├── .github └── workflows │ └── python-app.yml ├── .gitignore ├── Makefile ├── README.md ├── documentation ├── CyCIF Lung Cancer.ipynb ├── IMC Bladder Cancer.ipynb ├── IMC Breast Cancer.ipynb ├── IMC Healthy Lung.ipynb ├── IMC Lung Infection.ipynb ├── Running in R.Rmd ├── UTAG Tutorial.ipynb └── testing_memory.ipynb ├── environment.yml ├── pyproject.toml ├── setup.cfg └── utag ├── __init__.py ├── segmentation.py ├── tests ├── __init__.py └── utag_test.py ├── types.py ├── utils.py └── vizualize.py /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python application UTAG test automation 5 | 6 | on: 7 | push: 8 | branches: [ main, dev ] 9 | pull_request: 10 | branches: [ main, dev ] 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | build: 17 | name: every OS 18 | strategy: 19 | matrix: 20 | # run 2 copies of the current job in parallel 21 | # and they will load balance all specs 22 | os: ['ubuntu-latest','macos-latest'] #, 'windows-latest'] 23 | python-version: [3.9.19, 3.10.11] 24 | fail-fast: false 25 | 26 | runs-on: ${{ matrix.os }} 27 | 28 | steps: 29 | # Checkout the latest code from the repo 30 | - uses: actions/checkout@v3 31 | - name: Set up Python ${{ matrix.python-version }} 32 | uses: actions/setup-python@v3 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | # Display the Python version being used 36 | - name: Display Python version 37 | run: python -c "import sys; print(sys.version)" 38 | # Install the package using the setup.py 39 | - name: Install dependencies 40 | run: | 41 | python -m pip install --upgrade pip 42 | pip install flake8 pytest 43 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 44 | pip install . 45 | - name: Lint with flake8 46 | run: | 47 | # stop the build if there are Python syntax errors or undefined names 48 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 49 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 50 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 51 | - name: Test with pytest 52 | run: | 53 | pytest utag/tests/utag_test.py 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # project specific 2 | processed/ 3 | results/ 4 | manuscript/ 5 | data/ 6 | lib/ 7 | .ipynb_checkpoints/ 8 | 9 | # ignore test files 10 | .tox 11 | _version.py 12 | pytest.log 13 | .coverage* 14 | 15 | # Build-related stuff 16 | build/ 17 | dist/ 18 | *.egg-info 19 | 20 | 21 | # data files 22 | *.mcd 23 | *.h5 24 | *.h5ad 25 | *.gmt 26 | *.tiff 27 | *.tif 28 | *.mcd 29 | *.npz 30 | # *.csv 31 | *.pq 32 | 33 | # toy/experimental files 34 | *.txt 35 | *.tsv 36 | *.pkl 37 | *.xlsx 38 | *.pickle 39 | *.svg 40 | *.png 41 | *.jpg 42 | *.jpeg 43 | *.pdf 44 | 45 | # ignore mypy 46 | .mypy* 47 | 48 | # ignore eggs 49 | .eggs/ 50 | 51 | # ignore built docs 52 | doc/build/* 53 | 54 | # generic ignore list: 55 | *.lst 56 | 57 | # Compiled source 58 | *.com 59 | *.class 60 | *.dll 61 | *.exe 62 | *.o 63 | *.so 64 | *.pyc 65 | 66 | # Packages 67 | # it's better to unpack these files and commit the raw source 68 | # git has its own built in compression methods 69 | *.7z 70 | *.dmg 71 | *.gz 72 | *.iso 73 | *.jar 74 | *.rar 75 | *.tar 76 | *.zip 77 | 78 | # Logs and databases 79 | *.log 80 | *.sql 81 | *.sqlite 82 | 83 | # OS generated files 84 | .DS_Store 85 | .DS_Store? 86 | ._* 87 | .Spotlight-V100 88 | .Trashes 89 | ehthumbs.db 90 | Thumbs.db 91 | 92 | # Sublime files 93 | *.sublime-* 94 | 95 | # Gedit temporary files 96 | *~ 97 | 98 | # libreoffice lock files: 99 | .~lock* 100 | 101 | # IDE-specific items 102 | .idea/ 103 | 104 | # pytest-related 105 | .cache/ 106 | .coverage* 107 | coverage.xml 108 | 109 | # Reserved files for comparison 110 | *RESERVE* 111 | 112 | # Self Note 113 | README_future_plans.md 114 | 115 | 116 | data/ 117 | figures/ 118 | 119 | # ipynb checkpoints 120 | */.ipynb_checkpoints/ 121 | 122 | # pycache files 123 | gatdu/__pycache__ 124 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 2 | # 3 | # This file specifies the steps to run and their order and allows running them. 4 | # Type `make` for instructions. Type make to execute a command. 5 | # 6 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # 7 | 8 | .DEFAULT_GOAL := help 9 | 10 | NAME=$(shell basename `pwd`) 11 | SAMPLES=$(shell ls data) 12 | 13 | help: ## Display help and quit 14 | @echo Makefile for the $(NAME) package. 15 | @echo Available commands: 16 | @grep -E '^[0-9a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | \ 17 | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-15s\033[0m\ 18 | %s\n", $$1, $$2}' 19 | 20 | requirements: ## Install Python requirements 21 | pip install -r requirements.txt 22 | 23 | backup_time: 24 | echo "Last backup: " `date` >> _backup_time 25 | chmod 700 _backup_time 26 | 27 | _sync: 28 | rsync --copy-links --progress -r \ 29 | . afr4001@pascal.med.cornell.edu:projects/$(NAME) 30 | 31 | sync: _sync backup_time ## [dev] Sync data/code to SCU server 32 | 33 | install: 34 | pip install --use-feature=in-tree-build . 
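# NOTE: `--use-feature=in-tree-build` was a transitional pip flag: in-tree builds
# became the default behavior in pip 21.3 and the flag was removed in later pip
# releases, so with a recent pip a plain `pip install .` is equivalent.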
35 | 36 | docs: 37 | cd docs; make html; xdg-open build/html/index.html 38 | 39 | clean: 40 | cd docs; make clean 41 | 42 | .PHONY : help \ 43 | requirements \ 44 | sync \ 45 | backup_time \ 46 | _sync \ 47 | sync \ 48 | install \ 49 | docs \ 50 | clean 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised discovery of tissue architecture with graphs (UTAG) 2 | [![Zenodo badge](https://zenodo.org/badge/doi/10.1038/s41592-022-01657-2.svg)](https://doi.org/10.1038/s41592-022-01657-2) ⬅️ read the published article here<br>
3 | [![Biorxiv badge](https://zenodo.org/badge/doi/10.1101/2022.03.15.484534.svg)](https://doi.org/10.1101/2022.03.15.484534) ⬅️ read the preprint here
4 | [![Zenodo badge](https://zenodo.org/badge/doi/10.5281/zenodo.6376767.svg)](https://doi.org/10.5281/zenodo.6376767) ⬅️ Preprocessed Multiplexed Image Data and UTAG results
5 | 6 | 7 | This package implements segmentation of multiplexed imaging data into microanatomical domains. 8 | Multiplexed imaging data types include imaging mass cytometry (IMC), co-detection by indexing (CODEX), multiplexed ion beam imaging by time of flight (MIBI-TOF), cyclic immunofluorescence (CyCIF), and others. 9 | The package also provides functions for the downstream analysis of the detected microanatomical structures. 10 | 11 | 12 | ## Getting Started 13 | 14 | ### Install from GitHub 15 | 16 | ```bash 17 | pip install git+https://github.com/ElementoLab/utag.git@main 18 | ``` 19 | Installation should take less than 10 seconds. 20 | 21 | #### Requirements 22 | There are no specific hardware requirements. 23 | 24 | Software requirements: 25 | - UTAG has been tested on macOS, Linux, and Windows [WSL](https://docs.microsoft.com/en-us/windows/wsl/about) (Windows Subsystem for Linux). 26 | - Python 3.9+ (tested on 3.9.19) 27 | - Python packages (automatically installed by `pip`): 28 | - numpy 29 | - pandas 30 | - anndata 31 | - scanpy 32 | - parc 33 | - squidpy 34 | - scipy 35 | - matplotlib 36 | - tqdm 37 | - networkx 38 | - parmap 39 | - scikit-learn 40 | - setuptools_scm (may require manual installation through pip on macOS) 41 | 42 | Specific versions of Python packages are pinned in the [setup.cfg](setup.cfg) file. 43 | 44 | ## Tutorial 45 | Check out the new [tutorial](documentation/UTAG%20Tutorial.ipynb) on how to run UTAG and visualize the results on your own data with minimal required input (a numeric cell feature matrix and x/y coordinates). 46 | 47 | ## Basic Usage Principles 48 | 49 | The UTAG process can be run with a single function call: `utag.utag`. 50 | The input is an [AnnData](https://anndata.readthedocs.io/) object, which should have the positions of cells (typically centroids) in the `spatial` slot of `adata.obsm`. 51 | The function outputs domain classes for each cell, stored in the `obs` slot of the returned AnnData object. 52 | 53 | 54 | ### Running an example/demo dataset 55 | 56 | Please refer to the [notebook directory](documentation/), and to the notebook on [running UTAG on healthy lung data](documentation/IMC%20Healthy%20Lung.ipynb) for a reproducible example. 57 | All data and respective results used for analysis can be downloaded from [![Zenodo badge](https://zenodo.org/badge/doi/10.5281/zenodo.6376767.svg)](https://doi.org/10.5281/zenodo.6376767). 58 | 59 | All data can alternatively be downloaded through the command line: 60 | ```bash 61 | pip install zenodo_get 62 | zenodo_get -d 10.5281/zenodo.6376767 -o data 63 | ``` 64 | 65 | ### Running on your data 66 | 67 | To run the method on multiple images/slides in batch mode: 68 | ```python 69 | from utag import utag 70 | 71 | # Use Scanpy to get an h5ad file with the provided data 72 | import scanpy as sc 73 | adata = sc.read( 74 | 'data/healthy_lung_adata.h5ad', 75 | backup_url='https://zenodo.org/record/6376767/files/healthy_lung_adata.h5ad?download=1') 76 | 77 | # Run UTAG on provided data 78 | utag_results = utag( 79 | adata, 80 | slide_key="roi", 81 | max_dist=20, 82 | normalization_mode='l1_norm', 83 | apply_clustering=True, 84 | clustering_method = 'leiden', 85 | resolutions = [0.05, 0.1, 0.3] 86 | ) 87 | ``` 88 | In batch mode, UTAG should take \~2 min on a local machine for this dataset.
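The returned object contains one categorical domain column in `.obs` per combination of clustering method and resolution, plus matching per-cell assignment probabilities in `.obsm`. A minimal sketch to list them (assuming the default `save_key='UTAG Label'` used above):
```python
# Domain columns follow the pattern "{save_key}_{method}_{resolution}",
# e.g. "UTAG Label_leiden_0.1" for the call above.
utag_cols = [c for c in utag_results.obs.columns if c.startswith('UTAG Label')]
print(utag_cols)

# Matching per-cell cluster probabilities are stored in .obsm:
print([k for k in utag_results.obsm if k.endswith('_probabilities')])
```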
89 | 90 | To run the method on a single image, pass `None` to the `slide_key` argument: 91 | ```python 92 | from utag import utag 93 | utag_results = utag( 94 | adata, 95 | slide_key=None, 96 | max_dist=20, 97 | normalization_mode='l1_norm', 98 | apply_clustering=True, 99 | clustering_method = 'leiden', 100 | resolutions = [0.05, 0.1, 0.3] 101 | ) 102 | ``` 103 | 104 | To visually inspect the results of the method: 105 | ```python 106 | import scanpy as sc 107 | for roi in utag_results.obs['roi'].unique(): 108 | result = utag_results[utag_results.obs['roi'] == roi].copy() 109 | sc.pl.spatial(result, color = 'UTAG Label_leiden_0.1', spot_size = 10) 110 | ``` 111 | 112 | ## User Guide on UTAG (Hyperparameters to test and tune) 113 | 114 | 115 | Although UTAG greatly reduces the manual labor involved in segmenting microanatomical domains, its successful application depends on three key user inputs. The first is the `max_dist` parameter, which defines the threshold distance below which two cells are considered adjacent in the spatial graph. The second is the clustering resolution (the list of `resolutions`), which determines the granularity of the clustering. The last is the user's interpretation of the resulting clusters to identify the structures they represent. 116 | 117 | We intentionally leave the optimization of `max_dist` open to users to maximize the applicability of UTAG to unseen datasets. This is because this parameter is tightly related to the resolution or magnification of the data in use. In our manuscript, we apply the method to IMC data and to optical imaging-based CyCIF, which have different per-unit-area pixel densities. In the case of IMC, we suggest that a well-working `max_dist` is between 10 and 20, as 1 pixel maps exactly to 1 micrometer. With an imaging-based technique like CyCIF, the optimal distance can vary with magnification, focal length, distance to tissue, and other factors, which makes it hard to suggest a one-size-fits-all rule. There may also be nuanced differences in the exact tissue of interest that vary across the specimens under examination. 118 | 119 | We believe the optimal clustering resolution is a hyperparameter that should be explored to suit the user's biological question of interest. For this reason, we provide a default list of resolutions for users to explore. A general rule is that increasing the resolution parameter returns more refined substructures, while decreasing it returns coarser, broader structures. We also recommend using a higher resolution parameter when screening for rare structures, as a higher resolution will capture more structures, and vice versa. In our benchmarking, we saw that, with the exception of extreme hyperparameter values, UTAG’s performance was fairly robust across various clustering resolutions (Extended Data Figure S3). 120 | 121 | ## Key Parameters 122 | 123 | | Input Parameter | Description | 124 | | ---------- |----------| 125 | | `adata` | (`anndata.AnnData`) n_cells x n_features. `AnnData` of cells with spatial coordinates stored in `adata.obsm['spatial']` as a `numpy.ndarray`. | 126 | | `max_dist` | (`float`, default = 20.0) Threshold Euclidean distance that determines whether a pair of cells is adjacent in the graph structure. Recommended values are between 10 and 100 depending on magnification. | 127 | | `slide_key` | (`str`, optional, default = 'Slide') Key required for running UTAG across multiple images. Unique image identifiers should be placed under `adata.obs`. Use `None` to run UTAG on a single slide.
| 128 | `save_key` | (`str`, default = 'UTAG Label') Key added to the adata object holding the UTAG clusters. Depending on the values of `clustering_method` and `resolutions`, the final keys will be of the form `{save_key}_{method}_{resolution}`. | 129 | | `normalization_mode` | (`str`, default = 'l1_norm') Method to normalize the adjacency matrix. 'l1_norm' behaves as mean-aggregation during message passing. Any other value will not perform normalization, leading to sum-aggregation. | 130 | | `apply_clustering` | (`bool`, default = True) Whether to cluster the message-passed matrix. | 131 | | `clustering_method` | (`Sequence[str]`, default = ['leiden', 'parc', 'kmeans']) Which clustering method(s) to use for clustering of the message-passed matrix. | 132 | | `resolutions` | (`Sequence[float]`, default = [0.05, 0.1, 0.3, 1.0]) Resolutions at which the methods in `clustering_method` should be run. | 133 | | `parallel` | Whether the message passing part of the method should be parallelized. This is done using the `parmap` package and the `multiprocessing` module from the standard library. | 134 | 135 | For more detailed usage of the package and downstream analysis, please refer to [IMC Healthy Lung.ipynb](documentation/IMC%20Healthy%20Lung.ipynb) in the documentation folder. 136 | 137 | 138 | ## Running UTAG in R 139 | 140 | To make UTAG available to R users, we port the Python code to R using the `reticulate` package. The code was tested after installing `UTAG` natively for Python 3.8.10 and under a conda environment with Python 3.10.4, on Ubuntu 20.04.3 LTS and R 4.2.0. 141 | 142 | Nonetheless, we highly recommend using our package in Python for more involved analyses, as the package has been developed and tested more thoroughly in Python. 143 | 144 | ```R 145 | install.packages('reticulate') 146 | library(reticulate) 147 | 148 | # grab the python interpreter 149 | use_python('/usr/bin/python3') # in case UTAG is installed on native python 150 | # use_condaenv('utag') # in case UTAG is installed using conda 151 | # use_virtualenv('utag') # in case UTAG is installed under virtualenv 152 | 153 | 154 | # import necessary python packages 155 | utag <- import('utag') 156 | scanpy <- import('scanpy') 157 | 158 | # read anndata with cell expressions and locations 159 | adata <- scanpy$read( 160 | 'data/healthy_lung_adata.h5ad', 161 | backup_url='https://zenodo.org/record/6376767/files/healthy_lung_adata.h5ad?download=1' 162 | ) 163 | 164 | # show general content of the data 165 | print(adata) 166 | 167 | # run UTAG 168 | utag_results <- utag$utag( 169 | adata, 170 | slide_key = "roi", 171 | max_dist = 20, 172 | normalization_mode = 'l1_norm', 173 | apply_clustering = TRUE, 174 | clustering_method = 'leiden', 175 | resolutions = c(0.05, 0.1, 0.3) 176 | ) 177 | 178 | # show content of the data, which now includes UTAG results for the various resolutions 179 | print(utag_results) 180 | ``` 181 | Also available as an [R Markdown file](documentation/Running%20in%20R.Rmd). 182 | 183 | ## Development 184 | 185 | We are happy to receive community contributions to UTAG through pull requests on GitHub.
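For development, a typical workflow is to clone the repository and install it in editable mode (a minimal sketch; the editable install is standard `pip` usage rather than a documented project command):
```bash
git clone https://github.com/ElementoLab/utag.git
cd utag
pip install -e .
```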
186 | 187 | Please run tests after re-installing the package: 188 | ```bash 189 | pytest --pyargs utag 190 | ``` 191 | or before by: 192 | ```bash 193 | pytest utag/tests/utag_test.py 194 | ``` 195 | -------------------------------------------------------------------------------- /documentation/Running in R.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "UTAG" 3 | output: html_document 4 | date: '2022-05-24' 5 | --- 6 | 7 | ```{r setup, include=FALSE} 8 | #install.packages('reticulate') 9 | library(reticulate) 10 | 11 | use_condaenv('utag') 12 | pd = import('pandas', as = 'pd') 13 | #system('pip3 install git+https://github.com/ElementoLab/utag.git@main') 14 | ``` 15 | 16 | ```{r} 17 | print('Hello World') 18 | pd = import('pandas', as = 'pd', convert = FALSE) 19 | 20 | #import pandas 21 | #import anndata 22 | 23 | ``` 24 | 25 | ```{python} 26 | from utag import utag 27 | 28 | # Use Scanpy to get a h5ad file with provided data 29 | import scanpy as sc 30 | adata = sc.read( 31 | 'data/healthy_lung_adata.h5ad', 32 | backup_url='https://zenodo.org/record/6376767/files/healthy_lung_adata.h5ad?download=1') 33 | 34 | # Run UTAG on provided data 35 | utag_results = utag( 36 | adata, 37 | slide_key="roi", 38 | max_dist=20, 39 | normalization_mode='l1_norm', 40 | apply_clustering=True, 41 | clustering_method = 'leiden', 42 | resolutions = [0.05, 0.1, 0.3] 43 | ) 44 | ``` 45 | 46 | ## Including Plots 47 | 48 | You can also embed plots, for example: 49 | 50 | ```{r pressure, echo=FALSE} 51 | plot(pressure) 52 | ``` 53 | 54 | Note that the `echo = FALSE` parameter was added to the code chunk to prevent printing of the R code that generated the plot. 55 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: utag 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.01.10=h06a4308_0 8 | - certifi=2022.12.7=py38h06a4308_0 9 | - ld_impl_linux-64=2.38=h1181459_1 10 | - libffi=3.3=he6710b0_2 11 | - libgcc-ng=11.2.0=h1234567_1 12 | - libgomp=11.2.0=h1234567_1 13 | - libstdcxx-ng=11.2.0=h1234567_1 14 | - ncurses=6.4=h6a678d5_0 15 | - openssl=1.1.1t=h7f8727e_0 16 | - pip=22.3.1=py38h06a4308_0 17 | - python=3.8.10=h12debd9_8 18 | - readline=8.2=h5eee18b_0 19 | - setuptools=65.6.3=py38h06a4308_0 20 | - sqlite=3.40.1=h5082296_0 21 | - tk=8.6.12=h1ccaba5_0 22 | - wheel=0.37.1=pyhd3eb1b0_0 23 | - xz=5.2.10=h5eee18b_1 24 | - zlib=1.2.13=h5eee18b_0 25 | - pip: 26 | - aiohttp==3.8.4 27 | - aiosignal==1.3.1 28 | - anndata==0.8.0 29 | - asciitree==0.3.3 30 | - asttokens==2.2.1 31 | - async-timeout==4.0.2 32 | - attrs==22.2.0 33 | - backcall==0.2.0 34 | - blosc2==2.0.0 35 | - charset-normalizer==3.0.1 36 | - click==8.1.3 37 | - cloudpickle==2.2.1 38 | - contourpy==1.0.7 39 | - cycler==0.11.0 40 | - cython==0.29.33 41 | - dask==2023.2.0 42 | - dask-image==2022.9.0 43 | - decorator==5.1.1 44 | - docrep==0.3.2 45 | - dunamai==1.15.0 46 | - entrypoints==0.4 47 | - executing==1.2.0 48 | - fasteners==0.18 49 | - fonttools==4.38.0 50 | - frozenlist==1.3.3 51 | - fsspec==2023.1.0 52 | - get-version==3.5.4 53 | - h5py==3.8.0 54 | - hnswlib==0.7.0 55 | - idna==3.4 56 | - igraph==0.10.4 57 | - imageio==2.25.1 58 | - importlib-metadata==6.0.0 59 | - importlib-resources==5.10.2 60 | - inflect==6.0.2 61 | - ipython==8.10.0 62 | - jedi==0.18.2 63 | - joblib==1.2.0 64 
| - kiwisolver==1.4.4 65 | - legacy-api-wrap==1.2 66 | - leidenalg==0.9.1 67 | - llvmlite==0.38.1 68 | - locket==1.0.0 69 | - matplotlib==3.6.0 70 | - matplotlib-inline==0.1.6 71 | - matplotlib-scalebar==0.8.1 72 | - msgpack==1.0.4 73 | - multidict==6.0.4 74 | - natsort==8.2.0 75 | - networkx==3.0 76 | - numba==0.55.2 77 | - numcodecs==0.11.0 78 | - numexpr==2.8.4 79 | - numpy==1.22.4 80 | - omnipath==1.0.6 81 | - packaging==23.0 82 | - pandas==1.5.3 83 | - parc==0.33 84 | - parmap==1.6.0 85 | - parso==0.8.3 86 | - partd==1.3.0 87 | - patsy==0.5.3 88 | - pexpect==4.8.0 89 | - pickleshare==0.7.5 90 | - pillow==9.4.0 91 | - pims==0.6.1 92 | - prompt-toolkit==3.0.36 93 | - ptyprocess==0.7.0 94 | - pure-eval==0.2.2 95 | - py-cpuinfo==9.0.0 96 | - pybind11==2.10.3 97 | - pydantic==1.10.5 98 | - pygments==2.14.0 99 | - pynndescent==0.5.8 100 | - pyparsing==3.0.9 101 | - python-dateutil==2.8.2 102 | - pytz==2022.7.1 103 | - pywavelets==1.4.1 104 | - pyyaml==6.0 105 | - requests==2.28.2 106 | - scanpy==1.9.1 107 | - scikit-image==0.19.3 108 | - scikit-learn==1.2.1 109 | - scipy==1.10.0 110 | - seaborn==0.12.2 111 | - session-info==1.0.0 112 | - sinfo==0.3.4 113 | - six==1.16.0 114 | - slicerator==1.1.0 115 | - squidpy==1.2.3 116 | - stack-data==0.6.2 117 | - statsmodels==0.13.5 118 | - stdlib-list==0.8.0 119 | - tables==3.8.0 120 | - texttable==1.6.7 121 | - threadpoolctl==3.1.0 122 | - tifffile==2023.2.3 123 | - toolz==0.12.0 124 | - tqdm==4.64.1 125 | - traitlets==5.9.0 126 | - typing-extensions==4.5.0 127 | - umap-learn==0.5.3 128 | - urllib3==1.26.14 129 | - utag==0.1.1.dev37+g6091543 130 | - validators==0.20.0 131 | - wcwidth==0.2.6 132 | - wrapt==1.14.1 133 | - xarray==2023.1.0 134 | - yarl==1.8.2 135 | - zarr==2.14.1 136 | - zipp==3.13.0 137 | prefix: /home/june/anaconda3/envs/utag 138 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # PIP, using PEP621 2 | [project] 3 | name = "utag" 4 | authors = [ 5 | {name = "Junbum Kim", email = "juk4007@med.cornell.edu"}, 6 | {name = "Andre Rendeiro", email = "afrendeiro@gmail.com"}, 7 | ] 8 | description = "Unsupervised discovery of tissue architechture with graphs (UTAG)" 9 | readme = "README.md" 10 | keywords = [ 11 | "computational biology", 12 | "bioinformatics", 13 | "imaging", 14 | "multiplexed imaging", 15 | ] 16 | classifiers = [ 17 | "Programming Language :: Python :: 3 :: Only", 18 | "Programming Language :: Python :: 3.7", 19 | "Programming Language :: Python :: 3.8", 20 | "Development Status :: 3 - Alpha", 21 | "Typing :: Typed", 22 | "License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)", 23 | "Topic :: Scientific/Engineering :: Bio-Informatics", 24 | ] 25 | #license = "gpt3" 26 | requires-python = ">=3.5" 27 | # not yet supported by pip! 28 | dependencies = [ 29 | "numpy>=1.19", 30 | "pandas>=1.0.1", 31 | "anndata", 32 | "scanpy", 33 | "parc", 34 | "squidpy", 35 | "scipy>=1.6", 36 | "matplotlib>=3.4", 37 | "tqdm", 38 | "networkx>=2.4", 39 | "parmap", 40 | "scikit-learn" 41 | ] 42 | dynamic = ['version'] 43 | 44 | [project.optional-dependencies] 45 | # not yet supported by pip! 
46 | dev = [ 47 | "ipython", 48 | "black[d]", 49 | "mypy", # pin to version supporting pyproject.toml 50 | "pylint", 51 | "git-lint", 52 | "pytest", 53 | "rich", 54 | # data-science-types 55 | "PyQt5", 56 | ] 57 | test = [ 58 | "pytest>=6", 59 | "pytest-cov", 60 | ] 61 | doc = [ 62 | "Sphinx", 63 | "sphinx-issues", 64 | "sphinx-rtd-theme", 65 | "sphinx-autodoc-typehints" 66 | ] 67 | 68 | [project.urls] 69 | homepage = "https://github.com/ElementoLab/utag" 70 | documentation = "https://github.com/ElementoLab/utag/blob/main/README.md" 71 | repository = "https://github.com/ElementoLab/utag" 72 | 73 | [build-system] 74 | # requires = ["poetry>=0.12", "setuptools>=45", "wheel", "poetry-dynamic-versioning"] 75 | # build-backend = "poetry.masonry.api" 76 | requires = ["setuptools==68.2.2", "wheel", "setuptools_scm[toml]>=6.0"] 77 | build-backend = "setuptools.build_meta" 78 | 79 | [tool.setuptools_scm] 80 | write_to = "utag/_version.py" 81 | write_to_template = 'version = __version__ = "{version}"' 82 | 83 | # Poetry 84 | [tool.poetry-dynamic-versioning] 85 | enable = true 86 | vcs = "git" 87 | style = "semver" 88 | 89 | [tool.poetry] 90 | name = "utag" 91 | version = "0.0.0" # waiting on next release of poetry to use dynamic-versioning extension 92 | description = "Unsupervised discovery of tissue architechture with graphs (UTAG)" 93 | authors = ["Junbum Kim ", "Andre Rendeiro "] 94 | license = "GPL-3.0-or-later" 95 | 96 | [tool.poetry.dependencies] 97 | python = "^3.8" 98 | numpy = "^1.19" 99 | pandas = "^1.0.1" 100 | scipy = "^1.6" 101 | scikit-image = "^1.18" 102 | matplotlib = "^3.4" 103 | networkx = "^2.4" 104 | tensorflow-gpu = "^2.4.1" 105 | 106 | [tool.poetry.dev-dependencies] 107 | ipython = "^7.16.1" 108 | pylint = "^2.5.3" 109 | git-lint = "^0.1.2" 110 | black = {extras = ["d"], version = "^19.10b0"} 111 | mypy = "^0.782" 112 | pytest = "^5.4.3" 113 | Sphinx = "^3.1.1" 114 | sphinx-issues = "^1.2.0" 115 | sphinx-rtd-theme = "^0.5.0" 116 | sphinx-autodoc-typehints = "^1.12.0" 117 | 118 | [tool.poetry.extras] 119 | 120 | 121 | [tool.black] 122 | line-length = 90 123 | target-version = ['py36'] 124 | include = '\.pyi?$' 125 | exclude = ''' 126 | 127 | ( 128 | /( 129 | \.eggs # exclude a few common directories in the 130 | | \.git # root of the project 131 | | \.hg 132 | | \.mypy_cache 133 | | \.tox 134 | | \.venv 135 | | _build 136 | | buck-out 137 | | build 138 | | dist 139 | )/ 140 | | foo.py # also separately exclude a file named foo.py in 141 | # the root of the project 142 | ) 143 | ''' 144 | 145 | [tool.mypy] 146 | python_version = '3.8' 147 | warn_return_any = true 148 | warn_unused_configs = true 149 | 150 | # Packages without type annotations in shed yet 151 | [[tool.mypy.overrides]] 152 | module = [ 153 | 'numpy.*', 154 | 'pandas.*', 155 | 'scipy.*', 156 | 'skimage.*', 157 | 'matplotlib.*', 158 | 'networkx.*', 159 | # 160 | 'utag.*' 161 | ] 162 | ignore_missing_imports = true 163 | 164 | [tool.pytest.ini_options] 165 | minversion = "6.0" 166 | addopts = "-ra -q --strict-markers" 167 | testpaths = [ 168 | "tests", 169 | "integration", 170 | ] 171 | markers = [ 172 | 'slow', # 'marks tests as slow (deselect with "-m 'not slow'")', 173 | 'serial' 174 | ] 175 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Remove once PIP supports reading project metadata from pyproject.toml 2 | [metadata] 3 | name = utag 4 | description = Unsupervised discovery of 
tissue architecture with graphs. 5 | version = attr: utag.__version__ 6 | long_description = file: README.md 7 | long_description_content_type = text/markdown 8 | classifiers = 9 | Programming Language :: Python :: 3 :: Only 10 | Programming Language :: Python :: 3.8 11 | Programming Language :: Python :: 3.9 12 | Development Status :: 4 - Beta 13 | Topic :: Scientific/Engineering :: Bio-Informatics 14 | keywords = science, bioinformatics, bioimage analysis, multiplexed imaging 15 | url = https://github.com/ElementoLab/utag 16 | project_urls = 17 | Bug Tracker = https://github.com/ElementoLab/utag/issues 18 | Documentation = https://github.com/ElementoLab/utag/blob/main/README.md 19 | Source Code = https://github.com/ElementoLab/utag 20 | 21 | author = Junbum Kim, Andre Rendeiro 22 | author_email = juk4007@med.cornell.edu, afrendeiro@gmail.com 23 | license = GPL3 24 | 25 | [options] 26 | install_requires = 27 | numpy>=1.19 28 | pandas>=1.0.1 29 | anndata 30 | scanpy 31 | parc 32 | squidpy 33 | scipy>=1.6 34 | matplotlib>=3.4 35 | tqdm 36 | networkx>=2.4 37 | parmap 38 | scikit-learn 39 | 40 | packages = find: 41 | -------------------------------------------------------------------------------- /utag/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from scipy.sparse import SparseEfficiencyWarning 4 | 5 | from utag.segmentation import utag 6 | 7 | 8 | try: 9 | # Even though there is no "utag/_version" file in the source tree, 10 | # it should be generated by 11 | # setuptools_scm when building the package 12 | from utag._version import version 13 | 14 | __version__ = version 15 | except ImportError: 16 | from setuptools_scm import get_version as _get_version 17 | 18 | version = __version__ = _get_version(root="..", relative_to=__file__) 19 | 20 | warnings.simplefilter("ignore", FutureWarning) 21 | warnings.simplefilter("ignore", SparseEfficiencyWarning) 22 | -------------------------------------------------------------------------------- /utag/segmentation.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | import warnings 3 | import os 4 | 5 | import scanpy as sc 6 | import squidpy as sq 7 | import numpy as np 8 | import pandas as pd 9 | import matplotlib.pyplot as plt 10 | from tqdm import tqdm 11 | import anndata 12 | import parmap 13 | 14 | from utag.types import Path, Array, AnnData 15 | from utag.utils import sparse_matrix_dstack 16 | 17 | 18 | def utag( 19 | adata: AnnData, 20 | channels_to_use: tp.Sequence[str] = None, 21 | slide_key: tp.Optional[str] = "Slide", 22 | save_key: str = "UTAG Label", 23 | filter_by_variance: bool = False, 24 | max_dist: float = 20.0, 25 | normalization_mode: str = "l1_norm", 26 | keep_spatial_connectivity: bool = False, 27 | pca_kwargs: tp.Dict[str, tp.Any] = dict(n_comps=10), 28 | apply_umap: bool = False, 29 | umap_kwargs: tp.Dict[str, tp.Any] = dict(), 30 | apply_clustering: bool = True, 31 | clustering_method: tp.Sequence[str] = ["leiden", "parc", "kmeans"], 32 | resolutions: tp.Sequence[float] = [0.05, 0.1, 0.3, 1.0], 33 | leiden_kwargs: tp.Dict[str, tp.Any] = None, 34 | parc_kwargs: tp.Dict[str, tp.Any] = None, 35 | parallel: bool = True, 36 | processes: int = None, 37 | ) -> AnnData: 38 | """ 39 | Discover tissue architecture in single-cell imaging data 40 | by combining phenotypes and positional information of cells.
41 | 42 | Parameters 43 | ---------- 44 | adata: AnnData 45 | AnnData object with spatial positioning of cells in the obsm 'spatial' slot. 46 | channels_to_use: Optional[Sequence[str]] 47 | An optional sequence of strings used to subset variables to use. 48 | Default (None) is to use all variables. 49 | max_dist: float 50 | Maximum distance to cut edges within a graph. 51 | Should be adjusted depending on the resolution of the images; recommended values are between 20 and 50 depending on magnification. 52 | For imaging mass cytometry, where resolution is 1um, 20 often gives good results. 53 | Default is 20. 54 | slide_key: {str, None} 55 | Key of adata.obs containing information on the batch structure of the data. 56 | In general, for image data this will often be a variable indicating the image, 57 | so that image-specific effects are removed from the data. 58 | Default is "Slide". 59 | save_key: str 60 | Key to be added to the adata object holding the UTAG clusters. 61 | Depending on the values of `clustering_method` and `resolutions`, 62 | the final keys will be of the form: {save_key}_{method}_{resolution}. 63 | Default is "UTAG Label". 64 | filter_by_variance: bool 65 | Whether to filter variables by variance. 66 | Default is False, which keeps all variables. 70 | normalization_mode: str 71 | Method to normalize the adjacency matrix. 72 | Default is "l1_norm"; any other value will not use normalization. 73 | keep_spatial_connectivity: bool 74 | Whether to keep sparse matrices of spatial connectivity and distance in the obsp attribute of the 75 | resulting anndata object. This could be useful in downstream applications. 76 | Default is not to (False). 77 | pca_kwargs: Dict[str, Any] 78 | Keyword arguments to be passed to scanpy.pp.pca for dimensionality reduction after message passing. 79 | Default is to pass n_comps=10, which uses 10 principal components. 80 | apply_umap: bool 81 | Whether to build a UMAP representation after message passing. 82 | Default is False. 83 | umap_kwargs: Dict[str, Any] 84 | Keyword arguments to be passed to scanpy.tl.umap for dimensionality reduction after message passing. 85 | Default is an empty dict (scanpy defaults). 86 | apply_clustering: bool 87 | Whether to cluster the message passed matrix. 88 | Default is True. 89 | clustering_method: Sequence[str] 90 | Which clustering method(s) to use for clustering of the message passed matrix. 91 | Default is ["leiden", "parc", "kmeans"]. 92 | resolutions: Sequence[float] 93 | Resolutions at which the methods in `clustering_method` should be run. 94 | Default is [0.05, 0.1, 0.3, 1.0]. 95 | leiden_kwargs: dict[str, Any] 96 | Keyword arguments to pass to scanpy.tl.leiden. 97 | parc_kwargs: dict[str, Any] 98 | Keyword arguments to pass to parc.PARC. 99 | parallel: bool 100 | Whether to run the message passing part of the algorithm in parallel. 101 | Will accelerate the process but consume more memory. 102 | Default is True. 103 | processes: int 104 | Number of processes to use in parallel. 105 | Default (None) is to use all available. 106 | 107 | Returns 108 | ------- 109 | adata: AnnData 110 | AnnData object with UTAG domain predictions for each cell in adata.obs, column `save_key`.
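Examples
--------
(An illustrative sketch following the README usage, not part of the original
docstring; assumes the demo dataset with per-image identifiers in
``adata.obs['roi']``.)

>>> import scanpy as sc
>>> from utag import utag
>>> adata = sc.read("data/healthy_lung_adata.h5ad")
>>> results = utag(adata, slide_key="roi", max_dist=20)
>>> results.obs.filter(like="UTAG Label").head()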
111 | """ 112 | ad = adata.copy() 113 | 114 | if channels_to_use: 115 | ad = ad[:, channels_to_use] 116 | 117 | if filter_by_variance: 118 | ad = low_variance_filter(ad) 119 | 120 | if isinstance(clustering_method, list): 121 | clustering_method = [m.upper() for m in clustering_method] 122 | elif isinstance(clustering_method, str): 123 | clustering_method = [clustering_method.upper()] 124 | else: 125 | print( 126 | "Invalid Clustering Method. Clustering Method Should Either be a string or a list" 127 | ) 128 | return 129 | assert all(m in ["LEIDEN", "PARC", "KMEANS"] for m in clustering_method) 130 | 131 | if "PARC" in clustering_method: 132 | from parc import PARC # early fail if not available 133 | if "KMEANS" in clustering_method: 134 | from sklearn.cluster import KMeans 135 | 136 | print("Applying UTAG Algorithm...") 137 | if slide_key: 138 | ads = [ 139 | ad[ad.obs[slide_key] == slide].copy() for slide in ad.obs[slide_key].unique() 140 | ] 141 | ad_list = parmap.map( 142 | _parallel_message_pass, 143 | ads, 144 | radius=max_dist, 145 | coord_type="generic", 146 | set_diag=True, 147 | mode=normalization_mode, 148 | pm_pbar=True, 149 | pm_parallel=parallel, 150 | pm_processes=processes, 151 | ) 152 | ad_result = anndata.concat(ad_list) 153 | if keep_spatial_connectivity: 154 | ad_result.obsp["spatial_connectivities"] = sparse_matrix_dstack( 155 | [x.obsp["spatial_connectivities"] for x in ad_list] 156 | ) 157 | ad_result.obsp["spatial_distances"] = sparse_matrix_dstack( 158 | [x.obsp["spatial_distances"] for x in ad_list] 159 | ) 160 | else: 161 | sq.gr.spatial_neighbors(ad, radius=max_dist, coord_type="generic", set_diag=True) 162 | ad_result = custom_message_passing(ad, mode=normalization_mode) 163 | 164 | if apply_clustering: 165 | if "n_comps" in pca_kwargs: 166 | if pca_kwargs["n_comps"] > ad_result.shape[1]: 167 | pca_kwargs["n_comps"] = ad_result.shape[1] - 1 168 | print( 169 | f"Overwriding provided number of PCA dimensions to match number of features: {pca_kwargs['n_comps']}" 170 | ) 171 | sc.tl.pca(ad_result, **pca_kwargs) 172 | sc.pp.neighbors(ad_result) 173 | 174 | if apply_umap: 175 | print("Running UMAP on Input Dataset...") 176 | sc.tl.umap(ad_result, **umap_kwargs) 177 | 178 | for resolution in tqdm(resolutions): 179 | 180 | res_key1 = save_key + "_leiden_" + str(resolution) 181 | res_key2 = save_key + "_parc_" + str(resolution) 182 | res_key3 = save_key + "_kmeans_" + str(resolution) 183 | if "LEIDEN" in clustering_method: 184 | print(f"Applying Leiden Clustering at Resolution: {resolution}...") 185 | kwargs = dict() 186 | kwargs.update(leiden_kwargs or {}) 187 | sc.tl.leiden( 188 | ad_result, resolution=resolution, key_added=res_key1, **kwargs 189 | ) 190 | add_probabilities_to_centroid(ad_result, res_key1) 191 | 192 | if "PARC" in clustering_method: 193 | from parc import PARC 194 | 195 | print(f"Applying PARC Clustering at Resolution: {resolution}...") 196 | 197 | kwargs = dict(random_seed=1, small_pop=1000) 198 | kwargs.update(parc_kwargs or {}) 199 | model = PARC( 200 | ad_result.obsm["X_pca"], 201 | neighbor_graph=ad_result.obsp["connectivities"], 202 | resolution_parameter=resolution, 203 | **kwargs, 204 | ) 205 | model.run_PARC() 206 | ad_result.obs[res_key2] = pd.Categorical(model.labels) 207 | ad_result.obs[res_key2] = ad_result.obs[res_key2].astype("category") 208 | add_probabilities_to_centroid(ad_result, res_key2) 209 | 210 | if "KMEANS" in clustering_method: 211 | print(f"Applying K-means Clustering at Resolution: {resolution}...") 212 | k = 
int(np.ceil(resolution * 10)) 213 | kmeans = KMeans(n_clusters=k, random_state=1).fit(ad_result.obsm["X_pca"]) 214 | ad_result.obs[res_key3] = pd.Categorical(kmeans.labels_.astype(str)) 215 | add_probabilities_to_centroid(ad_result, res_key3) 216 | 217 | return ad_result 218 | 219 | 220 | def _parallel_message_pass( 221 | ad: AnnData, 222 | radius: int, 223 | coord_type: str, 224 | set_diag: bool, 225 | mode: str, 226 | ): 227 | sq.gr.spatial_neighbors(ad, radius=radius, coord_type=coord_type, set_diag=set_diag) 228 | ad = custom_message_passing(ad, mode=mode) 229 | return ad 230 | 231 | 232 | def custom_message_passing(adata: AnnData, mode: str = "l1_norm") -> AnnData: 233 | # from scipy.linalg import sqrtm 234 | # import logging 235 | if mode == "l1_norm": 236 | A = adata.obsp["spatial_connectivities"] 237 | from sklearn.preprocessing import normalize 238 | affinity = normalize(A, axis=1, norm="l1") 239 | else: 240 | # Plain A_mod multiplication 241 | A = adata.obsp["spatial_connectivities"] 242 | affinity = A 243 | # logging.info(type(affinity)) 244 | adata.X = affinity @ adata.X 245 | return adata 246 | 247 | 248 | def low_variance_filter(adata: AnnData) -> AnnData: 249 | return adata[:, adata.var["std"] > adata.var["std"].median()] 250 | 251 | 252 | def add_probabilities_to_centroid( 253 | adata: AnnData, col: str, name_to_output: str = None 254 | ) -> AnnData: 255 | from utag.utils import z_score 256 | from scipy.special import softmax 257 | 258 | if name_to_output is None: 259 | name_to_output = col + "_probabilities" 260 | 261 | mean = z_score(adata.to_df()).groupby(adata.obs[col]).mean() 262 | probs = softmax(adata.to_df() @ mean.T, axis=1) 263 | adata.obsm[name_to_output] = probs 264 | return adata 265 | 266 | 267 | def evaluate_performance( 268 | adata: AnnData, 269 | batch_key: str = "Slide", 270 | truth_key: str = "DOM_argmax", 271 | pred_key: str = "cluster", 272 | method: str = "rand", 273 | ) -> Array: 274 | assert method in ["rand", "homogeneity"] 275 | from sklearn.metrics import rand_score, homogeneity_score 276 | 277 | score_list = [] 278 | for key in adata.obs[batch_key].unique(): 279 | batch = adata[adata.obs[batch_key] == key] 280 | if method == "rand": 281 | score = rand_score(batch.obs[truth_key], batch.obs[pred_key]) 282 | elif method == "homogeneity": 283 | score = homogeneity_score(batch.obs[truth_key], batch.obs[pred_key]) 284 | score_list.append(score) 285 | return score_list 286 | -------------------------------------------------------------------------------- /utag/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ElementoLab/utag/545f7194fad2ef9a292e12cc4c88a2bf91eed082/utag/tests/__init__.py -------------------------------------------------------------------------------- /utag/tests/utag_test.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import pytest 4 | import numpy as np 5 | import scanpy as sc 6 | from anndata import AnnData 7 | import scipy 8 | 9 | from utag import utag 10 | 11 | 12 | kwargs = dict( 13 | slide_key="roi", max_dist=20, normalization_mode="l1_norm", apply_clustering=True 14 | ) 15 | 16 | 17 | @pytest.fixture 18 | def adata() -> AnnData: 19 | return sc.read( 20 | "data/healthy_lung_adata.h5ad", 21 | backup_url="https://zenodo.org/record/6376767/files/healthy_lung_adata.h5ad?download=1", 22 | ) 23 | 24 | 25 | def check( 26 | utag_results: AnnData, 27 | probabilities: tp.Sequence[float], 
28 | n: int, 29 | clustering: tp.Sequence[str], 30 | ) -> None: 31 | assert utag_results.obs.columns.str.contains("UTAG Label").any() 32 | for cluster in clustering: 33 | for prob in probabilities: 34 | col = f"UTAG Label_{cluster}_{prob}_probabilities" 35 | assert col in utag_results.obsm 36 | assert utag_results.obsm[col].shape[0] == n 37 | assert np.allclose(utag_results.obsm[col].sum(1), [1] * n) 38 | 39 | 40 | def test_subsample_serial(adata: AnnData) -> None: 41 | n = 10_000 42 | clustering = ["leiden", "parc"] 43 | utag_results = utag( 44 | adata[:n], 45 | **kwargs, 46 | clustering_method=clustering, 47 | parc_kwargs=dict(small_pop=10), 48 | resolutions=[0.3], 49 | parallel=False, 50 | ) 51 | check(utag_results, [0.3], n, clustering) 52 | 53 | 54 | def test_subsample_parallel(adata: AnnData) -> None: 55 | n = 10_000 56 | clustering = ["leiden", "parc"] 57 | utag_results = utag( 58 | adata[:n], 59 | **kwargs, 60 | clustering_method=clustering, 61 | parc_kwargs=dict(small_pop=10), 62 | resolutions=[0.3], 63 | parallel=True, 64 | ) 65 | check(utag_results, [0.3], n, clustering) 66 | 67 | 68 | def test_full(adata: AnnData) -> None: 69 | probabilities = [0.05, 0.1, 0.3] 70 | n = adata.shape[0] 71 | utag_results = utag( 72 | adata, 73 | **kwargs, 74 | clustering_method="leiden", 75 | resolutions=probabilities, 76 | ) 77 | check(utag_results, probabilities, adata.shape[0], ["leiden"]) 78 | 79 | 80 | def test_subsample_keep_spatial(adata: AnnData) -> None: 81 | n = 10_000 82 | clustering = ["parc"] 83 | utag_results = utag( 84 | adata[:n], 85 | **kwargs, 86 | clustering_method=clustering, 87 | parc_kwargs=dict(small_pop=10), 88 | keep_spatial_connectivity=True, 89 | resolutions=[0.3], 90 | parallel=False, 91 | ) 92 | check(utag_results, [0.3], n, clustering) 93 | assert "spatial_connectivities" in utag_results.obsp 94 | assert "spatial_distances" in utag_results.obsp 95 | assert utag_results.obsp["spatial_connectivities"].shape == (n, n) 96 | assert utag_results.obsp["spatial_distances"].shape == (n, n) 97 | assert ( 98 | scipy.sparse.csr_matrix.diagonal(utag_results.obsp["spatial_connectivities"]) == 1 99 | ).all() 100 | assert not scipy.sparse.csr_matrix.diagonal( 101 | utag_results.obsp["spatial_distances"] 102 | ).any() 103 | -------------------------------------------------------------------------------- /utag/types.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Specific data types used for type annotations in the package. 5 | """ 6 | 7 | from __future__ import annotations 8 | import os 9 | import typing as tp 10 | import pathlib 11 | 12 | 13 | import numpy 14 | import pandas 15 | import anndata 16 | import networkx 17 | import matplotlib 18 | from matplotlib.figure import Figure as _Figure 19 | 20 | 21 | __all__ = [ 22 | "Array", 23 | "Graph", 24 | "DataFrame", 25 | "Figure", 26 | "Axis", 27 | "Path", 28 | "AnnData", 29 | ] 30 | 31 | 32 | class Path(pathlib.Path): 33 | """ 34 | A pathlib.Path child class that allows concatenation with strings 35 | by overloading the addition operator. 36 | 37 | In addition, it implements the ``startswith`` and ``endswith`` methods 38 | just like in the base :obj:`str` type. 39 | 40 | The ``replace_`` implementation is meant to be an implementation closer 41 | to the :obj:`str` type. 42 | 43 | Iterating over a directory with ``iterdir`` that does not exists 44 | will return an empty iterator instead of throwing an error. 
45 | 46 | Creating a directory with ``mkdir`` allows existing directory and 47 | creates parents by default. 48 | """ 49 | 50 | _flavour = ( 51 | pathlib._windows_flavour # type: ignore[attr-defined] # pylint: disable=W0212 52 | if os.name == "nt" 53 | else pathlib._posix_flavour # type: ignore[attr-defined] # pylint: disable=W0212 54 | ) 55 | 56 | def __add__(self, string: str) -> Path: 57 | return Path(str(self) + string) 58 | 59 | def startswith(self, string: str) -> bool: 60 | return str(self).startswith(string) 61 | 62 | def endswith(self, string: str) -> bool: 63 | return str(self).endswith(string) 64 | 65 | def replace_(self, patt: str, repl: str) -> Path: 66 | return Path(str(self).replace(patt, repl)) 67 | 68 | def iterdir(self) -> tp.Generator: 69 | if self.exists(): 70 | yield from [Path(x) for x in pathlib.Path(str(self)).iterdir()] 71 | yield from [] 72 | 73 | def mkdir(self, mode=0o777, parents: bool = True, exist_ok: bool = True) -> Path: 74 | super().mkdir(mode=mode, parents=parents, exist_ok=exist_ok) 75 | return self 76 | 77 | 78 | Array = tp.Union[numpy.ndarray] 79 | Graph = tp.Union[networkx.Graph] 80 | 81 | DataFrame = tp.Union[pandas.DataFrame] 82 | AnnData = tp.Union[anndata.AnnData] 83 | 84 | Figure = tp.Union[_Figure] 85 | Axis = tp.Union[matplotlib.axis.Axis] 86 | -------------------------------------------------------------------------------- /utag/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Helper functions used throughout the package. 5 | """ 6 | 7 | import typing as tp 8 | 9 | import numpy as np 10 | import scipy 11 | import pandas as pd 12 | import networkx as nx 13 | 14 | from utag.types import Array, Graph, DataFrame, Path, AnnData 15 | 16 | 17 | def domain_connectivity( 18 | adata: AnnData, 19 | slide_key: str = 'Slide', 20 | domain_key: str = 'UTAG Label', 21 | ) -> AnnData: 22 | import squidpy as sq 23 | import numpy as np 24 | from tqdm import tqdm 25 | 26 | order = sorted(adata.obs[domain_key].unique().tolist()) 27 | 28 | global_pairwise_connection = pd.DataFrame(np.zeros(shape = (len(order),len(order))), index = order, columns = order) 29 | for slide in tqdm(adata.obs[slide_key].unique()): 30 | adata_batch = adata[adata.obs[slide_key] == slide].copy() 31 | 32 | sq.gr.spatial_neighbors(adata_batch, radius = 40, coord_type = 'generic') 33 | 34 | pairwise_connection = pd.DataFrame(index = order, columns = order) 35 | for label in adata_batch.obs[domain_key].unique(): 36 | self_connection = adata_batch[adata_batch.obs[domain_key] == label].obsp['spatial_connectivities'].todense().sum()/2 37 | self_connection = self_connection.round() 38 | 39 | pairwise_connection.loc[label, label] = self_connection 40 | 41 | for label in adata_batch.obs[domain_key].unique(): 42 | for label2 in adata_batch.obs[domain_key].unique(): 43 | if label != label2: 44 | pairwise = adata_batch[adata_batch.obs[domain_key].isin([label, label2])].obsp['spatial_connectivities'].todense().sum()/2 45 | pairwise = pairwise.round() 46 | pairwise_connection.loc[label, label2] = pairwise - pairwise_connection.loc[label, label] - pairwise_connection.loc[label2, label2] 47 | pairwise_connection.loc[label2, label] = pairwise_connection.loc[label, label2] 48 | 49 | pairwise_connection = pairwise_connection.fillna(0) 50 | global_pairwise_connection = global_pairwise_connection + pairwise_connection 51 | adata.uns[f'{domain_key}_domain_adjacency_matrix'] = global_pairwise_connection 52 | return adata 53 | 54 | 
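# Hypothetical usage sketch (names follow the defaults above, not official docs):
# `domain_connectivity` counts, across all images in `adata.obs[slide_key]`, how
# often cells from each pair of UTAG domains are spatial neighbors, and stores the
# summed matrix in `adata.uns[f'{domain_key}_domain_adjacency_matrix']`:
#
#     adata = domain_connectivity(adata, slide_key='roi', domain_key='UTAG Label')
#     adjacency = adata.uns['UTAG Label_domain_adjacency_matrix']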
def celltype_connectivity( 55 | adata: AnnData, 56 | slide_key: str = 'Slide', 57 | domain_key: str = 'UTAG Label', 58 | celltype_key: str = 'cluster_0.5_label', 59 | ) -> AnnData: 60 | import squidpy as sq 61 | import numpy as np 62 | from tqdm import tqdm 63 | 64 | global_pairwise_utag = dict() 65 | for label in adata.obs[domain_key].unique(): 66 | cell_types = adata.obs[celltype_key].unique().tolist() 67 | global_pairwise_utag[label] = pd.DataFrame(np.zeros(shape = (len(cell_types),len(cell_types))), index = cell_types, columns = cell_types) 68 | 69 | for slide in tqdm(adata.obs[slide_key].unique()): 70 | adata_batch = adata[adata.obs[slide_key] == slide].copy() 71 | sq.gr.spatial_neighbors(adata_batch, radius = 40, coord_type = 'generic') 72 | 73 | for label in adata.obs[domain_key].unique(): 74 | adata_batch2 = adata_batch[adata_batch.obs[domain_key] == label].copy() 75 | pairwise_connection = pd.DataFrame(index = cell_types, columns = cell_types) 76 | 77 | for cell_type1 in adata_batch2.obs[celltype_key].unique(): 78 | self_connection = adata_batch2[adata_batch2.obs[celltype_key] == cell_type1].obsp['spatial_connectivities'].todense().sum()/2 79 | self_connection = self_connection.round() 80 | 81 | pairwise_connection.loc[cell_type1, cell_type1] = self_connection 82 | 83 | for cell_type1 in adata_batch.obs[celltype_key].unique(): 84 | for cell_type2 in adata_batch2.obs[celltype_key].unique(): 85 | if cell_type1 != cell_type2: 86 | pairwise = adata_batch2[adata_batch2.obs[celltype_key].isin([cell_type1, cell_type2])].obsp['spatial_connectivities'].todense().sum()/2 87 | pairwise = pairwise.round() 88 | pairwise_connection.loc[cell_type1, cell_type2] = pairwise - pairwise_connection.loc[cell_type1, cell_type1] - pairwise_connection.loc[cell_type2, cell_type2] 89 | pairwise_connection.loc[cell_type2, cell_type1] = pairwise_connection.loc[cell_type1, cell_type2] 90 | 91 | pairwise_connection = pairwise_connection.fillna(0) 92 | global_pairwise_utag[label] = global_pairwise_utag[label] + pairwise_connection 93 | 94 | adata.uns[f'{domain_key}_celltype_adjacency_matrix'] = global_pairwise_utag 95 | return adata 96 | 97 | 98 | def slide_connectivity( 99 | adata: AnnData, 100 | slide_key: str = 'roi', 101 | domain_key: str = 'UTAG Label', 102 | ) -> dict(): 103 | import squidpy as sq 104 | import numpy as np 105 | from tqdm import tqdm 106 | 107 | order = sorted(adata.obs[domain_key].unique().tolist()) 108 | slide_connection = dict() 109 | 110 | for slide in tqdm(adata.obs[slide_key].unique()): 111 | adata_batch = adata[adata.obs[slide_key] == slide].copy() 112 | 113 | sq.gr.spatial_neighbors(adata_batch, radius = 40, coord_type = 'generic') 114 | 115 | pairwise_connection = pd.DataFrame(index = order, columns = order) 116 | for label in adata_batch.obs[domain_key].unique(): 117 | self_connection = adata_batch[adata_batch.obs[domain_key] == label].obsp['spatial_connectivities'].todense().sum()/2 118 | self_connection = self_connection.round() 119 | 120 | pairwise_connection.loc[label, label] = self_connection 121 | 122 | for label in adata_batch.obs[domain_key].unique(): 123 | for label2 in adata_batch.obs[domain_key].unique(): 124 | if label != label2: 125 | pairwise = adata_batch[adata_batch.obs[domain_key].isin([label, label2])].obsp['spatial_connectivities'].todense().sum()/2 126 | pairwise = pairwise.round() 127 | pairwise_connection.loc[label, label2] = pairwise - pairwise_connection.loc[label, label] - pairwise_connection.loc[label2, label2] 128 | pairwise_connection.loc[label2, label] 
= pairwise_connection.loc[label, label2] 129 | 130 | pairwise_connection = pairwise_connection.fillna(0) 131 | pairwise_connection = pairwise_connection.loc[(pairwise_connection!=0).any(1), (pairwise_connection!=0).any(0)] 132 | #pairwise_connection = pairwise_connection.dropna(axis = 1) 133 | slide_connection[slide] = pairwise_connection 134 | 135 | return slide_connection 136 | 137 | def measure_per_domain_cell_type_colocalization( 138 | adata: AnnData, 139 | utag_key: str = "UTAG Label", 140 | max_dist: int = 40, 141 | n_iterations: int = 100, 142 | ): 143 | import squidpy as sq 144 | a_ = adata.copy() 145 | sq.gr.spatial_neighbors(a_, radius=max_dist, coord_type="generic") 146 | 147 | G = nx.from_scipy_sparse_matrix(a_.obsp["spatial_connectivities"]) 148 | 149 | utag_map = {i: x for i, x in enumerate(adata.obs[utag_key])} 150 | nx.set_node_attributes(G, utag_map, name=utag_key) 151 | 152 | adj, order = nx.linalg.attrmatrix.attr_matrix(G, node_attr=utag_key) 153 | order = pd.Series(order).astype(adata.obs[utag_key].dtype) 154 | freqs = pd.DataFrame(adj, order, order).fillna(0) + 1 155 | 156 | norm_freqs = correct_interaction_background_random(G, freqs, utag_key, n_iterations) 157 | return norm_freqs 158 | 159 | 160 | def correct_interaction_background_random( 161 | graph: nx.Graph, freqs: pd.DataFrame, attribute: str, n_iterations: int = 100 162 | ): 163 | values = {x: graph.nodes[x][attribute] for x in graph.nodes} 164 | shuffled_freqs = list() 165 | for _ in range(n_iterations): 166 | g2 = graph.copy() 167 | shuffled_attr = pd.Series(values).sample(frac=1) 168 | shuffled_attr.index = values 169 | nx.set_node_attributes(g2, shuffled_attr.to_dict(), name=attribute) 170 | rf, rl = nx.linalg.attrmatrix.attr_matrix(g2, node_attr=attribute) 171 | rl = pd.Series(rl, dtype=freqs.index.dtype) 172 | shuffled_freqs.append(pd.DataFrame(rf, index=rl, columns=rl))#.fillna(0) + 1) 173 | shuffled_freq = pd.concat(shuffled_freqs) 174 | shuffled_freq = shuffled_freq.groupby(level=0).sum() 175 | shuffled_freq = shuffled_freq.fillna(0) + 1 176 | 177 | fl = np.log((freqs / freqs.values.sum())) 178 | sl = np.log((shuffled_freq / shuffled_freq.values.sum())) 179 | # make sure both contain all edges/nodes 180 | fl = fl.reindex(sl.index, axis=0).reindex(sl.index, axis=1) 181 | sl = sl.reindex(fl.index, axis=0).reindex(fl.index, axis=1) 182 | return fl - sl 183 | 184 | 185 | def evaluate_clustering( 186 | adata: AnnData, 187 | cluster_keys: Array, 188 | celltype_label: str = 'celltype', 189 | slide_key: str = 'roi', 190 | metrics: Array = ['entropy', 'cluster_number', 'silhouette_score', 'connectivity'] 191 | ) -> DataFrame: 192 | 193 | if type(cluster_keys) == str: 194 | cluster_keys = [cluster_keys] 195 | if type(metrics) == str: 196 | metrics = [metrics] 197 | 198 | cluster_loss = pd.DataFrame(index = metrics, columns = cluster_keys) 199 | from tqdm import tqdm 200 | 201 | for metric in metrics: 202 | print(f'Evaluating Cluster {metric}') 203 | for cluster in tqdm(cluster_keys): 204 | assert(metric in ['entropy', 'cluster_number', 'silhouette_score', 'connectivity']) 205 | 206 | if metric == 'entropy': 207 | from scipy.stats import entropy 208 | distribution = adata.obs.groupby([celltype_label, cluster]).count()[slide_key].reset_index().pivot(index = cluster, columns = celltype_label, values = slide_key) 209 | cluster_entropy = distribution.apply(entropy, axis = 1).sort_values().mean() 210 | 211 | cluster_loss.loc[metric, cluster] = cluster_entropy 212 | elif metric == 'cluster_number': 213 | 214 | 
cluster_loss.loc[metric, cluster] = len(adata.obs[cluster].unique()) 215 | elif metric == 'silhouette_score': 216 | 217 | from sklearn.metrics import silhouette_score 218 | cluster_loss.loc[metric, cluster] = silhouette_score(adata.X, labels = adata.obs[cluster]) 219 | elif metric == 'connectivity': 220 | global_pairwise_connection = domain_connectivity(adata = adata, slide_key = slide_key, domain_key = cluster) 221 | inter_spatial_connectivity = np.log(np.diag(global_pairwise_connection).sum() / (global_pairwise_connection.sum().sum() - np.diag(global_pairwise_connection).sum())) 222 | 223 | cluster_loss.loc[metric, cluster] = inter_spatial_connectivity 224 | return cluster_loss 225 | 226 | def to_uint(x: Array, base: int = 8) -> Array: 227 | return (x * (2 ** base - 1)).astype(f"uint{base}") 228 | 229 | 230 | def to_float(x: Array, base: int = 32) -> Array: 231 | return (x / x.max()).astype(f"float{base}") 232 | 233 | 234 | def open_image_with_tf(filename: str, file_type="png"): 235 | import tensorflow as tf 236 | 237 | img = tf.io.read_file(filename) 238 | return tf.io.decode_image(img, file_type) 239 | 240 | 241 | def filter_kwargs( 242 | kwargs: tp.Dict[str, tp.Any], callabl: tp.Callable, exclude: bool = None 243 | ) -> tp.Dict[str, tp.Any]: 244 | from inspect import signature 245 | 246 | args = signature(callabl).parameters.keys() 247 | if "kwargs" in args: 248 | return kwargs 249 | return {k: v for k, v in kwargs.items() if (k in args) and k not in (exclude or [])} 250 | 251 | 252 | def array_to_graph( 253 | arr: Array, 254 | max_dist: int = 5, 255 | node_attrs: tp.Mapping[int, tp.Mapping[str, tp.Union[str, int, float]]] = None, 256 | ) -> Graph: 257 | """ 258 | Generate a Graph of object distance-based connectivity in euclidean space. 259 | 260 | Parameters 261 | ---------- 262 | arr: np.ndarray 263 | Labeled array. 
264 | """ 265 | mask = arr > 0 266 | idx = arr[mask] 267 | xx, yy = np.mgrid[: arr.shape[0], : arr.shape[1]] 268 | arri = np.stack([xx[mask], yy[mask]]).T 269 | dists = pd.DataFrame(scipy.spatial.distance.cdist(arri, arri), index=idx, columns=idx) 270 | np.fill_diagonal(dists.values, np.nan) 271 | 272 | attrs = dists[dists <= max_dist].reset_index().melt(id_vars="index").dropna() 273 | attrs.index = attrs.iloc[:, :2].apply(tuple, axis=1).tolist() 274 | value = attrs["value"] 275 | g = nx.from_edgelist(attrs.index) 276 | nx.set_edge_attributes(g, value.to_dict(), "distance") 277 | nx.set_edge_attributes(g, (1 / value).to_dict(), "connectivity") 278 | 279 | if node_attrs is not None: 280 | nx.set_node_attributes(g, node_attrs) 281 | 282 | return g 283 | 284 | def compute_and_draw_network( 285 | adata, 286 | slide_key: str = 'roi', 287 | node_key: str = 'UTAG Label', 288 | figsize: tuple = (11,11), 289 | dpi: int = 100, 290 | font_size: int = 12, 291 | node_size_min: int = 1000, 292 | node_size_max: int = 3000, 293 | edge_weight: float = 10, 294 | log_transform: bool = True, 295 | ax = None 296 | ) -> nx.Graph: 297 | from utag.utils import domain_connectivity 298 | import networkx as nx 299 | import matplotlib.pyplot as plt 300 | 301 | adjacency_matrix = domain_connectivity(adata = adata, slide_key = slide_key, domain_key = node_key) 302 | s1 = adata.obs.groupby(node_key).count() 303 | s1 = s1[s1.columns[0]] 304 | node_size = s1.values 305 | node_size = (node_size - node_size.min()) / (node_size.max() - node_size.min()) * (node_size_max - node_size_min) + node_size_min 306 | 307 | if ax == None: 308 | fig = plt.figure(figsize = figsize, dpi = dpi) 309 | G = nx.from_numpy_matrix(np.matrix(adjacency_matrix), create_using=nx.Graph) 310 | G = nx.relabel.relabel_nodes(G, {i: label for i, label in enumerate(adjacency_matrix.index)}) 311 | pos = nx.circular_layout(G) 312 | 313 | edges, weights = zip(*nx.get_edge_attributes(G,'weight').items()) 314 | if log_transform: 315 | weights = np.log(np.array(list(weights))+1) 316 | else: 317 | weights = np.array(list(weights)) 318 | weights = (weights - weights.min()) / (weights.max() - weights.min()) * edge_weight + 0.2 319 | weights = tuple(weights.tolist()) 320 | 321 | if ax: 322 | nx.draw(G, pos, node_color='w', edgelist=edges, edge_color=weights, width=weights, edge_cmap=plt.cm.YlOrRd, with_labels=True, font_size = font_size, node_size = node_size, ax = ax) 323 | else: 324 | nx.draw(G, pos, node_color='w', edgelist=edges, edge_color=weights, width=weights, edge_cmap=plt.cm.YlOrRd, with_labels=True, font_size = font_size, node_size = node_size) 325 | #nx.draw(G, pos, cmap = plt.cm.tab10, node_color = range(8), edgelist=edges, edge_color=weights, width=3, edge_cmap=plt.cm.coolwarm, with_labels=True, font_size = 14, node_size = 1000) 326 | 327 | if ax == None: 328 | ax = plt.gca() 329 | 330 | color_key = node_key + '_colors' 331 | if color_key in adata.uns: 332 | ax.collections[0].set_edgecolor(adata.uns[color_key]) 333 | else: 334 | ax.collections[0].set_edgecolor('lightgray') 335 | ax.collections[0].set_linewidth(3) 336 | ax.set_xlim([1.1*x for x in ax.get_xlim()]) 337 | ax.set_ylim([1.1*y for y in ax.get_ylim()]) 338 | 339 | return G 340 | 341 | def get_adjacency_matrix(g: Graph) -> Array: 342 | return nx.adjacency_matrix(g, weight="connectivity").todense() 343 | 344 | 345 | def get_feature_matrix(g: Graph) -> DataFrame: 346 | return pd.DataFrame({n: g.nodes[n] for n in g.nodes}).T 347 | 348 | 349 | def message_pass_graph(adj: Array, feat: DataFrame) -> 
def message_pass_graph(adj: Array, feat: DataFrame) -> DataFrame:
    # one message-passing step: left-multiplying by the adjacency matrix sums
    # each node's neighbour features, weighted by connectivity
    return (adj @ feat).set_index(feat.index)


def pad_feature_matrix(df: DataFrame, size: int) -> DataFrame:
    # continue the (numeric) index past its current maximum so padded rows
    # do not duplicate existing labels
    index = df.index.tolist() + (
        df.index.max() + 1 + np.arange(size - df.shape[0])
    ).tolist()
    return pd.DataFrame(
        np.pad(
            df.values,
            [(0, size - df.shape[0]), (0, 0)],
        ),
        index=index,
        columns=df.columns,
    )


def pad_adjacency_matrix(mat: Array, size: int) -> Array:
    return np.pad(
        mat,
        [
            (0, size - mat.shape[0]),
            (0, size - mat.shape[0]),
        ],
    )


def message_pass_graphs(gs: tp.Sequence[Graph]) -> Array:
    # pad all graphs to the largest node count so adjacency and feature
    # tensors can be stacked and multiplied batch-wise
    n = max(len(g) for g in gs)
    _adjs = list()
    _feats = list()
    for g in gs:
        adj = get_adjacency_matrix(g)
        adj = pad_adjacency_matrix(adj, n)
        _adjs.append(adj)
        feat = get_feature_matrix(g)
        feat = pad_feature_matrix(feat, n)
        _feats.append(feat)
    adjs = np.stack(_adjs)
    feats = np.stack(_feats).astype(float)

    return adjs @ feats


def mask_to_labelme(
    labeled_image: Array,
    filename: Path,
    overwrite: bool = False,
    simplify: bool = True,
    simplification_threshold: float = 5.0,
) -> None:
    import io
    import base64
    import json

    import imageio
    import tifffile
    from imantics import Mask
    from shapely.geometry import Polygon

    output_file = filename.replace_(".tif", ".json")
    # skip existing outputs unless explicitly asked to overwrite
    if output_file.exists() and not overwrite:
        return
    polygons = Mask(labeled_image).polygons()
    shapes = list()
    for point in polygons.points:
        if not simplify:
            poly = np.asarray(point).tolist()
        else:
            poly = np.asarray(
                Polygon(point).simplify(simplification_threshold).exterior.coords.xy
            ).T.tolist()
        shape = {
            "label": "A",
            "points": poly,
            "group_id": None,
            "shape_type": "polygon",
            "flags": {},
        }
        shapes.append(shape)

    f = io.BytesIO()
    imageio.imwrite(f, tifffile.imread(filename), format="PNG")
    f.seek(0)
    encoded = base64.encodebytes(f.read())

    payload = {
        "version": "4.5.6",
        "flags": {},
        "shapes": shapes,
        "imagePath": filename.name,
        "imageData": encoded.decode("ascii"),
        "imageHeight": labeled_image.shape[0],
        "imageWidth": labeled_image.shape[1],
    }
    with open(output_file.as_posix(), "w") as fp:
        json.dump(payload, fp, indent=2)


def z_score(x: Array) -> Array:
    """
    Scale (divide by standard deviation) and center (subtract mean) array-like objects.
    """
    return (x - x.mean()) / x.std()
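# Minimal sketch of the batched message-passing pipeline above, on two toy
# path graphs (illustrative only; the node attribute "x" and the unit
# connectivities are arbitrary assumptions, not package conventions):
def _demo_batched_message_passing() -> Array:
    g1 = nx.path_graph(3)
    g2 = nx.path_graph(5)
    for g in (g1, g2):
        nx.set_edge_attributes(g, 1.0, "connectivity")
        nx.set_node_attributes(g, {n: {"x": float(n)} for n in g.nodes})
    # graphs are padded to the largest node count, so the result is (2, 5, 1)
    return message_pass_graphs([g1, g2])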
460 | """ 461 | import scipy 462 | from tqdm import tqdm 463 | 464 | n = sum([x.shape[0] for x in matrices]) 465 | _res = list() 466 | i = 0 467 | for x in tqdm(matrices): 468 | v = scipy.sparse.csr_matrix((x.shape[0], n)) 469 | v[:, i : i + x.shape[0]] = x 470 | _res.append(v) 471 | i += x.shape[0] 472 | return scipy.sparse.vstack(_res) 473 | -------------------------------------------------------------------------------- /utag/vizualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import scanpy as sc 4 | import numpy as np 5 | import pandas as pd 6 | import matplotlib.pyplot as plt 7 | import anndata 8 | import holoviews as hv 9 | from holoviews import opts, dim 10 | 11 | from utag.types import Path, Array, AnnData, DataFrame 12 | 13 | 14 | def add_spatial_image( 15 | adata: AnnData, 16 | image_path: Path, 17 | rgb_channels = [19, 9, 14], 18 | log_transform: bool = False, 19 | median_filter: bool = False, 20 | scale_method: str = 'adjust_gamma', 21 | contrast_percentile = (0, 90), 22 | gamma: float = 0.2, 23 | gain: float = 0.5 24 | ): 25 | 26 | adata.obsm['spatial'] = adata.obs[['Y_centroid', 'X_centroid']].to_numpy() 27 | adata.uns["spatial"] = {'image': {}} 28 | adata.uns["spatial"]['image']["images"] = {} 29 | 30 | img = rgbfy_multiplexed_image( 31 | image_path = image_path, 32 | rgb_channels = rgb_channels, 33 | contrast_percentile = contrast_percentile, 34 | log_transform = log_transform, 35 | median_filter = median_filter, 36 | scale_method = scale_method, 37 | gamma = gamma, 38 | gain = gain 39 | ) 40 | 41 | 42 | adata.uns["spatial"]['image']["images"] = {"hires": img} 43 | adata.uns["spatial"]['image']["scalefactors"] = {"tissue_hires_scalef": 1, "spot_diameter_fullres": 1} 44 | return adata 45 | 46 | def add_scale_box_to_fig( 47 | img: Array, 48 | ax, 49 | box_width: int = 100, 50 | box_height: float = 3, 51 | color: str = 'white' 52 | ) -> Array: 53 | import matplotlib.patches as patches 54 | x = img.shape[1] 55 | y = img.shape[0] 56 | 57 | # Create a Rectangle patch 58 | rect = patches.Rectangle((x - box_width, y * (1-box_height/100)), box_width, y * (box_height/100), linewidth=0.1, edgecolor='black', facecolor=color) 59 | 60 | # Add the patch to the Axes 61 | ax.add_patch(rect) 62 | return ax 63 | 64 | def rgbfy_multiplexed_image( 65 | image_path: Path, 66 | rgb_channels = [19, 9, 14], 67 | log_transform: bool = True, 68 | median_filter: bool = True, 69 | scale_method: str = 'adjust_gamma', 70 | contrast_percentile = (10, 90), 71 | gamma: float = 0.4, 72 | gain: float = 1 73 | ) -> Array: 74 | from skimage.exposure import rescale_intensity, adjust_gamma, equalize_hist 75 | from scipy.ndimage import median_filter as mf 76 | import tifffile 77 | 78 | def rescale(img, contrast_percentile): 79 | r1, r2 = np.percentile(img, contrast_percentile) 80 | img = rescale_intensity(img, in_range = (r1, r2), out_range = (0,1)) 81 | return img 82 | #assert(len(rgb_channels) == 3 or len(rgb_channels) == 1) 83 | 84 | img = tifffile.imread(image_path) 85 | img = img.astype(np.float32) 86 | if median_filter == True: 87 | img = mf(img, size = 3) 88 | 89 | image_to_save = np.stack([img[x] for x in rgb_channels], axis = 2) 90 | 91 | for i in range(len(rgb_channels)): 92 | if log_transform == True: 93 | image_to_save[:,:,i] = np.log(image_to_save[:,:,i] + 1) 94 | else: 95 | image_to_save[:,:,i] = image_to_save[:,:,i] 96 | 97 | output_img = image_to_save 98 | 99 | for i in range(3): 100 | if scale_method == 'contrast_stretch': 101 | 
def rgbfy_multiplexed_image(
    image_path: Path,
    rgb_channels=[19, 9, 14],
    log_transform: bool = True,
    median_filter: bool = True,
    scale_method: str = 'adjust_gamma',
    contrast_percentile=(10, 90),
    gamma: float = 0.4,
    gain: float = 1,
) -> Array:
    from skimage.exposure import rescale_intensity, adjust_gamma, equalize_hist
    from scipy.ndimage import median_filter as mf
    import tifffile

    def rescale(img, contrast_percentile):
        r1, r2 = np.percentile(img, contrast_percentile)
        img = rescale_intensity(img, in_range=(r1, r2), out_range=(0, 1))
        return img

    # expects exactly three channel indices (R, G, B)
    img = tifffile.imread(image_path)
    img = img.astype(np.float32)
    if median_filter:
        img = mf(img, size=3)

    image_to_save = np.stack([img[x] for x in rgb_channels], axis=2)

    if log_transform:
        for i in range(len(rgb_channels)):
            image_to_save[:, :, i] = np.log(image_to_save[:, :, i] + 1)

    output_img = image_to_save

    for i in range(3):
        if scale_method == 'contrast_stretch':
            output_img[:, :, i] = rescale(output_img[:, :, i], contrast_percentile)
        elif scale_method == 'adjust_gamma':
            output_img[:, :, i] = adjust_gamma(output_img[:, :, i], gamma=gamma, gain=gain)
        elif scale_method == 'equalize_hist':
            output_img[:, :, i] = equalize_hist(output_img[:, :, i])

        output_img[:, :, i] = np.clip(output_img[:, :, i], 0, 1)
    return output_img


def draw_network(
    adata: AnnData,
    node_key: str = 'UTAG Label',
    adjacency_matrix_key: str = 'UTAG Label_domain_adjacency_matrix',
    figsize: tuple = (11, 11),
    dpi: int = 200,
    font_size: int = 12,
    node_size_min: int = 1000,
    node_size_max: int = 3000,
    edge_weight: float = 5,
    edge_weight_baseline: float = 1,
    log_transform: bool = True,
    ax=None,
):
    import networkx as nx

    # node size is proportional to the number of cells per domain,
    # rescaled to [node_size_min, node_size_max]
    s1 = adata.obs.groupby(node_key).count()
    s1 = s1[s1.columns[0]]
    node_size = s1.values
    node_size = (node_size - node_size.min()) / (
        node_size.max() - node_size.min()
    ) * (node_size_max - node_size_min) + node_size_min

    # only return the figure if we created it ourselves
    return_fig = ax is None
    if return_fig:
        fig = plt.figure(figsize=figsize, dpi=dpi)
    # nx.from_numpy_matrix was removed in networkx 3.0; from_numpy_array is the equivalent
    G = nx.from_numpy_array(adata.uns[adjacency_matrix_key].values, create_using=nx.Graph)
    G = nx.relabel_nodes(G, dict(enumerate(adata.uns[adjacency_matrix_key].index)))

    edges, weights = zip(*nx.get_edge_attributes(G, 'weight').items())

    weights = np.array(list(weights))
    if log_transform:
        weights = np.log(weights + 1)

    # rescale edge weights to (edge_weight_baseline, edge_weight + edge_weight_baseline]
    weights = (weights - weights.min()) / (weights.max() - weights.min()) * edge_weight + edge_weight_baseline
    weights = tuple(weights.tolist())

    pos = nx.spring_layout(G, weight='weight', seed=42, k=1)

    nx.draw(
        G,
        pos,
        node_color='w',
        edgelist=edges,
        edge_color=weights,
        width=weights,
        edge_cmap=plt.cm.YlOrRd,
        with_labels=True,
        font_size=font_size,
        node_size=node_size,
        ax=ax,
    )

    if ax is None:
        ax = plt.gca()

    color_key = node_key + '_colors'
    if color_key in adata.uns:
        ax.collections[0].set_edgecolor(adata.uns[color_key])
        ax.collections[0].set_facecolor(adata.uns[color_key])
    else:
        ax.collections[0].set_edgecolor('lightgray')
    ax.collections[0].set_linewidth(3)
    ax.set_xlim([1.3 * x for x in ax.get_xlim()])
    ax.set_ylim([1 * y for y in ax.get_ylim()])

    if return_fig:
        return fig
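# Usage sketch (assumes the UTAG pipeline has already populated
# adata.uns['UTAG Label_domain_adjacency_matrix'] with a labelled DataFrame):
#
#   fig = draw_network(adata, node_key='UTAG Label')
#   fig.savefig('domain_network.pdf', bbox_inches='tight')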
def adj2chord(
    adjacency_matrix: Array,
    size: int = 300,
):
    # render a labelled (DataFrame) adjacency matrix as a holoviews chord
    # diagram; assumes the matplotlib backend, e.g. hv.extension('matplotlib')
    hv.output(fig='svg', size=size)

    links = (
        adjacency_matrix.stack()
        .reset_index()
        .rename(columns={'level_0': 'source', 'level_1': 'target', 0: 'value'})
        .dropna()
    )
    order2ind = {k: i for i, k in enumerate(adjacency_matrix.index.tolist())}

    links['source'] = links['source'].replace(order2ind)
    links['target'] = links['target'].replace(order2ind)
    links['value'] = links['value'].astype(int)

    nodes = pd.DataFrame(
        list(order2ind.keys()), index=list(order2ind.values()), columns=['name']
    ).reset_index()
    nodes['group'] = nodes['index']
    del nodes['index']
    nodes = hv.Dataset(nodes, 'index')

    # drop weak connections (value below 5) to reduce clutter
    chord = hv.Chord((links, nodes)).select(value=(5, None))
    chord.opts(
        opts.Chord(
            cmap='tab10',
            edge_cmap='tab10',
            edge_color=dim('source').str(),
            labels='name',
            node_color=dim('index').str(),
        )
    )

    return chord
--------------------------------------------------------------------------------