├── .github └── workflows │ └── build.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .pylintrc ├── LICENSE ├── README.org ├── poetry.lock ├── pyproject.toml ├── scripts ├── sync_version.py └── visualize_tissue_masks.py ├── setup.cfg ├── tests ├── __init__.py ├── analyze │ ├── test_differential_expression.py │ ├── test_gene_maps.py │ ├── test_metagenes.py │ └── test_prediction.py ├── conftest.py ├── data │ ├── files │ │ ├── image │ │ │ └── image.jpg │ │ ├── st │ │ │ ├── counts.tsv │ │ │ ├── image.jpg │ │ │ ├── mask.png │ │ │ └── spots.tsv │ │ ├── toydata.h5 │ │ └── visium │ │ │ ├── data.h5 │ │ │ ├── image.jpg │ │ │ ├── mask.png │ │ │ ├── scale_factors.json │ │ │ └── tissue_positions.csv │ ├── test_analysis_exit_status.1.toml │ ├── test_restore_session.1.toml │ ├── test_restore_session.2.toml │ ├── test_stats_writers.1.toml │ └── test_train_exit_status.1.toml ├── model │ └── experiment │ │ └── test_st.py ├── test_functional.py └── test_integration.py └── xfuse ├── __init__.py ├── __main__.py ├── __version__.py ├── _config.py ├── analyze ├── __init__.py ├── analyze.py ├── differential_expression.py ├── gene_maps.py ├── metagenes.py └── prediction.py ├── convert ├── __init__.py ├── image.py ├── st.py ├── utility.py └── visium.py ├── data ├── __init__.py ├── dataset.py ├── slide │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── annotated_image.py │ │ ├── slide_data.py │ │ └── st_slide.py │ ├── iterator │ │ ├── __init__.py │ │ ├── data_iterator.py │ │ ├── full_slide_iterator.py │ │ ├── random_iterator.py │ │ └── slide_iterator.py │ └── slide.py └── utility │ ├── __init__.py │ └── misc.py ├── logging ├── __init__.py ├── formatter.py └── logging.py ├── messengers ├── __init__.py ├── analysis_runner.py ├── checkpointer.py └── stats │ ├── __init__.py │ ├── conditions.py │ ├── elbo.py │ ├── image.py │ ├── latent.py │ ├── metagene_activation.py │ ├── rmse.py │ ├── scale.py │ ├── stats_handler.py │ └── writer │ ├── __init__.py │ ├── file.py │ ├── stats_writer.py │ └── tensorboard.py ├── model ├── __init__.py ├── experiment │ ├── __init__.py │ ├── experiment.py │ ├── image.py │ └── st │ │ ├── __init__.py │ │ ├── metagene_eval.py │ │ ├── metagene_expansion_strategy.py │ │ └── st.py ├── utility │ ├── __init__.py │ └── model_comparison.py └── xfuse.py ├── optim.py ├── run.py ├── session ├── __init__.py ├── io.py ├── items │ ├── __init__.py │ ├── colormap.py │ ├── covariates.py │ ├── dataloader.py │ ├── default_device.py │ ├── eval.py │ ├── genes.py │ ├── learning_rate.py │ ├── log_file.py │ ├── log_level.py │ ├── messengers.py │ ├── model.py │ ├── mpl_backend.py │ ├── optimizer.py │ ├── stats_writers.py │ ├── training_data.py │ └── work_dir.py ├── session.py └── session_item.py ├── train.py └── utility ├── __init__.py ├── core.py ├── file.py ├── mask.py ├── pyro.py ├── state ├── __init__.py ├── getters.py └── state.py ├── tensor.py └── visualization.py /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | 11 | jobs: 12 | lint-and-test: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v1 16 | - name: Set up Python 17 | uses: actions/setup-python@v1 18 | with: 19 | python-version: 3.8 20 | - name: Install dependencies 21 | run: | 22 | pip install --upgrade pip 23 | pip install "poetry >=1.1.4, <2.0.0" 24 | poetry install 25 | - name: Linting 26 | run: | 27 | poetry run pre-commit run --all-files 28 | poetry 
run mypy xfuse tests scripts 29 | poetry run pylint xfuse tests scripts 30 | - name: Testing 31 | run: | 32 | poetry run pytest --cov=./xfuse --cov-report=xml 33 | - name: Upload coverage report 34 | uses: codecov/codecov-action@v1 35 | with: 36 | token: ${{ secrets.CODECOV_TOKEN }} 37 | file: ./coverage.xml 38 | flags: unittests 39 | fail_ci_if_error: true 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **__pycache__ 2 | **.egg-info 3 | **.eggs 4 | **.mypy_cache 5 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.8 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v2.4.0 7 | hooks: 8 | - id: trailing-whitespace 9 | - id: end-of-file-fixer 10 | - id: check-yaml 11 | - id: check-added-large-files 12 | 13 | - repo: https://github.com/psf/black 14 | rev: 19.10b0 15 | hooks: 16 | - id: black 17 | additional_dependencies: ['click==8.0.4'] 18 | 19 | - repo: local 20 | hooks: 21 | - id: sync-version 22 | name: sync-version 23 | stages: [commit] 24 | entry: ./scripts/sync_version.py 25 | language: python 26 | files: pyproject.toml 27 | additional_dependencies: 28 | - tomlkit ~= 0.5.8 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.org: -------------------------------------------------------------------------------- 1 | #+TITLE: XFuse: Deep spatial data fusion 2 | 3 | [[https://github.com/ludvb/xfuse/actions?query=workflow%3Abuild+branch%3Amaster][https://github.com/ludvb/xfuse/workflows/build/badge.svg?branch=master]] 4 | 5 | This repository contains code for the paper "Super-resolved spatial transcriptomics by deep data fusion". 6 | 7 | Nature Biotechnology: https://doi.org/10.1038/s41587-021-01075-3 8 | 9 | BioRxiv preprint: https://doi.org/10.1101/2020.02.28.963413 10 | 11 | * Hardware requirements 12 | 13 | XFuse can run on CPU-only hardware, but training new models will take exceedingly long. 
14 | We recommend running XFuse on a GPU with at least 8 GB of VRAM. 15 | 16 | * Software requirements 17 | 18 | XFuse has been tested on GNU/Linux but should run on all major operating systems. 19 | XFuse requires Python 3.8. 20 | All other dependencies are pulled in by ~pip~ during the installation. 21 | 22 | * Installing 23 | 24 | To install XFuse to your home directory, run 25 | #+BEGIN_SRC sh 26 | pip install --user git+https://github.com/ludvb/xfuse@master 27 | #+END_SRC 28 | This step should only take a few minutes. 29 | 30 | * Getting started 31 | 32 | This section will guide you through how to start an analysis with XFuse using data on human breast cancer from [fn:1]. 33 | 34 | [fn:1]: https://doi.org/10.1126/science.aaf2403 35 | 36 | ** Data 37 | 38 | The data is available [[https://www.spatialresearch.org/resources-published-datasets/doi-10-1126science-aaf2403/][here]]. 39 | To download all of the required files for the analysis, run 40 | #+BEGIN_SRC sh 41 | # Image data 42 | curl -Lo section1.jpg https://www.spatialresearch.org/wp-content/uploads/2016/07/HE_layer1_BC.jpg 43 | curl -Lo section2.jpg https://www.spatialresearch.org/wp-content/uploads/2016/07/HE_layer2_BC.jpg 44 | curl -Lo section3.jpg https://www.spatialresearch.org/wp-content/uploads/2016/07/HE_layer3_BC.jpg 45 | curl -Lo section4.jpg https://www.spatialresearch.org/wp-content/uploads/2016/07/HE_layer4_BC.jpg 46 | 47 | # Gene expression count data 48 | curl -Lo section1.tsv https://www.spatialresearch.org/wp-content/uploads/2016/07/Layer1_BC_count_matrix-1.tsv 49 | curl -Lo section2.tsv https://www.spatialresearch.org/wp-content/uploads/2016/07/Layer2_BC_count_matrix-1.tsv 50 | curl -Lo section3.tsv https://www.spatialresearch.org/wp-content/uploads/2016/07/Layer3_BC_count_matrix-1.tsv 51 | curl -Lo section4.tsv https://www.spatialresearch.org/wp-content/uploads/2016/07/Layer4_BC_count_matrix-1.tsv 52 | 53 | # Alignment data 54 | curl -Lo section1-alignment.txt https://www.spatialresearch.org/wp-content/uploads/2016/07/Layer1_BC_transformation.txt 55 | curl -Lo section2-alignment.txt https://www.spatialresearch.org/wp-content/uploads/2016/07/Layer2_BC_transformation.txt 56 | curl -Lo section3-alignment.txt https://www.spatialresearch.org/wp-content/uploads/2016/07/Layer3_BC_transformation.txt 57 | curl -Lo section4-alignment.txt https://www.spatialresearch.org/wp-content/uploads/2016/07/Layer4_BC_transformation.txt 58 | #+END_SRC 59 | 60 | ** Preprocessing 61 | 62 | XFuse uses a specialized data format to optimize loading speeds and allow for lazy data loading. 63 | XFuse has inbuilt support for converting data from [[https://support.10xgenomics.com/spatial-gene-expression/software/pipelines/latest/installation][10X Space Ranger]] (~xfuse convert visium~) and the [[https://github.com/SpatialTranscriptomicsResearch/st_pipeline][Spatial Transcriptomics Pipeline]] (~xfuse convert st~) to its own data format. 64 | If your data has been produced by another pipeline, it may need to be wrangled into a supported format before continuing. 65 | Feel free to open an issue on our [[https://github.com/ludvb/xfuse/issues][issue tracker]] if you run into any problems or to request support for a new platform. 
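For custom pipelines, data can also be written programmatically with the internal ~xfuse.convert.utility.write_data~ helper. The following is a minimal sketch mirroring how the helper is called in ~tests/conftest.py~; the toy arrays, gene names, and output path are stand-in assumptions, not part of the repository:
#+BEGIN_SRC python
import numpy as np
import pandas as pd

from xfuse.convert.utility import write_data

# Hypothetical stand-ins for the outputs of a custom pipeline:
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)  # tissue image
label = np.zeros((64, 64), dtype=np.uint8)  # per-pixel spot map; 0 marks background
label[16:32, 16:32] = 1  # pixels belonging to spot 1
counts = pd.DataFrame(
    [[0.0, 0.0], [10.0, 5.0]],  # row 0 covers unlabeled pixels; row i holds spot i
    index=pd.Index([0, 1]),
    columns=["GENE1", "GENE2"],
)

write_data(
    counts,
    image,
    label,
    type_label="ST",
    annotation={},  # assumed to accept an empty mapping; named layers as in the tests
    auto_rotate=True,
    path="custom-section/data.h5",
)
#+END_SRC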
66 | 67 | The data from the [[Data]] section was produced by the Spatial Transcriptomics Pipeline, so we can run the following commands to convert it to the right format: 68 | #+BEGIN_SRC sh 69 | xfuse convert st --counts section1.tsv --image section1.jpg --transformation-matrix section1-alignment.txt --scale 0.15 --save-path section1 70 | xfuse convert st --counts section2.tsv --image section2.jpg --transformation-matrix section2-alignment.txt --scale 0.15 --save-path section2 71 | xfuse convert st --counts section3.tsv --image section3.jpg --transformation-matrix section3-alignment.txt --scale 0.15 --save-path section3 72 | xfuse convert st --counts section4.tsv --image section4.jpg --transformation-matrix section4-alignment.txt --scale 0.15 --save-path section4 73 | #+END_SRC 74 | It may be worthwhile to try out different values for the ~--scale~ argument, which downsamples the image data by the given factor. 75 | Essentially, a higher scale increases the resolution of the model but requires considerably more compute power. 76 | 77 | *** Verifying tissue masks 78 | 79 | It is usually a good idea to verify that the computed tissue masks look good. 80 | This can be done using the script ~./scripts/visualize_tissue_masks.py~ included in this repository: 81 | #+BEGIN_SRC sh 82 | curl -LO https://raw.githubusercontent.com/ludvb/xfuse/master/scripts/visualize_tissue_masks.py 83 | python visualize_tissue_masks.py */data.h5 84 | #+END_SRC 85 | The script will show the tissue images with the detected backgrounds blacked out. If tissue detection fails, a custom mask can be passed to ~xfuse convert~ using the ~--mask-file~ argument (see ~xfuse convert visium --help~ for more information). 86 | 87 | ** Configuring and starting the run 88 | 89 | Settings for the run are specified in a configuration file. 90 | Paste the following into a file named ~my-config.toml~: 91 | #+BEGIN_SRC toml 92 | [xfuse] 93 | network_depth = 6 94 | network_width = 16 95 | min_counts = 50 96 | 97 | [expansion_strategy] 98 | type = "DropAndSplit" 99 | [expansion_strategy.DropAndSplit] 100 | max_metagenes = 50 101 | 102 | [optimization] 103 | batch_size = 3 104 | epochs = 100000 105 | learning_rate = 0.0003 106 | patch_size = 768 107 | 108 | [analyses] 109 | [analyses.metagenes] 110 | type = "metagenes" 111 | [analyses.metagenes.options] 112 | method = "pca" 113 | 114 | [analyses.gene_maps] 115 | type = "gene_maps" 116 | [analyses.gene_maps.options] 117 | gene_regex = ".*" 118 | 119 | [slides] 120 | [slides.section1] 121 | data = "section1/data.h5" 122 | [slides.section1.covariates] 123 | section = 1 124 | 125 | [slides.section2] 126 | data = "section2/data.h5" 127 | [slides.section2.covariates] 128 | section = 2 129 | 130 | [slides.section3] 131 | data = "section3/data.h5" 132 | [slides.section3.covariates] 133 | section = 3 134 | 135 | [slides.section4] 136 | data = "section4/data.h5" 137 | [slides.section4.covariates] 138 | section = 4 139 | #+END_SRC 140 | 141 | Here is a non-exhaustive summary of the available configuration options: 142 | - ~xfuse.network_depth~: The number of up- and downsampling steps in the fusion network. If you are running on large images (using a large value for the ~--scale~ argument in ~xfuse convert~), you may need to increase this number. 143 | - ~xfuse.network_width~: The number of channels in the image and expression decoders. You may need to increase this value if you are studying tissues with many different cell types. 
144 | - ~xfuse.min_counts~: The minimum number of reads for a gene to be included in the analysis. 145 | - ~expansion_strategy.DropAndSplit.max_metagenes~: The maximum number of metagenes to create during inference. You may need to increase this value if you are studying tissues with many different cell types. 146 | - ~optimization.batch_size~: The mini-batch size. This number should be kept as high as possible to keep gradients stable but can be reduced if you are running XFuse on a GPU with limited memory capacity. 147 | - ~optimization.epochs~: The number of epochs to run. When set to a value below zero, XFuse will use a heuristic stopping criterion. 148 | - ~optimization.patch_size~: The size of training patches. This number should preferably be a multiple of ~2^xfuse.network_depth~ to avoid misalignments during up- and downsampling steps. 149 | - ~slides~: This section defines which slides to include in the experiment. Each slide is associated with a unique subsection. In each subsection, a data path and optional covariates to control for are specified. For example, in the configuration file above, we have given each slide a ~section~ condition with a distinct value to control for sample-wise batch effects. If our dataset contained samples from different patients, we could, for example, also include a ~patient~ condition to control for patient-wise effects. 150 | 151 | We are now ready to start the analysis! 152 | #+BEGIN_SRC sh 153 | xfuse run my-config.toml --save-path my-run 154 | #+END_SRC 155 | 156 | /Tip/: XFuse can generate a template for the configuration file by running 157 | #+BEGIN_SRC sh 158 | xfuse init my-config.toml section1/data.h5 section2/data.h5 section3/data.h5 section4/data.h5 159 | #+END_SRC 160 | 161 | ** Tracking the training progress 162 | 163 | XFuse continually writes training data to a [[https://github.com/tensorflow/tensorboard][Tensorboard]] log file. 164 | To check how the optimization is progressing, start a Tensorboard web server and direct it to the ~--save-path~ of the run: 165 | #+BEGIN_SRC sh 166 | tensorboard --logdir my-run 167 | #+END_SRC 168 | 169 | ** Stopping and resuming a run 170 | 171 | To stop the run before it has completed, press ~Ctrl+C~. 172 | A snapshot of the model state will be saved to the ~--save-path~. 173 | The snapshot can be restored by running 174 | #+BEGIN_SRC sh 175 | xfuse run my-config.toml --save-path my-run --session my-run/exception.session 176 | #+END_SRC 177 | 178 | ** Finishing the run 179 | 180 | Training the model from scratch will take roughly three days on a normal desktop computer with an Nvidia GeForce 20 series graphics card. 181 | After training, XFuse runs the analyses specified in the configuration file. 182 | Results will be saved to a directory named ~analyses~ in the ~--save-path~.
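For downstream processing, gene maps can also be exported as raw arrays by setting ~writer = "tensor"~ under ~[analyses.gene_maps.options]~, in which case each gene map is saved as ~<gene>.pt~, an array of posterior samples with shape ~(num_samples, height, width)~. The following is a minimal sketch of loading such a file; the ~analyses/gene_maps/section1~ path and the gene name are assumptions based on the output layout exercised in ~tests/analyze/test_gene_maps.py~:
#+BEGIN_SRC python
import torch

# "GENE1" is a placeholder; replace with a gene present in your dataset.
samples = torch.load("my-run/analyses/gene_maps/section1/GENE1.pt")
gene_map_mean = samples.mean(0)  # posterior mean expression map
gene_map_stdv = samples.std(0)   # posterior standard deviation map
#+END_SRC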
183 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["poetry>=0.12"] 3 | build-backend = "poetry.masonry.api" 4 | 5 | [tool.black] 6 | line-length = 79 7 | 8 | [tool.poetry] 9 | name = "xfuse" 10 | version = "0.2.1" 11 | description = "Deep spatial data fusion" 12 | authors = ["Ludvig Bergenstråhle "] 13 | 14 | [tool.poetry.dependencies] 15 | click = "^7.1.2" 16 | h5py = "^3.0.0" 17 | imageio = "^2.9.0" 18 | matplotlib = "^3.3.2" 19 | numpy = "^1.19.4" 20 | opencv-python = "^4.4.0" 21 | pandas = "^1.1.4" 22 | Pillow = "^9.0.1" 23 | pyro-ppl = ">=1.5.0,<1.6.0" 24 | python = "^3.8" 25 | scikit-learn = "^0.24.2" 26 | scipy = "^1.5.4" 27 | tensorboard = "^2.5.0" 28 | tifffile = "^2020.10.1" 29 | tomlkit = "^0.7.0" 30 | torch = "^1.8.1" 31 | torchvision = "^0.9.1" 32 | tqdm = "^4.51.0" 33 | tabulate = "^0.8.7" 34 | 35 | [tool.poetry.dev-dependencies] 36 | mypy = "^0.812" 37 | pre-commit = "^2.8.2" 38 | pylint = "^2.8.2" 39 | pytest = "^6.1.2" 40 | pytest-console-scripts = "^1.0.0" 41 | pytest-cov = "^2.10.1" 42 | pytest-datadir = "^1.3.1" 43 | pytest-mock = "^3.3.1" 44 | 45 | [tool.poetry.scripts] 46 | xfuse = "xfuse.__main__:cli" 47 | -------------------------------------------------------------------------------- /scripts/sync_version.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pylint: disable=invalid-name 4 | 5 | import os 6 | import sys 7 | 8 | import tomlkit 9 | 10 | 11 | project_file = sys.argv[1] 12 | 13 | with open(project_file, "r") as fp: 14 | project_config = tomlkit.loads(fp.read()) 15 | 16 | version = project_config["tool"]["poetry"]["version"] # type: ignore 17 | 18 | with open( 19 | os.path.join(os.path.dirname(project_file), "xfuse", "__version__.py"), "w" 20 | ) as fp: 21 | fp.write(f'__version__ = "{version}"\n') 22 | -------------------------------------------------------------------------------- /scripts/visualize_tissue_masks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | 5 | import h5py 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("data_files", nargs="+") 11 | options = parser.parse_args() 12 | 13 | n = len(options.data_files) 14 | c = int(np.ceil(n ** 0.5)) 15 | r = int(np.ceil(n / c)) 16 | fig, axs = plt.subplots(r, c) 17 | if n == 1: 18 | axs = np.array([axs]) 19 | for filename, ax in zip(options.data_files, axs.flatten()): 20 | with h5py.File(filename, "r") as data: 21 | img = (data["image"][()] + 1) / 2 22 | mask = data["label"][()] == 1 23 | img[mask] = 0.25 * img[mask] 24 | ax.imshow(img) 25 | plt.tight_layout() 26 | plt.show() 27 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | addopts = 3 | --doctest-modules 4 | --ignore=./scripts 5 | 6 | [mypy] 7 | ignore_missing_imports = True 8 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ludvb/xfuse/c420abb013c02f44120205ac184c393c14dcd14d/tests/__init__.py 
-------------------------------------------------------------------------------- /tests/analyze/test_differential_expression.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from xfuse.analyze.differential_expression import ( 4 | _run_differential_expression_analysis, 5 | ) 6 | from xfuse.session import Session 7 | from xfuse.session.items.work_dir import WorkDir 8 | 9 | 10 | def test_run_differential_expression_analysis( 11 | pretrained_toy_model, toydata, tmp_path 12 | ): 13 | with Session( 14 | model=pretrained_toy_model, 15 | genes=toydata.dataset.genes, 16 | dataloader=toydata, 17 | work_dir=WorkDir(tmp_path), 18 | eval=True, 19 | ): 20 | _run_differential_expression_analysis( 21 | "annotation2", comparisons=[("true", "false")] 22 | ) 23 | 24 | assert os.path.exists(tmp_path / "true-vs-false.csv.gz") 25 | assert os.path.exists(tmp_path / "true-vs-false_top_differential.pdf") 26 | -------------------------------------------------------------------------------- /tests/analyze/test_gene_maps.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from xfuse.analyze.gene_maps import _run_gene_maps_analysis 6 | from xfuse.session import Session 7 | from xfuse.session.items.work_dir import WorkDir 8 | 9 | 10 | def test_run_gene_maps_analysis_image_writer( 11 | pretrained_toy_model, toydata, tmp_path 12 | ): 13 | with Session( 14 | model=pretrained_toy_model, 15 | genes=toydata.dataset.genes, 16 | dataloader=toydata, 17 | work_dir=WorkDir(tmp_path), 18 | ): 19 | _run_gene_maps_analysis(writer="image") 20 | 21 | for section in toydata.dataset.data.design.index: 22 | for gene in toydata.dataset.genes: 23 | assert os.path.exists(tmp_path / section / f"{gene}_mean.jpg") 24 | assert os.path.exists(tmp_path / section / f"{gene}_stdv.jpg") 25 | 26 | 27 | def test_run_gene_maps_analysis_tensor_writer( 28 | pretrained_toy_model, toydata, tmp_path 29 | ): 30 | with Session( 31 | model=pretrained_toy_model, 32 | genes=toydata.dataset.genes, 33 | dataloader=toydata, 34 | work_dir=WorkDir(tmp_path), 35 | ): 36 | _run_gene_maps_analysis(writer="tensor") 37 | 38 | for section in toydata.dataset.data.design.index: 39 | for gene in toydata.dataset.genes: 40 | assert os.path.exists(tmp_path / section / f"{gene}.pt") 41 | -------------------------------------------------------------------------------- /tests/analyze/test_metagenes.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from xfuse.analyze.metagenes import compute_metagene_summary 4 | from xfuse.session import Session 5 | from xfuse.session.items.work_dir import WorkDir 6 | 7 | 8 | def test_metagenes(pretrained_toy_model, toydata, tmp_path): 9 | with Session( 10 | model=pretrained_toy_model, 11 | genes=toydata.dataset.genes, 12 | dataloader=toydata, 13 | work_dir=WorkDir(tmp_path), 14 | eval=True, 15 | ): 16 | compute_metagene_summary() 17 | 18 | for section in toydata.dataset.data.slides: 19 | assert os.path.exists(tmp_path / section / f"summary.png") 20 | -------------------------------------------------------------------------------- /tests/analyze/test_prediction.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from xfuse.analyze.prediction import _run_prediction_analysis 6 | from xfuse.session import Session 7 | from xfuse.session.items.work_dir import 
WorkDir 8 | 9 | 10 | def test_run_prediction_analysis(pretrained_toy_model, toydata, tmp_path): 11 | with Session( 12 | model=pretrained_toy_model, 13 | genes=toydata.dataset.genes, 14 | dataloader=toydata, 15 | work_dir=WorkDir(tmp_path), 16 | eval=True, 17 | ): 18 | _run_prediction_analysis("annotation1") 19 | 20 | for name, slide in toydata.dataset.data.slides.items(): 21 | name = os.path.basename(name) 22 | output_file = tmp_path / "data.csv.gz" 23 | assert os.path.exists(output_file) 24 | 25 | output_data = pd.read_csv(output_file) 26 | output_data_labels = list(np.unique(output_data.annotation1)) 27 | _, annotation_labels = slide.data.annotation("annotation1") 28 | annotation_labels = sorted(annotation_labels.keys()) 29 | assert output_data_labels == annotation_labels 30 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | r"""Config file for tests""" 2 | 3 | import itertools as it 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import pyro 8 | import pyro.distributions as distr 9 | import pytest 10 | import torch 11 | from scipy.ndimage import label as make_label 12 | from xfuse.convert.utility import write_data 13 | from xfuse.data import Data, Dataset 14 | from xfuse.data.slide import STSlide, FullSlideIterator, Slide 15 | from xfuse.data.utility.misc import make_dataloader 16 | from xfuse.model import XFuse 17 | from xfuse.model.experiment.st import ST, MetageneDefault 18 | from xfuse.session import Session, get 19 | from xfuse.train import train 20 | from xfuse.utility.state import reset_state 21 | 22 | 23 | def pytest_configure(config): 24 | # pylint: disable=missing-function-docstring 25 | config.addinivalue_line( 26 | "markers", "fix_rng: resets the RNG to a fixed value" 27 | ) 28 | config.addinivalue_line("markers", "slow: marks test as slow to run") 29 | 30 | 31 | def pytest_runtest_setup(item): 32 | # pylint: disable=missing-function-docstring 33 | pyro.clear_param_store() 34 | reset_state() 35 | if item.get_closest_marker("fix_rng") is not None: 36 | torch.manual_seed(0) 37 | 38 | 39 | def pytest_addoption(parser): 40 | # pylint: disable=missing-function-docstring 41 | parser.addoption( 42 | "--quick", action="store_true", default=False, help="skip slow tests" 43 | ) 44 | 45 | 46 | def pytest_collection_modifyitems(config, items): 47 | # pylint: disable=missing-function-docstring 48 | if config.getoption("--quick"): 49 | for item in filter(lambda x: "slow" in x.keywords, items): 50 | item.add_marker(pytest.mark.skip(reason="skipping slow test")) 51 | 52 | 53 | @pytest.fixture 54 | @pytest.mark.fix_rng 55 | def toydata(tmp_path): 56 | r"""Produces toy dataset""" 57 | # pylint: disable=too-many-locals 58 | 59 | num_genes = 10 60 | num_metagenes = 3 61 | probs = 0.1 62 | H, W = [100] * 2 63 | spot_size = 10 64 | 65 | gridy, gridx = np.meshgrid( 66 | np.linspace(0.0, H - 1, H), np.linspace(0.0, W - 1, W) 67 | ) 68 | yoffset, xoffset = ( 69 | distr.Normal(0.0, 0.2).sample([2, num_metagenes]).cpu().numpy() 70 | ) 71 | activity = ( 72 | np.cos(gridy[..., None] / 100 - 0.5 + yoffset[None, None]) ** 2 73 | * np.cos(gridx[..., None] / 100 - 0.5 + xoffset[None, None]) ** 2 74 | ) 75 | activity = torch.as_tensor(activity, dtype=torch.float32) 76 | 77 | metagene_profiles = ( 78 | distr.Normal(0.0, 1.0) 79 | .expand([num_genes, num_metagenes]) 80 | .sample() 81 | .exp() 82 | ) 83 | 84 | label = np.zeros(activity.shape[:2]).astype(np.uint8) 85 | 
counts = [torch.zeros(num_genes)] 86 | for i, (y, x) in enumerate( 87 | it.product( 88 | (np.linspace(0.0, 1, H // spot_size)[1:-1] * H).astype(int), 89 | (np.linspace(0.0, 1, W // spot_size)[1:-1] * W).astype(int), 90 | ), 91 | 1, 92 | ): 93 | spot_activity = torch.zeros(num_metagenes) 94 | 95 | for dy, dx in [ 96 | (dx, dy) 97 | for dx, dy in ( 98 | (dy - spot_size // 2, dx - spot_size // 2) 99 | for dy in range(spot_size) 100 | for dx in range(spot_size) 101 | ) 102 | if dy ** 2 + dx ** 2 < spot_size ** 2 / 4 103 | ]: 104 | label[y + dy, x + dx] = i 105 | spot_activity += activity[y + dy, x + dx] 106 | rate = spot_activity @ metagene_profiles.t() 107 | counts.append(distr.NegativeBinomial(rate, probs).sample()) 108 | 109 | image = 255 * ( 110 | (activity - activity.min()) / (activity.max() - activity.min()) 111 | ) 112 | image = image.round().byte().cpu().numpy() 113 | counts = torch.stack(counts) 114 | counts = pd.DataFrame( 115 | counts.cpu().numpy(), 116 | index=pd.Index(list(range(counts.shape[0]))), 117 | columns=[f"g{i + 1}" for i in range(counts.shape[1])], 118 | ) 119 | 120 | annotation1 = np.arange(100) // 10 % 2 == 1 121 | annotation1 = annotation1[:, None] & annotation1[None] 122 | annotation1, _ = make_label(annotation1) 123 | annotation2 = 1 + (annotation1 == 0).astype(np.uint8) 124 | 125 | filepath = tmp_path / "data.h5" 126 | write_data( 127 | counts, 128 | image, 129 | label, 130 | type_label="ST", 131 | annotation={ 132 | "annotation1": ( 133 | annotation1, 134 | {x: str(x) for x in np.unique(annotation1) if x != 0}, 135 | ), 136 | "annotation2": (annotation2, {1: "false", 2: "true"}), 137 | }, 138 | auto_rotate=True, 139 | path=str(filepath), 140 | ) 141 | 142 | design = pd.DataFrame({"ID": 1}, index=["toydata"]).astype("category") 143 | slide = Slide(data=STSlide(str(filepath)), iterator=FullSlideIterator) 144 | data = Data(slides={"toydata": slide}, design=design) 145 | dataset = Dataset(data) 146 | dataloader = make_dataloader(dataset) 147 | 148 | return dataloader 149 | 150 | 151 | @pytest.fixture 152 | def pretrained_toy_model(toydata): 153 | r"""Pretrained toy model""" 154 | # pylint: disable=redefined-outer-name 155 | st_experiment = ST( 156 | depth=2, 157 | num_channels=4, 158 | metagenes=[MetageneDefault(0.0, None) for _ in range(1)], 159 | ) 160 | xfuse = XFuse(experiments=[st_experiment]) 161 | with Session( 162 | model=xfuse, 163 | optimizer=pyro.optim.Adam({"lr": 0.001}), 164 | dataloader=toydata, 165 | genes=toydata.dataset.genes, 166 | covariates={ 167 | covariate: values.cat.categories.values.tolist() 168 | for covariate, values in toydata.dataset.data.design.iteritems() 169 | }, 170 | ): 171 | train(100 + get("training_data").epoch) 172 | return xfuse 173 | -------------------------------------------------------------------------------- /tests/data/files/image/image.jpg: -------------------------------------------------------------------------------- 1 | ../st/image.jpg -------------------------------------------------------------------------------- /tests/data/files/st/image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ludvb/xfuse/c420abb013c02f44120205ac184c393c14dcd14d/tests/data/files/st/image.jpg -------------------------------------------------------------------------------- /tests/data/files/st/mask.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ludvb/xfuse/c420abb013c02f44120205ac184c393c14dcd14d/tests/data/files/st/mask.png -------------------------------------------------------------------------------- /tests/data/files/st/spots.tsv: -------------------------------------------------------------------------------- 1 | x y new_x new_y pixel_x pixel_y 2 | 17 17 16.92 17.00 462.1 465.8 3 | 18 17 18.02 17.03 494.0 466.9 4 | 20 17 20.08 17.06 553.8 467.6 5 | 19 17 18.98 17.07 522.0 467.8 6 | 22 17 21.94 16.97 607.8 464.9 7 | 21 17 20.90 17.02 577.7 466.3 8 | 24 17 24.12 17.06 671.2 467.5 9 | 23 17 23.12 17.08 642.2 468.4 10 | 25 17 25.07 17.00 698.7 465.8 11 | 26 17 26.04 16.99 726.8 465.6 12 | 27 17 26.99 17.00 754.4 465.8 13 | 28 17 27.95 16.97 782.3 465.0 14 | 10 17 9.87 17.08 257.4 468.3 15 | 9 17 9.00 17.03 232.4 466.7 16 | 12 17 12.07 17.04 321.5 467.0 17 | 11 17 10.91 16.97 287.6 465.1 18 | 13 17 13.04 17.15 349.7 470.2 19 | 14 17 14.01 17.06 377.8 467.7 20 | 15 17 15.01 17.05 406.8 467.4 21 | 16 17 16.00 17.05 435.6 467.3 22 | 15 20 15.08 20.00 408.6 553.1 23 | 15 19 15.06 19.02 408.3 524.6 24 | 15 18 15.02 18.02 407.0 495.7 25 | 16 18 15.93 18.00 433.5 495.0 26 | 16 20 15.89 20.05 432.4 554.6 27 | 16 19 15.94 19.02 433.6 524.7 28 | 13 18 12.97 18.01 347.6 495.4 29 | 14 18 14.07 18.03 379.4 495.9 30 | 13 20 13.06 20.04 350.1 554.5 31 | 14 19 14.01 19.05 377.8 525.7 32 | 14 20 14.00 20.02 377.5 553.9 33 | 13 19 13.06 19.07 350.2 526.2 34 | 14 21 13.97 21.05 376.5 583.8 35 | 14 22 14.07 21.94 379.4 609.8 36 | 13 22 13.06 22.01 350.1 611.9 37 | 13 21 13.01 20.92 348.8 580.0 38 | 16 21 15.89 21.02 432.3 582.8 39 | 15 21 15.08 21.07 408.9 584.4 40 | 15 22 15.05 22.00 407.9 611.4 41 | 16 22 16.00 22.10 435.4 614.4 42 | 12 19 12.08 19.06 321.7 525.9 43 | 12 18 12.12 18.06 322.9 496.7 44 | 12 20 12.09 20.03 322.0 554.2 45 | 11 19 10.90 18.99 287.3 523.9 46 | 11 18 10.88 17.97 286.7 494.1 47 | 11 20 10.87 20.03 286.5 554.2 48 | 10 20 9.96 20.04 260.1 554.5 49 | 10 19 9.92 18.98 258.9 523.6 50 | 9 19 9.02 19.04 232.9 525.4 51 | 9 20 9.02 20.05 232.8 554.6 52 | 10 18 9.90 18.00 258.4 494.9 53 | 9 18 9.05 18.02 233.7 495.5 54 | 9 21 9.07 20.98 234.2 581.7 55 | 10 21 9.91 20.97 258.6 581.5 56 | 10 22 9.97 21.95 260.3 610.0 57 | 9 22 9.06 22.01 234.0 611.9 58 | 12 22 12.13 21.97 323.2 610.5 59 | 12 21 12.07 21.04 321.5 583.7 60 | 11 21 10.87 20.97 286.5 581.5 61 | 11 22 10.90 21.94 287.5 609.9 62 | 11 24 10.90 23.91 287.4 667.1 63 | 12 24 12.15 24.01 323.8 670.0 64 | 12 23 12.12 22.99 322.7 640.2 65 | 11 23 10.86 22.90 286.2 637.6 66 | 9 24 9.03 23.95 233.1 668.2 67 | 9 23 9.06 22.94 233.9 638.8 68 | 10 24 9.97 23.95 260.3 668.3 69 | 10 23 10.00 22.94 261.2 638.7 70 | 9 25 9.07 24.94 234.4 696.9 71 | 10 25 9.95 24.91 259.9 696.2 72 | 12 25 12.12 25.00 322.9 698.9 73 | 11 25 10.86 24.95 286.2 697.3 74 | 16 23 15.95 22.96 434.0 639.3 75 | 16 24 15.98 23.96 434.9 668.6 76 | 15 23 15.01 22.97 406.7 639.6 77 | 15 24 15.07 23.97 408.6 668.9 78 | 13 24 13.05 23.95 349.8 668.3 79 | 13 23 13.07 22.92 350.3 638.1 80 | 14 24 14.05 23.96 378.9 668.7 81 | 14 23 14.05 22.95 378.8 639.1 82 | 13 25 13.11 24.92 351.6 696.5 83 | 14 25 14.08 24.90 379.9 696.1 84 | 15 25 15.03 24.89 407.4 695.5 85 | 16 25 15.95 24.96 434.1 697.6 86 | 29 18 29.06 18.05 814.6 496.3 87 | 29 19 29.08 18.97 815.1 523.3 88 | 29 20 29.04 20.09 814.1 555.8 89 | 29 21 28.97 20.97 812.0 581.5 90 | 29 22 28.99 21.97 812.6 610.6 91 | 27 20 27.01 20.05 755.1 554.8 92 | 27 19 26.98 18.98 754.2 523.6 93 | 27 18 26.99 17.97 754.6 494.0 94 | 28 18 27.98 17.99 
783.2 494.7 95 | 28 19 28.01 18.99 784.2 523.7 96 | 28 20 28.00 20.02 783.7 553.7 97 | 25 19 25.06 19.02 698.5 524.8 98 | 25 20 25.06 20.01 698.6 553.4 99 | 26 20 26.03 20.02 726.6 553.9 100 | 26 18 26.06 17.98 727.6 494.4 101 | 25 18 25.04 18.02 698.0 495.6 102 | 26 19 26.07 18.99 727.8 523.9 103 | 26 22 26.11 21.98 729.0 610.9 104 | 25 22 25.06 21.99 698.4 611.3 105 | 26 21 26.08 20.96 728.0 581.3 106 | 25 21 25.04 21.01 697.9 582.6 107 | 28 21 27.94 20.97 782.0 581.5 108 | 27 21 27.04 20.95 755.9 580.9 109 | 27 22 26.99 21.96 754.6 610.2 110 | 28 22 27.97 21.95 783.0 609.9 111 | 28 23 27.98 22.90 783.2 637.6 112 | 28 24 28.00 23.90 783.8 666.8 113 | 27 23 27.01 22.91 755.2 637.9 114 | 27 24 26.99 23.90 754.6 666.7 115 | 26 24 26.03 23.94 726.6 667.8 116 | 26 23 26.07 22.88 727.9 637.2 117 | 25 23 25.11 22.90 699.9 637.6 118 | 25 24 25.12 23.92 700.2 667.4 119 | 25 25 25.01 24.81 697.1 693.2 120 | 26 25 26.16 24.84 730.4 694.3 121 | 28 25 27.98 24.88 783.2 695.4 122 | 27 25 27.03 24.86 755.7 694.9 123 | 29 24 28.98 23.89 812.3 666.5 124 | 29 23 28.99 22.86 812.7 636.5 125 | 29 25 29.01 24.91 813.1 696.1 126 | 23 20 23.16 20.03 643.3 554.0 127 | 23 19 23.10 19.04 641.6 525.2 128 | 24 19 24.12 19.03 671.3 525.0 129 | 24 18 24.12 18.02 671.3 495.7 130 | 24 20 24.10 20.05 670.5 554.7 131 | 23 18 23.11 18.02 642.0 495.6 132 | 22 18 21.92 17.95 607.3 493.6 133 | 21 20 20.98 20.04 579.9 554.3 134 | 21 18 20.97 17.97 579.8 494.2 135 | 22 19 21.93 18.93 607.7 522.2 136 | 22 20 21.96 19.96 608.4 552.2 137 | 21 19 20.96 18.97 579.5 523.2 138 | 21 21 20.93 20.97 578.6 581.5 139 | 22 22 21.93 21.93 607.7 609.5 140 | 21 22 21.01 21.94 580.8 609.8 141 | 22 21 21.98 20.92 609.0 580.2 142 | 24 22 24.10 22.01 670.6 611.8 143 | 23 21 23.12 21.02 642.0 582.9 144 | 23 22 23.16 22.01 643.2 611.8 145 | 24 21 24.13 21.01 671.5 582.6 146 | 19 18 19.04 17.98 523.6 494.3 147 | 20 20 20.05 20.01 553.1 553.4 148 | 19 20 18.98 20.05 521.9 554.6 149 | 19 19 18.96 18.98 521.3 523.4 150 | 20 18 19.99 18.04 551.4 496.0 151 | 20 19 20.06 19.00 553.4 524.1 152 | 17 18 16.99 17.98 464.1 494.4 153 | 17 20 17.03 19.98 465.3 552.8 154 | 18 18 17.96 18.01 492.2 495.2 155 | 17 19 16.99 19.00 464.1 524.0 156 | 18 19 17.98 19.00 492.8 524.1 157 | 18 20 18.00 20.09 493.5 555.9 158 | 18 22 18.07 22.02 495.5 612.0 159 | 18 21 17.99 21.04 493.2 583.7 160 | 17 21 16.98 20.99 464.0 582.0 161 | 17 22 16.96 22.10 463.5 614.4 162 | 19 22 19.03 21.98 523.5 611.0 163 | 19 21 18.98 21.06 522.0 584.2 164 | 20 21 20.07 21.00 553.7 582.3 165 | 20 22 20.05 21.98 553.1 610.9 166 | 20 23 20.08 22.84 553.9 635.9 167 | 20 24 20.11 23.89 554.8 666.5 168 | 19 24 19.04 23.98 523.7 669.2 169 | 19 23 18.98 22.88 522.1 637.1 170 | 17 23 16.93 22.93 462.5 638.6 171 | 18 23 18.16 22.88 498.0 637.1 172 | 17 24 16.87 23.91 460.7 667.1 173 | 18 24 18.13 23.93 497.2 667.6 174 | 17 25 16.99 24.90 464.2 696.1 175 | 18 25 18.03 24.97 494.5 697.9 176 | 19 25 19.00 24.95 522.5 697.5 177 | 20 25 20.09 24.90 554.3 696.0 178 | 23 23 23.18 23.00 643.8 640.6 179 | 24 23 24.11 22.90 671.0 637.8 180 | 24 24 24.12 23.95 671.1 668.4 181 | 23 24 23.18 23.98 644.0 669.0 182 | 21 24 20.92 23.87 578.4 666.0 183 | 22 24 21.94 23.89 607.9 666.5 184 | 22 23 21.95 22.84 608.2 635.9 185 | 21 23 21.01 22.80 580.9 634.6 186 | 21 25 21.02 24.87 581.1 695.0 187 | 22 25 21.96 24.88 608.4 695.2 188 | 23 25 23.23 24.94 645.3 697.0 189 | 24 25 24.21 24.84 673.9 694.2 190 | 18 16 17.99 15.99 493.4 436.5 191 | 17 16 16.93 15.98 462.4 436.1 192 | 19 16 18.90 15.99 519.8 436.4 193 | 20 16 20.08 
15.95 554.1 435.3 194 | 22 16 21.90 15.95 606.8 435.3 195 | 21 16 20.93 15.96 578.7 435.5 196 | 23 16 23.13 16.04 642.5 437.9 197 | 24 16 24.06 16.02 669.4 437.4 198 | 26 16 26.04 15.98 727.0 436.3 199 | 25 16 25.08 16.01 698.9 437.0 200 | 27 16 26.99 15.97 754.6 436.0 201 | 10 16 9.88 15.99 257.9 436.4 202 | 12 16 12.06 16.07 321.2 438.7 203 | 11 16 10.88 15.97 286.7 435.9 204 | 13 16 13.09 16.03 351.1 437.7 205 | 14 16 14.01 16.04 377.6 437.9 206 | 16 16 15.96 16.00 434.2 436.8 207 | 15 16 15.00 16.01 406.5 437.2 208 | 17 15 17.00 15.00 464.6 407.7 209 | 18 15 17.97 15.03 492.6 408.4 210 | 12 15 12.04 15.05 320.6 409.2 211 | 11 15 10.88 14.96 286.7 406.6 212 | 15 15 15.03 15.00 407.2 407.8 213 | 16 15 15.91 15.04 432.8 408.7 214 | 13 15 13.08 15.04 350.8 408.8 215 | 14 15 13.97 15.03 376.6 408.6 216 | 23 15 23.15 15.03 643.1 408.6 217 | 24 15 24.09 15.02 670.4 408.2 218 | 22 15 21.91 14.98 607.0 407.0 219 | 25 15 25.08 15.02 699.0 408.2 220 | 18 14 17.98 13.98 493.0 377.9 221 | 17 14 16.91 13.96 461.8 377.5 222 | 15 14 15.04 14.00 407.7 378.6 223 | 16 14 15.96 13.99 434.3 378.2 224 | 14 14 14.02 14.09 377.8 381.2 225 | 10 26 9.94 25.92 259.5 725.7 226 | 11 26 10.91 25.93 287.6 725.8 227 | 12 26 12.21 26.03 325.5 728.8 228 | 14 26 13.94 25.99 375.7 727.6 229 | 13 26 13.05 26.00 349.9 727.8 230 | 15 26 15.06 25.94 408.3 726.2 231 | 16 26 15.95 25.94 434.1 726.1 232 | 25 26 25.12 25.95 700.2 726.5 233 | 26 26 26.10 25.91 728.6 725.4 234 | 27 26 26.98 25.92 754.4 725.8 235 | 28 26 27.96 25.91 782.8 725.3 236 | 29 26 28.95 25.91 811.4 725.2 237 | 17 26 17.00 25.92 464.4 725.6 238 | 18 26 17.98 25.94 493.0 726.0 239 | 19 26 19.04 25.90 523.7 725.0 240 | 20 26 20.05 25.92 552.9 725.7 241 | 21 26 21.00 25.89 580.6 724.8 242 | 22 26 21.94 25.90 608.1 724.9 243 | 24 26 24.04 25.95 668.8 726.6 244 | 23 26 23.24 26.01 645.8 728.4 245 | 23 28 23.22 27.93 645.0 784.0 246 | 24 28 24.17 27.92 672.7 783.8 247 | 24 27 24.13 26.96 671.6 756.0 248 | 23 27 23.16 27.00 643.3 757.1 249 | 22 28 21.97 27.88 608.8 782.8 250 | 22 27 21.99 26.87 609.4 753.4 251 | 27 27 27.02 26.92 755.3 754.6 252 | 27 28 27.01 27.88 755.2 782.8 253 | 25 27 25.10 26.95 699.8 755.6 254 | 26 28 26.10 27.83 728.5 781.2 255 | 25 28 25.07 27.86 698.7 782.0 256 | 26 27 26.04 26.93 726.9 755.1 257 | 16 27 15.94 26.94 433.8 755.4 258 | 16 28 16.00 27.94 435.5 784.5 259 | 15 27 15.10 26.95 409.4 755.6 260 | 15 28 15.04 27.89 407.4 783.0 261 | 13 27 13.13 26.97 352.2 756.1 262 | 14 28 14.07 27.90 379.4 783.2 263 | 14 27 13.98 26.98 376.8 756.6 264 | 13 28 13.09 27.96 350.9 785.0 265 | 11 27 10.89 26.91 287.2 754.5 266 | 12 28 12.14 28.00 323.4 786.3 267 | 12 27 12.07 27.00 321.3 757.0 268 | 17 27 16.98 26.94 464.0 755.2 269 | 18 28 18.04 27.97 494.6 785.4 270 | 18 27 18.00 26.96 493.6 756.0 271 | 17 28 17.00 27.90 464.4 783.3 272 | 14 29 14.06 29.00 379.1 815.2 273 | 16 29 15.96 28.91 434.2 812.7 274 | 15 29 15.04 28.94 407.5 813.6 275 | 17 29 16.97 28.88 463.6 811.9 276 | 10 27 9.92 26.91 259.0 754.4 277 | 8 23 7.96 22.99 202.1 640.4 278 | 21 27 20.97 26.91 579.8 754.4 279 | 19 27 18.97 26.95 521.7 755.7 280 | 28 27 27.93 26.89 781.8 753.9 281 | 24 29 24.11 28.96 671.0 814.0 282 | 25 29 25.13 28.93 700.6 813.4 283 | 30 19 29.96 18.97 840.8 523.2 284 | -------------------------------------------------------------------------------- /tests/data/files/toydata.h5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ludvb/xfuse/c420abb013c02f44120205ac184c393c14dcd14d/tests/data/files/toydata.h5 -------------------------------------------------------------------------------- /tests/data/files/visium/data.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ludvb/xfuse/c420abb013c02f44120205ac184c393c14dcd14d/tests/data/files/visium/data.h5 -------------------------------------------------------------------------------- /tests/data/files/visium/image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ludvb/xfuse/c420abb013c02f44120205ac184c393c14dcd14d/tests/data/files/visium/image.jpg -------------------------------------------------------------------------------- /tests/data/files/visium/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ludvb/xfuse/c420abb013c02f44120205ac184c393c14dcd14d/tests/data/files/visium/mask.png -------------------------------------------------------------------------------- /tests/data/files/visium/scale_factors.json: -------------------------------------------------------------------------------- 1 | {"spot_diameter_fullres": 5.0, "tissue_hires_scalef": 1.0} 2 | -------------------------------------------------------------------------------- /tests/data/files/visium/tissue_positions.csv: -------------------------------------------------------------------------------- 1 | AAA,1,0,0,10,10 2 | AAT,1,0,1,10,21 3 | AAG,1,0,2,10,33 4 | AAC,1,0,3,10,44 5 | ATA,1,0,4,10,56 6 | ATT,1,0,5,10,67 7 | ATG,1,0,6,10,79 8 | ATC,1,0,7,10,90 9 | AGA,1,1,0,21,10 10 | AGT,1,1,1,21,21 11 | AGG,1,1,2,21,33 12 | AGC,1,1,3,21,44 13 | ACA,1,1,4,21,56 14 | ACT,1,1,5,21,67 15 | ACG,1,1,6,21,79 16 | ACC,1,1,7,21,90 17 | TAA,1,2,0,33,10 18 | TAT,1,2,1,33,21 19 | TAG,1,2,2,33,33 20 | TAC,1,2,3,33,44 21 | TTA,1,2,4,33,56 22 | TTT,1,2,5,33,67 23 | TTG,1,2,6,33,79 24 | TTC,1,2,7,33,90 25 | TGA,1,3,0,44,10 26 | TGT,1,3,1,44,21 27 | TGG,1,3,2,44,33 28 | TGC,1,3,3,44,44 29 | TCA,1,3,4,44,56 30 | TCT,1,3,5,44,67 31 | TCG,1,3,6,44,79 32 | TCC,1,3,7,44,90 33 | GAA,1,4,0,56,10 34 | GAT,1,4,1,56,21 35 | GAG,1,4,2,56,33 36 | GAC,1,4,3,56,44 37 | GTA,1,4,4,56,56 38 | GTT,1,4,5,56,67 39 | GTG,1,4,6,56,79 40 | GTC,1,4,7,56,90 41 | GGA,1,5,0,67,10 42 | GGT,1,5,1,67,21 43 | GGG,1,5,2,67,33 44 | GGC,1,5,3,67,44 45 | GCA,1,5,4,67,56 46 | GCT,1,5,5,67,67 47 | GCG,1,5,6,67,79 48 | GCC,1,5,7,67,90 49 | CAA,1,6,0,79,10 50 | CAT,1,6,1,79,21 51 | CAG,1,6,2,79,33 52 | CAC,1,6,3,79,44 53 | CTA,1,6,4,79,56 54 | CTT,1,6,5,79,67 55 | CTG,1,6,6,79,79 56 | CTC,1,6,7,79,90 57 | CGA,1,7,0,90,10 58 | CGT,1,7,1,90,21 59 | CGG,1,7,2,90,33 60 | CGC,1,7,3,90,44 61 | CCA,1,7,4,90,56 62 | CCT,1,7,5,90,67 63 | CCG,1,7,6,90,79 64 | CCC,1,7,7,90,90 65 | -------------------------------------------------------------------------------- /tests/data/test_analysis_exit_status.1.toml: -------------------------------------------------------------------------------- 1 | [xfuse] 2 | network_depth = 2 3 | 4 | [optimization] 5 | epochs = 10 6 | batch_size = 3 7 | patch_size = 32 8 | 9 | [analyses] 10 | [analyses.metagenes1] 11 | type = "metagenes" 12 | [analyses.metagenes1.options] 13 | 14 | [analyses.prediction1] 15 | type = "prediction" 16 | [analyses.prediction1.options] 17 | annotation_layer = "annotation1" 18 | num_samples = 10 19 | 20 | [analyses.differential_expression1] 21 | type = "differential_expression" 22 
| [analyses.differential_expression1.options] 23 | annotation_layer = "annotation2" 24 | comparisons = [["true", "false"]] 25 | num_samples = 10 26 | genes_per_batch = 3 27 | 28 | [analyses.gene_maps1] 29 | type = "gene_maps" 30 | [analyses.gene_maps1.options] 31 | 32 | [slides] 33 | [slides.toydata1] 34 | data = "./files/toydata.h5" 35 | [slides.toydata1.covariates] 36 | slide = "1" 37 | 38 | [slides.toydata2] 39 | data = "./files/toydata.h5" 40 | [slides.toydata2.covariates] 41 | slide = "2" 42 | -------------------------------------------------------------------------------- /tests/data/test_restore_session.1.toml: -------------------------------------------------------------------------------- 1 | [xfuse] 2 | network_depth = 2 3 | 4 | [optimization] 5 | epochs = 1 6 | 7 | [slides] 8 | [slides.toydata] 9 | data = "./files/toydata.h5" 10 | -------------------------------------------------------------------------------- /tests/data/test_restore_session.2.toml: -------------------------------------------------------------------------------- 1 | [xfuse] 2 | network_depth = 2 3 | 4 | [optimization] 5 | epochs = 1 6 | 7 | [slides] 8 | [slides.toydata] 9 | data = "./files/toydata.h5" 10 | [slides.toydata.covariates] 11 | condition = "A" 12 | -------------------------------------------------------------------------------- /tests/data/test_stats_writers.1.toml: -------------------------------------------------------------------------------- 1 | [xfuse] 2 | network_depth = 2 3 | 4 | [optimization] 5 | epochs = 1 6 | batch_size = 2 7 | patch_size = 32 8 | 9 | [slides] 10 | [slides.toydata1] 11 | data = "./files/toydata.h5" 12 | [slides.toydata1.covariates] 13 | section = 1 14 | condition = "A" 15 | 16 | [slides.toydata2] 17 | data = "./files/toydata.h5" 18 | [slides.toydata2.covariates] 19 | section = 2 20 | -------------------------------------------------------------------------------- /tests/data/test_train_exit_status.1.toml: -------------------------------------------------------------------------------- 1 | [xfuse] 2 | network_depth = 2 3 | 4 | [optimization] 5 | epochs = 10 6 | batch_size = 2 7 | patch_size = 32 8 | 9 | [slides] 10 | [slides.toydata] 11 | data = "./files/toydata.h5" 12 | -------------------------------------------------------------------------------- /tests/model/experiment/test_st.py: -------------------------------------------------------------------------------- 1 | import pyro 2 | from xfuse.model.experiment.st.st import _encode_metagene_name 3 | from xfuse.session import Session, get 4 | from xfuse.utility.tensor import to_device 5 | 6 | 7 | def test_split_metagene(pretrained_toy_model, toydata): 8 | r"""Test that metagenes are split correctly""" 9 | st_experiment = pretrained_toy_model.get_experiment("ST") 10 | metagene = next(iter(st_experiment.metagenes.keys())) 11 | metagene_new = st_experiment.split_metagene(metagene) 12 | 13 | with Session( 14 | model=pretrained_toy_model, 15 | dataloader=toydata, 16 | genes=toydata.dataset.genes, 17 | covariates={ 18 | covariate: values.cat.categories.values.tolist() 19 | for covariate, values in toydata.dataset.data.design.iteritems() 20 | }, 21 | ): 22 | x = to_device(next(iter(toydata))) 23 | with pyro.poutine.trace() as guide_tr: 24 | get("model").guide(x) 25 | with pyro.poutine.trace() as model_tr: 26 | with pyro.poutine.replay(trace=guide_tr.trace): 27 | get("model").model(x) 28 | 29 | rim_mean = model_tr.trace.nodes["rim"]["fn"].mean 30 | assert (rim_mean[0, 0] == rim_mean[-1][0, -1]).all() 31 | 32 | rate_mg = 
guide_tr.trace.nodes[_encode_metagene_name(metagene)]["fn"].mean 33 | rate_mg_new = guide_tr.trace.nodes[_encode_metagene_name(metagene_new)][ 34 | "fn" 35 | ].mean 36 | assert (rate_mg == rate_mg_new).all() 37 | -------------------------------------------------------------------------------- /tests/test_integration.py: -------------------------------------------------------------------------------- 1 | r"""Integration tests""" 2 | 3 | import pyro.optim 4 | 5 | import pytest 6 | 7 | from xfuse.messengers.stats import RMSE 8 | from xfuse.model import XFuse 9 | from xfuse.model.experiment.st import ST, MetageneDefault 10 | from xfuse.model.experiment.st.metagene_eval import purge_metagenes 11 | from xfuse.model.experiment.st.metagene_expansion_strategy import ( 12 | Extra, 13 | DropAndSplit, 14 | ) 15 | from xfuse.session import Session, get 16 | from xfuse.train import train 17 | 18 | 19 | @pytest.mark.fix_rng 20 | @pytest.mark.slow 21 | def test_toydata(mocker, toydata): 22 | r"""Integration test on toy dataset""" 23 | st_experiment = ST( 24 | depth=2, 25 | num_channels=4, 26 | metagenes=[MetageneDefault(0.0, None) for _ in range(3)], 27 | ) 28 | xfuse = XFuse(experiments=[st_experiment]) 29 | rmse = RMSE() 30 | mock_log_scalar = mocker.patch("xfuse.messengers.stats.rmse.log_scalar") 31 | with Session( 32 | model=xfuse, 33 | optimizer=pyro.optim.Adam({"lr": 0.0001}), 34 | dataloader=toydata, 35 | genes=toydata.dataset.genes, 36 | covariates={ 37 | covariate: values.cat.categories.values.tolist() 38 | for covariate, values in toydata.dataset.data.design.iteritems() 39 | }, 40 | messengers=[rmse], 41 | ): 42 | train(100 + get("training_data").epoch) 43 | rmses = [x[1][1] for x in mock_log_scalar.mock_calls] 44 | assert rmses[-1] < 6.0 45 | 46 | 47 | @pytest.mark.fix_rng 48 | @pytest.mark.parametrize( 49 | "expansion_strategies,compute_expected_metagenes", 50 | [ 51 | ((Extra(5),), lambda n: (n + 5, n)), 52 | ((DropAndSplit(),) * 2, lambda n: (2 * n, n)), 53 | ], 54 | ) 55 | def test_metagene_expansion( 56 | # pylint: disable=redefined-outer-name 57 | toydata, 58 | pretrained_toy_model, 59 | expansion_strategies, 60 | compute_expected_metagenes, 61 | ): 62 | r"""Test metagene expansion dynamics""" 63 | st_experiment = pretrained_toy_model.get_experiment("ST") 64 | num_start_metagenes = len(st_experiment.metagenes) 65 | 66 | for expansion_strategy, expected_metagenes in zip( 67 | expansion_strategies, compute_expected_metagenes(num_start_metagenes) 68 | ): 69 | with Session( 70 | dataloader=toydata, 71 | genes=toydata.dataset.genes, 72 | metagene_expansion_strategy=expansion_strategy, 73 | model=pretrained_toy_model, 74 | ): 75 | purge_metagenes(num_samples=10) 76 | assert len(st_experiment.metagenes) == expected_metagenes 77 | -------------------------------------------------------------------------------- /xfuse/__init__.py: -------------------------------------------------------------------------------- 1 | r"""XFuse""" 2 | 3 | from . 
import session 4 | from .__version__ import __version__ 5 | 6 | __all__ = ["__version__"] 7 | -------------------------------------------------------------------------------- /xfuse/__version__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.1" 2 | -------------------------------------------------------------------------------- /xfuse/analyze/__init__.py: -------------------------------------------------------------------------------- 1 | from .differential_expression import * 2 | from .gene_maps import * 3 | from .prediction import * 4 | from .metagenes import * 5 | from .analyze import _ANALYSES as analyses, Analysis 6 | -------------------------------------------------------------------------------- /xfuse/analyze/analyze.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, NamedTuple 2 | 3 | from ..logging import DEBUG, log 4 | 5 | 6 | class Analysis(NamedTuple): 7 | r"""Data type for analyses""" 8 | description: str 9 | function: Callable[..., None] 10 | 11 | 12 | _ANALYSES: Dict[str, Analysis] = {} 13 | 14 | 15 | def _register_analysis(name, analysis: Analysis): 16 | if name not in _ANALYSES: 17 | log(DEBUG, 'Registering analysis "%s"', name) 18 | _ANALYSES[name] = analysis 19 | else: 20 | raise RuntimeError(f'Analysis "{name}" has already been registered!') 21 | -------------------------------------------------------------------------------- /xfuse/analyze/differential_expression.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import List, Optional, Tuple 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from ..data import Data, Dataset 9 | from ..data.utility.misc import make_dataloader 10 | from ..data.slide import AnnotatedImage, FullSlideIterator, Slide 11 | from ..session import Session, require 12 | from .analyze import Analysis, _register_analysis 13 | from .prediction import predict_df 14 | 15 | 16 | def _run_differential_expression_analysis( 17 | annotation_layer: Optional[str] = None, 18 | comparisons: List[Tuple[str, str]] = None, 19 | normalize_covariates: List[str] = None, 20 | num_samples: int = 50, 21 | genes_per_batch: int = 100000, 22 | ) -> None: 23 | """Runs differential expression analysis""" 24 | 25 | dataloader = require("dataloader") 26 | 27 | if annotation_layer is None: 28 | warnings.warn( 29 | "No annotation layer specified." 30 | " Skipping differential gene expression analysis." 31 | ) 32 | return 33 | 34 | if comparisons is None: 35 | warnings.warn( 36 | "No comparisons specified." 37 | " Skipping differential gene expression analysis." 
38 | ) 39 | return 40 | 41 | if normalize_covariates is None: 42 | normalize_covariates = [] 43 | 44 | slides = { 45 | slide_name: Slide( 46 | data=AnnotatedImage.from_st_slide( 47 | slide.data, annotation_name=annotation_layer 48 | ), 49 | iterator=FullSlideIterator, 50 | ) 51 | for slide_name, slide in dataloader.dataset.data.slides.items() 52 | } 53 | dataloader = make_dataloader( 54 | Dataset(Data(slides, design=dataloader.dataset.data.design)), 55 | batch_size=1, 56 | shuffle=False, 57 | ) 58 | 59 | with Session(dataloader=dataloader, messengers=[]): 60 | samples = predict_df( 61 | num_samples=num_samples, 62 | genes_per_batch=genes_per_batch, 63 | normalize_covariates=normalize_covariates, 64 | ) 65 | 66 | samples = samples.groupby([annotation_layer, "sample", "gene"]).agg(sum) 67 | samples = samples.assign( 68 | count=samples.groupby([annotation_layer, "sample"]).transform( 69 | lambda x: np.log2(x / x.sum()) 70 | ) 71 | ) 72 | samples = samples.reset_index().pivot( 73 | ["sample", "gene"], columns=annotation_layer 74 | ) 75 | samples.columns = samples.columns.map(lambda x: x[1]) 76 | 77 | def _save_comparison(a, b): 78 | lfc = ( 79 | samples[[a, b]] 80 | .assign(lfc=lambda x: x[a] - x[b])["lfc"] 81 | .reset_index() 82 | .pivot("sample", "gene") 83 | ) 84 | lfc.columns = lfc.columns.map(lambda x: x[1]) 85 | 86 | lfc.to_csv(f"{a}-vs-{b}.csv.gz") 87 | 88 | sorted_values = lfc.mean(0).sort_values() 89 | log2_fold_top = lfc[ 90 | pd.concat([sorted_values[:10], sorted_values[-10:]]).index 91 | ] 92 | log2_fold_top.boxplot(vert=False) 93 | plt.title(f"{a} vs. {b}") 94 | plt.xlabel("log2 fold") 95 | plt.savefig(f"{a}-vs-{b}_top_differential.pdf") 96 | plt.close() 97 | 98 | for a, b in comparisons: 99 | _save_comparison(a, b) 100 | 101 | 102 | _register_analysis( 103 | name="differential_expression", 104 | analysis=Analysis( 105 | description="Performs differential gene expression analysis", 106 | function=_run_differential_expression_analysis, 107 | ), 108 | ) 109 | -------------------------------------------------------------------------------- /xfuse/analyze/gene_maps.py: -------------------------------------------------------------------------------- 1 | import itertools as it 2 | import re 3 | import warnings 4 | from typing import Any, Callable, Dict, Iterable, Tuple, cast 5 | 6 | import numpy as np 7 | import torch 8 | from imageio import imwrite 9 | from scipy.ndimage.morphology import binary_fill_holes 10 | 11 | from .analyze import Analysis, _register_analysis 12 | from .prediction import predict 13 | from ..data import Data, Dataset 14 | from ..data.slide import AnnotatedImage, FullSlideIterator, Slide 15 | from ..data.utility.misc import make_dataloader 16 | from ..logging import Progressbar 17 | from ..session import Session, require 18 | from ..utility.core import chunks_of, resize 19 | from ..utility.file import chdir 20 | from ..utility.visualization import ( 21 | balance_colors, 22 | greyscale2colormap, 23 | mask_background, 24 | ) 25 | from ..utility.tensor import to_device 26 | 27 | 28 | def generate_gene_maps( 29 | num_samples: int = 10, 30 | genes_per_batch: int = 10, 31 | predict_mean: bool = True, 32 | normalize: bool = False, 33 | scale: float = 1.0, 34 | ) -> Iterable[Tuple[str, str, np.ndarray]]: 35 | """Generates gene maps on the active dataset""" 36 | 37 | genes = require("genes") 38 | dataloader = require("dataloader") 39 | 40 | if scale <= 0 or scale > 1.0: 41 | raise ValueError("Argument `scale` must be in (0, 1]") 42 | 43 | def _compute_annotation(shape): 44 
| scaled_shape = [x * scale for x in shape] 45 | ys, xs = [ 46 | np.floor(np.linspace(0, scaled_x, x, endpoint=False)).astype(int) 47 | for scaled_x, x in zip(scaled_shape, shape) 48 | ] 49 | annotation = 1 + torch.as_tensor((1 + xs.max()) * ys[:, None] + xs) 50 | label_names = { 51 | 1 + (1 + xs.max()) * y + x: (y, x) 52 | for y in np.unique(ys) 53 | for x in np.unique(xs) 54 | } 55 | return annotation, label_names 56 | 57 | dataloader = make_dataloader( 58 | Dataset( 59 | Data( 60 | slides={ 61 | k: Slide( 62 | data=AnnotatedImage( 63 | torch.as_tensor(v.data.image[()]), 64 | annotation=annotation, 65 | name="coordinates", 66 | label_names=label_names, 67 | ), 68 | iterator=FullSlideIterator, 69 | ) 70 | for k, v in dataloader.dataset.data.slides.items() 71 | for annotation, label_names in [ 72 | _compute_annotation(v.data.label.shape) 73 | ] 74 | }, 75 | design=dataloader.dataset.data.design, 76 | ) 77 | ), 78 | batch_size=1, 79 | shuffle=False, 80 | ) 81 | 82 | def _process_batch(samples): 83 | rows, cols = [x + 1 for x in samples[0]["rownames"][-1]] 84 | data = torch.stack([x["data"] for x in samples]) 85 | data = data.reshape(data.shape[0], rows, cols, data.shape[-1]) 86 | for gene, gene_data in zip( 87 | samples[0]["colnames"], data.permute(3, 0, 1, 2) 88 | ): 89 | yield samples[0]["section"], gene, cast( 90 | np.ndarray, gene_data.cpu().numpy() 91 | ) 92 | 93 | with Progressbar( 94 | chunks_of(genes, genes_per_batch), 95 | total=int(np.ceil(len(genes) / genes_per_batch)), 96 | leave=False, 97 | ) as progress: 98 | for genes_batch in progress: 99 | with Session(dataloader=dataloader, genes=genes_batch): 100 | for _, samples in it.groupby( 101 | sorted( 102 | [ 103 | to_device(x, device=torch.device("cpu")) 104 | for x in predict( 105 | num_samples=num_samples, 106 | genes_per_batch=len(genes_batch), 107 | predict_mean=predict_mean, 108 | normalize_scale=normalize, 109 | normalize_size=True, 110 | ) 111 | ], 112 | key=lambda x: x["section"], 113 | ), 114 | key=lambda x: x["section"], 115 | ): 116 | yield from _process_batch(list(samples)) 117 | 118 | 119 | def _run_gene_maps_analysis( 120 | gene_regex: str = r".*", 121 | num_samples: int = 10, 122 | genes_per_batch: int = 10, 123 | predict_mean: bool = True, 124 | normalize: bool = False, 125 | mask_tissue: bool = True, 126 | scale: float = 1.0, 127 | writer: str = "image", 128 | writer_args: Dict[str, Any] = None, 129 | ) -> None: 130 | r"""Gene maps analysis function""" 131 | 132 | genes = require("genes") 133 | slides = require("dataloader").dataset.data.slides 134 | 135 | if writer_args is None: 136 | writer_args = {} 137 | 138 | def _save_image(gene, samples, tissue_mask, fileformat="jpg"): 139 | def _prepare(x): 140 | x = balance_colors(x, q=0, q_high=0.999) 141 | x = greyscale2colormap(x) 142 | if tissue_mask is not None: 143 | x = mask_background(x, tissue_mask) 144 | return x 145 | 146 | imwrite(f"{gene}_mean.{fileformat}", _prepare(samples.mean(0))) 147 | imwrite(f"{gene}_stdv.{fileformat}", _prepare(samples.std(0))) 148 | 149 | lfc = np.log2(samples.transpose(1, 2, 0)) - np.log2( 150 | samples.mean((1, 2)) 151 | ) 152 | imwrite( 153 | f"{gene}_invcv+.{fileformat}", 154 | _prepare(lfc.mean(-1).clip(0) / lfc.std(-1)), 155 | ) 156 | 157 | def _save_tensor(gene, samples, tissue_mask): 158 | if tissue_mask is not None: 159 | samples[:, ~tissue_mask] = 0.0 160 | torch.save(samples, f"{gene}.pt") 161 | 162 | writers: Dict[str, Callable[..., None]] = { 163 | "image": _save_image, 164 | "tensor": _save_tensor, 165 | } 166 | try: 
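# Look up the requested output writer by name; an unknown name falls
# through to the ValueError below, which lists the valid choices.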
167 | write = writers[writer] 168 | except KeyError as exc: 169 | raise ValueError( 170 | 'Invalid data format "{}" (choose between: {})'.format( 171 | writer, ", ".join(f'"{x}"' for x in writers) 172 | ) 173 | ) from exc 174 | 175 | tissue_masks = {} 176 | if mask_tissue: 177 | for slide_name in slides: 178 | try: 179 | zero_count_idxs = np.where( 180 | np.array(slides[slide_name].data.counts.todense()).sum(1) 181 | == 0.0 182 | )[0] 183 | tissue_masks[slide_name] = binary_fill_holes( 184 | np.isin( 185 | slides[slide_name].data.label, 186 | 1 + zero_count_idxs, 187 | invert=True, 188 | ) 189 | ) 190 | except AttributeError: 191 | warnings.warn(f'Failed to mask "{slide_name}"') 192 | 193 | with Session( 194 | genes=[ 195 | x for x in genes if re.match(gene_regex, x, flags=re.IGNORECASE) 196 | ] 197 | ): 198 | for slide_name, gene, samples in generate_gene_maps( 199 | num_samples=num_samples, 200 | genes_per_batch=genes_per_batch, 201 | predict_mean=predict_mean, 202 | normalize=normalize, 203 | scale=scale, 204 | ): 205 | try: 206 | tissue_mask = tissue_masks[slide_name] 207 | tissue_mask = resize(tissue_mask, samples.shape[1:]) 208 | except KeyError: 209 | tissue_mask = None 210 | with chdir(slide_name): 211 | write(gene, samples, tissue_mask, **writer_args) 212 | 213 | 214 | _register_analysis( 215 | name="gene_maps", 216 | analysis=Analysis( 217 | description=( 218 | "Constructs a map of imputed expression for each gene in the" 219 | " dataset." 220 | ), 221 | function=_run_gene_maps_analysis, 222 | ), 223 | ) 224 | -------------------------------------------------------------------------------- /xfuse/analyze/metagenes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from typing import cast 4 | 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | import pyro 8 | from imageio import imwrite 9 | 10 | from ..model.experiment.st.st import ST, _encode_metagene_name 11 | from ..session import Session, require 12 | from ..utility.visualization import visualize_metagenes 13 | from .analyze import Analysis, _register_analysis 14 | 15 | __all__ = [ 16 | "compute_metagene_profiles", 17 | "compute_metagene_summary", 18 | "visualize_metagene_profile", 19 | ] 20 | 21 | 22 | def compute_metagene_profiles(): 23 | r"""Computes metagene profiles""" 24 | model = require("model") 25 | genes = require("genes") 26 | 27 | def _metagene_profile_st(): 28 | model = require("model") 29 | experiment = cast(ST, model.get_experiment("ST")) 30 | with pyro.poutine.block(): 31 | with pyro.poutine.trace() as trace: 32 | # pylint: disable=protected-access 33 | experiment._sample_metagenes() 34 | return [ 35 | (n, trace.trace.nodes[_encode_metagene_name(n)]["fn"]) 36 | for n in experiment.metagenes 37 | ] 38 | 39 | _metagene_profile_fn = {"ST": _metagene_profile_st} 40 | 41 | for experiment in model.experiments.keys(): 42 | try: 43 | fn = _metagene_profile_fn[experiment] 44 | except KeyError: 45 | warnings.warn( 46 | f'Metagene profiles for experiment of type "{experiment}"' 47 | " not implemented" 48 | ) 49 | continue 50 | 51 | names, profiles = zip(*fn()) 52 | dataframe = ( 53 | pd.concat( 54 | [ 55 | pd.DataFrame( 56 | [ 57 | x.mean.detach().cpu().numpy(), 58 | x.stddev.detach().cpu().numpy(), 59 | ], 60 | columns=genes, 61 | index=pd.Index(["mean", "stddev"], name="log2fold"), 62 | ) 63 | for x in profiles 64 | ], 65 | keys=pd.Index(names, name="metagene"), 66 | ) 67 | .reset_index() 68 | .melt( 69 | ["metagene", "log2fold"], 
var_name="gene", value_name="value", 70 | ) 71 | .pivot_table( 72 | index=["metagene", "gene"], columns="log2fold", values="value", 73 | ) 74 | ) 75 | yield experiment, dataframe 76 | 77 | 78 | def visualize_metagene_profile( 79 | profile, num_high=20, num_low=20, sort_by="mean", ax=None, 80 | ): 81 | r"""Creates metagene profile visualization""" 82 | num_low = max(min(len(profile) - num_high, num_low), 0) 83 | x = profile.sort_values(sort_by) 84 | x = pd.concat([x.iloc[:num_low], x.iloc[-num_high:]]) 85 | (ax if ax else plt).errorbar( 86 | x["mean"], x.index, xerr=x["stddev"], fmt="none", c="black" 87 | ) 88 | (ax if ax else plt).vlines( 89 | 0.0, 90 | ymin=x.index[0], 91 | ymax=x.index[-1], 92 | colors="red", 93 | linestyles="--", 94 | lw=1, 95 | ) 96 | 97 | 98 | def compute_metagene_summary(method: str = "pca") -> None: 99 | r"""Computes metagene summary""" 100 | # pylint: disable=too-many-locals 101 | with Session(messengers=[]): 102 | for (slide_name, summarization, metagenes) in visualize_metagenes( 103 | method 104 | ): 105 | os.makedirs(slide_name, exist_ok=True) 106 | imwrite( 107 | os.path.join(slide_name, "summary.png"), summarization, 108 | ) 109 | for name, metagene in metagenes: 110 | imwrite( 111 | os.path.join(slide_name, f"metagene-{name}.png"), metagene, 112 | ) 113 | 114 | for experiment, metagene_profiles in compute_metagene_profiles(): 115 | metagene_profiles.to_csv(f"{experiment}-metagene-log2fold.csv.gz") 116 | metagene_profiles["invcv"] = ( 117 | metagene_profiles["mean"] / metagene_profiles["stddev"] 118 | ) 119 | for metagene, profile in metagene_profiles.groupby(level=0): 120 | plt.figure(figsize=(4, 10)) 121 | visualize_metagene_profile( 122 | profile.loc[metagene], 123 | num_high=30, 124 | num_low=15, 125 | sort_by="invcv", 126 | ) 127 | plt.title(f"{metagene=} ({experiment})") 128 | plt.tight_layout(pad=0.0) 129 | plt.savefig( 130 | f"{experiment}-metagene-{metagene}-invcvsort.png", dpi=600, 131 | ) 132 | plt.close() 133 | 134 | plt.figure(figsize=(4, 10)) 135 | visualize_metagene_profile( 136 | profile.loc[metagene], 137 | num_high=30, 138 | num_low=15, 139 | sort_by="mean", 140 | ) 141 | plt.title(f"{metagene=} ({experiment})") 142 | plt.tight_layout(pad=0.0) 143 | plt.savefig( 144 | f"{experiment}-metagene-{metagene}-meansort.png", dpi=600, 145 | ) 146 | plt.close() 147 | 148 | 149 | _register_analysis( 150 | name="metagenes", 151 | analysis=Analysis( 152 | description="Creates summary data of the metagenes", 153 | function=compute_metagene_summary, 154 | ), 155 | ) 156 | -------------------------------------------------------------------------------- /xfuse/analyze/prediction.py: -------------------------------------------------------------------------------- 1 | import itertools as it 2 | from copy import copy 3 | from typing import Any, Dict, Iterable, List, Optional 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import pyro 8 | import torch 9 | 10 | from ..data import Data, Dataset 11 | from ..data.slide import AnnotatedImage, FullSlideIterator, Slide 12 | from ..data.utility.misc import make_dataloader 13 | from ..logging import Progressbar 14 | from ..session import Session, require 15 | from ..utility.core import chunks_of 16 | from ..utility.tensor import to_device 17 | from .analyze import Analysis, _register_analysis 18 | 19 | 20 | def _run_model(data, normalize_covariates, normalize_size, predict_mean): 21 | genes = require("genes") 22 | model = require("model") 23 | 24 | data_type, *__this_should_be_empty = list(data.keys()) 25 | 26 | if 
data_type != "AnnotatedImage": 27 | raise ValueError( 28 | f'Invalid data type "{data_type}".' 29 | " Can only predict on AnnotatedImage." 30 | ) 31 | assert __this_should_be_empty == [] 32 | 33 | data = to_device(data[data_type]) 34 | 35 | # pylint: disable=fixme 36 | # FIXME: Compatibility hack. 37 | # Let's get back to this when reworking the model code. 38 | st_data = { 39 | "slide": data["slide"], 40 | "covariates": [ 41 | { 42 | covariate: condition 43 | for covariate, condition in covariates.items() 44 | if covariate not in normalize_covariates 45 | } 46 | for covariates in data["covariates"] 47 | ], 48 | "data": [ 49 | to_device( 50 | torch.zeros( 51 | int(label.max().item()), len(genes), dtype=torch.float32, 52 | ) 53 | ) 54 | for label in data["label"] 55 | ], 56 | "label": data["label"], 57 | "image": data["image"], 58 | } 59 | 60 | with Session(eval=True): 61 | with pyro.poutine.trace() as guide_trace: 62 | model.guide({"ST": st_data}) 63 | 64 | with Session(eval=True): 65 | with pyro.poutine.replay(trace=guide_trace.trace): 66 | with pyro.poutine.trace() as model_trace: 67 | model({"ST": st_data}) 68 | 69 | for i, (annotation_name, label, label_names, section_name) in enumerate( 70 | zip(data["name"], data["label"], data["label_names"], data["slide"],) 71 | ): 72 | try: 73 | idxs = ( 74 | model_trace.trace.nodes[f"ST/idx-{i}"]["value"] 75 | .long() 76 | .cpu() 77 | .numpy() 78 | ) 79 | emission_distr = model_trace.trace.nodes[f"ST/xsg-{i}"]["fn"] 80 | except KeyError: 81 | continue 82 | 83 | sample = ( 84 | emission_distr.mean if predict_mean else emission_distr.sample() 85 | ) 86 | 87 | if normalize_size: 88 | _, sizes = np.unique( 89 | np.searchsorted(idxs, label.cpu().numpy()), return_counts=True, 90 | ) 91 | sample = sample / torch.as_tensor(sizes).to(sample).unsqueeze(1) 92 | 93 | label_names = label_names[idxs] 94 | 95 | yield { 96 | "data": sample, 97 | "rownames": label_names, 98 | "colnames": genes, 99 | "section": section_name, 100 | "annotation": annotation_name, 101 | } 102 | 103 | 104 | def predict( 105 | num_samples: int = 1, 106 | genes_per_batch: int = 10, 107 | predict_mean: bool = True, 108 | normalize_scale: bool = False, 109 | normalize_size: bool = False, 110 | normalize_covariates: Optional[List[str]] = None, 111 | ) -> Iterable[Dict[str, Any]]: 112 | """Predicts gene expression""" 113 | 114 | if normalize_covariates is None: 115 | normalize_covariates = [] 116 | 117 | dataloader = require("dataloader") 118 | genes = require("genes") 119 | model = require("model") 120 | 121 | def _sample(): 122 | conditional_model = copy(model) 123 | if normalize_scale: 124 | conditional_model.model = pyro.poutine.condition( 125 | conditional_model.model, 126 | # pylint: disable=not-callable 127 | {"scale": torch.tensor(1.0)}, 128 | ) 129 | 130 | iterator = it.product(dataloader, chunks_of(genes, genes_per_batch)) 131 | 132 | with pyro.poutine.trace() as global_trace: 133 | with pyro.poutine.block( 134 | expose_fn=lambda msg: ( 135 | "is_guide" in msg 136 | and msg["is_guide"] 137 | and "is_global" in msg["infer"] 138 | and msg["infer"]["is_global"] 139 | ) 140 | ): 141 | try: 142 | data, batch_genes = next(iterator) 143 | except StopIteration: 144 | return 145 | with Session(genes=batch_genes, model=conditional_model): 146 | yield from _run_model( 147 | data=data, 148 | predict_mean=predict_mean, 149 | normalize_size=normalize_size, 150 | normalize_covariates=normalize_covariates, 151 | ) 152 | 153 | conditional_model.guide = pyro.poutine.condition( 154 | 
conditional_model.guide, 155 | { 156 | variable: properties["value"] 157 | for variable, properties in global_trace.trace.nodes.items() 158 | }, 159 | ) 160 | conditional_model.model = pyro.poutine.condition( 161 | conditional_model.model, 162 | { 163 | variable: properties["value"] 164 | for variable, properties in global_trace.trace.nodes.items() 165 | }, 166 | ) 167 | 168 | for data, batch_genes in iterator: 169 | with Session(genes=batch_genes, model=conditional_model): 170 | yield from _run_model( 171 | data=data, 172 | predict_mean=predict_mean, 173 | normalize_size=normalize_size, 174 | normalize_covariates=normalize_covariates, 175 | ) 176 | 177 | with Progressbar( 178 | range(1, num_samples + 1), desc="Sampling", leave=False 179 | ) as iterator: 180 | for sample_num in iterator: 181 | for sample in _sample(): 182 | yield {**sample, "sample": sample_num} 183 | 184 | 185 | def predict_df(**kwargs) -> pd.DataFrame: 186 | """ 187 | Similar to :func:`predict` but, instead of streaming result :class:`Dict`s, 188 | return all results in a tidy :class:`~pd.DataFrame`. 189 | """ 190 | return pd.concat( 191 | [ 192 | pd.DataFrame(x["data"].cpu().numpy(), columns=x["colnames"]) 193 | .assign( 194 | **{ 195 | x["annotation"]: x["rownames"], 196 | "section": x["section"], 197 | "sample": x["sample"], 198 | } 199 | ) 200 | .melt( 201 | [x["annotation"], "section", "sample"], 202 | var_name="gene", 203 | value_name="count", 204 | ) 205 | for x in predict(**kwargs) 206 | ], 207 | axis=0, 208 | ) 209 | 210 | 211 | def _run_prediction_analysis( 212 | annotation_layer: str = "", 213 | num_samples: int = 1, 214 | genes_per_batch: int = 10, 215 | predict_mean: bool = True, 216 | normalize_scale: bool = False, 217 | normalize_size: bool = False, 218 | normalize_covariates: Optional[List[str]] = None, 219 | ) -> None: 220 | """Runs prediction analysis""" 221 | 222 | if normalize_covariates is None: 223 | normalize_covariates = [] 224 | 225 | dataloader = require("dataloader") 226 | dataloader = make_dataloader( 227 | Dataset( 228 | Data( 229 | slides={ 230 | k: Slide( 231 | data=AnnotatedImage.from_st_slide( 232 | v.data, annotation_name=annotation_layer 233 | ), 234 | iterator=FullSlideIterator, 235 | ) 236 | for k, v in dataloader.dataset.data.slides.items() 237 | }, 238 | design=dataloader.dataset.data.design, 239 | ) 240 | ), 241 | batch_size=1, 242 | shuffle=False, 243 | ) 244 | 245 | with Session(dataloader=dataloader, messengers=[]): 246 | samples = predict_df( 247 | num_samples=num_samples, 248 | genes_per_batch=genes_per_batch, 249 | predict_mean=predict_mean, 250 | normalize_scale=normalize_scale, 251 | normalize_size=normalize_size, 252 | normalize_covariates=normalize_covariates, 253 | ) 254 | 255 | samples = samples.pivot( 256 | index=[annotation_layer, "section", "sample"], columns=["gene"] 257 | ) 258 | samples = samples.reset_index() 259 | samples.columns = samples.columns.map( 260 | lambda x: x[1] if x[1] != "" else x[0] 261 | ) 262 | samples.to_csv("data.csv.gz") 263 | 264 | 265 | _register_analysis( 266 | name="prediction", 267 | analysis=Analysis( 268 | description="Predicts expression data", 269 | function=_run_prediction_analysis, 270 | ), 271 | ) 272 | -------------------------------------------------------------------------------- /xfuse/convert/__init__.py: -------------------------------------------------------------------------------- 1 | from . import image 2 | from . import st 3 | from . 
import visium 4 | -------------------------------------------------------------------------------- /xfuse/convert/image.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from PIL import Image 6 | 7 | from ..utility.core import rescale 8 | from ..utility.mask import compute_tissue_mask 9 | from .utility import find_margin, write_data 10 | 11 | 12 | def run( 13 | image: np.ndarray, 14 | output_file: str, 15 | annotation: Optional[Dict[str, np.ndarray]] = None, 16 | scale_factor: Optional[float] = None, 17 | mask: bool = True, 18 | custom_mask: Optional[np.ndarray] = None, 19 | rotate: bool = False, 20 | ) -> None: 21 | r""" 22 | Converts image data into the data format used by xfuse. 23 | """ 24 | if annotation is None: 25 | annotation = {} 26 | 27 | if scale_factor is not None: 28 | image = rescale(image, scale_factor, Image.BOX) 29 | annotation = { 30 | k: rescale(v, scale_factor, Image.NEAREST) 31 | for k, v in annotation.items() 32 | } 33 | if custom_mask is not None: 34 | custom_mask = rescale(custom_mask, scale_factor, Image.NEAREST) 35 | 36 | col_mask, row_mask = find_margin(image) 37 | image = image[row_mask][:, col_mask] 38 | if custom_mask is not None: 39 | custom_mask = custom_mask[row_mask][:, col_mask] 40 | 41 | if scale_factor is not None: 42 | # The outermost pixels may belong in part to the margin if we 43 | # downscaled the image. Therefore, remove one extra row/column. 44 | image = image[1:-1, 1:-1] 45 | if custom_mask is not None: 46 | custom_mask = custom_mask[1:-1, 1:-1] 47 | 48 | if mask: 49 | tissue_mask = compute_tissue_mask(image, initial_mask=custom_mask) 50 | label = np.array(tissue_mask == 0, dtype=np.int16) 51 | else: 52 | label = np.zeros(image.shape[:2], dtype=np.int16) 53 | 54 | counts = pd.DataFrame( 55 | index=pd.Series(np.unique(label[label != 0]), name="n") 56 | ) 57 | 58 | write_data( 59 | counts, 60 | image, 61 | label, 62 | type_label="ST", 63 | annotation={ 64 | k: (v, {x: str(x) for x in np.unique(v)}) 65 | for k, v in annotation.items() 66 | }, 67 | auto_rotate=rotate, 68 | path=output_file, 69 | ) 70 | -------------------------------------------------------------------------------- /xfuse/convert/st.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Dict, Optional 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from PIL import Image 7 | 8 | from ..utility.core import rescale 9 | from .utility import ( 10 | Spot, 11 | find_margin, 12 | labels_from_spots, 13 | mask_tissue, 14 | write_data, 15 | ) 16 | 17 | 18 | def run( 19 | counts: pd.DataFrame, 20 | image: np.ndarray, 21 | output_file: str, 22 | spots: Optional[pd.DataFrame] = None, 23 | transformation: Optional[np.ndarray] = None, 24 | annotation: Optional[Dict[str, np.ndarray]] = None, 25 | scale_factor: Optional[float] = None, 26 | mask: bool = True, 27 | custom_mask: Optional[np.ndarray] = None, 28 | rotate: bool = False, 29 | ) -> None: 30 | r""" 31 | Converts data from the Spatial Transcriptomics pipeline into the data 32 | format used by xfuse. 
33 | """ 34 | if annotation is None: 35 | annotation = {} 36 | 37 | if scale_factor is not None: 38 | image = rescale(image, scale_factor, Image.BOX) 39 | annotation = { 40 | k: rescale(v, scale_factor, Image.NEAREST) 41 | for k, v in annotation.items() 42 | } 43 | if spots is not None: 44 | spots[["pixel_x", "pixel_y"]] *= scale_factor 45 | if transformation is not None: 46 | scale_matrix = np.array( 47 | [[scale_factor, 0, 0], [0, scale_factor, 0], [0, 0, 1]] 48 | ) 49 | transformation = transformation @ scale_matrix 50 | if custom_mask is not None: 51 | custom_mask = rescale(custom_mask, scale_factor, Image.NEAREST) 52 | 53 | if spots is not None: 54 | spots.index = spots[["x", "y"]].apply( 55 | lambda x: "x".join(map(str, x)), 1 56 | ) 57 | spot_names = np.intersect1d(spots.index, counts.index) 58 | spots = spots.loc[spot_names] 59 | counts = counts.loc[spot_names] 60 | xmax, xmin = [f(spots.x) for f in (np.max, np.min)] 61 | pxmax, pxmin = [ 62 | np.mean(spots.pixel_x[spots.x == x]) for x in (xmax, xmin) 63 | ] 64 | radius = (pxmax - pxmin) / (xmax - xmin) / 4 65 | spots = list( 66 | spots[["pixel_x", "pixel_y"]].apply( 67 | lambda x: Spot(*x, radius), # type: ignore 68 | 1, 69 | ) 70 | ) 71 | else: 72 | warnings.warn( 73 | "Converting data from the Spatial Transcriptomics pipeline" 74 | " without a spot detector file has been deprecated and will be" 75 | " removed in a future version.", 76 | DeprecationWarning, 77 | ) 78 | coordinates = np.array( 79 | [ 80 | [float(x), float(y)] 81 | for x, y in (x.split("x") for x in counts.index) 82 | ] 83 | ) 84 | if transformation is not None: 85 | coordinates = np.concatenate( 86 | [coordinates, np.ones((len(coordinates), 1))], axis=-1 87 | ) 88 | coordinates = coordinates @ transformation 89 | coordinates = coordinates[:, :2] 90 | else: 91 | coordinates[:, 0] = (coordinates[:, 0] - 1) / 32 * image.shape[1] 92 | coordinates[:, 1] = (coordinates[:, 1] - 1) / 34 * image.shape[0] 93 | radius = np.sqrt(np.product(image.shape[:2]) / 32 / 34) / 4 94 | spots = [Spot(x=x, y=y, r=radius) for x, y in coordinates] 95 | 96 | counts.index = pd.Index([*range(1, counts.shape[0] + 1)], name="n") 97 | 98 | label = np.zeros(image.shape[:2]).astype(np.int16) 99 | labels_from_spots(label, spots) 100 | 101 | col_mask, row_mask = find_margin(image) 102 | image = image[row_mask][:, col_mask] 103 | label = label[row_mask][:, col_mask] 104 | if custom_mask is not None: 105 | custom_mask = custom_mask[row_mask][:, col_mask] 106 | 107 | if scale_factor is not None: 108 | # The outermost pixels may belong in part to the margin if we 109 | # downscaled the image. Therefore, remove one extra row/column. 
110 | image = image[1:-1, 1:-1] 111 | label = label[1:-1, 1:-1] 112 | if custom_mask is not None: 113 | custom_mask = custom_mask[1:-1, 1:-1] 114 | 115 | if mask: 116 | counts, label = mask_tissue( 117 | image, counts, label, initial_mask=custom_mask 118 | ) 119 | 120 | write_data( 121 | counts, 122 | image, 123 | label, 124 | type_label="ST", 125 | annotation={ 126 | k: (v, {x: str(x) for x in np.unique(v)}) 127 | for k, v in annotation.items() 128 | }, 129 | auto_rotate=rotate, 130 | path=output_file, 131 | ) 132 | -------------------------------------------------------------------------------- /xfuse/convert/visium.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | import cv2 as cv 4 | import h5py 5 | import numpy as np 6 | import pandas as pd 7 | from PIL import Image 8 | from scipy.ndimage.morphology import distance_transform_edt 9 | from scipy.sparse import csr_matrix 10 | 11 | from ..utility.core import rescale 12 | from .utility import ( 13 | Spot, 14 | find_margin, 15 | labels_from_spots, 16 | mask_tissue, 17 | write_data, 18 | ) 19 | 20 | 21 | def run( 22 | image: np.ndarray, 23 | bc_matrix: h5py.File, 24 | tissue_positions: pd.DataFrame, 25 | spot_radius: float, 26 | output_file: str, 27 | annotation: Optional[Dict[str, np.ndarray]] = None, 28 | scale_factor: Optional[float] = None, 29 | mask: bool = True, 30 | custom_mask: Optional[np.ndarray] = None, 31 | rotate: bool = False, 32 | ) -> None: 33 | r""" 34 | Converts data from the 10X SpaceRanger pipeline for visium arrays into 35 | the data format used by xfuse. 36 | """ 37 | if annotation is None: 38 | annotation = {} 39 | 40 | counts = csr_matrix( 41 | ( 42 | bc_matrix["matrix"]["data"], 43 | bc_matrix["matrix"]["indices"], 44 | bc_matrix["matrix"]["indptr"], 45 | ), 46 | shape=( 47 | bc_matrix["matrix"]["barcodes"].shape[0], 48 | bc_matrix["matrix"]["features"]["name"].shape[0], 49 | ), 50 | ) 51 | counts = pd.DataFrame.sparse.from_spmatrix( 52 | counts.astype(float), 53 | columns=bc_matrix["matrix"]["features"]["name"][()].astype(str), 54 | index=pd.Index([*range(1, counts.shape[0] + 1)], name="n"), 55 | ) 56 | 57 | if scale_factor is not None: 58 | tissue_positions[["x", "y"]] *= scale_factor 59 | spot_radius *= scale_factor 60 | image = rescale(image, scale_factor, Image.BOX) 61 | annotation = { 62 | k: rescale(v, scale_factor, Image.NEAREST) 63 | for k, v in annotation.items() 64 | } 65 | if custom_mask is not None: 66 | custom_mask = rescale(custom_mask, scale_factor, Image.NEAREST) 67 | 68 | spots = list( 69 | tissue_positions[["x", "y"]] 70 | .loc[bc_matrix["matrix"]["barcodes"][()].astype(str)] 71 | .apply(lambda x: Spot(x=x["x"], y=x["y"], r=spot_radius), 1) 72 | ) 73 | 74 | label = np.zeros(image.shape[:2]).astype(np.int16) 75 | labels_from_spots(label, spots) 76 | 77 | col_mask, row_mask = find_margin(image) 78 | image = image[row_mask][:, col_mask] 79 | label = label[row_mask][:, col_mask] 80 | if custom_mask is not None: 81 | custom_mask = custom_mask[row_mask][:, col_mask] 82 | 83 | if scale_factor is not None: 84 | # The outermost pixels may belong in part to the margin if we 85 | # downscaled the image. Therefore, remove one extra row/column. 
86 | image = image[1:-1, 1:-1] 87 | label = label[1:-1, 1:-1] 88 | if custom_mask is not None: 89 | custom_mask = custom_mask[1:-1, 1:-1] 90 | 91 | if mask: 92 | if custom_mask is not None: 93 | initial_mask = custom_mask 94 | else: 95 | (in_tissue_idxs,) = np.where( 96 | tissue_positions["in_tissue"] 97 | .loc[bc_matrix["matrix"]["barcodes"][()].astype(str)] 98 | .values 99 | ) 100 | in_tissue_idxs = in_tissue_idxs + 1 101 | in_tissue = np.where(np.isin(label, in_tissue_idxs), True, False) 102 | idx1, idx2 = distance_transform_edt( 103 | label == 0, return_indices=True, return_distances=False 104 | ) 105 | initial_mask = np.where( 106 | label != 0, 107 | np.where(in_tissue, cv.GC_FGD, cv.GC_BGD), 108 | np.where(in_tissue[idx1, idx2], cv.GC_PR_FGD, cv.GC_PR_BGD), 109 | ) 110 | initial_mask = initial_mask.astype(np.uint8) 111 | counts, label = mask_tissue( 112 | image, counts, label, initial_mask=initial_mask 113 | ) 114 | 115 | write_data( 116 | counts, 117 | image, 118 | label, 119 | type_label="ST", 120 | annotation={ 121 | k: (v, {x: str(x) for x in np.unique(v)}) 122 | for k, v in annotation.items() 123 | }, 124 | auto_rotate=rotate, 125 | path=output_file, 126 | ) 127 | -------------------------------------------------------------------------------- /xfuse/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import slide, utility 2 | from .dataset import Data, Dataset 3 | 4 | __all__ = ["Data", "Dataset", "slide", "utility"] 5 | -------------------------------------------------------------------------------- /xfuse/data/dataset.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | from typing import Dict, NamedTuple, Optional 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | 8 | from .slide import Slide 9 | from ..session import get 10 | 11 | __all__ = ["Data", "Dataset"] 12 | 13 | 14 | class Data(NamedTuple): 15 | r"""Dataset consisting of multiple instances of :class:`Slide`""" 16 | 17 | slides: Dict[str, Slide] 18 | design: pd.DataFrame 19 | 20 | 21 | class Dataset(torch.utils.data.Dataset): 22 | r""" 23 | :class:`~torch.utils.data.Dataset` yielding data from a :class:`Data` 24 | instance 25 | """ 26 | 27 | def __init__(self, data: Data): 28 | self._data = data 29 | 30 | if get("genes"): 31 | self.genes = get("genes") 32 | else: 33 | self.genes = list( 34 | sorted( 35 | reduce( 36 | set.union, # type: ignore 37 | ( 38 | set(slide.data.genes) 39 | for slide in self.data.slides.values() 40 | ), 41 | ) 42 | ) 43 | ) 44 | 45 | self._data_iterators = { 46 | name: slide.iterator(slide.data) 47 | for name, slide in self.data.slides.items() 48 | } 49 | self.observations = pd.DataFrame( 50 | dict( 51 | data_type=np.repeat( 52 | [x.data.data_type for x in self._data.slides.values()], 53 | [len(x) for x in self._data_iterators.values()], 54 | ), 55 | slide=np.repeat( 56 | list(self._data_iterators.keys()), 57 | [len(x) for x in self._data_iterators.values()], 58 | ), 59 | idx=np.concatenate( 60 | [range(len(x)) for x in self._data_iterators.values()] 61 | ), 62 | ) 63 | ) 64 | 65 | def size( 66 | self, 67 | data_type: Optional[str] = None, 68 | slide: Optional[str] = None, 69 | covariate: Optional[str] = None, 70 | condition: Optional[str] = None, 71 | ) -> int: 72 | """Returns the size of the dataset for a given `data_type`""" 73 | observations = self.observations 74 | if data_type is not None: 75 | observations = 
observations[observations.data_type == data_type] 76 | if slide is not None: 77 | observations = observations[observations.slide == slide] 78 | if covariate is not None: 79 | observations = observations.merge( 80 | self.data.design[covariate].rename("condition"), 81 | left_on="slide", 82 | right_index=True, 83 | ) 84 | if condition is not None: 85 | observations = observations[ 86 | observations["condition"] == condition 87 | ] 88 | return len(observations) 89 | 90 | @property 91 | def genes(self): 92 | r"""The genes present in the dataset""" 93 | return self.__genes 94 | 95 | @genes.setter 96 | def genes(self, genes): 97 | self.__genes = genes 98 | for slide in self.data.slides.values(): 99 | slide.data.genes = genes 100 | 101 | @property 102 | def data(self): 103 | r"""The underlying :class:`Data`""" 104 | return self._data 105 | 106 | def __len__(self): 107 | return len(self.observations) 108 | 109 | def __getitem__(self, idx): 110 | slide = self.observations["slide"].iloc[idx] 111 | return dict( 112 | data_type=self._data.slides[slide].data.data_type, 113 | slide=slide, 114 | covariates=dict(self.data.design.loc[slide].iteritems()), 115 | **self._data_iterators[slide].__getitem__( 116 | self.observations["idx"].iloc[idx] 117 | ), 118 | ) 119 | 120 | def __iter__(self): 121 | for idx in range(len(self)): 122 | yield self[idx] 123 | -------------------------------------------------------------------------------- /xfuse/data/slide/__init__.py: -------------------------------------------------------------------------------- 1 | from .slide import Slide 2 | from .data import AnnotatedImage, SlideData, STSlide 3 | from .iterator import ( 4 | SlideIterator, 5 | DataIterator, 6 | FullSlideIterator, 7 | RandomIterator, 8 | ) 9 | from . import data, iterator 10 | -------------------------------------------------------------------------------- /xfuse/data/slide/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .slide_data import SlideData 2 | from .annotated_image import AnnotatedImage 3 | from .st_slide import STSlide 4 | -------------------------------------------------------------------------------- /xfuse/data/slide/data/annotated_image.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Dict, List, Optional, Union 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from .slide_data import SlideData 8 | from .st_slide import STSlide 9 | 10 | 11 | class AnnotatedImage(SlideData): 12 | """ 13 | Data class for annotated images that lack associated expression data. 
14 | """ 15 | 16 | def __init__( 17 | self, 18 | image: torch.Tensor, 19 | annotation: torch.Tensor, 20 | name: str = "Annotation", 21 | label_names: Optional[Dict[int, str]] = None, 22 | ): 23 | if label_names is None: 24 | label_names = {x: str(x) for x in torch.unique(annotation)} 25 | self.image = image 26 | self.label = annotation 27 | self.name = name 28 | self.set_label_names(label_names) 29 | 30 | @property 31 | def label_names(self) -> np.ndarray: 32 | """Names corresponding to the integer labels in :attr:`label`""" 33 | return self._label_names_array 34 | 35 | def set_label_names(self, x: Union[np.ndarray, Dict[int, str]]) -> None: 36 | """Sets the label names""" 37 | if isinstance(x, np.ndarray): 38 | label_names = x 39 | elif isinstance(x, dict): 40 | label_names = np.full(max(x.keys()) + 1, "", dtype="object") 41 | label_names[list(x.keys())] = list(x.values()) 42 | else: 43 | raise ValueError(f"Unsupported {type(x)=}") 44 | self._label_names_array = label_names 45 | 46 | @property 47 | def data_type(self) -> str: 48 | return "AnnotatedImage" 49 | 50 | @property 51 | def genes(self) -> List[str]: 52 | return [] 53 | 54 | @genes.setter 55 | def genes(self, genes: List[str]) -> AnnotatedImage: 56 | # pylint: disable=unused-argument 57 | return self 58 | 59 | @classmethod 60 | def from_st_slide( 61 | cls, st_slide: STSlide, annotation_name: Optional[str] = None 62 | ) -> AnnotatedImage: 63 | """Creates an :class:`AnnotatedImage` from an :class:`STSlide`""" 64 | if annotation_name is None: 65 | annotation = st_slide.label[()] 66 | annotation_name = "n" 67 | label_names = None 68 | else: 69 | annotation, label_names = st_slide.annotation(annotation_name) 70 | image = st_slide.image[()] 71 | return cls( 72 | torch.as_tensor(image.astype(np.float32)), 73 | torch.as_tensor(annotation.astype(np.int64)), 74 | name=annotation_name, 75 | label_names=label_names, 76 | ) 77 | -------------------------------------------------------------------------------- /xfuse/data/slide/data/slide_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from abc import ABCMeta, abstractproperty 3 | from typing import List 4 | 5 | 6 | class SlideData(metaclass=ABCMeta): 7 | r"""Abstract class for different kinds of slide data""" 8 | 9 | @abstractproperty 10 | def data_type(self) -> str: 11 | r"""The type tag of this slide""" 12 | 13 | @abstractproperty 14 | def genes(self) -> List[str]: 15 | r"""Genes returned from this dataset""" 16 | 17 | @genes.setter 18 | def genes(self, genes: List[str]) -> SlideData: 19 | r"""Setter for which genes to return from this dataset""" 20 | -------------------------------------------------------------------------------- /xfuse/data/slide/data/st_slide.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Dict, List, Tuple 3 | 4 | import numpy as np 5 | import scipy.sparse 6 | import torch 7 | 8 | import h5py 9 | 10 | from ....logging import DEBUG, log 11 | from .slide_data import SlideData 12 | 13 | 14 | class STSlide(SlideData): 15 | r""":class:`SlideData` for Spatial Transcriptomics slides""" 16 | 17 | def __init__( 18 | self, 19 | datafile: str, 20 | cache_data: bool = True, 21 | min_counts: float = 0, 22 | always_filter: List[int] = None, 23 | always_keep: List[int] = None, 24 | ): 25 | self._datafile = datafile 26 | with h5py.File(datafile, "r") as data: 27 | self.H, self.W, _ = data["image"].shape 28 
| self.genes = list(data["counts"]["columns"][()].astype(str)) 29 | self.__always_filter = always_filter or [] 30 | self.__always_keep = always_keep or [] 31 | self.cache_data = cache_data 32 | self._counts = None 33 | self._label = None 34 | self._image = None 35 | self.min_counts = min_counts 36 | 37 | @property 38 | def data_type(self) -> str: 39 | return "ST" 40 | 41 | @property 42 | def min_counts(self) -> float: 43 | r""" 44 | The minimum number of reads for an ST spot to be included in this 45 | dataset. This attribute can be used to filter out low quality spots. 46 | """ 47 | return self.__min_counts 48 | 49 | @min_counts.setter 50 | def min_counts(self, n: float): 51 | self.__min_counts = n 52 | self.__label_mask = np.unique( 53 | self.__always_filter 54 | + [ 55 | x 56 | for x in ( 57 | (np.array(self.counts.sum(1)).flatten() < n).nonzero()[0] 58 | + 1 59 | ) 60 | if x not in self.__always_keep 61 | ] 62 | ) 63 | if self.__label_mask.shape[0] > 0: 64 | log( 65 | DEBUG, 66 | "The following labels will be masked out in %s: %s", 67 | self._datafile, 68 | ", ".join(map(str, self.__label_mask)), 69 | ) 70 | 71 | @property 72 | def genes(self): 73 | return list(self.__gene_list.copy()) 74 | 75 | @genes.setter 76 | def genes(self, genes: List[str]) -> STSlide: 77 | self.__gene_list = np.array(genes) 78 | with h5py.File(self._datafile, "r") as data: 79 | idxs = { 80 | gene: i 81 | for i, gene in enumerate( 82 | data["counts"]["columns"][()].astype(str) 83 | ) 84 | } 85 | self.__gene_idxs = np.array( 86 | [idxs[gene] if gene in idxs else -1 for gene in genes] 87 | ) 88 | self._counts = None 89 | return self 90 | 91 | def __construct_count_matrix(self): 92 | with h5py.File(self._datafile, "r") as data: 93 | counts = scipy.sparse.csr_matrix( 94 | ( 95 | data["counts"]["data"], 96 | data["counts"]["indices"], 97 | data["counts"]["indptr"], 98 | ), 99 | shape=( 100 | len(data["counts"]["index"]), 101 | len(data["counts"]["columns"]), 102 | ), 103 | ) 104 | counts = scipy.sparse.hstack( 105 | [counts, np.zeros((counts.shape[0], 1))], format="csr" 106 | ) 107 | counts = counts[:, self.__gene_idxs] 108 | return counts 109 | 110 | @property 111 | def counts(self): 112 | r"""Getter for the count data""" 113 | if self._counts is not None: 114 | return self._counts 115 | counts = self.__construct_count_matrix() 116 | if self.cache_data: 117 | self._counts = counts 118 | return counts 119 | 120 | @property 121 | def image(self): 122 | r"""Getter for the slide image""" 123 | if self._image is not None: 124 | return self._image 125 | data = h5py.File(self._datafile, "r") 126 | image = data["image"] 127 | if self.cache_data: 128 | self._image = image[()] 129 | return image 130 | 131 | @property 132 | def label(self): 133 | r"""Getter for the label image of the slide""" 134 | if self._label is not None: 135 | return self._label 136 | data = h5py.File(self._datafile, "r") 137 | label = data["label"] 138 | if self.cache_data: 139 | self._label = label[()] 140 | return label 141 | 142 | def annotation(self, name) -> Tuple[np.ndarray, Dict[int, str]]: 143 | r"""Getter for annotation layers""" 144 | with h5py.File(self._datafile, "r") as data: 145 | if name not in data["annotation"]: 146 | raise RuntimeError(f'Annotation layer "{name}" is missing') 147 | return ( 148 | data["annotation"][name]["label"][()], 149 | dict( 150 | zip( 151 | data["annotation"][name]["names"]["keys"][()], 152 | data["annotation"][name]["names"]["values"][()].astype( 153 | str 154 | ), 155 | ) 156 | ), 157 | ) 158 | 159 | def 
prepare_data(self, image, label): 160 | r"""Prepare data from image and label patches""" 161 | 162 | label[np.isin(label, self.__label_mask)] = 0 163 | labels = np.sort(np.unique(label[label != 0])) 164 | data = self.counts[(labels - 1).tolist()] 165 | label = np.searchsorted([0, *labels], label) 166 | 167 | return dict( 168 | image=torch.as_tensor(image).float(), 169 | label=torch.as_tensor(label).long(), 170 | data=torch.as_tensor(data.todense()).float(), 171 | ) 172 | -------------------------------------------------------------------------------- /xfuse/data/slide/iterator/__init__.py: -------------------------------------------------------------------------------- 1 | from .slide_iterator import SlideIterator 2 | from .data_iterator import DataIterator 3 | from .full_slide_iterator import FullSlideIterator 4 | from .random_iterator import RandomIterator 5 | -------------------------------------------------------------------------------- /xfuse/data/slide/iterator/data_iterator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from . import SlideIterator 4 | from ..data import STSlide 5 | 6 | 7 | class DataIterator(SlideIterator): 8 | r""" 9 | A :class:`SlideIterator` that yields only the count data from the slide 10 | """ 11 | 12 | def __init__(self, slide: STSlide): 13 | self._slide = slide 14 | 15 | def __len__(self): 16 | return self._slide.counts.shape[0] 17 | 18 | def __getitem__(self, idx): 19 | return dict(data=torch.as_tensor(self._slide.counts[idx].todense())) 20 | -------------------------------------------------------------------------------- /xfuse/data/slide/iterator/full_slide_iterator.py: -------------------------------------------------------------------------------- 1 | from ..data import AnnotatedImage, SlideData, STSlide 2 | from . 
import SlideIterator 3 | 4 | 5 | class FullSlideIterator(SlideIterator): 6 | r"""A :class:`SlideIterator` that yields the full (uncropped) sample""" 7 | 8 | def __init__(self, slide: SlideData): 9 | self._slide = slide 10 | 11 | def __len__(self): 12 | return 1 13 | 14 | def __getitem__(self, idx): 15 | if isinstance(self._slide, STSlide): 16 | image = self._slide.image[()].transpose(2, 0, 1) 17 | label = self._slide.label[()] 18 | return self._slide.prepare_data(image, label) 19 | if isinstance(self._slide, AnnotatedImage): 20 | return { 21 | "image": self._slide.image.permute(2, 0, 1), 22 | "label": self._slide.label, 23 | "name": self._slide.name, 24 | "label_names": self._slide.label_names, 25 | } 26 | raise NotImplementedError() 27 | -------------------------------------------------------------------------------- /xfuse/data/slide/iterator/random_iterator.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import numpy as np 4 | from PIL import Image 5 | from torchvision import transforms 6 | from torchvision.transforms.functional import ( 7 | _get_inverse_affine_matrix, 8 | to_pil_image, 9 | ) 10 | 11 | from ....utility import center_crop 12 | from ..data import STSlide 13 | from ..iterator import SlideIterator 14 | 15 | 16 | class RandomIterator(SlideIterator): 17 | r""" 18 | A :class:`SlideIterator` that yields randomly cropped patches of the sample 19 | """ 20 | 21 | def __init__( 22 | self, 23 | slide: STSlide, 24 | patch_size: Optional[Tuple[float, float]] = None, 25 | max_rotation_jitter: float = 180.0, 26 | max_scale_jitter: float = 0.05, 27 | max_shear_jitter: float = 10.0, 28 | ): 29 | self._slide = slide 30 | self.image_augmentation = transforms.Compose( 31 | [ 32 | transforms.ColorJitter( 33 | brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05 34 | ) 35 | ] 36 | ) 37 | self._max_rotation_jitter = max_rotation_jitter 38 | self._max_scale_jitter = max_scale_jitter 39 | self._max_shear_jitter = max_shear_jitter 40 | if patch_size is None: 41 | patch_size = self._slide.W, self._slide.H 42 | self._patch_w, self._patch_h = patch_size 43 | 44 | @staticmethod 45 | def _compute_extended_patch_size( 46 | w: float, h: float, rotation: float, scale: float, shear: List[float] 47 | ) -> Tuple[float, float]: 48 | transform = np.concatenate( 49 | [ 50 | np.array( 51 | _get_inverse_affine_matrix( 52 | center=(0.5, 0.5), 53 | angle=rotation, 54 | translate=(0, 0), 55 | scale=scale, 56 | shear=shear, 57 | ) 58 | ).reshape(2, -1), 59 | np.array([[0.0, 0.0, 1.0]]), 60 | ] 61 | ) 62 | corners = np.array([[0, 0, 1], [0, h, 1], [w, 0, 1], [w, h, 1]]) 63 | inv_corners = transform @ np.transpose(corners) 64 | xmax, ymax = inv_corners[:2].max(1) 65 | xmin, ymin = inv_corners[:2].min(1) 66 | return xmax - xmin, ymax - ymin 67 | 68 | def __len__(self): 69 | return int( 70 | np.ceil( 71 | self._slide.W / self._patch_w * self._slide.H / self._patch_h 72 | ) 73 | ) 74 | 75 | def __getitem__(self, idx): 76 | # pylint: disable=too-many-locals 77 | 78 | # Sample a random transformation 79 | rotation = np.random.uniform( 80 | -self._max_rotation_jitter, self._max_rotation_jitter 81 | ) 82 | scale = np.exp( 83 | np.random.uniform(-self._max_scale_jitter, self._max_scale_jitter) 84 | ) 85 | shear = np.random.uniform( 86 | -self._max_shear_jitter, self._max_shear_jitter, size=2 87 | ) 88 | 89 | # Compute the "extended" patch size. 
This is the size of the patch that 90 | # we will first transform and then center crop to the final size. 91 | extpatch_w, extpatch_h = self._compute_extended_patch_size( 92 | w=self._patch_w, 93 | h=self._patch_h, 94 | rotation=rotation, 95 | scale=scale, 96 | shear=shear, 97 | ) 98 | 99 | # The slide may not be large enough for the extended patch size. In 100 | # this case, we will downscale the target patch size until the extended 101 | # patch size fits. 102 | adjmul = min( 103 | 1.0, self._slide.W / extpatch_w, self._slide.H / extpatch_h 104 | ) 105 | extpatch_w = min(int(np.ceil(extpatch_w * adjmul)), self._slide.W) 106 | extpatch_h = min(int(np.ceil(extpatch_h * adjmul)), self._slide.H) 107 | patch_w = int(self._patch_w * adjmul) 108 | patch_h = int(self._patch_h * adjmul) 109 | 110 | # Extract the extended patch by sampling uniformly from the size of the 111 | # slide 112 | x, y = [ 113 | np.random.randint(a - b + 1) 114 | for a, b in zip( 115 | (self._slide.W, self._slide.H), (extpatch_w, extpatch_h) 116 | ) 117 | ] 118 | image = self._slide.image[y : y + extpatch_h, x : x + extpatch_w] 119 | image = (255 * (image + 1) / 2).astype(np.uint8) 120 | image = to_pil_image(image) 121 | label = to_pil_image( 122 | self._slide.label[y : y + extpatch_h, x : x + extpatch_w] 123 | ) 124 | 125 | # Apply augmentations 126 | output_size = (max(extpatch_w, patch_w), max(extpatch_h, patch_h)) 127 | transformation = _get_inverse_affine_matrix( 128 | center=(image.size[0] * 0.5, image.size[1] * 0.5), 129 | angle=rotation, 130 | translate=[(a - b) / 2 for a, b in zip(output_size, image.size)], 131 | scale=scale, 132 | shear=shear, 133 | ) 134 | image = self.image_augmentation(image) 135 | image = np.array( 136 | image.transform( 137 | output_size, 138 | Image.AFFINE, 139 | transformation, 140 | resample=Image.BILINEAR, 141 | ) 142 | ) 143 | image = center_crop(image, (patch_h, patch_w)) 144 | label = np.array( 145 | label.transform( 146 | output_size, 147 | Image.AFFINE, 148 | transformation, 149 | resample=Image.NEAREST, 150 | ) 151 | ) 152 | label = center_crop(label, (patch_h, patch_w)) 153 | if np.random.rand() < 0.5: 154 | image = np.flip(image, 0).copy() 155 | label = np.flip(label, 0).copy() 156 | 157 | # Convert image to the correct data format (float32 in [-1, 1] and in 158 | # CHW order) 159 | image = 2 * image.astype(np.float32) / 255 - 1 160 | image = image.transpose(2, 0, 1) 161 | 162 | return self._slide.prepare_data(image, label) 163 | -------------------------------------------------------------------------------- /xfuse/data/slide/iterator/slide_iterator.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | 4 | class SlideIterator(metaclass=ABCMeta): 5 | r"""Slide iterator""" 6 | 7 | @abstractmethod 8 | def __len__(self): 9 | pass 10 | 11 | @abstractmethod 12 | def __getitem__(self, idx): 13 | pass 14 | 15 | def __iter__(self): 16 | for idx in range(len(self)): 17 | yield self.__getitem__(idx) 18 | -------------------------------------------------------------------------------- /xfuse/data/slide/slide.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, NamedTuple 2 | 3 | from .data import SlideData 4 | from .iterator import SlideIterator 5 | 6 | 7 | class Slide(NamedTuple): 8 | r"""Data structure for tissue slide""" 9 | data: SlideData 10 | iterator: Callable[[SlideData], SlideIterator] 11 | 
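Taken together, a Slide simply pairs a SlideData implementation with an iterator factory. The following is a minimal sketch of how these pieces compose into a training dataloader; the file name "section1.h5", the patch size, and the single-slide design matrix are hypothetical stand-ins, and the HDF5 file is assumed to have been produced by one of the converters in xfuse.convert:

    from functools import partial

    import pandas as pd

    from xfuse.data import Data, Dataset
    from xfuse.data.slide import RandomIterator, Slide, STSlide
    from xfuse.data.utility.misc import make_dataloader

    # Pair the slide data with a factory for randomly cropped patches
    slide = Slide(
        data=STSlide("section1.h5"),  # hypothetical converted data file
        iterator=partial(RandomIterator, patch_size=(224, 224)),
    )

    # A single slide and no covariates, so the design matrix has no columns
    design = pd.DataFrame(index=["section1"])

    dataloader = make_dataloader(
        Dataset(Data(slides={"section1": slide}, design=design)),
        batch_size=2,
        shuffle=True,
    )

The analysis code above follows the same pattern with FullSlideIterator and batch_size=1 when predicting on whole slides.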
-------------------------------------------------------------------------------- /xfuse/data/utility/__init__.py: -------------------------------------------------------------------------------- 1 | from . import misc 2 | -------------------------------------------------------------------------------- /xfuse/data/utility/misc.py: -------------------------------------------------------------------------------- 1 | import itertools as it 2 | from typing import Any, Dict 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data.dataloader import default_collate # type: ignore 7 | 8 | from ...session import get 9 | from ...utility import center_crop 10 | from ..dataset import Dataset 11 | 12 | __all__ = ["make_dataloader", "estimate_spot_size"] 13 | 14 | 15 | class _RepeatSampler: 16 | """Sampler that repeats forever.""" 17 | 18 | def __init__(self, sampler): 19 | self.sampler = sampler 20 | 21 | def __iter__(self): 22 | while True: 23 | yield from iter(self.sampler) 24 | 25 | 26 | class DataLoader(torch.utils.data.DataLoader): 27 | r""" 28 | DataLoader that avoids spawning new workers in each epoch. 29 | See https://github.com/pytorch/pytorch/issues/15849 30 | """ 31 | 32 | def __init__(self, *args, **kwargs): 33 | super().__init__(*args, **kwargs) 34 | object.__setattr__( 35 | self, "batch_sampler", _RepeatSampler(self.batch_sampler) 36 | ) 37 | self.reset_workers() 38 | 39 | def reset_workers(self): 40 | r""" 41 | Reloads worker processes on the next call to `self.__iter__`. This 42 | should be called if the dataset in the main process has been changed. 43 | """ 44 | self.__iterator = super().__iter__() 45 | return self 46 | 47 | def __len__(self): 48 | return len(self.batch_sampler.sampler) 49 | 50 | def __iter__(self): 51 | for _ in range(len(self)): 52 | # pylint: disable=stop-iteration-return 53 | yield next(self.__iterator) 54 | 55 | 56 | def estimate_spot_size(dataset: Dataset) -> Dict[str, float]: 57 | r"""Computes the mean spot size in the :class:`Dataset`""" 58 | 59 | def _compute_size(x): 60 | if x["data_type"] == "ST": 61 | zero_count_idxs = 1 + torch.where(x["data"].sum(1) == 0)[0] 62 | partial_idxs = np.unique( 63 | torch.cat( 64 | [ 65 | x["label"][0], 66 | x["label"][-1], 67 | x["label"][:, 0], 68 | x["label"][:, -1], 69 | ] 70 | ) 71 | .cpu() 72 | .numpy() 73 | ) 74 | partial_idxs = np.setdiff1d( 75 | partial_idxs, zero_count_idxs.cpu().numpy() 76 | ) 77 | mask = np.invert( 78 | np.isin(x["label"].cpu().numpy(), [0, *partial_idxs]) 79 | ) 80 | _, sizes = np.unique( 81 | x["label"].cpu().numpy()[mask].flatten(), return_counts=True, 82 | ) 83 | return sizes 84 | raise NotImplementedError() 85 | 86 | return { 87 | k: np.concatenate([v[1] for v in vs]).mean() 88 | for k, vs in it.groupby( 89 | [(x["data_type"], _compute_size(x)) for x in dataset], 90 | key=lambda x: x[0], 91 | ) 92 | } 93 | 94 | 95 | def make_dataloader(dataset: Dataset, **kwargs: Any) -> DataLoader: 96 | r"""Creates a :class:`~torch.utils.data.DataLoader` for `dataset`""" 97 | 98 | def _collate(xs): 99 | def _remove_key(v): 100 | v.pop("data_type") 101 | return v 102 | 103 | def _sort_key(x): 104 | return x["data_type"] 105 | 106 | def _collate(ys): 107 | collated_data = {} 108 | 109 | # we can't collate the count data as a tensor since its dimension 110 | # will differ between samples. therefore, we return it as a list 111 | # instead. 
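# (Concretely: for ST batches, each element of the list is a
# (num_spots x num_genes) tensor whose first dimension varies
# between patches.)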
112 | try: 113 | collated_data.update({"data": [y.pop("data") for y in ys]}) 114 | except KeyError: 115 | pass 116 | 117 | # Collate any other non-tensor as list 118 | collated_data.update( 119 | { 120 | k: [y.pop(k) for y in ys] 121 | for k in set( 122 | k 123 | for y in ys 124 | for k, v in y.items() 125 | if not torch.is_tensor(v) 126 | ) 127 | } 128 | ) 129 | 130 | # Crop image sizes to the minimum size over the batch 131 | min_size = {} 132 | for y in ys: 133 | for k, v in y.items(): 134 | if k in min_size: 135 | min_size[k] = torch.min( 136 | min_size[k], torch.as_tensor(v.shape) 137 | ) 138 | else: 139 | min_size[k] = torch.as_tensor(v.shape) 140 | for y in ys: 141 | for k, v in min_size.items(): 142 | y[k] = center_crop(y[k], v.numpy().tolist()) 143 | collated_data.update(default_collate(ys)) 144 | 145 | return collated_data 146 | 147 | return { 148 | k: _collate([_remove_key(v) for v in vs]) 149 | for k, vs in it.groupby(sorted(xs, key=_sort_key), key=_sort_key) 150 | } 151 | 152 | def _worker_init(n): 153 | np.random.seed(np.random.get_state()[1][0] + get("training_data").step) 154 | np.random.seed(np.random.randint(np.iinfo(np.int32).max) + n) 155 | 156 | return DataLoader( 157 | dataset, collate_fn=_collate, worker_init_fn=_worker_init, **kwargs 158 | ) 159 | -------------------------------------------------------------------------------- /xfuse/logging/__init__.py: -------------------------------------------------------------------------------- 1 | from .logging import ( 2 | Progressbar, 3 | DEBUG, 4 | ERROR, 5 | INFO, 6 | WARNING, 7 | LOGGER, 8 | log, 9 | set_level, 10 | ) 11 | -------------------------------------------------------------------------------- /xfuse/logging/formatter.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from . 
import DEBUG, INFO, WARNING, ERROR 4 | 5 | 6 | LEVEL_NAMES = { 7 | DEBUG: "DEBUG", 8 | INFO: "INFO", 9 | WARNING: "WARNING", 10 | ERROR: "ERROR", 11 | } 12 | LEVEL_NAMES_FANCY = { 13 | DEBUG: "🐛", 14 | INFO: "ℹ", 15 | WARNING: "⚠ WARNING", 16 | ERROR: "🚨 ERROR", 17 | } 18 | 19 | 20 | class Formatter(logging.Formatter): 21 | r"""Custom log message formatter""" 22 | 23 | def __init__(self, *args, fancy_formatting=False, **kwargs): 24 | self.fancy = fancy_formatting 25 | super().__init__(*args, **kwargs) 26 | 27 | def format(self, record): 28 | if self.fancy: 29 | if record.levelno >= ERROR: 30 | style = "\033[1m\033[91m" 31 | elif record.levelno >= WARNING: 32 | style = "\033[1m\033[93m" 33 | elif record.levelno >= INFO: 34 | style = "\033[1m" 35 | else: 36 | style = "" 37 | reset_style = "\033[0m" 38 | else: 39 | style = "" 40 | reset_style = "" 41 | 42 | try: 43 | levelname = (LEVEL_NAMES_FANCY if self.fancy else LEVEL_NAMES)[ 44 | record.levelno 45 | ] 46 | except KeyError: 47 | levelname = str(record.levelno) 48 | 49 | if record.levelno == DEBUG: 50 | where = f"({record.filename}:{record.lineno})" 51 | else: 52 | where = None 53 | 54 | return " ".join( 55 | x 56 | for x in [ 57 | f"[{self.formatTime(record)}]", 58 | "".join([style, levelname, reset_style]), 59 | where, 60 | ":", 61 | record.getMessage(), 62 | ] 63 | if x 64 | ) 65 | -------------------------------------------------------------------------------- /xfuse/logging/logging.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import logging 3 | from functools import wraps 4 | from logging import ( # pylint: disable=unused-import 5 | DEBUG, 6 | ERROR, 7 | INFO, 8 | WARNING, 9 | ) 10 | from typing import List 11 | 12 | from tqdm import tqdm 13 | 14 | from ..utility import temp_attr 15 | 16 | 17 | LOGGER = logging.getLogger(__name__) 18 | 19 | _PROGRESSBARS: List[tqdm] = [] 20 | 21 | 22 | def _refresh_progressbars(): 23 | # pylint: disable=protected-access 24 | for i, pbar in enumerate(reversed(_PROGRESSBARS)): 25 | if pbar._tqdm_instance.pos != i: 26 | pbar._tqdm_instance.clear() 27 | pbar._tqdm_instance.pos = i 28 | pbar._tqdm_instance.refresh() 29 | 30 | 31 | @wraps(LOGGER.log) 32 | def log(*args, **kwargs): 33 | # pylint: disable=missing-function-docstring 34 | # pylint: disable=protected-access 35 | for pbar in _PROGRESSBARS: 36 | pbar._tqdm_instance.clear() 37 | msg_frame = inspect.currentframe().f_back 38 | with temp_attr( 39 | LOGGER, 40 | "findCaller", 41 | lambda self, stack_info=None: ( 42 | msg_frame.f_code.co_filename, 43 | msg_frame.f_lineno, 44 | msg_frame.f_code.co_name, 45 | None, 46 | ), 47 | ): 48 | LOGGER.log(*args, **kwargs) 49 | for pbar in _PROGRESSBARS: 50 | pbar._tqdm_instance.refresh() 51 | 52 | 53 | def set_level(level: int): 54 | r"""Set logging level""" 55 | LOGGER.setLevel(level) 56 | 57 | 58 | class Progressbar: 59 | r""" 60 | Context manager for creating progress bars compatible with the logging 61 | environment 62 | """ 63 | 64 | def __init__(self, iterable, /, *, position=-1, **kwargs): 65 | self._iterable = iterable 66 | self._position = position 67 | self._kwargs = kwargs 68 | self._tqdm_instance = None 69 | 70 | def __enter__(self): 71 | # pylint: disable=no-member,attribute-defined-outside-init 72 | # ^ disable false positive linting errors 73 | self._tqdm_instance = tqdm(self._iterable, **self._kwargs) 74 | _PROGRESSBARS.insert(self._position % (len(_PROGRESSBARS) + 1), self) 75 | _refresh_progressbars() 76 | return self._tqdm_instance 
77 | 78 | def __exit__(self, err_type, err, tb): 79 | _PROGRESSBARS.remove(self) 80 | self._tqdm_instance.close() 81 | _refresh_progressbars() 82 | -------------------------------------------------------------------------------- /xfuse/messengers/__init__.py: -------------------------------------------------------------------------------- 1 | from . import stats 2 | from .analysis_runner import * 3 | from .checkpointer import * 4 | -------------------------------------------------------------------------------- /xfuse/messengers/analysis_runner.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any, Dict, Tuple 3 | 4 | from pyro.poutine.messenger import Messenger 5 | 6 | from ..analyze import analyses as _analyses 7 | from ..logging import INFO, log 8 | from ..session import Session, get 9 | from ..utility.file import chdir 10 | 11 | 12 | class AnalysisRunner(Messenger): 13 | r"""Runs the configured analyses at a fixed interval during training""" 14 | 15 | def __init__( 16 | self, 17 | analyses: Dict[str, Tuple[str, Dict[str, Any]]], 18 | period: int = 10000, 19 | ): 20 | super().__init__() 21 | self._analyses = analyses 22 | self._period = period 23 | 24 | def _pyro_post_step(self, _msg): 25 | if (step := get("training_data").step) % self._period == 0: 26 | for name, (analysis_type, options) in self._analyses.items(): 27 | if analysis_type in _analyses: 28 | log(INFO, 'Running analysis "%s"', name) 29 | with Session(messengers=[]): 30 | with chdir(f"/analyses/step-{step:06d}/{name}"): 31 | _analyses[analysis_type].function(**options) 32 | else: 33 | warnings.warn(f'Unknown analysis "{analysis_type}"') 34 | -------------------------------------------------------------------------------- /xfuse/messengers/checkpointer.py: -------------------------------------------------------------------------------- 1 | from pyro.poutine.messenger import Messenger 2 | 3 | from ..session.io import save_session 4 | from ..utility.file import chdir 5 | 6 | 7 | class Checkpointer(Messenger): 8 | r"""Saves the currently running session to disk at a fixed interval""" 9 | 10 | def __init__(self, period: int = 1): 11 | super().__init__() 12 | self._period = period 13 | 14 | def _pyro_post_epoch(self, msg): 15 | epoch = msg["kwargs"]["epoch"] 16 | if epoch % self._period == 0: 17 | with chdir("/checkpoints"): 18 | save_session(f"epoch-{epoch:08d}") 19 | -------------------------------------------------------------------------------- /xfuse/messengers/stats/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from .conditions import * 4 | from .elbo import * 5 | from .metagene_activation import * 6 | from .image import * 7 | from .latent import * 8 | from .rmse import * 9 | from .scale import * 10 | -------------------------------------------------------------------------------- /xfuse/messengers/stats/conditions.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import torch 4 | 5 | from .stats_handler import StatsHandler, log_scalars 6 | from ...session import get 7 | 8 | 9 | class Conditions(StatsHandler): 10 | r"""Condition probability stats tracker""" 11 | 12 | def _select_msg(self, type, **_): 13 | # pylint: disable=arguments-differ 14 | # pylint: disable=redefined-builtin 15 | return type == "step" 16 | 17 | def _handle(self, value, **_): 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=no-member
20 | for slide, covariate, logits in [ 21 | (*match.groups(), x["value"]) 22 | for x in value.nodes.values() 23 | for match in [re.match("^logits-(.*)-(.*)$", x["site"])] 24 | if match 25 | ]: 26 | log_scalars( 27 | "/".join(["conditions", slide, covariate]), 28 | { 29 | condition: prob.item() 30 | for prob, condition in zip( 31 | torch.softmax(logits, 0), get("covariates")[covariate] 32 | ) 33 | }, 34 | ) 35 | -------------------------------------------------------------------------------- /xfuse/messengers/stats/elbo.py: -------------------------------------------------------------------------------- 1 | from .stats_handler import StatsHandler, log_scalar 2 | from ...session import get 3 | 4 | 5 | class ELBO(StatsHandler): 6 | r"""ELBO stats tracker""" 7 | 8 | def _select_msg(self, type, **_): 9 | # pylint: disable=arguments-differ 10 | # pylint: disable=redefined-builtin 11 | return type == "step" 12 | 13 | def _handle(self, value, **_): 14 | # pylint: disable=arguments-differ 15 | # pylint: disable=no-member 16 | training_data = get("training_data") 17 | guide_log_prob = value.log_prob_sum( 18 | site_filter=lambda _, x: x["is_guide"] 19 | ) 20 | model_log_prob = value.log_prob_sum( 21 | site_filter=lambda _, x: not x["is_guide"] 22 | ) 23 | elbo = (model_log_prob - guide_log_prob).item() 24 | try: 25 | training_data.elbo_short = training_data.elbo_short + 1e-3 * ( 26 | elbo - training_data.elbo_short 27 | ) 28 | training_data.elbo_long = training_data.elbo_long + 1e-4 * ( 29 | elbo - training_data.elbo_long 30 | ) 31 | except TypeError: 32 | training_data.elbo_short = elbo 33 | training_data.elbo_long = elbo 34 | log_scalar("loss/elbo", elbo) 35 | -------------------------------------------------------------------------------- /xfuse/messengers/stats/image.py: -------------------------------------------------------------------------------- 1 | from .stats_handler import StatsHandler, log_images 2 | 3 | 4 | class Image(StatsHandler): 5 | r"""Image stats tracker""" 6 | 7 | def _select_msg(self, type, name, is_observed, **_): 8 | # pylint: disable=arguments-differ 9 | # pylint: disable=redefined-builtin 10 | return type == "sample" and is_observed and name[-5:] == "image" 11 | 12 | def _handle(self, fn, value, **_): 13 | # pylint: disable=arguments-differ 14 | # pylint: disable=no-member 15 | log_images("image/ground_truth", (1 + value.permute(0, 2, 3, 1)) / 2) 16 | log_images("image/mean", (1 + fn.mean.permute(0, 2, 3, 1)) / 2) 17 | log_images("image/sample", (1 + fn.sample().permute(0, 2, 3, 1)) / 2) 18 | -------------------------------------------------------------------------------- /xfuse/messengers/stats/latent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .stats_handler import StatsHandler, log_images 4 | from ...utility.visualization import reduce_last_dimension 5 | 6 | 7 | class Latent(StatsHandler):  8 | r"""Latent state stats tracker""" 9 | 10 | def _select_msg(self, type, name, **msg): 11 | # pylint: disable=arguments-differ 12 | # pylint: disable=redefined-builtin 13 | return type == "sample" and name[:2] == "z-" and not msg["is_guide"] 14 | 15 | def _handle(self, value, name, **_): 16 | # pylint: disable=arguments-differ 17 | # pylint: disable=no-member 18 | try: 19 | log_images( 20 | f"z/{name[2:]}", 21 | torch.as_tensor( 22 | reduce_last_dimension(value.permute(0, 2, 3, 1)) 23 | ), 24 | ) 25 | except ValueError: 26 | pass 27 | --------------------------------------------------------------------------------
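The trackers above all follow the same two-method contract: `_select_msg` filters Pyro messages by type and site name, and `_handle` pushes the selected data to the session's stats writers. A minimal sketch of a custom tracker built on the same pattern (`StepCount` and the `debug/step` tag are hypothetical; the relative imports assume a module placed alongside the trackers above):

from .stats_handler import StatsHandler, log_scalar
from ...session import get


class StepCount(StatsHandler):
    r"""Training step stats tracker (illustrative sketch)"""

    def _select_msg(self, type, **_):
        # pylint: disable=redefined-builtin
        return type == "step"  # fire once per training step

    def _handle(self, **_):
        # forward the current step to all registered stats writers
        log_scalar("debug/step", float(get("training_data").step))

Like the built-in trackers, an instance would be activated by including it in the session's `messengers`.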
/xfuse/messengers/stats/metagene_activation.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import matplotlib.pyplot as plt 4 | import pyro.poutine 5 | import torch 6 | 7 | from ...analyze.metagenes import ( 8 | compute_metagene_profiles, 9 | visualize_metagene_profile, 10 | ) 11 | from ...model.experiment.st import ST 12 | from ...session import get 13 | from ...utility.visualization import reduce_last_dimension, visualize_metagenes 14 | from .stats_handler import ( 15 | StatsHandler, 16 | log_figure, 17 | log_histogram, 18 | log_image, 19 | log_images, 20 | log_scalar, 21 | ) 22 | 23 | __all__ = [ 24 | "MetageneHistogram", 25 | "MetageneMean", 26 | "MetageneSummary", 27 | "MetageneFullSummary", 28 | ] 29 | 30 | 31 | class Metagene(StatsHandler): 32 | r"""Abstract class for metagene trackers""" 33 | 34 | def _select_msg(self, type, name, **_): 35 | # pylint: disable=arguments-differ 36 | # pylint: disable=redefined-builtin 37 | return type == "sample" and name[-3:] == "rim" 38 | 39 | @abstractmethod 40 | def _handle_metagene(self, name, metagene): 41 | pass 42 | 43 | def _handle(self, fn, **_): 44 | # pylint: disable=arguments-differ 45 | if (model := get("model")) is not None: 46 | experiment: ST = model.get_experiment("ST") 47 | for name, metagene in zip( 48 | experiment.metagenes, fn.mean.permute(1, 0, 2, 3) 49 | ): 50 | self._handle_metagene(name, metagene) 51 | 52 | 53 | class MetageneHistogram(Metagene): 54 | r"""Summarizes the spatial activation of each metagene in a histogram""" 55 | 56 | def _handle_metagene(self, name, metagene): 57 | # pylint: disable=no-member 58 | log_histogram( 59 | f"metagene-histogram/metagene-{name}", metagene.flatten() 60 | ) 61 | 62 | 63 | class MetageneMean(Metagene): 64 | r"""Summarizes the mean spatial activation of each metagene""" 65 | 66 | def _handle_metagene(self, name, metagene): 67 | # pylint: disable=no-member 68 | log_scalar(f"metagene-mean/metagene-{name}", metagene.mean()) 69 | 70 | 71 | class MetageneSummary(StatsHandler): 72 | r"""Plots summarized spatial activations of all metagenes""" 73 | 74 | def _select_msg(self, type, name, value, **msg): 75 | # pylint: disable=arguments-differ 76 | # pylint: disable=redefined-builtin 77 | return ( 78 | type == "sample" 79 | and not msg["is_guide"] 80 | and name[-3:] == "rim" 81 | and value.shape[1] >= 3 82 | ) 83 | 84 | def _handle(self, fn, **_): 85 | # pylint: disable=arguments-differ 86 | # pylint: disable=no-member 87 | try: 88 | log_images( 89 | "metagene-batch-summary", 90 | torch.as_tensor( 91 | reduce_last_dimension(fn.mean.permute(0, 2, 3, 1)) 92 | ), 93 | ) 94 | except ValueError: 95 | pass 96 | 97 | 98 | class MetageneFullSummary(StatsHandler): 99 | r"""Plots summarized spatial activations of all metagenes in each sample""" 100 | 101 | def _select_msg(self, type, **_): 102 | # pylint: disable=arguments-differ 103 | # pylint: disable=redefined-builtin 104 | return type == "step" 105 | 106 | def _handle(self, **msg): 107 | try: 108 | with pyro.poutine.block(): 109 | for ( 110 | slide_name, 111 | summarization, 112 | metagenes, 113 | ) in visualize_metagenes(): 114 | # pylint: disable=no-member 115 | log_image( 116 | f"metagene-summary/{slide_name}", 117 | torch.as_tensor(summarization), 118 | ) 119 | for name, metagene in metagenes: 120 | log_image( 121 | f"metagene-{name}/{slide_name}", 122 | torch.as_tensor(metagene), 123 | ) 124 | except ValueError: 125 | pass 126 | 127 | for experiment, metagene_profiles 
in compute_metagene_profiles(): 128 | metagene_profiles["invcv"] = ( 129 | metagene_profiles["mean"] / metagene_profiles["stddev"] 130 | ) 131 | for name, profile in metagene_profiles.groupby(level=0): 132 | fig = plt.figure(figsize=(3.5, 3.7)) 133 | visualize_metagene_profile( 134 | profile.loc[name], 135 | num_high=20, 136 | num_low=10, 137 | sort_by="invcv", 138 | ) 139 | plt.tight_layout(pad=0.0) 140 | log_figure( 141 | f"metagene-{name}/profile/{experiment}/invcvsort", fig, 142 | ) 143 | 144 | fig = plt.figure(figsize=(3.5, 3.7)) 145 | visualize_metagene_profile( 146 | profile.loc[name], num_high=20, num_low=10, sort_by="mean", 147 | ) 148 | plt.tight_layout(pad=0.0) 149 | log_figure( 150 | f"metagene-{name}/profile/{experiment}/meansort", fig, 151 | ) 152 | -------------------------------------------------------------------------------- /xfuse/messengers/stats/rmse.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import torch 4 | 5 | from .stats_handler import StatsHandler, log_scalar 6 | from ...session import get 7 | 8 | 9 | class RMSE(StatsHandler): 10 | r"""Root-mean-square error stats tracker""" 11 | 12 | def _select_msg(self, type, **_): 13 | # pylint: disable=arguments-differ 14 | # pylint: disable=redefined-builtin 15 | return type == "step" 16 | 17 | def _handle(self, value, **_): 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=no-member 20 | training_data = get("training_data") 21 | 22 | try: 23 | means, values = zip( 24 | *[ 25 | (x["fn"].mean, x["value"]) 26 | for x in value.nodes.values() 27 | if re.match("ST/xsg-\\d+", x["site"]) 28 | ] 29 | ) 30 | except ValueError: 31 | return 32 | 33 | rmse = ( 34 | ((torch.cat(means) - torch.cat(values)) ** 2) 35 | .mean(1) 36 | .sqrt() 37 | .mean() 38 | .item() 39 | ) 40 | 41 | try: 42 | training_data.rmse = training_data.rmse + 1e-3 * ( 43 | rmse - training_data.rmse 44 | ) 45 | except TypeError: 46 | training_data.rmse = rmse 47 | 48 | log_scalar("accuracy/rmse", rmse) 49 | -------------------------------------------------------------------------------- /xfuse/messengers/stats/scale.py: -------------------------------------------------------------------------------- 1 | from .stats_handler import StatsHandler, log_images 2 | 3 | 4 | class Scale(StatsHandler): 5 | r"""Scaling factor stats tracker""" 6 | 7 | def _select_msg(self, type, name, **_): 8 | # pylint: disable=arguments-differ 9 | # pylint: disable=redefined-builtin 10 | return type == "sample" and name[-5:] == "scale" 11 | 12 | def _handle(self, fn, **_): 13 | # pylint: disable=arguments-differ 14 | # pylint: disable=no-member 15 | scale = fn.mean.permute(0, 2, 3, 1) 16 | scale = scale / scale.max() 17 | log_images("scale", scale) 18 | -------------------------------------------------------------------------------- /xfuse/messengers/stats/stats_handler.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from io import BytesIO 3 | from typing import Callable, List, Optional 4 | 5 | import matplotlib 6 | import torch 7 | from imageio import imread 8 | from pyro.poutine.messenger import Messenger 9 | 10 | from ...logging import DEBUG, log 11 | from ...session import Session, get 12 | from ...utility.file import chdir 13 | from .writer import StatsWriter 14 | 15 | 16 | __all__ = [ 17 | "StatsHandler", 18 | "log_figure", 19 | "log_histogram", 20 | "log_image", 21 | "log_images", 22 | "log_scalar", 23 | ] 24 | 25 | 26 | class 
StatsHandler(Messenger, metaclass=ABCMeta): 27 | r"""Abstract class for stats trackers""" 28 | 29 | def __init__( 30 | self, predicate: Optional[Callable[..., bool]] = None, 31 | ): 32 | super().__init__() 33 | 34 | if predicate is None: 35 | predicate = lambda **_: not get("eval") 36 | self.predicate = predicate 37 | 38 | def __enter__(self, *args, **kwargs): 39 | # pylint: disable=arguments-differ 40 | log(DEBUG, "Activating stats tracker: %s", type(self).__name__) 41 | super().__enter__(*args, **kwargs) 42 | 43 | def __exit__(self, *args, **kwargs): 44 | # pylint: disable=arguments-differ 45 | log(DEBUG, "Deactivating stats tracker: %s", type(self).__name__) 46 | super().__exit__(*args, **kwargs) 47 | 48 | @abstractmethod 49 | def _handle(self, **msg) -> None: 50 | pass 51 | 52 | @abstractmethod 53 | def _select_msg(self, **msg) -> bool: 54 | pass 55 | 56 | def _postprocess_message(self, msg): 57 | if self._select_msg(**msg) and self.predicate(**msg): 58 | self._handle(**msg) 59 | 60 | 61 | def log_figure(tag: str, figure: matplotlib.figure.Figure, **kwargs,) -> None: 62 | r""" 63 | Converts :class:`~matplotlib.figure.Figure` to image data and logs it 64 | using :func:`log_image` 65 | """ 66 | if "format" not in kwargs: 67 | kwargs["format"] = "tiff" 68 | bio = BytesIO() 69 | with Session(mpl_backend="Agg"): 70 | figure.savefig(bio, **kwargs) 71 | bio.seek(0) 72 | fig_image = torch.as_tensor(imread(bio)) 73 | log_image(tag, img_tensor=fig_image) 74 | 75 | 76 | def log_histogram(*args, **kwargs) -> None: 77 | r"""Pushes histogram data to the session `stats_writers`""" 78 | stats_writers: List[StatsWriter] = get("stats_writers") 79 | with chdir("/stats"), torch.no_grad(): 80 | for stats_writer in stats_writers: 81 | stats_writer.write_histogram(*args, **kwargs) 82 | 83 | 84 | def log_image(*args, **kwargs) -> None: 85 | r"""Pushes image data to the session `stats_writers`""" 86 | stats_writers: List[StatsWriter] = get("stats_writers") 87 | with chdir("/stats"), torch.no_grad(): 88 | for stats_writer in stats_writers: 89 | stats_writer.write_image(*args, **kwargs) 90 | 91 | 92 | def log_images(*args, **kwargs) -> None: 93 | r"""Pushes image grid to the session `stats_writers`""" 94 | stats_writers: List[StatsWriter] = get("stats_writers") 95 | with chdir("/stats"), torch.no_grad(): 96 | for stats_writer in stats_writers: 97 | stats_writer.write_images(*args, **kwargs) 98 | 99 | 100 | def log_scalar(*args, **kwargs) -> None: 101 | r"""Pushes scalar to the session `stats_writers`""" 102 | stats_writers: List[StatsWriter] = get("stats_writers") 103 | with chdir("/stats"), torch.no_grad(): 104 | for stats_writer in stats_writers: 105 | stats_writer.write_scalar(*args, **kwargs) 106 | 107 | 108 | def log_scalars(*args, **kwargs) -> None: 109 | r"""Pushes scalars to the session `stats_writers`""" 110 | stats_writers: List[StatsWriter] = get("stats_writers") 111 | with chdir("/stats"), torch.no_grad(): 112 | for stats_writer in stats_writers: 113 | stats_writer.write_scalars(*args, **kwargs) 114 | -------------------------------------------------------------------------------- /xfuse/messengers/stats/writer/__init__.py: -------------------------------------------------------------------------------- 1 | from .stats_writer import StatsWriter 2 | from .file import FileWriter 3 | from .tensorboard import TensorboardWriter 4 | 5 | 6 | __all__ = [ 7 | "StatsWriter", 8 | "FileWriter", 9 | "TensorboardWriter", 10 | ] 11 | --------------------------------------------------------------------------------
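The two writers that follow implement the `StatsWriter` interface re-exported here. For tests, an in-memory implementation can stand in for both; a minimal sketch (`MemoryWriter` is a hypothetical name; the signatures follow the abstract base class in stats_writer.py below):

from typing import Dict, List, Optional

import torch

from . import StatsWriter


class MemoryWriter(StatsWriter):
    r"""Stats writer that collects scalars in memory (illustrative sketch)"""

    def __init__(self):
        self.scalars: Dict[str, List[float]] = {}

    def write_histogram(
        self, tag: str, values: torch.Tensor, bins: Optional[int] = None
    ) -> None:
        pass  # histograms are ignored in this sketch

    def write_image(self, tag: str, img_tensor: torch.Tensor) -> None:
        pass  # images are ignored in this sketch

    def write_images(self, tag: str, img_tensor: torch.Tensor) -> None:
        pass  # image grids are ignored in this sketch

    def write_scalar(self, tag: str, scalar_value: float) -> None:
        self.scalars.setdefault(tag, []).append(scalar_value)

    def write_scalars(self, tag: str, scalar_values: Dict[str, float]) -> None:
        for name, value in scalar_values.items():
            self.write_scalar("/".join([tag, name]), value)

An instance would be installed for the duration of a session with `Session(stats_writers=[MemoryWriter()])`.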
/xfuse/messengers/stats/writer/file.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import os 3 | import time 4 | from typing import Dict, Optional 5 | from warnings import warn 6 | 7 | import numpy as np 8 | import torch 9 | from imageio import imwrite 10 | 11 | from . import StatsWriter 12 | from ....session import get 13 | from ....utility.file import first_unique_filename 14 | from ....utility.visualization import _normalize 15 | 16 | 17 | __all__ = ["FileWriter"] 18 | 19 | 20 | class FileWriter(StatsWriter): 21 | r"""Stats writer emitting .jpg and .csv.gz files""" 22 | 23 | def __init__(self): 24 | self._file_cons: Dict[str, gzip.GzipFile] = {} 25 | 26 | def write_histogram( 27 | self, tag: str, values: torch.Tensor, bins: Optional[int] = None 28 | ) -> None: 29 | r"""Logs a histogram""" 30 | warn( 31 | RuntimeWarning( 32 | "Histogram logging is not yet supported for this writer" 33 | ) 34 | ) 35 | 36 | def write_image(self, tag: str, img_tensor: torch.Tensor) -> None: 37 | r"""Logs an image""" 38 | training_data = get("training_data") 39 | *prefix, name = tag.split("/") 40 | filename = first_unique_filename( 41 | os.path.join( 42 | *prefix, 43 | f"{name}-{training_data.epoch}-{training_data.step}.png", 44 | ) 45 | ) 46 | if (dirname := os.path.dirname(filename)) != "": 47 | os.makedirs(dirname, exist_ok=True) 48 | img = img_tensor.detach().cpu().numpy() 49 | img = _normalize(img) 50 | img = (255 * img).astype(np.uint8) 51 | imwrite(os.path.abspath(filename), img) 52 | 53 | def write_images(self, tag: str, img_tensor: torch.Tensor) -> None: 54 | r"""Logs an image grid""" 55 | N, H, W, C = img_tensor.shape # pylint: disable=invalid-name 56 | cols = int(np.ceil(N ** 0.5)) 57 | rows = int(np.ceil(N / cols)) 58 | img_tensor = torch.cat( 59 | [img_tensor, torch.zeros(rows * cols - N, H, W, C).to(img_tensor)], 60 | ) 61 | img_tensor = ( 62 | img_tensor.reshape(rows, cols, H, W, C) 63 | .permute(0, 2, 1, 3, 4) 64 | .reshape(rows * H, cols * W, C) 65 | ) 66 | self.write_image(tag, img_tensor) 67 | 68 | def write_scalar(self, tag: str, scalar_value: float) -> None: 69 | r"""Logs a scalar""" 70 | training_data = get("training_data") 71 | *prefix, name = tag.split("/") 72 | filename = os.path.join(*prefix, f"{name}.csv.gz") 73 | if filename not in self._file_cons: 74 | os.makedirs(os.path.dirname(filename), exist_ok=True) 75 | if not os.path.exists(filename): 76 | with gzip.open(filename, "wb") as fcon: 77 | fcon.write("time,epoch,step,value\n".encode()) 78 | self._file_cons[filename] = gzip.open(filename, "ab") 79 | self._file_cons[filename].write( 80 | str.encode( 81 | "{:f},{:d},{:d},{:f}\n".format( 82 | time.time(), 83 | training_data.epoch, 84 | training_data.step, 85 | scalar_value, 86 | ) 87 | ) 88 | ) 89 | 90 | def write_scalars(self, tag: str, scalar_values: Dict[str, float]) -> None: 91 | r"""Logs a set of associated scalars""" 92 | training_data = get("training_data") 93 | *prefix, name = tag.split("/") 94 | filename = os.path.join(*prefix, f"{name}.csv.gz") 95 | if filename not in self._file_cons: 96 | os.makedirs(os.path.dirname(filename), exist_ok=True) 97 | if not os.path.exists(filename): 98 | with gzip.open(filename, "wb") as fcon: 99 | fcon.write("time,epoch,step,name,value\n".encode()) 100 | self._file_cons[filename] = gzip.open(filename, "ab") 101 | for name, value in scalar_values.items(): 102 | self._file_cons[filename].write( 103 | str.encode( 104 | "{:f},{:d},{:d},{:s},{:f}\n".format( 105 | time.time(), 106 | 
training_data.epoch, 107 | training_data.step, 108 | name, 109 | value, 110 | ) 111 | ) 112 | ) 113 | -------------------------------------------------------------------------------- /xfuse/messengers/stats/writer/stats_writer.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import Dict, Optional 3 | 4 | import torch 5 | 6 | 7 | __all__ = ["StatsWriter"] 8 | 9 | 10 | class StatsWriter(metaclass=ABCMeta): 11 | r"""Abstract class for stats writers""" 12 | 13 | @abstractmethod 14 | def write_histogram( 15 | self, tag: str, values: torch.Tensor, bins: Optional[int] = None 16 | ) -> None: 17 | r"""Writes histogram data""" 18 | 19 | @abstractmethod 20 | def write_image(self, tag: str, img_tensor: torch.Tensor) -> None: 21 | r"""Writes image data""" 22 | 23 | @abstractmethod 24 | def write_images(self, tag: str, img_tensor: torch.Tensor) -> None: 25 | r"""Writes image grid data""" 26 | 27 | @abstractmethod 28 | def write_scalar(self, tag: str, scalar_value: float) -> None: 29 | r"""Writes scalar data""" 30 | 31 | @abstractmethod 32 | def write_scalars(self, tag: str, scalar_values: Dict[str, float]) -> None: 33 | r"""Writes scalar data""" 34 | -------------------------------------------------------------------------------- /xfuse/messengers/stats/writer/tensorboard.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | import torch 4 | from torch.utils.tensorboard.writer import SummaryWriter 5 | 6 | from . import StatsWriter 7 | from ....logging import DEBUG, log 8 | from ....session import get 9 | 10 | 11 | __all__ = ["TensorboardWriter"] 12 | 13 | 14 | class TensorboardWriter(StatsWriter): 15 | r"""Tensorboard stats writer""" 16 | 17 | def __init__(self): 18 | self.__summary_writer = None 19 | 20 | @property 21 | def _summary_writer(self) -> SummaryWriter: 22 | log_dir = get("work_dir").full_path 23 | if ( 24 | self.__summary_writer is None 25 | or self.__summary_writer.log_dir != log_dir 26 | ): 27 | log(DEBUG, "Creating new SummaryWriter (log_dir = %s)", log_dir) 28 | self.__summary_writer = SummaryWriter(log_dir=log_dir) 29 | return self.__summary_writer 30 | 31 | def write_histogram( 32 | self, tag: str, values: torch.Tensor, bins: Optional[int] = None 33 | ) -> None: 34 | r"""Logs a histogram""" 35 | self._summary_writer.add_histogram( 36 | tag, values, global_step=get("training_data").step 37 | ) 38 | 39 | def write_image(self, tag: str, img_tensor: torch.Tensor) -> None: 40 | r"""Logs an image""" 41 | self._summary_writer.add_image( 42 | tag, 43 | img_tensor, 44 | global_step=get("training_data").step, 45 | dataformats="HWC", 46 | ) 47 | 48 | def write_images(self, tag: str, img_tensor: torch.Tensor) -> None: 49 | r"""Logs an image grid""" 50 | self._summary_writer.add_images( 51 | tag, 52 | img_tensor, 53 | global_step=get("training_data").step, 54 | dataformats="NHWC", 55 | ) 56 | 57 | def write_scalar(self, tag: str, scalar_value: float) -> None: 58 | r"""Logs a scalar""" 59 | self._summary_writer.add_scalar( 60 | tag, scalar_value, global_step=get("training_data").step 61 | ) 62 | 63 | def write_scalars(self, tag: str, scalar_values: Dict[str, float]) -> None: 64 | r"""Logs a set of associated scalars""" 65 | self._summary_writer.add_scalars( 66 | tag, scalar_values, global_step=get("training_data").step 67 | ) 68 | -------------------------------------------------------------------------------- 
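Both writers key their records on `get("training_data").step`, so file and Tensorboard output stay aligned. Because `FileWriter` emits gzipped CSV with a `time,epoch,step,value` header, its output is straightforward to analyze after a run; a sketch, assuming the `loss/elbo` tag used by the `ELBO` tracker and the run's stats directory as the current working directory:

import pandas as pd  # already a dependency of this project

# read_csv infers gzip compression from the file extension
elbo = pd.read_csv("loss/elbo.csv.gz")
print(elbo.tail())  # most recent time/epoch/step/value rows
elbo.plot(x="step", y="value")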
/xfuse/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .xfuse import * 2 | from . import experiment 3 | from . import utility 4 | -------------------------------------------------------------------------------- /xfuse/model/experiment/__init__.py: -------------------------------------------------------------------------------- 1 | from .experiment import * 2 | -------------------------------------------------------------------------------- /xfuse/model/experiment/experiment.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod, abstractproperty 2 | 3 | import torch 4 | 5 | 6 | class Experiment(torch.nn.Module): 7 | r"""Abstract class defining the experiment type""" 8 | 9 | @property 10 | def num_z(self): 11 | r"""Number of independent tissue states""" 12 | return 1 13 | 14 | @abstractproperty 15 | def tag(self): 16 | r"""Experiment name""" 17 | 18 | @abstractmethod 19 | def model(self, x, zs): 20 | r"""Experiment model""" 21 | 22 | @abstractmethod 23 | def guide(self, x): 24 | r"""Experiment guide for :class:`pyro.infer.SVI`""" 25 | 26 | def forward(self, x, zs): 27 | r"""Alias for :func:`model`""" 28 | # pylint: disable=arguments-differ 29 | return self.model(x, zs) 30 | -------------------------------------------------------------------------------- /xfuse/model/experiment/image.py: -------------------------------------------------------------------------------- 1 | import pyro 2 | import torch 3 | from pyro.distributions import Normal # pylint: disable=no-name-in-module 4 | 5 | from ...utility import center_crop 6 | from ...utility.state import get_module 7 | from . import Experiment 8 | 9 | 10 | class Image(Experiment): 11 | r"""Image experiment""" 12 | 13 | def __init__(self, *args, depth=4, num_channels=8, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | self.depth = depth 16 | self.num_channels = num_channels 17 | 18 | @property 19 | def num_z(self): 20 | return self.depth 21 | 22 | @property 23 | def tag(self): 24 | return "image" 25 | 26 | def _decode(self, zs): 27 | # pylint: disable=no-self-use 28 | def _decode(y, i): 29 | decoder = get_module( 30 | f"img-decoder-{i}", 31 | lambda: torch.nn.Sequential( 32 | torch.nn.Conv2d( 33 | y.shape[1], y.shape[1], kernel_size=3, padding=1 34 | ), 35 | torch.nn.BatchNorm2d(y.shape[1], momentum=0.05), 36 | torch.nn.LeakyReLU(0.2, inplace=True), 37 | torch.nn.Conv2d( 38 | y.shape[1], y.shape[1], kernel_size=3, padding=1 39 | ), 40 | torch.nn.BatchNorm2d(y.shape[1], momentum=0.05), 41 | torch.nn.LeakyReLU(0.2, inplace=True), 42 | ), 43 | checkpoint=True, 44 | ) 45 | return decoder(y) 46 | 47 | def _combine(y, z, i): 48 | combiner = get_module( 49 | f"img-combiner-{i}", 50 | lambda: torch.nn.Sequential( 51 | torch.nn.Conv2d( 52 | y.shape[1] + z.shape[1], 53 | z.shape[1], 54 | kernel_size=3, 55 | padding=1, 56 | ), 57 | torch.nn.BatchNorm2d(z.shape[1], momentum=0.05), 58 | torch.nn.LeakyReLU(0.2, inplace=True), 59 | ), 60 | checkpoint=True, 61 | ) 62 | y = center_crop(y, [None, None, *z.shape[-2:]]) 63 | return combiner(torch.cat([y, z], 1)) 64 | 65 | def _upsample(y, i): 66 | upsampler = get_module( 67 | f"upsampler-{i}", 68 | lambda: torch.nn.Sequential( 69 | torch.nn.Upsample( 70 | scale_factor=2.0, mode="bilinear", align_corners=False 71 | ), 72 | torch.nn.Conv2d( 73 | y.shape[1], y.shape[1] // 2, kernel_size=5, padding=2 74 | ), 75 | torch.nn.BatchNorm2d(y.shape[1] // 2, momentum=0.05), 76 | 
torch.nn.LeakyReLU(0.2, inplace=True), 77 | ), 78 | checkpoint=True, 79 | ) 80 | return upsampler(y) 81 | 82 | y = _decode(zs[-1], self.depth - 1) 83 | for i, z in zip(reversed(range(self.depth - 1)), zs[::-1][1:]): 84 | y = _decode(_combine(_upsample(y, i), z, i), i) 85 | 86 | return y 87 | 88 | def _encode(self, x): 89 | def _encode(x, i): 90 | encoder = get_module( 91 | f"encoder-{i}", 92 | lambda: torch.nn.Sequential( 93 | torch.nn.Conv2d( 94 | x.shape[1], x.shape[1], kernel_size=3, padding=1 95 | ), 96 | torch.nn.BatchNorm2d(x.shape[1], momentum=0.05), 97 | torch.nn.LeakyReLU(0.2, inplace=True), 98 | torch.nn.Conv2d( 99 | x.shape[1], x.shape[1], kernel_size=3, padding=1 100 | ), 101 | torch.nn.BatchNorm2d(x.shape[1], momentum=0.05), 102 | torch.nn.LeakyReLU(0.2, inplace=True), 103 | ), 104 | checkpoint=True, 105 | ) 106 | return encoder(x) 107 | 108 | def _downsample(x, i): 109 | downsampler = get_module( 110 | f"downsampler-{i}", 111 | lambda: torch.nn.Sequential( 112 | torch.nn.Conv2d( 113 | x.shape[1], 114 | 2 * x.shape[1], 115 | kernel_size=5, 116 | stride=2, 117 | padding=2, 118 | ), 119 | torch.nn.BatchNorm2d(2 * x.shape[1], momentum=0.05), 120 | torch.nn.LeakyReLU(0.2, inplace=True), 121 | ), 122 | checkpoint=True, 123 | ) 124 | return downsampler(x) 125 | 126 | preencoder = get_module( 127 | "preencoder", 128 | lambda: torch.nn.Sequential( 129 | torch.nn.Conv2d( 130 | x.shape[1], self.num_channels, kernel_size=3, padding=1 131 | ), 132 | torch.nn.BatchNorm2d(self.num_channels, momentum=0.05), 133 | torch.nn.LeakyReLU(0.2, inplace=True), 134 | ), 135 | ) 136 | 137 | ys = [_encode(preencoder(x), 0)] 138 | for i in range(1, self.depth): 139 | ys.append(_encode(_downsample(ys[-1], i), i)) 140 | 141 | return ys 142 | 143 | def _sample_image(self, x, decoded): 144 | def _create_mu_decoder(): 145 | decoder = torch.nn.Sequential( 146 | torch.nn.Conv2d( 147 | self.num_channels, self.num_channels, kernel_size=1 148 | ), 149 | torch.nn.BatchNorm2d(self.num_channels, momentum=0.05), 150 | torch.nn.LeakyReLU(0.2, inplace=True), 151 | torch.nn.Conv2d( 152 | self.num_channels, x["image"].shape[1], kernel_size=1 153 | ), 154 | torch.nn.Tanh(), 155 | ) 156 | torch.nn.init.constant_(decoder[-2].weight, 0.0) 157 | mean = x["image"].mean((0, 2, 3)) 158 | decoder[-2].bias.data = ((1 + mean) / (1 - mean)).log() / 2 159 | return decoder 160 | 161 | def _create_sd_decoder(): 162 | decoder = torch.nn.Sequential( 163 | torch.nn.Conv2d( 164 | self.num_channels, self.num_channels, kernel_size=1 165 | ), 166 | torch.nn.BatchNorm2d(self.num_channels, momentum=0.05), 167 | torch.nn.LeakyReLU(0.2, inplace=True), 168 | torch.nn.Conv2d( 169 | self.num_channels, x["image"].shape[1], kernel_size=1 170 | ), 171 | torch.nn.Softplus(), 172 | ) 173 | torch.nn.init.constant_(decoder[-2].weight, 0.0) 174 | std = x["image"].std((0, 2, 3)) 175 | decoder[-2].bias.data = (std.exp() - 1).log() 176 | return decoder 177 | 178 | img_mu = get_module("img_mu", _create_mu_decoder, checkpoint=True) 179 | img_sd = get_module("img_sd", _create_sd_decoder, checkpoint=True) 180 | mu = img_mu(decoded) 181 | sd = img_sd(decoded) 182 | 183 | image_distr = Normal(mu, 1e-8 + sd).to_event(3) 184 | pyro.sample( 185 | "image", 186 | image_distr, 187 | obs=center_crop(x["image"], image_distr.shape()), 188 | ) 189 | 190 | def model(self, x, zs): 191 | decoded = self._decode(zs) 192 | self._sample_image(x, decoded) 193 | 194 | def guide(self, x): 195 | return self._encode(x["image"]) 196 | 
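The likelihood at the end of `_sample_image` treats the channel and spatial dimensions as a single event, so each image in the batch contributes one log-density term. A standalone sketch of that construction (the shapes are illustrative stand-ins for the outputs of `img_mu` and `img_sd`):

import torch
from pyro.distributions import Normal  # pylint: disable=no-name-in-module

mu = torch.zeros(2, 3, 8, 8)  # stand-in for img_mu(decoded)
sd = torch.ones(2, 3, 8, 8)   # stand-in for img_sd(decoded)
image_distr = Normal(mu, 1e-8 + sd).to_event(3)
obs = torch.tanh(torch.randn(2, 3, 8, 8))  # images are rescaled to (-1, 1)
assert image_distr.log_prob(obs).shape == (2,)  # one log-density per image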
-------------------------------------------------------------------------------- /xfuse/model/experiment/st/__init__.py: -------------------------------------------------------------------------------- 1 | from .st import * 2 | -------------------------------------------------------------------------------- /xfuse/model/experiment/st/metagene_eval.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Any, Callable, NoReturn, Union, cast 3 | 4 | import numpy as np 5 | import pyro as p 6 | from pyro.poutine.messenger import Messenger 7 | 8 | from ....logging import INFO, WARNING, log 9 | from ....session.session import Session, require 10 | from ....utility.tensor import to_device 11 | from ... import XFuse 12 | from ...utility import compare 13 | from . import ST 14 | from .metagene_expansion_strategy import ExpansionStrategy 15 | 16 | 17 | def purge_metagenes(num_samples: int = 1) -> None: 18 | r""" 19 | Purges superfluous metagenes and adds new ones based on the 20 | `metagene_expansion_strategy` of the current :class:`Session` 21 | """ 22 | 23 | log(INFO, "Evaluating metagenes") 24 | 25 | xfuse: XFuse = require("model") 26 | metagene_expansion_strategy: ExpansionStrategy = require( 27 | "metagene_expansion_strategy" 28 | ) 29 | 30 | def _xfuse_without(n): 31 | reduced_xfuse = deepcopy(xfuse) 32 | reduced_xfuse.get_experiment("ST").remove_metagene( 33 | n, remove_params=False 34 | ) 35 | return reduced_xfuse 36 | 37 | with Session(log_level=WARNING, eval=True): 38 | st_experiment = cast(ST, xfuse.get_experiment("ST")) 39 | metagenes = st_experiment.metagenes 40 | 41 | if len(metagenes) == 1: 42 | contrib = list(metagenes) 43 | noncontrib = [] 44 | else: 45 | reduced_models, ns = zip( 46 | *[(_xfuse_without(n).model, n) for n in metagenes] 47 | ) 48 | 49 | def _eval_on(x): 50 | def _sample_once(): 51 | guide = p.poutine.trace(xfuse.guide).get_trace(x) 52 | full, *reduced = compare( 53 | x, guide, xfuse.model, *reduced_models 54 | ) 55 | return [x - full for x in reduced] 56 | 57 | res = [_sample_once() for _ in range(num_samples)] 58 | return np.mean(res, 0) 59 | 60 | dataloader = require("dataloader") 61 | scores = np.mean([_eval_on(to_device(x)) for x in dataloader], 0) 62 | 63 | noncontrib = [ 64 | n for res, n in reversed(sorted(zip(scores, ns))) if res >= 0 65 | ] 66 | contrib = [n for n in ns if n not in noncontrib] 67 | 68 | log( 69 | INFO, 70 | "Contributing metagenes: %s", 71 | ", ".join(contrib) if contrib != [] else "-", 72 | ) 73 | log( 74 | INFO, 75 | "Non-contributing metagenes: %s", 76 | ", ".join(noncontrib) if noncontrib != [] else "-", 77 | ) 78 | 79 | metagene_expansion_strategy(st_experiment, contrib, noncontrib) 80 | 81 | 82 | class MetagenePurger(Messenger): 83 | r""" 84 | Runs :func:`purge_metagenes` at a fixed interval to purge superfluous 85 | metagenes or add new ones 86 | """ 87 | 88 | def __init__( 89 | self, period: Union[int, Callable[[int], bool]] = 1, **kwargs: Any 90 | ): 91 | super().__init__() 92 | self._predicate = ( 93 | period 94 | if callable(period) 95 | else lambda epoch: epoch % cast(int, period) == 0 96 | ) 97 | self._kwargs = kwargs 98 | 99 | def _handle(self, **_msg) -> NoReturn: 100 | # pylint: disable=no-self-use 101 | raise RuntimeError("Unreachable code path") 102 | 103 | def _select_msg(self, **_msg) -> bool: 104 | # pylint: disable=no-self-use 105 | return False 106 | 107 | def _pyro_post_epoch(self, msg) -> None: 108 | if 
self._predicate(msg["kwargs"]["epoch"]): 109 | with Session(messengers=[]): 110 | purge_metagenes(**self._kwargs) 111 | -------------------------------------------------------------------------------- /xfuse/model/experiment/st/metagene_expansion_strategy.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from functools import reduce 3 | from inspect import isabstract, isclass 4 | from typing import Callable, List, Optional, Set 5 | 6 | import numpy as np 7 | 8 | from ....logging import DEBUG, log 9 | from ....session import SessionItem, register_session_item, get 10 | from . import ST 11 | 12 | 13 | class ExpansionStrategy(ABC): 14 | r"""Abstract base class for metagene expansion strategies""" 15 | 16 | @abstractmethod 17 | def __call__( 18 | self, 19 | experiment: ST, 20 | contributing_metagenes: List[str], 21 | noncontributing_metagenes: List[str], 22 | ) -> None: 23 | pass 24 | 25 | 26 | class Extra(ExpansionStrategy): 27 | r""" 28 | An :class:`ExpansionStrategy` that keeps a number of "extra", 29 | non-contributing metagenes around. The number of extra metagenes can be 30 | fixed or linearly annealed. 31 | """ 32 | 33 | def __init__( 34 | self, 35 | num_metagenes: int = 4, 36 | anneal_to: Optional[int] = 1, 37 | anneal_epochs: Optional[int] = 10000, 38 | ): 39 | self.__num_from = num_metagenes 40 | self.__num_to = anneal_to 41 | self.__anneal_epochs = anneal_epochs 42 | 43 | @property 44 | def num(self): 45 | r"""The current number of metagenes""" 46 | if self.__num_to is None or self.__anneal_epochs is None: 47 | return self.__num_from 48 | training_data = get("training_data") 49 | q = min(1.0, training_data.epoch / self.__anneal_epochs) 50 | return round((1 - q) * self.__num_from + q * self.__num_to) 51 | 52 | def __call__( 53 | self, 54 | experiment: ST, 55 | contributing_metagenes: List[str], 56 | noncontributing_metagenes: List[str], 57 | ) -> None: 58 | n_missing = self.num - len(noncontributing_metagenes) 59 | for _ in range(n_missing): 60 | experiment.add_metagene() 61 | for n, _ in zip(noncontributing_metagenes, range(-n_missing)): 62 | experiment.remove_metagene(n, remove_params=False) 63 | 64 | 65 | class _Node: 66 | @abstractmethod 67 | def get_nodes(self): 68 | r"""Get all child nodes""" 69 | 70 | 71 | class _Split(_Node): 72 | def __init__(self, a: _Node, b: _Node): 73 | self.a = a 74 | self.b = b 75 | 76 | def get_nodes(self): 77 | return [*self.a.get_nodes(), *self.b.get_nodes()] 78 | 79 | 80 | class _Leaf(_Node): 81 | def __init__(self, name: str, contributing: bool = False): 82 | self.name = name 83 | self.contributing = contributing 84 | 85 | def get_nodes(self): 86 | return [self.name] 87 | 88 | 89 | def _map_modify(root: _Node, fn: Callable[[_Leaf], None]): 90 | if isinstance(root, _Leaf): 91 | fn(root) 92 | return 93 | if isinstance(root, _Split): 94 | _map_modify(root.a, fn) 95 | _map_modify(root.b, fn) 96 | return 97 | raise NotImplementedError() 98 | 99 | 100 | def _show(root: _Node) -> str: 101 | if isinstance(root, _Split): 102 | return f"({_show(root.a)}), ({_show(root.b)})" 103 | if isinstance(root, _Leaf): 104 | return f"{root.name}: {root.contributing}" 105 | raise NotImplementedError() 106 | 107 | 108 | class DropAndSplit(ExpansionStrategy): 109 | r""" 110 | An :class:`ExpansionStrategy` that splits contributing metagenes and merges 111 | back previously split, non-contributing metagenes 112 | """ 113 | 114 | def __init__(self, max_metagenes: int = 50): 115 | self._root_nodes: 
Set[_Node] = set() 116 | self._max_metagenes = max_metagenes 117 | 118 | def __call__( 119 | self, 120 | experiment: ST, 121 | contributing_metagenes: List[str], 122 | noncontributing_metagenes: List[str], 123 | ) -> None: 124 | contrib = set(contributing_metagenes) 125 | noncontrib = set(noncontributing_metagenes) 126 | 127 | def _set_contributing(x: _Leaf): 128 | x.contributing = x.name in contrib 129 | 130 | def _drop_nonexistant_branches(root: _Node) -> Optional[_Node]: 131 | if isinstance(root, _Split): 132 | a = _drop_nonexistant_branches(root.a) 133 | b = _drop_nonexistant_branches(root.b) 134 | if a and b: 135 | return root 136 | if a and not b: 137 | return a 138 | if b and not a: 139 | return b 140 | return None 141 | if isinstance(root, _Leaf): 142 | if root.name in set.union(contrib, noncontrib): 143 | return root 144 | return None 145 | raise NotImplementedError() 146 | 147 | def _drop_noncontributing_branches(root: _Node) -> _Node: 148 | if isinstance(root, _Split): 149 | if isinstance(root.b, _Leaf) and not root.b.contributing: 150 | return _drop_noncontributing_branches(root.a) 151 | if isinstance(root.a, _Leaf) and not root.a.contributing: 152 | return _drop_noncontributing_branches(root.b) 153 | return _Split( 154 | *np.random.permutation( 155 | [ 156 | _drop_noncontributing_branches(root.a), 157 | _drop_noncontributing_branches(root.b), 158 | ] 159 | ) 160 | ) 161 | if isinstance(root, _Leaf): 162 | return root 163 | raise NotImplementedError() 164 | 165 | def _extend_contributing_branches(root: _Node) -> _Node: 166 | if isinstance(root, _Split): 167 | if isinstance(root.a, _Leaf) and isinstance(root.b, _Leaf): 168 | if root.a.contributing and root.b.contributing: 169 | return _Split( 170 | _extend_contributing_branches(root.a), 171 | _extend_contributing_branches(root.b), 172 | ) 173 | return root 174 | return _Split( 175 | *[ 176 | _extend_contributing_branches(node) 177 | if not isinstance(node, _Leaf) or node.contributing 178 | else node 179 | for node in (root.a, root.b) 180 | ] 181 | ) 182 | if isinstance(root, _Leaf): 183 | if root.contributing and ( 184 | self._max_metagenes <= 0 185 | or len(experiment.metagenes) < self._max_metagenes 186 | ): 187 | return _Split( 188 | root, _Leaf(experiment.split_metagene(root.name)) 189 | ) 190 | return root 191 | raise NotImplementedError() 192 | 193 | def _log_trees(title: str): 194 | log(DEBUG, "%s:", title) 195 | for tree in self._root_nodes: 196 | log(DEBUG, " %s", _show(tree)) 197 | 198 | self._root_nodes = { 199 | tree 200 | for tree in map(_drop_nonexistant_branches, self._root_nodes) 201 | if tree is not None 202 | } 203 | 204 | for tree in self._root_nodes: 205 | _map_modify(tree, _set_contributing) 206 | 207 | _log_trees("Trees before retraction") 208 | 209 | self._root_nodes = set( 210 | map(_drop_noncontributing_branches, self._root_nodes) 211 | ) 212 | 213 | # Remove non-contributing trees, keeping at least one 214 | noncontributing = [ 215 | x 216 | for x in self._root_nodes 217 | if isinstance(x, _Leaf) 218 | if not x.contributing 219 | ] 220 | for tree in noncontributing[: len(self._root_nodes) - 1]: 221 | self._root_nodes.discard(tree) 222 | 223 | _log_trees("Trees after retraction / before splitting") 224 | 225 | forest: Set[str] = reduce( 226 | set.union, (x.get_nodes() for x in self._root_nodes), set() 227 | ) 228 | for x in contrib: 229 | if x not in forest: 230 | log(DEBUG, "Adding new root node: %s", x) 231 | self._root_nodes.add(_Leaf(x, True)) 232 | for x in noncontrib: 233 | if x not in forest: 234 
| experiment.remove_metagene(x, remove_params=True) 235 | 236 | self._root_nodes = set( 237 | map(_extend_contributing_branches, self._root_nodes) 238 | ) 239 | 240 | _log_trees("Trees after splitting") 241 | 242 | 243 | register_session_item( 244 | "metagene_expansion_strategy", 245 | SessionItem(setter=lambda _: None, default=None, persistent=True), 246 | ) 247 | 248 | 249 | STRATEGIES = { 250 | x.__name__: x 251 | for x in locals().values() 252 | if isclass(x) 253 | if issubclass(x, ExpansionStrategy) 254 | if not isabstract(x) 255 | } 256 | -------------------------------------------------------------------------------- /xfuse/model/utility/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_comparison import * 2 | -------------------------------------------------------------------------------- /xfuse/model/utility/model_comparison.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pyro 4 | 5 | from ...session import Session 6 | 7 | 8 | def compare(data, guide, *models) -> List[float]: 9 | r""" 10 | Returns the ELBO of given models on provided data using the same guide 11 | trace 12 | """ 13 | 14 | def _evaluate(model): 15 | with pyro.poutine.trace() as trace: 16 | with pyro.poutine.replay(trace=guide): 17 | model(data) 18 | return ( 19 | trace.trace.log_prob_sum().item() 20 | - guide.log_prob_sum( 21 | site_filter=lambda name, site: name in trace.trace.nodes 22 | ).item() 23 | ) 24 | 25 | with Session(eval=True): 26 | result = [_evaluate(model) for model in models] 27 | return result 28 | -------------------------------------------------------------------------------- /xfuse/model/xfuse.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import pyro as p 4 | from pyro.distributions import Normal # pylint: disable=no-name-in-module 5 | from pyro.poutine.messenger import Messenger 6 | 7 | import torch 8 | 9 | from .experiment import Experiment 10 | from ..logging import INFO, log 11 | from ..utility.state import get_module 12 | from ..utility.tensor import find_device 13 | 14 | 15 | class ModelWrapper(Messenger): 16 | r""":class:`Messenger` wrapping the model to add context information.""" 17 | 18 | def _process_message(self, msg): 19 | msg["is_guide"] = False 20 | 21 | 22 | class GuideWrapper(Messenger): 23 | r""":class:`Messenger` wrapping the guide to add context information.""" 24 | 25 | def _process_message(self, msg): 26 | msg["is_guide"] = True 27 | 28 | 29 | class XFuse(torch.nn.Module): 30 | r"""XFuse""" 31 | 32 | def __init__(self, experiments: List[Experiment]): 33 | super().__init__() 34 | self.__experiment_store: Dict[str, Experiment] = {} 35 | for experiment in experiments: 36 | self.register_experiment(experiment) 37 | 38 | @property 39 | def experiments(self): 40 | r"""Returns the registered experiments""" 41 | return self.__experiment_store.copy() 42 | 43 | def get_experiment(self, experiment_type: str) -> Experiment: 44 | r"""Get registered :class:`Experiment` by tag""" 45 | 46 | try: 47 | return self.__experiment_store[experiment_type] 48 | except KeyError as exc: 49 | raise RuntimeError( 50 | f"Unknown experiment type: {experiment_type}" 51 | ) from exc 52 | 53 | def register_experiment(self, experiment: Experiment) -> None: 54 | r"""Registered :class:`Experiment`""" 55 | 56 | if experiment.tag in self.__experiment_store: 57 | raise RuntimeError( 58 | f'Model 
for data type "{experiment.tag}" already registered' 59 | ) 60 | log( 61 | INFO, 62 | 'Registering experiment: %s (data type: "%s")', 63 | type(experiment).__name__, 64 | experiment.tag, 65 | ) 66 | self.add_module(experiment.tag, experiment) 67 | self.__experiment_store[experiment.tag] = experiment 68 | 69 | def forward(self, *input): 70 | r"""Alias for :func:`model`""" 71 | # pylint: disable=redefined-builtin 72 | return self.model(*input) 73 | 74 | @ModelWrapper() 75 | def model(self, xs): 76 | r"""Runs XFuse on the given data""" 77 | 78 | def _go(experiment, x): 79 | zs = [ 80 | p.sample( 81 | f"z-{experiment.tag}-{i}", 82 | ( 83 | # pylint: disable=not-callable 84 | Normal(torch.tensor(0.0, device=find_device(x)), 1.0) 85 | .expand([1, 1, 1, 1]) 86 | .to_event(3) 87 | ), 88 | ) 89 | for i in range(experiment.num_z) 90 | ] 91 | experiment.model(x, zs) 92 | 93 | for experiment, x in xs.items(): 94 | _go(self.get_experiment(experiment), x) 95 | 96 | @GuideWrapper() 97 | def guide(self, xs): 98 | r""" 99 | Runs the :class:`pyro.infer.SVI` `guide` for XFuse on the given data 100 | """ 101 | 102 | def _go(experiment, x): 103 | def _sample(name, y): 104 | z_mu = get_module( 105 | f"{name}-mu", 106 | lambda: torch.nn.Sequential( 107 | torch.nn.Conv2d(y.shape[1], y.shape[1], 1), 108 | torch.nn.BatchNorm2d(y.shape[1], momentum=0.05), 109 | torch.nn.LeakyReLU(0.2, inplace=True), 110 | torch.nn.Conv2d(y.shape[1], y.shape[1], 1), 111 | ), 112 | checkpoint=True, 113 | ) 114 | z_sd = get_module( 115 | f"{name}-sd", 116 | lambda: torch.nn.Sequential( 117 | torch.nn.Conv2d(y.shape[1], y.shape[1], 1), 118 | torch.nn.BatchNorm2d(y.shape[1], momentum=0.05), 119 | torch.nn.LeakyReLU(0.2, inplace=True), 120 | torch.nn.Conv2d(y.shape[1], y.shape[1], 1), 121 | torch.nn.Softplus(), 122 | ), 123 | checkpoint=True, 124 | ) 125 | return p.sample( 126 | name, Normal(z_mu(y), 1e-8 + z_sd(y)).to_event(3) 127 | ) 128 | 129 | for i, y in enumerate(experiment.guide(x)): 130 | _sample(f"z-{experiment.tag}-{i}", y) 131 | 132 | for experiment, x in xs.items(): 133 | _go(self.get_experiment(experiment), x) 134 | -------------------------------------------------------------------------------- /xfuse/optim.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from inspect import signature 3 | 4 | import pyro 5 | 6 | from .logging import DEBUG, log 7 | from .utility.state.getters import get_param_optim_args 8 | from .session import get 9 | 10 | 11 | __all__ = [] 12 | 13 | 14 | def _make_wrapped_constructor(constructor): 15 | @wraps(constructor) 16 | def _constructor(default_optim_args, *args, **kwargs): 17 | def _optim_args(_module_name, param_name): 18 | optim_args = default_optim_args.copy() 19 | try: 20 | param_optim_args = get_param_optim_args(param_name) 21 | except KeyError: 22 | param_optim_args = {} 23 | if "lr" not in optim_args: 24 | optim_args["lr"] = get("learning_rate") 25 | for k, value in param_optim_args.items(): 26 | if k == "lr_multiplier": 27 | log( 28 | DEBUG, 29 | f"Adjusting learning rate to {optim_args['lr']*value=}" 30 | ' for parameter "%s"', 31 | param_name, 32 | ) 33 | optim_args["lr"] *= value 34 | else: 35 | raise RuntimeError(f'Unknown optim arg "{k}"') 36 | return optim_args 37 | 38 | return constructor(_optim_args, *args, **kwargs) 39 | 40 | return _constructor 41 | 42 | 43 | for __name, __constructor in [ 44 | (k, v) 45 | for k, v in pyro.optim.__dict__.items() 46 | if callable(v) and "optim_args" in 
signature(v).parameters.keys() 47 | ]: 48 | locals()[__name] = _make_wrapped_constructor(__constructor) 49 | __all__.append(__name) 50 | -------------------------------------------------------------------------------- /xfuse/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import warnings 4 | from functools import partial, reduce 5 | from operator import add 6 | from typing import Any, Dict, Optional, Tuple 7 | 8 | import numpy as np 9 | import pandas as pd 10 | 11 | from ._config import _ANNOTATED_CONFIG as CONFIG  # type: ignore 12 | from .analyze import analyses as _analyses 13 | from .data import Data, Dataset 14 | from .data.slide import RandomIterator, Slide, STSlide 15 | from .data.utility.misc import make_dataloader 16 | from .logging import INFO, log 17 | from .model import XFuse 18 | from .model.experiment.st import ST as STExperiment 19 | from .model.experiment.st.metagene_expansion_strategy import ( 20 | STRATEGIES, 21 | ExpansionStrategy, 22 | ) 23 | from .model.experiment.st.metagene_eval import MetagenePurger, purge_metagenes 24 | from .model.experiment.st.metagene_expansion_strategy import Extra 25 | from .optim import Adam  # type: ignore # pylint: disable=no-name-in-module 26 | from .session import Session, get 27 | from .train import test_convergence, train 28 | from .utility.file import chdir 29 | from .session.io import save_session 30 | 31 | 32 | def run( 33 | design: pd.DataFrame, 34 | slide_paths: Dict[str, str], 35 | analyses: Optional[Dict[str, Tuple[str, Dict[str, Any]]]] = None, 36 | expansion_strategy: ExpansionStrategy = STRATEGIES[ 37 | CONFIG["expansion_strategy"].value["type"].value 38 | ](), 39 | purge_interval: int = ( 40 | CONFIG["expansion_strategy"].value["purge_interval"].value 41 | ), 42 | network_depth: int = CONFIG["xfuse"].value["network_depth"].value, 43 | network_width: int = CONFIG["xfuse"].value["network_width"].value, 44 | min_counts: int = CONFIG["xfuse"].value["min_counts"].value, 45 | gene_regex: str = CONFIG["xfuse"].value["gene_regex"].value, 46 | patch_size: int = CONFIG["optimization"].value["patch_size"].value, 47 | batch_size: int = CONFIG["optimization"].value["batch_size"].value, 48 | epochs: int = CONFIG["optimization"].value["epochs"].value, 49 | learning_rate: float = CONFIG["optimization"].value["learning_rate"].value, 50 | cache_data: bool = CONFIG["settings"].value["cache_data"].value, 51 | num_data_workers: int = CONFIG["settings"].value["data_workers"].value, 52 | slide_options: Optional[Dict[str, Any]] = None, 53 | ): 54 | r"""Runs an analysis""" 55 | 56 | # pylint: disable=too-many-arguments,too-many-locals 57 | 58 | if analyses is None: 59 | analyses = {} 60 | 61 | if slide_options is None: 62 | slide_options = {} 63 | 64 | if (available_cores := len(os.sched_getaffinity(0))) < num_data_workers: 65 | warnings.warn( 66 | " ".join( 67 | [ 68 | f"Available cores ({available_cores}) is less than the" 69 | f" requested number of workers ({num_data_workers}).", 70 | f"Setting the number of workers to {available_cores}.", 71 | ] 72 | ), 73 | ) 74 | num_data_workers = available_cores 75 | 76 | slides = { 77 | slide: Slide( 78 | data=STSlide( 79 | slide_paths[slide], 80 | cache_data=cache_data, 81 | **slide_options.get(slide, {}), 82 | ), 83 | iterator=partial( 84 | RandomIterator, 85 | patch_size=( 86 | None if patch_size < 0 else (patch_size, patch_size) 87 | ), 88 | ), 89 | ) 90 | for slide in design.index 91 | } 92 | dataset
= Dataset(data=Data(slides=slides, design=design)) 93 | dataloader = make_dataloader( 94 | dataset, 95 | batch_size=batch_size if batch_size < len(dataset) else len(dataset), 96 | shuffle=True, 97 | num_workers=num_data_workers, 98 | drop_last=True, 99 | ) 100 | 101 | genes = get("genes") 102 | if genes is None: 103 | summed_counts = reduce( 104 | add, 105 | [ 106 | np.array(slide.data.counts.sum(0)).flatten() 107 | for slide in dataset.data.slides.values() 108 | ], 109 | ) 110 | filtered_genes = set( 111 | g for g, x in zip(dataset.genes, summed_counts) if x < min_counts 112 | ) 113 | filtered_genes = filtered_genes | set( 114 | g for g in dataset.genes if not re.match(gene_regex, g) 115 | ) 116 | 117 | if len(filtered_genes) > 0: 118 | log( 119 | INFO, 120 | "The following %d genes have been filtered out: %s", 121 | len(filtered_genes), 122 | ", ".join(sorted(filtered_genes)), 123 | ) 124 | genes = sorted(set(dataset.genes) - filtered_genes) 125 | log( 126 | INFO, 127 | "Using the following set of %d genes: %s", 128 | len(genes), 129 | ", ".join(genes), 130 | ) 131 | 132 | xfuse = get("model") 133 | if xfuse is None: 134 | st_experiment = STExperiment( 135 | depth=network_depth, num_channels=network_width, 136 | ) 137 | xfuse = XFuse(experiments=[st_experiment]).to(get("default_device")) 138 | 139 | optimizer = get("optimizer") 140 | if optimizer is None: 141 | optimizer = Adam({"amsgrad": True}) 142 | 143 | def _panic(_session, _err_type, _err, _tb): 144 | with chdir("/"): 145 | save_session("exception") 146 | 147 | with Session( 148 | model=xfuse, 149 | genes=genes, 150 | learning_rate=learning_rate, 151 | messengers=[ 152 | MetagenePurger( 153 | period=lambda e: ( 154 | purge_interval > 0 155 | and e % purge_interval == 0 156 | and (epochs < 0 or e <= epochs - purge_interval) 157 | ), 158 | num_samples=3, 159 | ), 160 | *get("messengers"), 161 | ], 162 | metagene_expansion_strategy=expansion_strategy, 163 | optimizer=optimizer, 164 | dataloader=dataloader, 165 | panic=_panic, 166 | ): 167 | has_converged = ( 168 | test_convergence() 169 | if epochs < 0 170 | else get("training_data").epoch >= epochs 171 | ) 172 | if not has_converged: 173 | train(epochs) 174 | with Session(model=xfuse, metagene_expansion_strategy=Extra(0)): 175 | try: 176 | purge_metagenes(num_samples=10) 177 | except RuntimeError as exc: 178 | if "Cannot remove last metagene" in str(exc): 179 | warnings.warn("Failed to find metagenes") 180 | else: 181 | raise 182 | save_session("final") 183 | 184 | for name, (analysis_type, options) in analyses.items(): 185 | if analysis_type in _analyses: 186 | log(INFO, 'Running analysis "%s"', name) 187 | with Session( 188 | model=xfuse, dataloader=dataloader, genes=genes, messengers=[], 189 | ): 190 | with chdir(f"analyses/final/{name}"): 191 | _analyses[analysis_type].function(**options) 192 | else: 193 | warnings.warn(f'Unknown analysis "{analysis_type}"') 194 | -------------------------------------------------------------------------------- /xfuse/session/__init__.py: -------------------------------------------------------------------------------- 1 | from .session_item import * 2 | from .session import * 3 | from . import items 4 | -------------------------------------------------------------------------------- /xfuse/session/io.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import warnings 3 | from typing import Union 4 | 5 | import torch 6 | 7 | from _io import BufferedReader 8 | 9 | from . 
import Session, get_session 10 | from .session import _SESSION_STORE 11 | from ..logging import INFO, log 12 | from ..utility.file import first_unique_filename 13 | from ..utility.state.state import get_state_dict, load_state_dict 14 | 15 | __all__ = ["load_session", "save_session"] 16 | 17 | 18 | def save_session(filename_prefix: str) -> None: 19 | r"""Saves the current :class:`Session`""" 20 | 21 | def _can_pickle(name, x): 22 | try: 23 | _ = pickle.dumps(x) 24 | except pickle.PickleError as exc: 25 | warnings.warn( 26 | f'Session item "{name}" cannot be saved.' 27 | f" The error returned was: {str(exc)}", 28 | ) 29 | return False 30 | return True 31 | 32 | session = Session( 33 | **{ 34 | k: v 35 | for k, v in iter(get_session()) 36 | if _SESSION_STORE[k].persistent 37 | if v is not None 38 | if _can_pickle(k, v) 39 | } 40 | ) 41 | 42 | path = first_unique_filename(f"{filename_prefix}.session") 43 | log(INFO, "Saving session to %s", path) 44 | torch.save((session, get_state_dict()), path) 45 | 46 | 47 | def load_session(file: Union[str, BufferedReader]) -> Session: 48 | r"""Loads :class:`Session` from a file""" 49 | log( 50 | INFO, 51 | "Loading session from %s", 52 | file.name if isinstance(file, BufferedReader) else file, 53 | ) 54 | session, state_dict = torch.load(file, map_location="cpu") 55 | with session: 56 | load_state_dict(state_dict) 57 | return session 58 | -------------------------------------------------------------------------------- /xfuse/session/items/__init__.py: -------------------------------------------------------------------------------- 1 | from .colormap import * 2 | from .covariates import * 3 | from .dataloader import * 4 | from .default_device import * 5 | from .eval import * 6 | from .genes import * 7 | from .learning_rate import * 8 | from .log_file import * 9 | from .log_level import * 10 | from .messengers import * 11 | from .model import * 12 | from .mpl_backend import * 13 | from .optimizer import * 14 | from .stats_writers import * 15 | from .training_data import * 16 | from .work_dir import * 17 | -------------------------------------------------------------------------------- /xfuse/session/items/colormap.py: -------------------------------------------------------------------------------- 1 | from matplotlib.cm import inferno # pylint: disable=no-name-in-module 2 | from .. import SessionItem, register_session_item 3 | 4 | 5 | register_session_item( 6 | "colormap", 7 | SessionItem(setter=lambda _: None, default=inferno, persistent=False), 8 | ) 9 | -------------------------------------------------------------------------------- /xfuse/session/items/covariates.py: -------------------------------------------------------------------------------- 1 | from .. import SessionItem, register_session_item 2 | 3 | 4 | register_session_item( 5 | "covariates", 6 | SessionItem(setter=lambda _: None, default={}, persistent=True), 7 | ) 8 | -------------------------------------------------------------------------------- /xfuse/session/items/dataloader.py: -------------------------------------------------------------------------------- 1 | from .. import SessionItem, register_session_item 2 | 3 | register_session_item( 4 | "dataloader", 5 | SessionItem(setter=lambda _: None, default=None, persistent=False), 6 | ) 7 | -------------------------------------------------------------------------------- /xfuse/session/items/default_device.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .. 
import SessionItem, register_session_item 4 | from ...logging import DEBUG, log 5 | from ...utility.tensor import to_device 6 | from ...utility.state.state import StateDict, get_state_dict, load_state_dict 7 | 8 | 9 | __DEFAULT_DEVICE = ( 10 | torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 11 | ) 12 | __CURRENT_DEVICE = __DEFAULT_DEVICE 13 | 14 | 15 | def _set_default_device(device): 16 | # pylint: disable=global-statement 17 | global __CURRENT_DEVICE 18 | if device != __CURRENT_DEVICE: 19 | log(DEBUG, "Setting default device to %s", str(device)) 20 | state_dict = get_state_dict() 21 | new_state_dict = StateDict( 22 | params=state_dict.params, 23 | modules=to_device(state_dict.modules, device=device), 24 | optimizer=to_device(state_dict.optimizer, device=device), 25 | ) 26 | load_state_dict(new_state_dict) 27 | __CURRENT_DEVICE = device 28 | 29 | 30 | register_session_item( 31 | "default_device", 32 | SessionItem( 33 | setter=_set_default_device, default=__DEFAULT_DEVICE, persistent=False, 34 | ), 35 | ) 36 | -------------------------------------------------------------------------------- /xfuse/session/items/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .. import SessionItem, register_session_item 4 | 5 | 6 | def _set_eval(eval_mode): 7 | torch.set_grad_enabled(not eval_mode) 8 | 9 | 10 | register_session_item( 11 | "eval", SessionItem(setter=_set_eval, default=False, persistent=False) 12 | ) 13 | -------------------------------------------------------------------------------- /xfuse/session/items/genes.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from .. import SessionItem, get, register_session_item 4 | 5 | 6 | def _set_genes(x: Optional[List[str]]) -> None: 7 | dataloader = get("dataloader") 8 | if dataloader and x and dataloader.dataset.genes != x: 9 | dataloader.dataset.genes = x 10 | dataloader.reset_workers() 11 | 12 | 13 | register_session_item( 14 | "genes", SessionItem(setter=_set_genes, default=None, persistent=True) 15 | ) 16 | -------------------------------------------------------------------------------- /xfuse/session/items/learning_rate.py: -------------------------------------------------------------------------------- 1 | from .. import SessionItem, register_session_item 2 | from ...logging import DEBUG, log 3 | from ...utility.state.state import get_state_dict, load_state_dict 4 | 5 | 6 | __DEFAULT_LR = 0.0003 7 | __CURRENT_LR = __DEFAULT_LR 8 | 9 | 10 | def _set_learning_rate(learning_rate): 11 | # pylint: disable=global-statement 12 | global __CURRENT_LR 13 | if learning_rate != __CURRENT_LR: 14 | log(DEBUG, "Setting learning rate to %f", learning_rate) 15 | load_state_dict(get_state_dict()) 16 | __CURRENT_LR = learning_rate 17 | 18 | 19 | register_session_item( 20 | "learning_rate", 21 | SessionItem(setter=_set_learning_rate, default=1e-3, persistent=False), 22 | ) 23 | -------------------------------------------------------------------------------- /xfuse/session/items/log_file.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Union 3 | from _io import TextIOWrapper 4 | 5 | from ...logging import LOGGER 6 | from ...logging.formatter import Formatter 7 | from .. 
import SessionItem, Unset, register_session_item 8 | 9 | 10 | def _setter(filebuffers: Union[List[TextIOWrapper], Unset]): 11 | warnings_logger = logging.getLogger("py.warnings") 12 | while warnings_logger.handlers != []: 13 | warnings_logger.removeHandler(warnings_logger.handlers[0]) 14 | 15 | while LOGGER.handlers != []: 16 | LOGGER.removeHandler(LOGGER.handlers[0]) 17 | 18 | if isinstance(filebuffers, list): 19 | for filebuffer in filebuffers: 20 | fancy_formatting = filebuffer.isatty() 21 | 22 | handler = logging.StreamHandler(filebuffer) 23 | handler.setFormatter(Formatter(fancy_formatting=fancy_formatting)) 24 | 25 | LOGGER.addHandler(handler) 26 | logging.getLogger("py.warnings").addHandler(handler) 27 | 28 | 29 | register_session_item( 30 | "log_file", SessionItem(setter=_setter, default=[], persistent=False), 31 | ) 32 | -------------------------------------------------------------------------------- /xfuse/session/items/log_level.py: -------------------------------------------------------------------------------- 1 | from .. import SessionItem, register_session_item 2 | from ...logging import INFO, set_level 3 | 4 | 5 | register_session_item( 6 | "log_level", SessionItem(setter=set_level, default=INFO, persistent=False) 7 | ) 8 | -------------------------------------------------------------------------------- /xfuse/session/items/messengers.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pyro.poutine.runtime as pyro_runtime 4 | from pyro.poutine.messenger import Messenger 5 | 6 | from .. import SessionItem, register_session_item 7 | 8 | 9 | __all__: List[str] = [] 10 | 11 | __INSTALLED_MESSENGERS: List[Messenger] = [] 12 | 13 | 14 | def _messengers_setter(messengers: List[Messenger]) -> None: 15 | # pylint: disable=protected-access 16 | # ^ HACK: This setter installs and uninstall Messengers manually by 17 | # modifying Pyro's runtime stack, effectively replacing their 18 | # __enter__ and __exit__ methods. 19 | # Messengers are always added to the bottom of the stack to avoid 20 | # interfering with other Messengers. 21 | pyro_runtime._PYRO_STACK[:] = [ 22 | messenger 23 | for messenger in pyro_runtime._PYRO_STACK 24 | if messenger not in __INSTALLED_MESSENGERS 25 | ] 26 | pyro_runtime._PYRO_STACK[:] = [*messengers, *pyro_runtime._PYRO_STACK] 27 | __INSTALLED_MESSENGERS[:] = messengers 28 | 29 | 30 | register_session_item( 31 | "messengers", 32 | SessionItem(setter=_messengers_setter, default=[], persistent=False), 33 | ) 34 | -------------------------------------------------------------------------------- /xfuse/session/items/model.py: -------------------------------------------------------------------------------- 1 | from .. import SessionItem, register_session_item 2 | 3 | register_session_item( 4 | "model", SessionItem(setter=lambda _: None, default=None, persistent=True) 5 | ) 6 | -------------------------------------------------------------------------------- /xfuse/session/items/mpl_backend.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | from .. 
import SessionItem, register_session_item 3 | 4 | 5 | register_session_item( 6 | "mpl_backend", 7 | SessionItem( 8 | setter=matplotlib.use, 9 | default=matplotlib.get_backend(), 10 | persistent=False, 11 | ), 12 | ) 13 | -------------------------------------------------------------------------------- /xfuse/session/items/optimizer.py: -------------------------------------------------------------------------------- 1 | from .. import SessionItem, register_session_item 2 | 3 | register_session_item( 4 | "optimizer", 5 | SessionItem(setter=lambda _: None, default=None, persistent=False), 6 | ) 7 | -------------------------------------------------------------------------------- /xfuse/session/items/stats_writers.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ...messengers.stats.writer import StatsWriter 4 | from .. import SessionItem, register_session_item 5 | 6 | 7 | __all__: List[str] = [] 8 | 9 | 10 | def _stats_writer_setter(stats_writers: List[StatsWriter]) -> None: 11 | # pylint: disable=unused-argument 12 | pass 13 | 14 | 15 | register_session_item( 16 | "stats_writers", 17 | SessionItem(setter=_stats_writer_setter, default=[], persistent=False), 18 | ) 19 | -------------------------------------------------------------------------------- /xfuse/session/items/training_data.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from .. import SessionItem, register_session_item 4 | 5 | 6 | class TrainingData: 7 | r"""Data structure for holding training data""" 8 | epoch: int = 0 9 | step: int = 0 10 | elbo_long: Optional[float] = None 11 | elbo_short: Optional[float] = None 12 | rmse: Optional[float] = None 13 | 14 | 15 | register_session_item( 16 | "training_data", 17 | SessionItem( 18 | setter=lambda _: None, default=TrainingData(), persistent=True 19 | ), 20 | ) 21 | -------------------------------------------------------------------------------- /xfuse/session/items/work_dir.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional 3 | 4 | from .. 
import SessionItem, register_session_item 5 | from ...logging import DEBUG, log 6 | 7 | 8 | __all__: List[str] = [] 9 | 10 | 11 | class WorkDir: 12 | """Stores the current working directory""" 13 | 14 | def __init__(self, root: Optional[str] = None, subpath: str = os.curdir): 15 | if root is None: 16 | root = os.getcwd() 17 | self.root = root 18 | self.subpath = subpath 19 | 20 | def __eq__(self, other: object) -> bool: 21 | if not isinstance(other, WorkDir): 22 | raise NotImplementedError() 23 | return other.full_path == self.full_path 24 | 25 | @property 26 | def root(self) -> str: 27 | """The root file path""" 28 | return self.__root 29 | 30 | @root.setter 31 | def root(self, root: str): 32 | """Setter for the root file path""" 33 | root = os.path.expanduser(root) 34 | root = os.path.expandvars(root) 35 | root = os.path.abspath(root) 36 | root = os.path.normcase(root) 37 | root = os.path.normpath(root) 38 | self.__root = root 39 | 40 | @property 41 | def subpath(self) -> str: 42 | """The currently active subdirectory""" 43 | return self.__subpath 44 | 45 | @subpath.setter 46 | def subpath(self, subpath: str): 47 | """Setter for the currently active subdirectory""" 48 | subpath = os.path.expandvars(subpath) 49 | subpath = os.path.normcase(subpath) 50 | subpath = os.path.normpath(subpath) 51 | self.__subpath = subpath 52 | 53 | @property 54 | def full_path(self) -> str: 55 | """The full path (:func:`root` + :func:`subpath`)""" 56 | return os.path.join(self.root, self.subpath) 57 | 58 | 59 | __DEFAULT_WORKDIR = WorkDir() 60 | __CUR_WORKDIR = __DEFAULT_WORKDIR 61 | 62 | 63 | def _work_dir_setter(work_dir: WorkDir) -> None: 64 | # pylint: disable=global-statement 65 | global __CUR_WORKDIR 66 | if work_dir != __CUR_WORKDIR: 67 | log(DEBUG, "Changing working directory to: %s", work_dir.full_path) 68 | if not os.path.exists(work_dir.full_path): 69 | os.makedirs(work_dir.full_path, exist_ok=True) 70 | os.chdir(work_dir.full_path) 71 | __CUR_WORKDIR = work_dir 72 | 73 | 74 | register_session_item( 75 | "work_dir", 76 | SessionItem(setter=_work_dir_setter, default=WorkDir(), persistent=False), 77 | ) 78 | -------------------------------------------------------------------------------- /xfuse/session/session.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from traceback import format_exc 3 | from typing import Any, Dict, List 4 | 5 | from ..logging import DEBUG, ERROR, log 6 | from .session_item import SessionItem 7 | 8 | __all__ = [ 9 | "Session", 10 | "Unset", 11 | "get_session", 12 | "get", 13 | "register_session_item", 14 | "require", 15 | ] 16 | 17 | 18 | class Unset: 19 | r"""Marker for unset :class:`Session` items""" 20 | 21 | def __str__(self): 22 | return "UNSET" 23 | 24 | 25 | class Session: 26 | r"""Session context manager""" 27 | 28 | def __init__(self, **kwargs): 29 | for name in _SESSION_STORE: 30 | try: 31 | value = kwargs.pop(name) 32 | except KeyError: 33 | value = Unset() 34 | setattr(self, name, value) 35 | if len(kwargs) != 0: 36 | raise ValueError( 37 | f'invalid session items: {",".join(kwargs.keys())}' 38 | ) 39 | self._level = -1 40 | 41 | def __enter__(self): 42 | _SESSION_STACK.append(self) 43 | for session in _SESSION_STACK: 44 | session._level += 1 45 | _apply_session(get_session()) 46 | 47 | def __exit__(self, err_type, err, tb): 48 | if err_type is not None: 49 | if self._level == 0: 50 | log( 51 | ERROR, 52 | "%s: %s\n%s", 53 | err_type.__name__, 54 | str(err), 55 | format_exc(), 56 | ) 57 | panic_handler = 
get("panic") 58 | if not isinstance(panic_handler, Unset): 59 | panic_handler(get_session(), err_type, err, tb) 60 | else: 61 | for session in _SESSION_STACK: 62 | session._level -= 1 63 | assert self._level == -1 64 | assert self == _SESSION_STACK.pop() 65 | _apply_session(get_session()) 66 | 67 | def __str__(self): 68 | return ( 69 | "Session {" 70 | + "; ".join(f"{x}={getattr(self, str(x))}" for x in _SESSION_STORE) 71 | + "}" 72 | ) 73 | 74 | def __iter__(self): 75 | for key in _SESSION_STORE: 76 | yield key, getattr(self, key) 77 | 78 | 79 | _SESSION_STACK: List[Session] = [] 80 | _SESSION_STORE: Dict[str, SessionItem] = {} 81 | 82 | 83 | def _apply_session(session: Session): 84 | for name, (setter, default, _persistent) in _SESSION_STORE.items(): 85 | setter(getattr(session, name, default)) 86 | 87 | 88 | def get(name: str) -> Any: 89 | r""" 90 | Gets session item from the current context. Returns its default value if 91 | unset. 92 | """ 93 | try: 94 | return require(name) 95 | except RuntimeError: 96 | return _SESSION_STORE[name].default 97 | 98 | 99 | def require(name: str) -> Any: 100 | r""" 101 | Gets session item from the current context. Raises `RuntimeError` if unset. 102 | """ 103 | if name not in _SESSION_STORE: 104 | raise ValueError(f"{name} is not a session item") 105 | 106 | for obj in reversed(_SESSION_STACK): 107 | try: 108 | val = getattr(obj, name) 109 | if not isinstance(val, Unset): 110 | return val 111 | except AttributeError: 112 | warnings.warn(f'Session object lacks attribute "{name}"') 113 | 114 | raise RuntimeError(f"Session item {name} has not been set!") 115 | 116 | 117 | def get_session(): 118 | r""" 119 | Constructs a new :class:`Sessions` based on the current session context 120 | """ 121 | return Session(**{name: get(name) for name in _SESSION_STORE}) 122 | 123 | 124 | def register_session_item(name: str, x: SessionItem) -> None: 125 | r"""Registers new :class:`SessionItem`""" 126 | log(DEBUG, 'Registering session item "%s"', name) 127 | _SESSION_STORE[name] = x 128 | 129 | 130 | register_session_item( 131 | "panic", SessionItem(lambda _: None, lambda *_: None, persistent=False) 132 | ) 133 | -------------------------------------------------------------------------------- /xfuse/session/session_item.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, NamedTuple 2 | 3 | 4 | class SessionItem(NamedTuple): 5 | r"""Data structure for session items""" 6 | 7 | setter: Callable[[Any], None] 8 | default: Any 9 | persistent: bool = True 10 | -------------------------------------------------------------------------------- /xfuse/train.py: -------------------------------------------------------------------------------- 1 | import itertools as it 2 | 3 | import pyro 4 | from pyro.poutine.runtime import effectful 5 | 6 | from .logging import DEBUG, INFO, Progressbar, log 7 | from .session import get, require 8 | from .utility.pyro import TraceWithDuplicates 9 | from .utility.tensor import to_device 10 | 11 | 12 | def test_convergence(): 13 | r""" 14 | Tests if the model has converged according to a heuristic stopping 15 | criterion 16 | """ 17 | training_data = get("training_data") 18 | return ( 19 | training_data.epoch > 1000 20 | and training_data.elbo_long > training_data.elbo_short 21 | ) 22 | 23 | 24 | def train(epochs: int = -1): 25 | """Trains the session model""" 26 | optim = require("optimizer") 27 | model = require("model") 28 | dataloader = require("dataloader") 29 | training_data = 
get("training_data") 30 | 31 | @effectful(type="step") 32 | def _step(*, x): 33 | loss = pyro.infer.Trace_ELBO() 34 | with TraceWithDuplicates() as trace: 35 | pyro.infer.SVI(model.model, model.guide, optim, loss).step(x) 36 | return trace.trace 37 | 38 | @effectful(type="epoch") 39 | def _epoch(*, epoch): 40 | if isinstance(optim, pyro.optim.PyroLRScheduler): 41 | optim.step(epoch=epoch) 42 | with Progressbar( 43 | dataloader, desc=f"Epoch {epoch:05d}", leave=False, 44 | ) as iterator: 45 | for x in iterator: 46 | training_data.step += 1 47 | _step(x=to_device(x)) 48 | 49 | with Progressbar( 50 | ( 51 | it.count(training_data.epoch + 1) 52 | if epochs < 0 53 | else range(training_data.epoch + 1, epochs + 1) 54 | ), 55 | desc="Optimizing model", 56 | unit="epoch", 57 | dynamic_ncols=True, 58 | leave=False, 59 | ) as iterator: 60 | for epoch in iterator: 61 | training_data.epoch = epoch 62 | _epoch(epoch=epoch) 63 | log( 64 | INFO, 65 | " | ".join( 66 | ["Epoch %05d", "Running ELBO %+.4e", "Running RMSE %.3f"] 67 | ), 68 | epoch, 69 | training_data.elbo_long or 0.0, 70 | training_data.rmse or 0.0, 71 | ) 72 | 73 | if epochs < 0 and test_convergence(): 74 | log(DEBUG, "Model has converged, stopping") 75 | break 76 | -------------------------------------------------------------------------------- /xfuse/utility/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /xfuse/utility/core.py: -------------------------------------------------------------------------------- 1 | import itertools as it 2 | from typing import ( 3 | Any, 4 | ContextManager, 5 | Iterable, 6 | List, 7 | Protocol, 8 | Tuple, 9 | TypeVar, 10 | Sequence, 11 | Union, 12 | ) 13 | 14 | import warnings 15 | 16 | import numpy as np 17 | from PIL import Image 18 | 19 | 20 | __all__ = [ 21 | "center_crop", 22 | "chunks_of", 23 | "rescale", 24 | "resize", 25 | "temp_attr", 26 | ] 27 | 28 | 29 | ArrayType = TypeVar("ArrayType", bound="ArrayLike") 30 | 31 | 32 | class ArrayLike(Protocol): 33 | r""" 34 | A protocol for sliceable objects (e.g., numpy arrays or pytorch tensors) 35 | """ 36 | 37 | @property 38 | def shape(self) -> Tuple[int, ...]: 39 | # pylint: disable=missing-docstring 40 | ... 41 | 42 | def __getitem__( 43 | self: ArrayType, idx: Union[slice, Tuple[slice, ...]] 44 | ) -> ArrayType: 45 | ... 46 | 47 | def __setitem__( 48 | self: ArrayType, idx: Union[slice, Tuple[slice, ...]], value: Any 49 | ) -> None: 50 | ... 
51 | 52 | 53 | def center_crop(x: ArrayType, target_shape: Tuple[int, ...]) -> ArrayType: 54 | r"""Crops `x` to the given `target_shape` from the center""" 55 | return x[ 56 | tuple( 57 | slice(round((a - b) / 2), round((a - b) / 2) + b) 58 | if b is not None 59 | else slice(None) 60 | for a, b in zip(x.shape, target_shape) 61 | ) 62 | ] 63 | 64 | 65 | def rescale( 66 | image: np.ndarray, scaling_factor: float, resample: int = Image.NEAREST 67 | ) -> np.ndarray: 68 | r""" 69 | Rescales image by a given `scaling_factor` 70 | 71 | :param image: Image array 72 | :param scaling_factor: Scaling factor 73 | :param resample: Resampling filter 74 | :returns: The rescaled image 75 | """ 76 | image_pil = Image.fromarray(image) 77 | image_pil = image_pil.resize( 78 | [round(x * scaling_factor) for x in image_pil.size], resample=resample, 79 | ) 80 | return np.array(image_pil) 81 | 82 | 83 | def resize( 84 | image: np.ndarray, 85 | target_shape: Sequence[int], 86 | resample: int = Image.NEAREST, 87 | ) -> np.ndarray: 88 | r""" 89 | Resizes image to a given `target_shape` 90 | 91 | :param image: Image array 92 | :param target_shape: Target shape 93 | :param resample: Resampling filter 94 | :returns: The rescaled image 95 | """ 96 | image_pil = Image.fromarray(image) 97 | image_pil = image_pil.resize(target_shape[::-1], resample=resample) 98 | return np.array(image_pil) 99 | 100 | 101 | def temp_attr(obj: object, attr: str, value: Any) -> ContextManager: 102 | r""" 103 | Creates a context manager for setting transient object attributes. 104 | 105 | >>> from types import SimpleNamespace 106 | >>> obj = SimpleNamespace(x=1) 107 | >>> with temp_attr(obj, 'x', 2): 108 | ... print(obj.x) 109 | 2 110 | >>> print(obj.x) 111 | 1 112 | """ 113 | 114 | class _TempAttr: 115 | def __init__(self): 116 | self.__original_value = None 117 | 118 | def __enter__(self): 119 | self.__original_value = getattr(obj, attr) 120 | setattr(obj, attr, value) 121 | 122 | def __exit__(self, *_): 123 | if getattr(obj, attr) == value: 124 | setattr(obj, attr, self.__original_value) 125 | else: 126 | warnings.warn( 127 | f'Attribute "{attr}" changed while in context.' 128 | " The new value will be kept.", 129 | ) 130 | 131 | return _TempAttr() 132 | 133 | 134 | T = TypeVar("T") 135 | 136 | 137 | def chunks_of(xs: Iterable[T], size: int) -> Iterable[List[T]]: 138 | r""" 139 | Yields size `size` chunks of `xs`. 140 | 141 | >>> list(chunks_of([1, 2, 3, 4], 2)) 142 | [[1, 2], [3, 4]] 143 | """ 144 | 145 | class _StopMarker: 146 | pass 147 | 148 | for chunk in it.zip_longest(*[iter(xs)] * size, fillvalue=_StopMarker): 149 | yield list(filter(lambda x: x is not _StopMarker, chunk)) 150 | -------------------------------------------------------------------------------- /xfuse/utility/file.py: -------------------------------------------------------------------------------- 1 | import itertools as it 2 | import os 3 | 4 | from ..session import Session, get 5 | from ..session.items.work_dir import WorkDir 6 | 7 | 8 | def chdir(dirname: str) -> Session: 9 | r""" 10 | Changes the session working directory to `dirname`. Absolute paths are 11 | rerooted with the session root. 
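
    A minimal usage sketch (the directory name is hypothetical), mirroring
    how analysis outputs are written under the session root:

    >>> with chdir("analyses/final"):  # doctest: +SKIP
    ...     pass  # the session work_dir now points at <root>/analyses/final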
12 | """ 13 | cwd = get("work_dir") 14 | if os.path.isabs(dirname): 15 | _root, *subdirs = os.path.normpath(dirname).split(os.sep) 16 | subpath = os.path.join(*subdirs) 17 | else: 18 | subpath = os.path.join(cwd.subpath, dirname) 19 | return Session(work_dir=WorkDir(root=cwd.root, subpath=subpath)) 20 | 21 | 22 | def first_unique_filename(root_name: str) -> str: 23 | r""" 24 | Returns the first non-existent filename in the sequence "`root_name`", 25 | "`root_name`.1", "`root_name`.2", ... 26 | """ 27 | for path in it.chain( 28 | (root_name,), (f"{root_name}.{i}" for i in it.count(1)) 29 | ): 30 | if not os.path.exists(path): 31 | return path 32 | raise RuntimeError("Unreachable code path") 33 | -------------------------------------------------------------------------------- /xfuse/utility/mask.py: -------------------------------------------------------------------------------- 1 | import itertools as it 2 | import warnings 3 | from typing import Optional 4 | 5 | import cv2 as cv 6 | import numpy as np 7 | from PIL import Image 8 | from scipy.ndimage import label 9 | from scipy.ndimage.morphology import binary_fill_holes 10 | 11 | from .core import rescale, resize 12 | from ..logging import INFO, log 13 | 14 | 15 | def remove_fg_elements(mask: np.ndarray, size_threshold: float): 16 | r"""Removes small foreground elements""" 17 | labels, _ = label(mask) 18 | labels_unique, label_counts = np.unique(labels, return_counts=True) 19 | small_labels = labels_unique[ 20 | label_counts < size_threshold ** 2 * np.prod(mask.shape) 21 | ] 22 | mask[np.isin(labels, small_labels)] = False 23 | return mask 24 | 25 | 26 | def compute_tissue_mask( 27 | image: np.ndarray, 28 | convergence_threshold: float = 0.0001, 29 | size_threshold: float = 0.01, 30 | initial_mask: Optional[np.ndarray] = None, 31 | ) -> np.ndarray: 32 | r""" 33 | Computes boolean mask indicating likely foreground elements in histology 34 | image. 
35 | """ 36 | # pylint: disable=no-member 37 | # ^ pylint fails to identify cv.* members 38 | original_shape = image.shape[:2] 39 | scale_factor = 1000 / max(original_shape) 40 | 41 | image = rescale(image, scale_factor, resample=Image.NEAREST) 42 | 43 | if initial_mask is None: 44 | initial_mask = ( 45 | cv.blur(cv.Canny(cv.blur(image, (5, 5)), 100, 200), (20, 20)) > 0 46 | ) 47 | initial_mask = binary_fill_holes(initial_mask) 48 | initial_mask = remove_fg_elements(initial_mask, 0.1) # type: ignore 49 | 50 | mask = np.where(initial_mask, cv.GC_PR_FGD, cv.GC_PR_BGD) 51 | mask = mask.astype(np.uint8) 52 | else: 53 | mask = initial_mask 54 | mask = rescale(mask, scale_factor, resample=Image.NEAREST) 55 | 56 | bgd_model = np.zeros((1, 65), np.float64) 57 | fgd_model = bgd_model.copy() 58 | 59 | log(INFO, "Computing tissue mask:") 60 | 61 | for i in it.count(1): 62 | old_mask = mask.copy() 63 | try: 64 | cv.grabCut( 65 | image, 66 | mask, 67 | None, 68 | bgd_model, 69 | fgd_model, 70 | 1, 71 | cv.GC_INIT_WITH_MASK, 72 | ) 73 | except cv.error as cv_err: 74 | warnings.warn(f"Failed to mask tissue\n{str(cv_err).strip()}") 75 | mask = np.full_like(mask, cv.GC_PR_FGD) 76 | break 77 | prop_changed = (mask != old_mask).sum() / np.prod(mask.shape) 78 | log(INFO, " Iteration %2d Δ = %.2f%%", i, 100 * prop_changed) 79 | if prop_changed < convergence_threshold: 80 | break 81 | 82 | mask = np.isin(mask, [cv.GC_FGD, cv.GC_PR_FGD]) 83 | mask = cleanup_mask(mask, size_threshold) 84 | 85 | mask = resize(mask, target_shape=original_shape, resample=Image.NEAREST) 86 | 87 | return mask 88 | 89 | 90 | def cleanup_mask(mask: np.ndarray, size_threshold: float): 91 | r"""Removes small background and foreground elements""" 92 | mask = ~remove_fg_elements(~mask, size_threshold) 93 | mask = remove_fg_elements(mask, size_threshold) 94 | return mask 95 | -------------------------------------------------------------------------------- /xfuse/utility/pyro.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | 3 | import pyro 4 | 5 | 6 | class TraceWithDuplicates(pyro.poutine.trace_messenger.TraceMessenger): 7 | """ 8 | A version of :class:`pyro.poutine.trace_messenger.TraceMessenger` that 9 | allows tracing with duplicated sample sites. This is necessary for tracing 10 | the guide and model simultaneously. 
11 | """ 12 | 13 | def _pyro_post_sample(self, msg): 14 | msg = copy(msg) 15 | msg["site"] = msg.pop("name") 16 | msg["name"] = str(len(self.trace.nodes)) 17 | super()._pyro_post_sample(msg) 18 | 19 | def _pyro_post_param(self, msg): 20 | msg = copy(msg) 21 | msg["site"] = msg.pop("name") 22 | msg["name"] = str(len(self.trace.nodes)) 23 | super()._pyro_post_param(msg) 24 | -------------------------------------------------------------------------------- /xfuse/utility/state/__init__.py: -------------------------------------------------------------------------------- 1 | from .getters import get_module, get_param 2 | from .state import StateDict, get_state_dict, load_state_dict, reset_state 3 | 4 | 5 | __all__ = [ 6 | "get_module", 7 | "get_param", 8 | "StateDict", 9 | "get_state_dict", 10 | "load_state_dict", 11 | "reset_state", 12 | ] 13 | -------------------------------------------------------------------------------- /xfuse/utility/state/getters.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Optional 2 | 3 | import pyro 4 | import torch 5 | 6 | from .state import __MODULES, __STATE_DICT, Param 7 | from ...session import get 8 | from ...utility.tensor import checkpoint as _checkpoint 9 | 10 | 11 | def get_module( 12 | name: str, 13 | module: Optional[Callable[[], torch.nn.Module]] = None, 14 | checkpoint: bool = False, 15 | ) -> Callable[..., Any]: 16 | r""" 17 | Retrieves :class:`~torch.nn.Module` by name or creates it if it doesn't 18 | exist. 19 | 20 | :param name: Module name 21 | :param module: Module to register if it doesn't already exist. The module 22 | should be "quoted" by encapsulating it in a `Callable` in order to lazify 23 | its creation. 24 | :param checkpoint: Flag indicating whether the module should be 25 | checkpointed 26 | 27 | :returns: The module 28 | :raises RuntimeError: If there is no module named `name` and `module` is 29 | `None`. 30 | """ 31 | try: 32 | module_ = pyro.module(name, __MODULES[name]) 33 | except KeyError as exc: 34 | if module is None: 35 | raise RuntimeError(f'Module "{name}" does not exist') from exc 36 | module_ = pyro.module(name, module(), update_module_params=True) 37 | if name in __STATE_DICT.modules: 38 | module_.load_state_dict(__STATE_DICT.modules[name]) 39 | module_ = module_.to(get("default_device")) 40 | __MODULES[name] = module_ 41 | module_ = module_.train(not get("eval")) 42 | if checkpoint: 43 | return lambda *args, **kwargs: _checkpoint(module_, *args, **kwargs) 44 | return module_ 45 | 46 | 47 | def get_param( 48 | name: str, 49 | default_value: Optional[Callable[[], torch.Tensor]] = None, 50 | lr_multiplier: float = 1.0, 51 | **kwargs: Any, 52 | ) -> torch.Tensor: 53 | r""" 54 | Retrieves learnable :class:`~torch.Tensor` (non-module parameter) by 55 | name or creates it if it doesn't exist. 56 | 57 | :param name: Parameter name 58 | :param default_value: Default value if parameter doesn't exist. The value 59 | should be "quoted" by encapsulating it in a `Callable` in order to lazify 60 | its creation. 61 | :param lr_multiplier: Learning rate multiplier 62 | :param kwargs: Arguments passed to :func:`~pyro.sample`. 63 | 64 | :returns: The parameter 65 | :raises RuntimeError: If there is no parameter named `name` and 66 | `default_value` is `None`. 
67 | """ 68 | if name in pyro.get_param_store(): 69 | return pyro.param(name) 70 | try: 71 | value = __STATE_DICT.params[name].data 72 | except KeyError as exc: 73 | if default_value is None: 74 | raise RuntimeError(f'Parameter "{name}" does not exist') from exc 75 | if callable(default_value): 76 | value = default_value() 77 | else: 78 | value = default_value 79 | __STATE_DICT.params[name] = Param( 80 | data=value.detach().cpu(), 81 | optim_args={"lr_multiplier": lr_multiplier}, 82 | ) 83 | return pyro.param(name, value.to(get("default_device")), **kwargs) 84 | 85 | 86 | def get_param_optim_args(name: str) -> Dict[str, Any]: 87 | r""" 88 | :param name: Parameter name 89 | :returns: The optimizer arguments 90 | :raises KeyError: If there is no parameter named `name` 91 | """ 92 | return __STATE_DICT.params[name].optim_args 93 | -------------------------------------------------------------------------------- /xfuse/utility/state/state.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | from typing import Any, Dict, NamedTuple 3 | 4 | import pyro 5 | import torch 6 | 7 | from ...session import get 8 | from ...utility.tensor import to_device 9 | 10 | 11 | class Param(NamedTuple): 12 | r"""Data structure for model parameters""" 13 | data: torch.Tensor 14 | optim_args: Dict[str, Any] = {} 15 | 16 | 17 | class StateDict(NamedTuple): 18 | r"""Data structure for the states of modules and non-module parameters""" 19 | modules: Dict[str, Dict[str, torch.Tensor]] # type: ignore 20 | params: Dict[str, Param] 21 | optimizer: Dict[str, Dict[str, torch.Tensor]] 22 | 23 | 24 | __MODULES: Dict[str, torch.nn.Module] = {} 25 | __STATE_DICT: StateDict = StateDict(modules={}, params={}, optimizer={}) 26 | 27 | 28 | def get_state_dict() -> StateDict: 29 | r"""Returns the state dicts of the modules in the module store""" 30 | state_dict = StateDict( 31 | modules=copy(__STATE_DICT.modules), 32 | params=copy(__STATE_DICT.params), 33 | optimizer=copy(__STATE_DICT.optimizer), 34 | ) 35 | state_dict.modules.update( 36 | { 37 | name: to_device(module.state_dict(), torch.device("cpu")) 38 | for name, module in __MODULES.items() 39 | } 40 | ) 41 | param_store = pyro.get_param_store() 42 | state_dict.params.update( 43 | { 44 | name: Param( 45 | data=param_store[name].detach().cpu(), 46 | optim_args=param.optim_args, 47 | ) 48 | for name, param in __STATE_DICT.params.items() 49 | if name in param_store 50 | } 51 | ) 52 | optimizer = get("optimizer") 53 | if optimizer is not None: 54 | state_dict.optimizer.update( 55 | to_device(optimizer.get_state(), torch.device("cpu")) 56 | ) 57 | return state_dict 58 | 59 | 60 | def load_state_dict(state_dict: StateDict) -> None: 61 | r"""Sets the default state dicts for the modules in the module store""" 62 | reset_state() 63 | __STATE_DICT.modules.update(state_dict.modules) 64 | __STATE_DICT.params.update(state_dict.params) 65 | __STATE_DICT.optimizer.update(state_dict.optimizer) 66 | optimizer = get("optimizer") 67 | if optimizer is not None: 68 | optimizer.set_state(__STATE_DICT.optimizer) 69 | 70 | 71 | def reset_state() -> None: 72 | r"""Resets all state modules and parameters""" 73 | __MODULES.clear() 74 | __STATE_DICT.modules.clear() 75 | __STATE_DICT.params.clear() 76 | __STATE_DICT.optimizer.clear() 77 | pyro.clear_param_store() 78 | optimizer = get("optimizer") 79 | if optimizer is not None: 80 | optimizer.optim_objs.clear() 81 | optimizer.grad_clip.clear() 82 | # pylint: disable=protected-access 83 | 
optimizer._state_waiting_to_be_consumed.clear() 84 | -------------------------------------------------------------------------------- /xfuse/utility/tensor.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Any, 3 | Dict, 4 | List, 5 | Optional, 6 | cast, 7 | overload, 8 | ) 9 | 10 | import numpy as np 11 | import torch 12 | from torch.utils.checkpoint import checkpoint as _checkpoint 13 | 14 | from ..session import get 15 | 16 | 17 | def checkpoint(function, *args, **kwargs): 18 | r""" 19 | Wrapper for :func:`torch.utils.checkpoint.checkpoint` that conditions 20 | checkpointing on the session `eval` state. 21 | """ 22 | if get("eval"): 23 | return function(*args, **kwargs) 24 | return _checkpoint(function, *args, **kwargs) 25 | 26 | 27 | class NoDevice(Exception): 28 | # pylint: disable=missing-class-docstring 29 | pass 30 | 31 | 32 | def find_device(x: Any) -> torch.device: 33 | r""" 34 | Tries to find the :class:`torch.device` associated with the given object 35 | """ 36 | 37 | if isinstance(x, torch.Tensor): 38 | return x.device 39 | 40 | if isinstance(x, list): 41 | for y in x: 42 | try: 43 | return find_device(y) 44 | except NoDevice: 45 | pass 46 | 47 | if isinstance(x, dict): 48 | for y in x.values(): 49 | try: 50 | return find_device(y) 51 | except NoDevice: 52 | pass 53 | 54 | raise NoDevice(f"Failed to find a device associated with {x}") 55 | 56 | 57 | def sparseonehot(labels: torch.Tensor, num_classes: Optional[int] = None): 58 | r"""One-hot encodes a label vectors into a sparse tensor""" 59 | if num_classes is None: 60 | num_classes = cast(int, labels.max().item()) + 1 61 | idx = torch.stack([torch.arange(labels.shape[0]).to(labels), labels]) 62 | return torch.sparse.LongTensor( # type: ignore 63 | idx, 64 | torch.ones(idx.shape[1]).to(idx), 65 | torch.Size([labels.shape[0], num_classes]), 66 | ) 67 | 68 | 69 | def isoftplus(x, /): 70 | r""" 71 | Inverse softplus. 72 | 73 | >>> ((isoftplus(torch.nn.functional.softplus(torch.linspace(-5, 5, 10))) 74 | ... - torch.linspace(-5, 5, 10)) < 1e-5).all() 75 | tensor(True) 76 | """ 77 | return np.log(np.exp(x) - 1) 78 | 79 | 80 | @overload 81 | def to_device( 82 | x: torch.Tensor, device: Optional[torch.device] = None 83 | ) -> torch.Tensor: 84 | # pylint: disable=missing-function-docstring 85 | ... 86 | 87 | 88 | @overload 89 | def to_device( 90 | x: List[Any], device: Optional[torch.device] = None 91 | ) -> List[Any]: 92 | # pylint: disable=missing-function-docstring 93 | ... 94 | 95 | 96 | @overload 97 | def to_device( 98 | x: Dict[Any, Any], device: Optional[torch.device] = None 99 | ) -> Dict[Any, Any]: 100 | # pylint: disable=missing-function-docstring 101 | ... 102 | 103 | 104 | def to_device(x, device=None): 105 | r""" 106 | Converts :class:`torch.Tensor` or a collection of :class:`torch.Tensor` to 107 | the given :class:`torch.device` 108 | """ 109 | if device is None: 110 | device = get("default_device") 111 | if isinstance(x, torch.Tensor): 112 | return x.to(device) 113 | if isinstance(x, list): 114 | return [to_device(y, device) for y in x] 115 | if isinstance(x, dict): 116 | return {k: to_device(v, device) for k, v in x.items()} 117 | return x 118 | --------------------------------------------------------------------------------