├── .gitattributes ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── images └── workflow.png ├── miso ├── __init__.py ├── checkpoints │ ├── vit256_small_dino.pth │ └── vit4k_xs_dino.pth ├── hipt_4k.py ├── hipt_model_utils.py ├── hist_features.py ├── model.py ├── nets.py ├── utils.py ├── vision_transformer.py └── vision_transformer4k.py ├── requirements.txt ├── setup.py └── tutorial └── tutorial.ipynb /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pth filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | miso_tutorial_data/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | miso/checkpoints/* 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | 159 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | COPYRIGHT AND PERMISSION NOTICE 2 | 3 | Penn Software MISO 4 | 5 | Copyright (C) 2022 The Trustees of the University of Pennsylvania 6 | 7 | All rights reserved. 8 | 9 | 10 | 11 | The Trustees of the University of Pennsylvania ("Penn") and Kyle Coleman the developers ("Developer") of Penn Software MISO ("Software") give recipient ("Recipient") and Recipient's Institution ("Institution") permission to use, copy, and modify the software in source and binary forms, with or without modification for non-profit research purposes only provided that the following conditions are met: 12 | 13 | 14 | 15 | 1) All copies of Software in binary form and/or source code, related documentation and/or other materials provided with the Software must reproduce and retain the above copyright notice, this list of conditions and the following disclaimer. 16 | 17 | 18 | 19 | 20 | 21 | 2) Recipient shall have the right to create modifications of the Software ("Modifications") for their internal research, non-commercial, or academic purposes only. 22 | 23 | 24 | 25 | 26 | 27 | 3) All copies of Modifications in binary form and/or source code and related documentation must reproduce and retain the above copyright notice, this list of conditions and the following disclaimer. 28 | 29 | 30 | 31 | 32 | 33 | 4) Recipient and Institution shall not distribute Software or Modifications to any commercial third parties without the prior written approval of Penn. 34 | 35 | 36 | 37 | 38 | 39 | 5) Recipient is encouraged to provide the Developer with feedback on the use of the Software and Modifications. The Developers and Penn are permitted to use any information Recipient provides in making changes to the Software. 40 | 41 | Please report these by opening an issue on our Github at: https://github.com/kpcoleman/miso 42 | 43 | 44 | 45 | 6) Recipient acknowledges that the Developers, Penn and its licensees may develop modifications to Software that may be substantially similar to Recipient's modifications of Software, and that the Developers, Penn and its licensees shall not be constrained in any way by Recipient in Penn's or its licensees' use or management of such modifications. 
Recipient acknowledges the right of the Developers and Penn to prepare and publish modifications to Software that may be substantially similar or functionally equivalent to your modifications and improvements, and if Recipient or Institution obtains patent protection for any modification or improvement to Software, Recipient and Institution agree not to allege or enjoin infringement of their patent by the Developers, Penn or any of Penn's licensees obtaining modifications or improvements to Software from the Penn or the Developers. 46 | 47 | 48 | 49 | 50 | 51 | 7) Recipient and Developer will acknowledge in their respective publications, if any, the contributions made to each other's research involving or based on the Software. The current citations for Software are: 52 | 53 | 54 | 55 | 56 | 57 | 8) Any party desiring a license to use the Software and/or Modifications for commercial purposes shall contact The Penn Center for Innovation at 215-898-9591. 58 | 59 | 60 | 61 | 62 | 63 | 64 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS, CONTRIBUTORS, AND THE TRUSTEES OF THE UNIVERSITY OF PENNSYLVANIA "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER, CONTRIBUTORS OR THE TRUSTEES OF THE UNIVERSITY OF PENNSYLVANIA BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 65 | 66 | 67 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.pth 2 | 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Resolving tissue complexity by multi-modal spatial omics modeling with MISO 2 | 3 | ### Kyle Coleman*, Amelia Schroeder, Melanie Loth, Daiwei Zhang, Jeong Hwan Park, Ji-Youn Sung, Niklas Blank, Alexis Jazmyn, Xuyu Qian, Jianfeng Chen, Jiahui Jiang, Hanying Yan, Laith Z. Samarah, Jean R. Clemenceau, Inyeop Jang, Minji Kim, Isabelle Barnfather, Joshua D. Rabinowitz, Yanxiang Deng, Edward B. Lee, Alexander Lazar, Jianjun Gao, Emma E. Furth, Tae Hyun Hwang, Linghua Wang, Christoph A. Thaiss, Jian Hu*, Mingyao Li* 4 | 5 | MISO is a deep-learning based method developed for the integration and clustering of multi-modal spatial omics data. MISO requires minimal hyperparameter tuning, and can be applied to any number of 6 | omic and imaging data modalities from any multi-modal spatial omics experiment. 
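After installation (see below), MISO is used through the `Miso` class exported by the package. The following is a minimal, illustrative sketch with synthetic inputs; the variable names and matrix sizes are placeholders, and the tutorial linked below walks through a complete analysis:

```python
import numpy as np
from miso import Miso

# Synthetic stand-ins for two modalities measured on the same 1,000 spots
# (e.g. normalized RNA expression and protein abundance). In practice, pass one
# (n_spots, n_features) array per modality; H&E image features can be extracted
# with miso.hist_features.get_features and supplied as an additional modality.
rna = np.random.rand(1000, 2000)
protein = np.random.rand(1000, 50)

model = Miso([rna, protein], device='cpu')  # set device='cuda' if a GPU is available
model.train()                               # learns one embedding network per modality
clusters = model.cluster(n_clusters=10)     # k-means on the joint embedding
```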
MISO has been evaluated on datasets from experiments including spatial transcriptomics (transcriptomics and histology), 7 | spatial epigenome-transcriptome co-profiling (chromatin accessibility, histone modification, and transcriptomics), spatial CITE-seq (transcriptomics, 8 | proteomics, and histology), and spatial transcriptomics and metabolomics (transcriptomics, metabolomics, and histology). 9 | 10 | ![png](images/workflow.png) 11 | 12 | 13 | ## MISO Installation 14 | 15 | Typical install time is ~1 min. 16 | MISO has been tested on the following operating systems: 17 | - macOS: Ventura (13.5.1) 18 | - Linux: CentOS (7) 19 | 20 | 21 | MISO installation requires Python version 3.7. The version of Python can be checked by: 22 | ```python 23 | import platform 24 | platform.python_version() 25 | ``` 26 | 27 | '3.7.13' 28 | 29 | 30 | We recommend creating and activating a new conda environment when installing the MISO package. For instance, 31 | ```bash 32 | conda create -n miso python=3.7.13 33 | conda activate miso 34 | ``` 35 | 36 | The MISO repository can be downloaded using: 37 | 38 | ```bash 39 | git clone https://github.com/kpcoleman/miso 40 | ``` 41 | 42 | The pretrained ViT weights are stored on Git LFS, and can be downloaded using: 43 | 44 | ```bash 45 | cd miso 46 | git lfs install 47 | git lfs fetch 48 | git lfs pull 49 | ``` 50 | 51 | The MISO package and dependencies can then be installed: 52 | 53 | ```bash 54 | python -m pip install . 55 | ``` 56 | 57 | Typical training time for MISO on a dataset containing fewer than 10,000 spots is <1 min on a GPU and <5 min on a CPU. The H&E histology image feature extraction step takes approximately 10 minutes on a GPU and 2 hours on a CPU. 58 | For a tutorial, please see: https://github.com/kpcoleman/miso/blob/main/tutorial/tutorial.ipynb 59 | 60 | The miso conda environment can be registered as a Jupyter kernel for the tutorial by: 61 | 62 | ```bash 63 | python -m pip install ipykernel 64 | python -m ipykernel install --user --name=miso 65 | ``` 66 | 67 | 68 | ## Software Requirements 69 | einops==0.6.0 70 | importlib 71 | importlib-metadata 72 | numpy==1.21.6 73 | opencv_python==4.6.0.66 74 | Pillow>=6.1.0 75 | scanpy==1.9.1 76 | scikit_image==0.19.3 77 | scikit_learn==1.0.2 78 | scipy==1.7.3 79 | setuptools==65.6.3 80 | torch==1.13.1 81 | torchvision==0.14.1 82 | tqdm==4.64.1 83 | 84 | H&E image feature extraction code is based on HIPT and iSTAR. Pre-trained vision transformer models are from HIPT. 85 | 86 | 87 | -------------------------------------------------------------------------------- /images/workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kpcoleman/miso/1252c6f4b280a0fd303c2a3e131e68b379acb3a6/images/workflow.png -------------------------------------------------------------------------------- /miso/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.0' 2 | from .
model import Miso 3 | -------------------------------------------------------------------------------- /miso/checkpoints/vit256_small_dino.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6960cd5a8657dc8bb214671aa0c6dbd3f5b698e84386884955836487ddc89e24 3 | size 704238867 4 | -------------------------------------------------------------------------------- /miso/checkpoints/vit4k_xs_dino.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:2b0bd9e9a602a35f2bb3f76da39d2b53a91f23fc3f115dc59a63267d95ad2b7b 3 | size 395710078 4 | -------------------------------------------------------------------------------- /miso/hipt_4k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | Image.MAX_IMAGE_PIXELS = None 5 | import torch 6 | import torch.multiprocessing 7 | from torchvision import transforms 8 | torch.multiprocessing.set_sharing_strategy('file_system') 9 | from einops import rearrange, reduce, repeat 10 | from . hipt_model_utils import get_vit256, get_vit4k, tensorbatch2im, eval_transforms 11 | 12 | 13 | class HIPT_4K(torch.nn.Module): 14 | """ 15 | HIPT Model (ViT-4K) for encoding non-square images (with [256 x 256] patch tokens), with 16 | [256 x 256] patch tokens encoded via ViT-256 using [16 x 16] patch tokens. 17 | """ 18 | def __init__(self, 19 | model256_path=None, 20 | model4k_path=None, 21 | device256=torch.device('cuda:0'), 22 | device4k=torch.device('cuda:0')): 23 | 24 | super().__init__() 25 | self.model256 = get_vit256(pretrained_weights=model256_path).to(device256) 26 | self.model4k = get_vit4k(pretrained_weights=model4k_path).to(device4k) 27 | self.device256 = device256 28 | self.device4k = device4k 29 | 30 | def forward(self, x): 31 | return self.forward_all(x)[0] 32 | 33 | def forward_all(self, x): 34 | """ 35 | Forward pass of HIPT (given an image tensor x), outputting the [CLS] token from ViT-4K. 36 | 1. x is center-cropped such that the W / H is divisible by the patch token size in ViT-4K (e.g. - 256 x 256). 37 | 2. x then gets unfolded into a "batch" of [256 x 256] images. 38 | 3. A pretrained ViT-256 model extracts the CLS token from each [256 x 256] image in the batch. 39 | 4. These batch-of-features are then reshaped into a 2D feature grid (of width "w_256" and height "h_256".) 40 | 5. This feature grid is then used as the input to ViT-4K, outputting [CLS]_4K. 41 | 42 | Args: 43 | - x (torch.Tensor): [1 x C x W' x H'] image tensor. 44 | 45 | Return: 46 | - features_cls4k (torch.Tensor): [1 x 192] cls token (d_4k = 192 by default). 47 | """ 48 | features_cls256, features_sub256 = self.forward_all256(x) 49 | features_cls4k, features_sub4k = self.forward_all4k(features_cls256) 50 | 51 | return features_cls4k, features_sub4k, features_sub256 52 | 53 | def forward_all256(self, x): 54 | batch_256, w_256, h_256 = self.prepare_img_tensor(x) # 1. [1 x 3 x W x H] 55 | batch_256 = batch_256.unfold(2, 256, 256).unfold(3, 256, 256) # 2. [1 x 3 x w_256 x h_256 x 256 x 256] 56 | batch_256 = rearrange(batch_256, 'b c p1 p2 w h -> (b p1 p2) c w h') # 2. [B x 3 x 256 x 256], where B = (1*w_256*h_256) 57 | 58 | features_cls256 = [] 59 | features_sub256 = [] 60 | for mini_bs in range(0, batch_256.shape[0], 256): # 3. B may be too large for ViT-256. We further take minibatches of 256. 
61 | minibatch_256 = batch_256[mini_bs:mini_bs+256].to(self.device256, non_blocking=True) 62 | fea_all256 = self.model256.forward_all(minibatch_256).cpu() 63 | fea_cls256 = fea_all256[:, 0] 64 | fea_sub256 = fea_all256[:, 1:] 65 | features_cls256.append(fea_cls256) # 3. Extracting ViT-256 features from [256 x 3 x 256 x 256] image batches. 66 | features_sub256.append(fea_sub256) 67 | 68 | features_cls256 = torch.vstack(features_cls256) # 3. [B x 384], where 384 == dim of ViT-256 [ClS] token. 69 | features_sub256 = torch.vstack(features_sub256) 70 | features_cls256 = features_cls256.reshape(w_256, h_256, 384).transpose(0,1).transpose(0,2).unsqueeze(dim=0) # [1 x 384 x w_256 x h_256] 71 | features_sub256 = features_sub256.reshape(w_256, h_256, 16, 16, 384).permute(4, 0, 1, 2, 3).unsqueeze(dim=0) # [1 x 384 x w_256 x h_256 x 16 x 16] 72 | return features_cls256, features_sub256 73 | 74 | def forward_all4k(self, features_cls256): 75 | __, __, w_256, h_256 = features_cls256.shape 76 | features_cls256 = features_cls256.to(self.device4k, non_blocking=True) 77 | features_all4k = self.model4k.forward_all(features_cls256) 78 | # attn_all4k = self.model4k.get_last_selfattention(features_cls256) 79 | features_cls4k = features_all4k[:, 0] # 5. [1 x 192], where 192 == dim of ViT-4K [ClS] token. 80 | features_sub4k = features_all4k[:, 1:] 81 | features_sub4k = features_sub4k.reshape(1, w_256, h_256, 192).permute(0, 3, 1, 2) 82 | return features_cls4k, features_sub4k 83 | 84 | def prepare_img_tensor(self, img: torch.Tensor, patch_size=256): 85 | """ 86 | Helper function that takes a non-square image tensor, and takes a center crop s.t. the width / height 87 | are divisible by 256. 88 | 89 | (Note: "_256" for w / h should technically be renamed as "_ps", but may not be easier to read. 90 | Until I need to make HIPT with patch_sizes != 256, keeping the naming convention as-is.) 91 | 92 | Args: 93 | - img (torch.Tensor): [1 x C x W' x H'] image tensor. 94 | - patch_size (int): Desired patch size to evenly subdivide the image. 95 | 96 | Return: 97 | - img_new (torch.Tensor): [1 x C x W x H] image tensor, where W and H are divisible by patch_size. 98 | - w_256 (int): # of [256 x 256] patches of img_new's width (e.g. - W/256) 99 | - h_256 (int): # of [256 x 256] patches of img_new's height (e.g. - H/256) 100 | """ 101 | make_divisible = lambda l, patch_size: (l - (l % patch_size)) 102 | b, c, w, h = img.shape 103 | load_size = make_divisible(w, patch_size), make_divisible(h, patch_size) 104 | w_256, h_256 = w // patch_size, h // patch_size 105 | img_new = transforms.CenterCrop(load_size)(img) 106 | return img_new, w_256, h_256 107 | -------------------------------------------------------------------------------- /miso/hipt_model_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | import torch.multiprocessing 5 | torch.multiprocessing.set_sharing_strategy('file_system') 6 | from torchvision import transforms 7 | from . import vision_transformer as vits 8 | from . import vision_transformer4k as vits4k 9 | from einops import rearrange  # used by roll_batch2img below 10 | def get_vit256(pretrained_weights=None, arch='vit_small', device=torch.device('cuda:0')): 11 | r""" 12 | Builds ViT-256 Model. 13 | 14 | Args: 15 | - pretrained_weights (str): Path to ViT-256 Model Checkpoint. 16 | - arch (str): Which model architecture. 17 | - device (torch): Torch device to save model. 18 | 19 | Returns: 20 | - model256 (torch.nn): Initialized model.
21 | """ 22 | 23 | checkpoint_key = 'teacher' 24 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 25 | model256 = vits.__dict__[arch](patch_size=16, num_classes=0) 26 | for p in model256.parameters(): 27 | p.requires_grad = False 28 | model256.eval() 29 | model256.to(device) 30 | 31 | if pretrained_weights is not None: 32 | state_dict = torch.load(pretrained_weights, map_location="cpu") 33 | if checkpoint_key is not None and checkpoint_key in state_dict: 34 | #print(f"Take key {checkpoint_key} in provided checkpoint dict") 35 | state_dict = state_dict[checkpoint_key] 36 | # remove `module.` prefix 37 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 38 | # remove `backbone.` prefix induced by multicrop wrapper 39 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 40 | msg = model256.load_state_dict(state_dict, strict=False) 41 | #print('Pretrained weights loaded from {}'.format(pretrained_weights)) 42 | 43 | return model256 44 | 45 | 46 | def get_vit4k(pretrained_weights=None, arch='vit4k_xs', device=torch.device('cuda:0')): 47 | r""" 48 | Builds ViT-4K Model. 49 | 50 | Args: 51 | - pretrained_weights (str): Path to ViT-4K Model Checkpoint. 52 | - arch (str): Which model architecture. 53 | - device (torch): Torch device to save model. 54 | 55 | Returns: 56 | - model256 (torch.nn): Initialized model. 57 | """ 58 | 59 | checkpoint_key = 'teacher' 60 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 61 | model4k = vits4k.__dict__[arch](num_classes=0) 62 | for p in model4k.parameters(): 63 | p.requires_grad = False 64 | model4k.eval() 65 | model4k.to(device) 66 | 67 | if pretrained_weights is not None: 68 | state_dict = torch.load(pretrained_weights, map_location="cpu") 69 | if checkpoint_key is not None and checkpoint_key in state_dict: 70 | #print(f"Take key {checkpoint_key} in provided checkpoint dict") 71 | state_dict = state_dict[checkpoint_key] 72 | # remove `module.` prefix 73 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 74 | # remove `backbone.` prefix induced by multicrop wrapper 75 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 76 | msg = model4k.load_state_dict(state_dict, strict=False) 77 | #print('Pretrained weights loaded from {}'.format(pretrained_weights)) 78 | 79 | return model4k 80 | 81 | 82 | def eval_transforms(): 83 | """ 84 | """ 85 | mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5) 86 | eval_t = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean = mean, std = std)]) 87 | return eval_t 88 | 89 | 90 | def roll_batch2img(batch: torch.Tensor, w: int, h: int, patch_size=256): 91 | """ 92 | Rolls an image tensor batch (batch of [256 x 256] images) into a [W x H] Pil.Image object. 93 | 94 | Args: 95 | batch (torch.Tensor): [B x 3 x 256 x 256] image tensor batch. 96 | 97 | Return: 98 | Image.PIL: [W x H X 3] Image. 99 | """ 100 | batch = batch.reshape(w, h, 3, patch_size, patch_size) 101 | img = rearrange(batch, 'p1 p2 c w h-> c (p1 w) (p2 h)').unsqueeze(dim=0) 102 | return Image.fromarray(tensorbatch2im(img)[0]) 103 | 104 | 105 | def tensorbatch2im(input_image, imtype=np.uint8): 106 | r"""" 107 | Converts a Tensor array into a numpy image array. 108 | 109 | Args: 110 | - input_image (torch.Tensor): (B, C, W, H) Torch Tensor. 111 | - imtype (type): the desired type of the converted numpy array 112 | 113 | Returns: 114 | - image_numpy (np.array): (B, W, H, C) Numpy Array. 
115 | """ 116 | if not isinstance(input_image, np.ndarray): 117 | image_numpy = input_image.cpu().float().numpy() # convert it into a numpy array 118 | #if image_numpy.shape[0] == 1: # grayscale to RGB 119 | # image_numpy = np.tile(image_numpy, (3, 1, 1)) 120 | image_numpy = (np.transpose(image_numpy, (0, 2, 3, 1)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling 121 | else: # if it is a numpy array, do nothing 122 | image_numpy = input_image 123 | return image_numpy.astype(imtype) 124 | -------------------------------------------------------------------------------- /miso/hist_features.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import time 3 | import argparse 4 | import sys 5 | import warnings 6 | import pkg_resources 7 | 8 | from einops import rearrange, reduce, repeat 9 | import numpy as np 10 | import skimage 11 | import torch 12 | from torch import nn 13 | from PIL import Image 14 | import cv2 as cv 15 | from scipy.ndimage import label 16 | from scipy.ndimage.morphology import binary_fill_holes 17 | from scipy.ndimage import uniform_filter 18 | from skimage.transform import rescale 19 | 20 | from . hipt_model_utils import eval_transforms 21 | from . hipt_4k import HIPT_4K 22 | from typing import Optional, Sequence 23 | 24 | try: 25 | shell = get_ipython().__class__.__name__ 26 | if shell == 'ZMQInteractiveShell': 27 | from tqdm.notebook import tqdm 28 | else: 29 | from tqdm import tqdm 30 | except NameError: 31 | from tqdm import tqdm 32 | 33 | 34 | Image.MAX_IMAGE_PIXELS = None 35 | 36 | 37 | def rescale_image(img, scale): 38 | img = np.array(img).astype(np.float32) 39 | if img.ndim == 2: 40 | scale = [scale, scale] 41 | elif img.ndim == 3: 42 | scale = [scale, scale, 1] 43 | else: 44 | raise ValueError('Unrecognized image ndim') 45 | img = rescale(img, scale, preserve_range=True) 46 | img = img.astype(np.uint8) 47 | return img 48 | 49 | def preprocess(img): 50 | img = np.array(img) 51 | if img.ndim == 3 and img.shape[-1] == 4: 52 | img = img[..., :3] # remove alpha channel 53 | return img 54 | 55 | def remove_fg_elements(mask: np.ndarray, size_threshold: float): 56 | r'''Removes small foreground elements''' 57 | labels, _ = label(mask) 58 | labels_unique, label_counts = np.unique(labels, return_counts=True) 59 | small_labels = labels_unique[ 60 | label_counts < size_threshold ** 2 * np.prod(mask.shape) 61 | ] 62 | mask[np.isin(labels, small_labels)] = False 63 | return mask 64 | 65 | def cleanup_mask(mask: np.ndarray, size_threshold: float): 66 | r'''Removes small background and foreground elements''' 67 | mask = ~remove_fg_elements(~mask, size_threshold) 68 | mask = remove_fg_elements(mask, size_threshold) 69 | return mask 70 | 71 | def resize( 72 | image: np.ndarray, 73 | target_shape: Sequence[int], 74 | resample: int = Image.NEAREST, 75 | ) -> np.ndarray: 76 | r''' 77 | Resizes image to a given `target_shape` 78 | 79 | :param image: Image array 80 | :param target_shape: Target shape 81 | :param resample: Resampling filter 82 | :returns: The rescaled image 83 | ''' 84 | image_pil = Image.fromarray(image) 85 | image_pil = image_pil.resize(target_shape[::-1], resample=resample) 86 | return np.array(image_pil) 87 | 88 | 89 | def compute_tissue_mask( 90 | image: np.ndarray, 91 | convergence_threshold: float = 0.0001, 92 | size_threshold: float = 0.01, 93 | initial_mask: Optional[np.ndarray] = None, 94 | max_iter: int = 100, 95 | ) -> np.ndarray: 96 | r''' 97 | Computes boolean mask indicating likely 
foreground elements in histology 98 | image. 99 | ''' 100 | # pylint: disable=no-member 101 | # ^ pylint fails to identify cv.* members 102 | original_shape = image.shape[:2] 103 | scale_factor = 1000 / max(original_shape) 104 | image = rescale(image, scale_factor, resample=Image.NEAREST) 105 | image = cv.blur(image, (5, 5)) 106 | if initial_mask is None: 107 | initial_mask = ( 108 | cv.blur(cv.Canny(image, 100, 200), (20, 20)) > 0 109 | ) 110 | else: 111 | initial_mask = rescale( 112 | initial_mask, scale_factor, resample=Image.NEAREST) 113 | initial_mask = binary_fill_holes(initial_mask) 114 | initial_mask = remove_fg_elements(initial_mask, 0.1) # type: ignore 115 | mask = np.where(initial_mask, cv.GC_PR_FGD, cv.GC_PR_BGD) 116 | mask = mask.astype(np.uint8) 117 | bgd_model = np.zeros((1, 65), np.float64) 118 | fgd_model = bgd_model.copy() 119 | print('Computing tissue mask:') 120 | for i in range(max_iter): 121 | old_mask = mask.copy() 122 | try: 123 | cv.grabCut( 124 | image, 125 | mask, 126 | None, 127 | bgd_model, 128 | fgd_model, 129 | 1, 130 | cv.GC_INIT_WITH_MASK, 131 | ) 132 | except cv.error as cv_err: 133 | warnings.warn(f'Failed to mask tissue\n{str(cv_err).strip()}') 134 | mask = np.full_like(mask, cv.GC_PR_FGD) 135 | break 136 | prop_changed = (mask != old_mask).sum() / np.prod(mask.shape) 137 | print(' Iteration %2d Δ = %.2f%%', i, 100 * prop_changed) 138 | if prop_changed < convergence_threshold: 139 | break 140 | mask = np.isin(mask, [cv.GC_FGD, cv.GC_PR_FGD]) 141 | mask = cleanup_mask(mask, size_threshold) 142 | mask = resize(mask, target_shape=original_shape, resample=Image.NEAREST) 143 | return mask 144 | 145 | 146 | def remove_border(x): 147 | x = x.copy() 148 | x[0] = 0 149 | x[-1] = 0 150 | x[:, 0] = 0 151 | x[:, -1] = 0 152 | return x 153 | 154 | 155 | def get_extent(mask): 156 | extent = [] 157 | for ax in range(mask.ndim): 158 | ma = mask.swapaxes(0, ax) 159 | ma = ma.reshape(ma.shape[0], -1) 160 | notempty = ma.any(1) 161 | start = notempty.argmax() 162 | stop = notempty.size - notempty[::-1].argmax() 163 | extent.append([start, stop]) 164 | extent = np.array(extent) 165 | return extent 166 | 167 | 168 | def crop_image(img, extent, mode='edge', constant_values=None): 169 | extent = np.array(extent) 170 | pad = np.zeros((img.ndim, 2), dtype=int) 171 | for i, (lower, upper) in enumerate(extent): 172 | if lower < 0: 173 | pad[i][0] = 0 - lower 174 | if upper > img.shape[i]: 175 | pad[i][1] = upper - img.shape[i] 176 | if (pad != 0).any(): 177 | kwargs = {} 178 | if mode == 'constant' and constant_values is not None: 179 | kwargs['constant_values'] = constant_values 180 | img = np.pad(img, pad, mode=mode, **kwargs) 181 | extent += pad[:extent.shape[0], [0]] 182 | for i, (lower, upper) in enumerate(extent): 183 | img = img.take(range(lower, upper), axis=i) 184 | return img 185 | 186 | 187 | 188 | def adjust_margins(img, pad, pad_value=None): 189 | extent = np.stack([[0, 0], img.shape[:2]]).T 190 | # make size divisible by pad without changing coords 191 | remainder = (extent[:, 1] - extent[:, 0]) % pad 192 | complement = (pad - remainder) % pad 193 | extent[:, 1] += complement 194 | if pad_value is None: 195 | mode = 'edge' 196 | else: 197 | mode = 'constant' 198 | img = crop_image( 199 | img, extent, mode=mode, constant_values=pad_value) 200 | return img 201 | 202 | 203 | def shrink_mask(x, size): 204 | size = size * 2 - 1 205 | x = uniform_filter(x.astype(float), size=size) 206 | x = np.isclose(x, 1) 207 | return x 208 | 209 | 210 | def patchify(x, patch_size): 211 
| shape_ori = np.array(x.shape[:2]) 212 | shape_ext = ( 213 | (shape_ori + patch_size - 1) 214 | // patch_size * patch_size) 215 | x = np.pad( 216 | x, 217 | ( 218 | (0, shape_ext[0] - x.shape[0]), 219 | (0, shape_ext[1] - x.shape[1]), 220 | (0, 0)), 221 | mode='edge') 222 | tiles_shape = np.array(x.shape[:2]) // patch_size 223 | # x = rearrange( 224 | # x, '(h1 h) (w1 w) c -> h1 w1 h w c', 225 | # h=patch_size, w=patch_size) 226 | # x = rearrange( 227 | # x, '(h1 h) (w1 w) c -> (h1 w1) h w c', 228 | # h=patch_size, w=patch_size) 229 | tiles = [] 230 | for i0 in range(tiles_shape[0]): 231 | a0 = i0 * patch_size # TODO: change to patch_size[0] 232 | b0 = a0 + patch_size # TODO: change to patch_size[0] 233 | for i1 in range(tiles_shape[1]): 234 | a1 = i1 * patch_size # TODO: change to patch_size[1] 235 | b1 = a1 + patch_size # TODO: change to patch_size[1] 236 | tiles.append(x[a0:b0, a1:b1]) 237 | 238 | shapes = dict( 239 | original=shape_ori, 240 | padded=shape_ext, 241 | tiles=tiles_shape) 242 | return tiles, shapes 243 | 244 | 245 | def get_embeddings_sub(model, x): 246 | x = x.astype(np.float32) / 255.0 247 | x = eval_transforms()(x) 248 | x_cls, x_sub = model.forward_all256(x[None]) 249 | x_cls = x_cls.cpu().detach().numpy() 250 | x_sub = x_sub.cpu().detach().numpy() 251 | x_cls = x_cls[0].transpose(1, 2, 0) 252 | x_sub = x_sub[0].transpose(1, 2, 3, 4, 0) 253 | return x_cls, x_sub 254 | 255 | 256 | def get_embeddings_cls(model, x): 257 | x = torch.tensor(x.transpose(2, 0, 1)) 258 | with torch.no_grad(): 259 | __, x_sub4k = model.forward_all4k(x[None]) 260 | x_sub4k = x_sub4k.cpu().detach().numpy() 261 | x_sub4k = x_sub4k[0].transpose(1, 2, 0) 262 | return x_sub4k 263 | 264 | 265 | def get_embeddings(img, pretrained=True, device='cuda'): 266 | ''' 267 | Extract embeddings from histology tiles 268 | Args: 269 | tiles: Histology image tiles. 270 | Shape: (N, H, W, C). 271 | `H` and `W` are both divisible by 256. 272 | Channels `C` include R, G, B, foreground mask. 
273 | Returns: 274 | emb_cls: Embeddings of (256 x 256)-sized patches 275 | Shape: (H/256, W/256, 384) 276 | emb_sub: Embeddings of (16 x 16)-sized patches 277 | Shape: (H/16, W/16, 384) 278 | ''' 279 | #print('Extracting embeddings...') 280 | t0 = time() 281 | 282 | tile_size = 4096 283 | tiles, shapes = patchify(img, patch_size=tile_size) 284 | 285 | model256_path, model4k_path = None, None 286 | if pretrained: 287 | model256_path = pkg_resources.resource_filename(__name__,'checkpoints/vit256_small_dino.pth') 288 | model4k_path = pkg_resources.resource_filename(__name__,'checkpoints/vit4k_xs_dino.pth') 289 | model = HIPT_4K( 290 | model256_path=model256_path, 291 | model4k_path=model4k_path, 292 | device256=device, device4k=device) 293 | model.eval() 294 | patch_size = (256, 256) 295 | subpatch_size = (16, 16) 296 | n_subpatches = tuple( 297 | a // b for a, b in zip(patch_size, subpatch_size)) 298 | 299 | emb_sub = [] 300 | emb_mid = [] 301 | for i in range(len(tiles)): 302 | #if i % 10 == 0: 303 | # print('tile', i, '/', len(tiles)) 304 | x_mid, x_sub = get_embeddings_sub(model, tiles[i]) 305 | emb_mid.append(x_mid) 306 | emb_sub.append(x_sub) 307 | del tiles 308 | torch.cuda.empty_cache() 309 | emb_mid = rearrange( 310 | emb_mid, '(h1 w1) h2 w2 k -> (h1 h2) (w1 w2) k', 311 | h1=shapes['tiles'][0], w1=shapes['tiles'][1]) 312 | 313 | emb_cls = get_embeddings_cls(model, emb_mid) 314 | del emb_mid, model 315 | torch.cuda.empty_cache() 316 | 317 | shape_orig = np.array(shapes['original']) // subpatch_size 318 | 319 | chans_sub = [] 320 | for i in range(emb_sub[0].shape[-1]): 321 | chan = rearrange( 322 | np.array([e[..., i] for e in emb_sub]), 323 | '(h1 w1) h2 w2 h3 w3 -> (h1 h2 h3) (w1 w2 w3)', 324 | h1=shapes['tiles'][0], w1=shapes['tiles'][1]) 325 | chan = chan[:shape_orig[0], :shape_orig[1]] 326 | chans_sub.append(chan) 327 | del emb_sub 328 | 329 | chans_cls = [] 330 | for i in range(emb_cls[0].shape[-1]): 331 | chan = repeat( 332 | np.array([e[..., i] for e in emb_cls]), 333 | 'h12 w12 -> (h12 h3) (w12 w3)', 334 | h3=n_subpatches[0], w3=n_subpatches[1]) 335 | chan = chan[:shape_orig[0], :shape_orig[1]] 336 | chans_cls.append(chan) 337 | del emb_cls 338 | 339 | #print(int(time() - t0), 'sec') 340 | 341 | return chans_cls, chans_sub 342 | 343 | 344 | def get_embeddings_shift( 345 | img, margin=256, stride=64, 346 | pretrained=True, device='cuda'): 347 | # margin: margin for shifting. Divisble by 256 348 | # stride: stride for shifting. Divides `margin`. 349 | factor = 16 # scaling factor between cls and sub. 
Fixed 350 | shape_emb = np.array(img.shape[:2]) // factor 351 | chans_cls = [ 352 | np.zeros(shape_emb, dtype=np.float32) 353 | for __ in range(192)] 354 | chans_sub = [ 355 | np.zeros(shape_emb, dtype=np.float32) 356 | for __ in range(384)] 357 | start_list = list(range(0, margin, stride)) 358 | n_reps = 0 359 | for k,start0 in enumerate(start_list): 360 | for start1 in tqdm(start_list, desc = 'Extracting image features: ' + str(k+1) + '/' + str(len(start_list))): 361 | #print(f'shift {start0}/{margin}, {start1}/{margin}') 362 | t0 = time() 363 | stop0, stop1 = -margin+start0, -margin+start1 364 | im = img[start0:stop0, start1:stop1] 365 | cls, sub = get_embeddings( 366 | im, pretrained=pretrained, device=device) 367 | del im 368 | sta0, sta1 = start0 // factor, start1 // factor 369 | sto0, sto1 = stop0 // factor, stop1 // factor 370 | for i in range(len(chans_cls)): 371 | chans_cls[i][sta0:sto0, sta1:sto1] += cls[i] 372 | del cls 373 | for i in range(len(chans_sub)): 374 | chans_sub[i][sta0:sto0, sta1:sto1] += sub[i] 375 | del sub 376 | n_reps += 1 377 | #print(int(time() - t0), 'sec') 378 | 379 | mar = margin // factor 380 | for chan in chans_cls: 381 | chan /= n_reps 382 | chan[-mar:] = 0.0 383 | chan[:, -mar:] = 0.0 384 | for chan in chans_sub: 385 | chan /= n_reps 386 | chan[-mar:] = 0.0 387 | chan[:, -mar:] = 0.0 388 | 389 | return chans_cls, chans_sub 390 | 391 | 392 | 393 | 394 | def impute_missing(x, mask, radius=3, method='ns'): 395 | method_dict = { 396 | 'telea': cv.INPAINT_TELEA, 397 | 'ns': cv.INPAINT_NS} 398 | method = method_dict[method] 399 | channels = [x[..., i] for i in range(x.shape[-1])] 400 | mask = mask.astype(np.uint8) 401 | y = [cv.inpaint(c, mask, radius, method) for c in channels] 402 | y = np.stack(y, -1) 403 | return y 404 | 405 | 406 | def smoothen( 407 | x, size, kernel='gaussian', backend='cv', mode='mean', 408 | impute_missing_values=True, device='cuda'): 409 | 410 | if x.ndim == 3: 411 | expand_dim = False 412 | elif x.ndim == 2: 413 | expand_dim = True 414 | x = x[..., np.newaxis] 415 | else: 416 | raise ValueError('ndim must be 2 or 3') 417 | 418 | mask = np.isfinite(x).all(-1) 419 | if (~mask).any() and impute_missing_values: 420 | x = impute_missing(x, ~mask) 421 | 422 | if kernel == 'gaussian': 423 | sigma = size / 4 # approximate std of uniform filter 1/sqrt(12) 424 | truncate = 4.0 425 | winsize = np.ceil(sigma * truncate).astype(int) * 2 + 1 426 | if backend == 'cv': 427 | print(f'gaussian filter: winsize={winsize}, sigma={sigma}') 428 | y = cv.GaussianBlur( 429 | x, (winsize, winsize), sigmaX=sigma, sigmaY=sigma, 430 | borderType=cv.BORDER_REFLECT) 431 | elif backend == 'skimage': 432 | y = skimage.filters.gaussian( 433 | x, sigma=sigma, truncate=truncate, 434 | preserve_range=True, channel_axis=-1) 435 | else: 436 | raise ValueError('backend must be cv or skimage') 437 | elif kernel == 'uniform': 438 | if backend == 'cv': 439 | kernel = np.ones((size, size), np.float32) / size**2 440 | y = cv.filter2D( 441 | x, ddepth=-1, kernel=kernel, 442 | borderType=cv.BORDER_REFLECT) 443 | if y.ndim == 2: 444 | y = y[..., np.newaxis] 445 | elif backend == 'torch': 446 | assert isinstance(size, int) 447 | padding = size // 2 448 | size = size + 1 449 | 450 | pool_dict = { 451 | 'mean': nn.AvgPool2d( 452 | kernel_size=size, stride=1, padding=0), 453 | 'max': nn.MaxPool2d( 454 | kernel_size=size, stride=1, padding=0)} 455 | pool = pool_dict[mode] 456 | 457 | mod = nn.Sequential( 458 | nn.ReflectionPad2d(padding), 459 | pool) 460 | y = mod(torch.tensor(x, 
device=device).permute(2, 0, 1)) 461 | y = y.permute(1, 2, 0) 462 | y = y.cpu().detach().numpy() 463 | else: 464 | raise ValueError('backend must be cv or torch') 465 | else: 466 | raise ValueError('kernel must be gaussian or uniform') 467 | 468 | if not mask.all(): 469 | y[~mask] = np.nan 470 | 471 | if expand_dim and y.ndim == 3: 472 | y = y[..., 0] 473 | 474 | return y 475 | 476 | 477 | def smoothen_embeddings( 478 | embs, size, kernel, 479 | method='cv', groups=None, device='cuda'): 480 | if groups is None: 481 | groups = embs.keys() 482 | out = {} 483 | for grp, em in embs.items(): 484 | if grp in groups: 485 | if isinstance(em, list): 486 | smoothened = [ 487 | smoothen( 488 | c[..., np.newaxis], size=size, 489 | kernel=kernel, backend=method, 490 | device=device)[..., 0] 491 | for c in em] 492 | else: 493 | smoothened = smoothen(em, size, method, device=device) 494 | else: 495 | smoothened = em 496 | out[grp] = smoothened 497 | return out 498 | 499 | 500 | def match_foregrounds(embs): 501 | print('Matching foregrounds...') 502 | t0 = time() 503 | channels = np.concatenate(list(embs.values())) 504 | mask = np.isfinite(channels).all(0) 505 | for group, channels in embs.items(): 506 | for chan in channels: 507 | chan[~mask] = np.nan 508 | print(int(time() - t0), 'sec') 509 | 510 | 511 | def adjust_weights(embs, weights=None): 512 | print('Adjusting weights...') 513 | t0 = time() 514 | if weights is None: 515 | weights = {grp: 1.0 for grp in embs.keys()} 516 | for grp in embs.keys(): 517 | channels = embs[grp] 518 | wt = weights[grp] 519 | means = np.array([np.nanmean(chan) for chan in channels]) 520 | std = np.sum([np.nanvar(chan) for chan in channels])**0.5 521 | for chan, me in zip(channels, means): 522 | chan[:] -= me 523 | chan[:] /= std 524 | chan[:] *= wt**0.5 525 | print(int(time() - t0), 'sec') 526 | 527 | 528 | 529 | def get_features(img,locs,rad,pixel_size_raw,pixel_size=0.5,pretrained=True,device='cpu'): 530 | scale = pixel_size_raw / pixel_size 531 | print('Scaling image') 532 | img = rescale_image(img, scale = scale) 533 | rad = rad*scale 534 | locs1 = locs.copy() 535 | locs1['4'] = locs1['4']*scale 536 | locs1['5'] = locs1['5']*scale 537 | print('Preprocessing image') 538 | img = preprocess(img) 539 | #mask = compute_tissue_mask(img) 540 | #mask = remove_border(mask) 541 | print('Adjusting margins') 542 | img = adjust_margins(img, pad=256, pad_value=255) 543 | #img[~mask] = 0 544 | #mask = shrink_mask(mask, size=256) 545 | #mask = mask[..., np.newaxis].astype(np.uint8) * 255 546 | #img = np.concatenate([img, mask], -1) 547 | #print('Extracting image features') 548 | emb_cls, emb_sub = get_embeddings_shift(img, pretrained=True, device=device) 549 | embs = dict(cls=emb_cls, sub=emb_sub) 550 | print('Smoothing embeddings') 551 | embs = smoothen_embeddings(embs, size=16, kernel='uniform', groups=['cls'], method='cv', device=device) 552 | embs = smoothen_embeddings(embs, size=4, kernel='uniform', groups=['sub'], method='cv', device=device) 553 | #match_foregrounds(embs) 554 | #adjust_weights(embs) # use uniform weights by default 555 | cls1 = rearrange(emb_cls, 'c h w -> h w c') 556 | sub1 = rearrange(emb_sub, 'c h w -> h w c') 557 | cls_sub1 = np.concatenate((cls1, sub1), 2) 558 | if rad>16: 559 | cls_sub2 = np.stack([rearrange(cls_sub1[(int(np.ceil((locs1['4'][i]-rad)/16))):(int(np.floor((locs1['4'][i]+rad)/16))),(int(np.ceil((locs1['5'][i]-rad)/16))):(int(np.floor((locs1['5'][i]+rad)/16)))], 'h w c -> (h w) c').mean(0) for i in range(locs1.shape[0])]) 560 | else: 561 | 
cls_sub2 = cls_sub1[(locs1['4']/16).round().astype('int'), (locs1['5']/16).round().astype('int'), :] 562 | return cls_sub2 563 | -------------------------------------------------------------------------------- /miso/model.py: -------------------------------------------------------------------------------- 1 | from . nets import * 2 | import torch 3 | from torch import nn, optim 4 | from torch.utils.data import TensorDataset, DataLoader 5 | from . utils import calculate_affinity 6 | import numpy as np 7 | import pandas as pd 8 | from numpy.linalg import svd 9 | from sklearn.metrics.pairwise import euclidean_distances 10 | from scanpy.external.tl import phenograph 11 | from sklearn.metrics import adjusted_rand_score 12 | from sklearn.cluster import KMeans 13 | from scipy.sparse import csr_matrix 14 | from scipy.sparse import kron 15 | from scipy.sparse import coo_matrix 16 | from scipy.spatial.distance import cdist 17 | from sklearn.preprocessing import StandardScaler 18 | from itertools import combinations 19 | from sklearn.decomposition import PCA 20 | from PIL import Image 21 | import scipy 22 | 23 | try: 24 | shell = get_ipython().__class__.__name__ 25 | if shell == 'ZMQInteractiveShell': 26 | from tqdm.notebook import tqdm 27 | else: 28 | from tqdm import tqdm 29 | except NameError: 30 | from tqdm import tqdm 31 | 32 | 33 | class Miso(nn.Module): 34 | def __init__(self, features, ind_views='all', combs='all', sparse=False, neighbors = None, device='cpu'): 35 | super(Miso, self).__init__() 36 | self.device = device 37 | self.num_views = len(features) 38 | self.features = [torch.Tensor(i).to(self.device) for i in features] 39 | self.sparse = sparse 40 | features = [StandardScaler().fit_transform(i) for i in features] 41 | if neighbors is None and self.sparse: 42 | neighbors=100 43 | 44 | adj = [calculate_affinity(i, sparse = self.sparse, neighbors=neighbors) for i in features] 45 | self.adj1 = adj 46 | pcs = [PCA(128).fit_transform(i) if i.shape[1] > 128 else i for i in features] 47 | self.pcs = [torch.Tensor(i).to(self.device) for i in pcs] 48 | if not self.sparse: 49 | self.adj = [torch.Tensor(i).to(self.device) for i in adj] 50 | else: 51 | adj = [coo_matrix(i) for i in adj] 52 | indices = [torch.LongTensor(np.vstack((i.row, i.col))) for i in adj] 53 | values = [torch.FloatTensor(i.data) for i in adj] 54 | shape = [torch.Size(i.shape) for i in adj] 55 | self.adj = [torch.sparse.FloatTensor(indices[i], values[i], shape[i]).to(self.device) for i in range(len(adj))] 56 | 57 | if ind_views=='all': 58 | self.ind_views = list(range(len(self.pcs))) 59 | else: 60 | self.ind_views = ind_views 61 | if combs=='all': 62 | self.combinations = list(combinations(list(range(len(self.pcs))),2)) 63 | else: 64 | self.combinations = combs 65 | 66 | def train(self): 67 | self.mlps = [MLP(input_shape = self.pcs[i].shape[1], output_shape = 32).to(self.device) for i in range(len(self.pcs))] 68 | def sc_loss(A,Y): 69 | if not self.sparse: 70 | return (torch.triu(torch.cdist(Y,Y))*torch.triu(A)).mean() 71 | else: 72 | row = A.coalesce().indices()[0] 73 | col = A.coalesce().indices()[1] 74 | rows1 = Y[row] 75 | rows2 = Y[col] 76 | dist = torch.norm(rows1 - rows2, dim=1) 77 | return (dist*A.coalesce().values()).mean() 78 | 79 | 80 | for i in range(self.num_views): 81 | self.mlps[i].train() 82 | optimizer = optim.Adam(self.mlps[i].parameters(), lr=1e-3) 83 | for epoch in tqdm(range(1000), desc='Training network for modality ' + str(i+1)): 84 | optimizer.zero_grad() 85 | x_hat = self.mlps[i](self.pcs[i]) 86 | Y1 = 
self.mlps[i].get_embeddings(self.pcs[i]) 87 | loss1 = nn.MSELoss()(self.pcs[i],x_hat) 88 | loss2 = sc_loss(self.adj[i], Y1) 89 | loss=loss1+loss2 90 | loss.backward() 91 | optimizer.step() 92 | 93 | [self.mlps[i].eval() for i in range(self.num_views)] 94 | Y = [self.mlps[i].get_embeddings(self.pcs[i]) for i in range(self.num_views)] 95 | if self.combinations is not None: 96 | interactions = [Y[i][:, :, None]*Y[j][:, None, :] for i,j in self.combinations] 97 | interactions = [i.reshape(i.shape[0],-1) for i in interactions] 98 | interactions = [torch.matmul(i,torch.pca_lowrank(i,q=32)[2]) for i in interactions] 99 | Y = [Y[i] for i in self.ind_views] 100 | Y = [StandardScaler().fit_transform(i.cpu().detach().numpy()) for i in Y] 101 | Y = np.concatenate(Y,1) 102 | if self.combinations is not None: 103 | interactions = [StandardScaler().fit_transform(i.cpu().detach().numpy()) for i in interactions] 104 | interactions = np.concatenate(interactions,1) 105 | emb = np.concatenate((Y,interactions),1) 106 | else: 107 | emb = Y 108 | self.emb = emb 109 | 110 | def cluster(self, n_clusters=10): 111 | clusters = KMeans(n_clusters, random_state = 100).fit_predict(self.emb) 112 | self.clusters = clusters 113 | return clusters 114 | 115 | -------------------------------------------------------------------------------- /miso/nets.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.nn.utils.parametrizations import orthogonal 4 | 5 | class AE(nn.Module): 6 | def __init__(self, **kwargs): 7 | super().__init__() 8 | self.encoder = nn.Linear(in_features=kwargs["input_shape"], out_features=128) 9 | self.decoder = nn.Linear(in_features=128, out_features=kwargs["input_shape"]) 10 | self.relu = nn.ReLU() 11 | 12 | def forward(self, x): 13 | x = self.encoder(x) 14 | x = self.relu(x) 15 | x = self.decoder(x) 16 | x = self.relu(x) 17 | return x 18 | 19 | def get_embeddings(self,x): 20 | x = self.encoder(x) 21 | x = self.relu(x) 22 | return x 23 | 24 | class MLP(nn.Module): 25 | def __init__(self, **kwargs): 26 | super().__init__() 27 | self.layer1 = nn.Linear(kwargs["input_shape"], 32) 28 | #self.layer2 = nn.Linear(64, kwargs["output_shape"]) 29 | self.layer3 = orthogonal(nn.Linear(kwargs["output_shape"], kwargs["output_shape"])) 30 | self.layer4 = nn.Linear(32,kwargs["input_shape"]) 31 | self.relu = nn.ReLU() 32 | self.tanh = nn.Tanh() 33 | self.dropout = nn.Dropout() 34 | #self.softmax = nn.Softmax() 35 | 36 | def forward(self, x): 37 | x = self.layer1(x) 38 | #x = self.tanh(x) 39 | #x = self.layer2(x) 40 | x = self.layer3(x) 41 | x = self.layer4(x) 42 | #x = self.relu(x) 43 | #x1 = x.T@x + torch.eye(x.shape[1]).to(x.device)*1e-7 44 | #l = torch.cholesky(x1) 45 | #x = x@((x.shape[0])**(1/2)*l.inverse().T) 46 | #x, _ = torch.qr(x) 47 | return x 48 | 49 | def get_embeddings(self,x): 50 | x = self.layer3(self.layer1(x)) 51 | return x 52 | 53 | class MLP1(nn.Module): 54 | def __init__(self, **kwargs): 55 | super().__init__() 56 | self.layer1 = nn.Linear(32, kwargs["output_shape"]) 57 | self.layer2 = orthogonal(nn.Linear(kwargs["output_shape"], kwargs["output_shape"])) 58 | self.relu = nn.ReLU() 59 | #self.softmax = nn.Softmax() 60 | 61 | def forward(self, x): 62 | x = self.layer1(x) 63 | #x = self.relu(x) 64 | x = self.layer2(x) 65 | #x = self.relu(x) 66 | #x1 = x.T@x + torch.eye(x.shape[1]).to(x.device)*1e-7 67 | #l = torch.cholesky(x1) 68 | #x = x@((x.shape[0])**(1/2)*l.inverse().T) 69 | #x, _ = torch.qr(x) 70 | return x 71 | 
72 | -------------------------------------------------------------------------------- /miso/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import pairwise_distances 3 | import torch 4 | import scipy 5 | import scanpy as sc 6 | import matplotlib.pyplot as plt 7 | import matplotlib.colors as mcolors 8 | from sklearn.neighbors import kneighbors_graph 9 | from PIL import Image 10 | import random 11 | 12 | def protein_norm(x): 13 | s = np.sum(np.log1p(x[x > 0])) 14 | exp = np.exp(s / len(x)) 15 | return np.log1p(x / exp) 16 | 17 | 18 | def preprocess(adata,modality): 19 | adata.var_names_make_unique() 20 | if modality in ['rna','atac']: 21 | sc.pp.filter_genes(adata,min_cells=10) 22 | sc.pp.log1p(adata) 23 | 24 | if scipy.sparse.issparse(adata.X): 25 | return adata.X.A 26 | else: 27 | return adata.X 28 | 29 | elif modality=='protein': 30 | adata.X = np.apply_along_axis(protein_norm, 1, (adata.X.A if scipy.sparse.issparse(adata.X) else np.array(adata.X))) 31 | return adata.X 32 | 33 | elif modality=='metabolite': 34 | sc.pp.log1p(adata) 35 | if scipy.sparse.issparse(adata.X): 36 | return adata.X.A 37 | else: 38 | return adata.X 39 | 40 | 41 | def calculate_affinity(X1, sig=30, sparse = False, neighbors = 100): 42 | if not sparse: 43 | dist1 = pairwise_distances(X1) 44 | a1 = np.exp(-1*(dist1**2)/(2*(sig**2))) 45 | return a1 46 | else: 47 | dist1 = kneighbors_graph(X1, n_neighbors = neighbors, mode='distance') 48 | dist1.data = np.exp(-1*(dist1.data**2)/(2*(sig**2))) 49 | dist1.eliminate_zeros() 50 | return dist1 51 | 52 | def cmap_tab20(x): 53 | cmap = plt.get_cmap('tab20') 54 | x = x % 20 55 | x = (x // 10) + (x % 10) * 2 56 | return cmap(x) 57 | 58 | 59 | 60 | def cmap_tab30(x): 61 | n_base = 20 62 | n_max = 30 63 | brightness = 0.7 64 | brightness = (brightness,) * 3 + (1.0,) 65 | isin_base = (x < n_base)[..., np.newaxis] 66 | isin_extended = ((x >= n_base) * (x < n_max))[..., np.newaxis] 67 | isin_beyond = (x >= n_max)[..., np.newaxis] 68 | color = ( 69 | isin_base * cmap_tab20(x) 70 | + isin_extended * cmap_tab20(x-n_base) * brightness 71 | + isin_beyond * (0.0, 0.0, 0.0, 1.0)) 72 | return color 73 | 74 | 75 | def cmap_tab70(x): 76 | cmap_base = cmap_tab30 77 | brightness = 0.5 78 | brightness = np.array([brightness] * 3 + [1.0]) 79 | color = [ 80 | cmap_base(x), # same as base colormap 81 | 1 - (1 - cmap_base(x-20)) * brightness, # brighter 82 | cmap_base(x-20) * brightness, # darker 83 | 1 - (1 - cmap_base(x-40)) * brightness**2, # even brighter 84 | cmap_base(x-40) * brightness**2, # even darker 85 | [0.0, 0.0, 0.0, 1.0], # black 86 | ] 87 | x = x[..., np.newaxis] 88 | isin = [ 89 | (x < 30), 90 | (x >= 30) * (x < 40), 91 | (x >= 40) * (x < 50), 92 | (x >= 50) * (x < 60), 93 | (x >= 60) * (x < 70), 94 | (x >= 70)] 95 | color_out = np.sum( 96 | [isi * col for isi, col in zip(isin, color)], 97 | axis=0) 98 | return color_out 99 | 100 | 101 | def plot(clusters,locs): 102 | locs['2'] = locs['2'].astype('int') 103 | locs['3'] = locs['3'].astype('int') 104 | im1 = np.empty((locs['2'].max()+1, locs['3'].max()+1)) 105 | im1[:] = np.nan 106 | im1[locs['2'],locs['3']] = clusters 107 | im2 = cmap_tab70(im1.astype('int')) 108 | im2[np.isnan(im1)] = 1 109 | im3 = Image.fromarray((im2 * 255).astype(np.uint8)) 110 | return im3 111 | 112 | def plot_on_histology(clusters, locs, im, scale, s=10): 113 | locs = locs*scale 114 | locs = locs.round().astype('int') 115 | im = 
im[(locs['4'].min()-10):(locs['4'].max()+10),(locs['5'].min()-10):(locs['5'].max()+10)] 116 | locs = locs-locs.min()+10 117 | cmap1 = mcolors.ListedColormap([cmap_tab70(np.array(i)) for i in range(len(np.unique(clusters)))]) 118 | plt.imshow(im, alpha=0.7); 119 | plot = plt.scatter(x=locs['5'], y=locs['4'], c = clusters, cmap=cmap1, s=s); 120 | plt.axis('off'); 121 | return plot 122 | 123 | def set_random_seed(seed=100): 124 | np.random.seed(seed) 125 | torch.manual_seed(seed) 126 | random.seed(seed) 127 | 128 | -------------------------------------------------------------------------------- /miso/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | """ 12 | Mostly copy-paste from timm library. 13 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 14 | """ 15 | import math 16 | from functools import partial 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | 22 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 23 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 24 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 25 | def norm_cdf(x): 26 | # Computes standard normal cumulative distribution function 27 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 28 | 29 | if (mean < a - 2 * std) or (mean > b + 2 * std): 30 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 31 | "The distribution of values may be incorrect.", 32 | stacklevel=2) 33 | 34 | with torch.no_grad(): 35 | # Values are generated by using a truncated uniform distribution and 36 | # then using the inverse CDF for the normal distribution. 37 | # Get upper and lower cdf values 38 | l = norm_cdf((a - mean) / std) 39 | u = norm_cdf((b - mean) / std) 40 | 41 | # Uniformly fill tensor with values from [l, u], then translate to 42 | # [2l-1, 2u-1]. 43 | tensor.uniform_(2 * l - 1, 2 * u - 1) 44 | 45 | # Use inverse cdf transform for normal distribution to get truncated 46 | # standard normal 47 | tensor.erfinv_() 48 | 49 | # Transform to proper mean, std 50 | tensor.mul_(std * math.sqrt(2.)) 51 | tensor.add_(mean) 52 | 53 | # Clamp to ensure it's in the proper range 54 | tensor.clamp_(min=a, max=b) 55 | return tensor 56 | 57 | 58 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 59 | # type: (Tensor, float, float, float, float) -> Tensor 60 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 61 | 62 | 63 | def drop_path(x, drop_prob: float = 0., training: bool = False): 64 | if drop_prob == 0. 
or not training: 65 | return x 66 | keep_prob = 1 - drop_prob 67 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 68 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 69 | random_tensor.floor_() # binarize 70 | output = x.div(keep_prob) * random_tensor 71 | return output 72 | 73 | 74 | class DropPath(nn.Module): 75 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 76 | """ 77 | def __init__(self, drop_prob=None): 78 | super(DropPath, self).__init__() 79 | self.drop_prob = drop_prob 80 | 81 | def forward(self, x): 82 | return drop_path(x, self.drop_prob, self.training) 83 | 84 | 85 | class Mlp(nn.Module): 86 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 87 | super().__init__() 88 | out_features = out_features or in_features 89 | hidden_features = hidden_features or in_features 90 | self.fc1 = nn.Linear(in_features, hidden_features) 91 | self.act = act_layer() 92 | self.fc2 = nn.Linear(hidden_features, out_features) 93 | self.drop = nn.Dropout(drop) 94 | 95 | def forward(self, x): 96 | x = self.fc1(x) 97 | x = self.act(x) 98 | x = self.drop(x) 99 | x = self.fc2(x) 100 | x = self.drop(x) 101 | return x 102 | 103 | 104 | class Attention(nn.Module): 105 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 106 | super().__init__() 107 | self.num_heads = num_heads 108 | head_dim = dim // num_heads 109 | self.scale = qk_scale or head_dim ** -0.5 110 | 111 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 112 | self.attn_drop = nn.Dropout(attn_drop) 113 | self.proj = nn.Linear(dim, dim) 114 | self.proj_drop = nn.Dropout(proj_drop) 115 | 116 | def forward(self, x): 117 | B, N, C = x.shape 118 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 119 | q, k, v = qkv[0], qkv[1], qkv[2] 120 | 121 | attn = (q @ k.transpose(-2, -1)) * self.scale 122 | attn = attn.softmax(dim=-1) 123 | attn = self.attn_drop(attn) 124 | 125 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 126 | x = self.proj(x) 127 | x = self.proj_drop(x) 128 | return x, attn 129 | 130 | 131 | class Block(nn.Module): 132 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 133 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 134 | super().__init__() 135 | self.norm1 = norm_layer(dim) 136 | self.attn = Attention( 137 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 138 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
139 |         self.norm2 = norm_layer(dim)
140 |         mlp_hidden_dim = int(dim * mlp_ratio)
141 |         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
142 |
143 |     def forward(self, x, return_attention=False):
144 |         y, attn = self.attn(self.norm1(x))
145 |         if return_attention:
146 |             return attn
147 |         x = x + self.drop_path(y)
148 |         x = x + self.drop_path(self.mlp(self.norm2(x)))
149 |         return x
150 |
151 |
152 | class PatchEmbed(nn.Module):
153 |     """ Image to Patch Embedding
154 |     """
155 |     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
156 |         super().__init__()
157 |         num_patches = (img_size // patch_size) * (img_size // patch_size)
158 |         self.img_size = img_size
159 |         self.patch_size = patch_size
160 |         self.num_patches = num_patches
161 |
162 |         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
163 |
164 |     def forward(self, x):
165 |         B, C, H, W = x.shape
166 |         x = self.proj(x).flatten(2).transpose(1, 2)
167 |         return x
168 |
169 |
170 | class VisionTransformer(nn.Module):
171 |     """ Vision Transformer """
172 |     def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
173 |                  num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
174 |                  drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
175 |         super().__init__()
176 |         self.num_features = self.embed_dim = embed_dim
177 |
178 |         self.patch_embed = PatchEmbed(
179 |             img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
180 |         num_patches = self.patch_embed.num_patches
181 |
182 |         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
183 |         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
184 |         self.pos_drop = nn.Dropout(p=drop_rate)
185 |
186 |         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
187 |         self.blocks = nn.ModuleList([
188 |             Block(
189 |                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
190 |                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
191 |             for i in range(depth)])
192 |         self.norm = norm_layer(embed_dim)
193 |
194 |         # Classifier head
195 |         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
196 |
197 |         trunc_normal_(self.pos_embed, std=.02)
198 |         trunc_normal_(self.cls_token, std=.02)
199 |         self.apply(self._init_weights)
200 |
201 |     def _init_weights(self, m):
202 |         if isinstance(m, nn.Linear):
203 |             trunc_normal_(m.weight, std=.02)
204 |             if isinstance(m, nn.Linear) and m.bias is not None:
205 |                 nn.init.constant_(m.bias, 0)
206 |         elif isinstance(m, nn.LayerNorm):
207 |             nn.init.constant_(m.bias, 0)
208 |             nn.init.constant_(m.weight, 1.0)
209 |
210 |     def interpolate_pos_encoding(self, x, w, h):
211 |         npatch = x.shape[1] - 1
212 |         N = self.pos_embed.shape[1] - 1
213 |         if npatch == N and w == h:
214 |             return self.pos_embed
215 |         class_pos_embed = self.pos_embed[:, 0]
216 |         patch_pos_embed = self.pos_embed[:, 1:]
217 |         dim = x.shape[-1]
218 |         w0 = w // self.patch_embed.patch_size
219 |         h0 = h // self.patch_embed.patch_size
220 |         # we add a small number to avoid floating point error in the interpolation
221 |         # see discussion at https://github.com/facebookresearch/dino/issues/8
222 |         w0, h0 = w0 + 0.1, h0 + 0.1
223 |         patch_pos_embed = nn.functional.interpolate(
224 |             patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
225 |             scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
226 |             mode='bicubic',
227 |         )
228 |         assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
229 |         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
230 |         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
231 |
232 |     def prepare_tokens(self, x):
233 |         B, nc, w, h = x.shape
234 |         x = self.patch_embed(x)  # patch linear embedding
235 |
236 |         # add the [CLS] token to the embed patch tokens
237 |         cls_tokens = self.cls_token.expand(B, -1, -1)
238 |         x = torch.cat((cls_tokens, x), dim=1)
239 |
240 |         # add positional encoding to each token
241 |         x = x + self.interpolate_pos_encoding(x, w, h)
242 |
243 |         return self.pos_drop(x)
244 |
245 |     def forward(self, x):
246 |         x = self.forward_all(x)
247 |         return x[:, 0]
248 |
249 |     def forward_all(self, x):
250 |         x = self.prepare_tokens(x)
251 |         for blk in self.blocks:
252 |             x = blk(x)
253 |         x = self.norm(x)
254 |         return x
255 |
256 |     def get_last_selfattention(self, x):
257 |         x = self.prepare_tokens(x)
258 |         for i, blk in enumerate(self.blocks):
259 |             if i < len(self.blocks) - 1:
260 |                 x = blk(x)
261 |             else:
262 |                 # return attention of the last block
263 |                 return blk(x, return_attention=True)
264 |
265 |     def get_intermediate_layers(self, x, n=1):
266 |         x = self.prepare_tokens(x)
267 |         # we return the output tokens from the `n` last blocks
268 |         output = []
269 |         for i, blk in enumerate(self.blocks):
270 |             x = blk(x)
271 |             if len(self.blocks) - i <= n:
272 |                 output.append(self.norm(x))
273 |         return output
274 |
275 |
276 | def vit_tiny(patch_size=16, **kwargs):
277 |     model = VisionTransformer(
278 |         patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
279 |         qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
280 |     return model
281 |
282 |
283 | def vit_small(patch_size=16, **kwargs):
284 |     model = VisionTransformer(
285 |         patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
286 |         qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
287 |     return model
288 |
289 |
290 | def vit_base(patch_size=16, **kwargs):
291 |     model = VisionTransformer(
292 |         patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
293 |         qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
294 |     return model
295 |
296 |
297 | class DINOHead(nn.Module):
298 |     def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
299 |         super().__init__()
300 |         nlayers = max(nlayers, 1)
301 |         if nlayers == 1:
302 |             self.mlp = nn.Linear(in_dim, bottleneck_dim)
303 |         else:
304 |             layers = [nn.Linear(in_dim, hidden_dim)]
305 |             if use_bn:
306 |                 layers.append(nn.BatchNorm1d(hidden_dim))
307 |             layers.append(nn.GELU())
308 |             for _ in range(nlayers - 2):
309 |                 layers.append(nn.Linear(hidden_dim, hidden_dim))
310 |                 if use_bn:
311 |                     layers.append(nn.BatchNorm1d(hidden_dim))
312 |                 layers.append(nn.GELU())
313 |             layers.append(nn.Linear(hidden_dim, bottleneck_dim))
314 |             self.mlp = nn.Sequential(*layers)
315 |         self.apply(self._init_weights)
316 |         self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
317 |         self.last_layer.weight_g.data.fill_(1)
318 |         if norm_last_layer:
319 |             self.last_layer.weight_g.requires_grad = False
320 |
321 |     def _init_weights(self, m):
322 |         if isinstance(m, nn.Linear):
323 |             trunc_normal_(m.weight, std=.02)
324 |             if isinstance(m, nn.Linear) and m.bias is not None:
325 |                 nn.init.constant_(m.bias, 0)
326 |
327 |     def forward(self, x):
328 |         x = self.mlp(x)
329 |         x = nn.functional.normalize(x, dim=-1, p=2)
330 |         x = self.last_layer(x)
331 |         return x
332 |
--------------------------------------------------------------------------------
/miso/vision_transformer4k.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 | import torch
4 | import torch.nn as nn
5 | import warnings  # needed for the out-of-range warning in _no_grad_trunc_normal_ below
6 |
7 | def softmax(x, dim, inplace=False):
8 |     if inplace:
9 |         torch.exp(x, out=x)
10 |     else:
11 |         x = torch.exp(x)
12 |     s = torch.sum(x, dim=dim, keepdim=True)
13 |     x /= s
14 |     return x
15 |
16 |
17 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
18 |     # Cut & paste from PyTorch official master until it's in a few official releases - RW
19 |     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
20 |     def norm_cdf(x):
21 |         # Computes standard normal cumulative distribution function
22 |         return (1. + math.erf(x / math.sqrt(2.))) / 2.
23 |
24 |     if (mean < a - 2 * std) or (mean > b + 2 * std):
25 |         warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
26 |                       "The distribution of values may be incorrect.",
27 |                       stacklevel=2)
28 |
29 |     with torch.no_grad():
30 |         # Values are generated by using a truncated uniform distribution and
31 |         # then using the inverse CDF for the normal distribution.
32 |         # Get upper and lower cdf values
33 |         l = norm_cdf((a - mean) / std)
34 |         u = norm_cdf((b - mean) / std)
35 |
36 |         # Uniformly fill tensor with values from [l, u], then translate to
37 |         # [2l-1, 2u-1].
38 |         tensor.uniform_(2 * l - 1, 2 * u - 1)
39 |
40 |         # Use inverse cdf transform for normal distribution to get truncated
41 |         # standard normal
42 |         tensor.erfinv_()
43 |
44 |         # Transform to proper mean, std
45 |         tensor.mul_(std * math.sqrt(2.))
46 |         tensor.add_(mean)
47 |
48 |         # Clamp to ensure it's in the proper range
49 |         tensor.clamp_(min=a, max=b)
50 |     return tensor
51 |
52 |
53 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
54 |     # type: (Tensor, float, float, float, float) -> Tensor
55 |     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
56 |
57 |
58 |
59 | def drop_path(x, drop_prob: float = 0., training: bool = False):
60 |     if drop_prob == 0. or not training:
61 |         return x
62 |     keep_prob = 1 - drop_prob
63 |     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
64 |     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
65 |     random_tensor.floor_()  # binarize
66 |     output = x.div(keep_prob) * random_tensor
67 |     return output
68 |
69 |
70 | class DropPath(nn.Module):
71 |     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
72 | """ 73 | def __init__(self, drop_prob=None): 74 | super(DropPath, self).__init__() 75 | self.drop_prob = drop_prob 76 | 77 | def forward(self, x): 78 | return drop_path(x, self.drop_prob, self.training) 79 | 80 | 81 | class Mlp(nn.Module): 82 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 83 | super().__init__() 84 | out_features = out_features or in_features 85 | hidden_features = hidden_features or in_features 86 | self.fc1 = nn.Linear(in_features, hidden_features) 87 | self.act = act_layer() 88 | self.fc2 = nn.Linear(hidden_features, out_features) 89 | self.drop = nn.Dropout(drop) 90 | 91 | def forward(self, x): 92 | x = self.fc1(x) 93 | x = self.act(x) 94 | x = self.drop(x) 95 | x = self.fc2(x) 96 | x = self.drop(x) 97 | return x 98 | 99 | 100 | class Attention(nn.Module): 101 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 102 | super().__init__() 103 | self.num_heads = num_heads 104 | head_dim = dim // num_heads 105 | self.scale = qk_scale or head_dim ** -0.5 106 | 107 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 108 | self.attn_drop = nn.Dropout(attn_drop) 109 | self.proj = nn.Linear(dim, dim) 110 | self.proj_drop = nn.Dropout(proj_drop) 111 | 112 | def forward(self, x): 113 | B, N, C = x.shape 114 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 115 | q, k, v = qkv[0], qkv[1], qkv[2] 116 | 117 | attn = q @ k.transpose(-2, -1) 118 | attn *= self.scale 119 | softmax(attn, dim=-1, inplace=True) # attn = attn.softmax(dim=-1) 120 | attn = self.attn_drop(attn) 121 | 122 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 123 | del qkv, q, k, v 124 | x = self.proj(x) 125 | x = self.proj_drop(x) 126 | return x, attn 127 | 128 | 129 | class Block(nn.Module): 130 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 131 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 132 | super().__init__() 133 | self.norm1 = norm_layer(dim) 134 | self.attn = Attention( 135 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 136 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
137 |         self.norm2 = norm_layer(dim)
138 |         mlp_hidden_dim = int(dim * mlp_ratio)
139 |         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
140 |
141 |     def forward(self, x, return_attention=False):
142 |         y, attn = self.attn(self.norm1(x))
143 |         if return_attention:
144 |             return attn
145 |         x = x + self.drop_path(y)
146 |         x = x + self.drop_path(self.mlp(self.norm2(x)))
147 |         return x
148 |
149 |
150 | class VisionTransformer4K(nn.Module):
151 |     """ Vision Transformer 4K """
152 |     def __init__(self, num_classes=0, img_size=[224], input_embed_dim=384, output_embed_dim = 192,
153 |                  depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
154 |                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, num_prototypes=64, **kwargs):
155 |         super().__init__()
156 |         embed_dim = output_embed_dim
157 |         self.num_features = self.embed_dim = embed_dim
158 |         self.phi = nn.Sequential(*[nn.Linear(input_embed_dim, output_embed_dim), nn.GELU(), nn.Dropout(p=drop_rate)])
159 |         num_patches = int(img_size[0] // 16)**2
160 |         #print("# of Patches:", num_patches)
161 |
162 |         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
163 |         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
164 |         self.pos_drop = nn.Dropout(p=drop_rate)
165 |
166 |         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
167 |         self.blocks = nn.ModuleList([
168 |             Block(
169 |                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
170 |                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
171 |             for i in range(depth)])
172 |         self.norm = norm_layer(embed_dim)
173 |
174 |         # Classifier head
175 |         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
176 |
177 |         trunc_normal_(self.pos_embed, std=.02)
178 |         trunc_normal_(self.cls_token, std=.02)
179 |         self.apply(self._init_weights)
180 |
181 |     def _init_weights(self, m):
182 |         if isinstance(m, nn.Linear):
183 |             trunc_normal_(m.weight, std=.02)
184 |             if isinstance(m, nn.Linear) and m.bias is not None:
185 |                 nn.init.constant_(m.bias, 0)
186 |         elif isinstance(m, nn.LayerNorm):
187 |             nn.init.constant_(m.bias, 0)
188 |             nn.init.constant_(m.weight, 1.0)
189 |
190 |     def interpolate_pos_encoding(self, x, w, h):
191 |         npatch = x.shape[1] - 1
192 |         N = self.pos_embed.shape[1] - 1
193 |         if npatch == N and w == h:
194 |             return self.pos_embed
195 |         class_pos_embed = self.pos_embed[:, 0]
196 |         patch_pos_embed = self.pos_embed[:, 1:]
197 |         dim = x.shape[-1]
198 |         w0 = w // 1
199 |         h0 = h // 1
200 |         # we add a small number to avoid floating point error in the interpolation
201 |         # see discussion at https://github.com/facebookresearch/dino/issues/8
202 |         w0, h0 = w0 + 0.1, h0 + 0.1
203 |         patch_pos_embed = nn.functional.interpolate(
204 |             patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
205 |             scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
206 |             mode='bicubic',
207 |         )
208 |         assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
209 |         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
210 |         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
211 |
212 |     def prepare_tokens(self, x):
213 |         #print('preparing tokens (after crop)', x.shape)
214 |         self.mpp_feature = x
215 |         B, embed_dim, w, h = x.shape
216 |         x = x.flatten(2, 3).transpose(1,2)
217 |
218 |         x = self.phi(x)
219 |
220 |
221 |         # add the [CLS] token to the embed patch tokens
222 |         cls_tokens = self.cls_token.expand(B, -1, -1)
223 |         x = torch.cat((cls_tokens, x), dim=1)
224 |
225 |         # add positional encoding to each token
226 |         x = x + self.interpolate_pos_encoding(x, w, h)
227 |
228 |         return self.pos_drop(x)
229 |
230 |     def forward(self, x):
231 |         x = self.forward_all(x)
232 |         return x[:, 0]
233 |
234 |     def forward_all(self, x):
235 |         x = self.prepare_tokens(x)
236 |         for blk in self.blocks:
237 |             x = blk(x)
238 |         x = self.norm(x)
239 |         return x
240 |
241 |     def get_last_selfattention(self, x):
242 |         x = self.prepare_tokens(x)
243 |         for i, blk in enumerate(self.blocks):
244 |             if i < len(self.blocks) - 1:
245 |                 x = blk(x)
246 |             else:
247 |                 # return attention of the last block
248 |                 return blk(x, return_attention=True)
249 |
250 |     def get_intermediate_layers(self, x, n=1):
251 |         x = self.prepare_tokens(x)
252 |         # we return the output tokens from the `n` last blocks
253 |         output = []
254 |         for i, blk in enumerate(self.blocks):
255 |             x = blk(x)
256 |             if len(self.blocks) - i <= n:
257 |                 output.append(self.norm(x))
258 |         return output
259 |
260 | def vit4k_xs(patch_size=16, **kwargs):
261 |     model = VisionTransformer4K(
262 |         patch_size=patch_size, input_embed_dim=384, output_embed_dim=192,
263 |         depth=6, num_heads=6, mlp_ratio=4,
264 |         qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
265 |     return model
266 |
267 | def count_parameters(model):
268 |     return sum(p.numel() for p in model.parameters() if p.requires_grad)
269 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.6.0
2 | importlib
3 | importlib-metadata
4 | numpy==1.21.6
5 | opencv_python==4.6.0.66
6 | Pillow>=6.1.0
7 | scanpy==1.9.1
8 | scikit_image==0.19.3
9 | scikit_learn==1.0.2
10 | scipy==1.7.3
11 | setuptools==65.6.3
12 | torch==1.13.1
13 | torchvision==0.14.1
14 | tqdm==4.64.1
15 |
16 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | import pathlib
3 |
4 | here = pathlib.Path(__file__).parent.resolve()
5 |
6 | long_description = (here / "README.md").read_text(encoding="utf-8")
7 |
8 | setup(
9 |     name="miso",
10 |     version="0.1.0",
11 |     description="Resolving tissue complexity by multi-modal spatial omics modeling with MISO",
12 |     long_description=long_description,
13 |     long_description_content_type="text/markdown",
14 |     url="https://github.com/kpcoleman/miso",
15 |     author="Kyle Coleman",
16 |     author_email="kpcoleman87@gmail.com",
17 |     packages=find_packages(),
18 |     #include_package_data=True,
19 |     python_requires="==3.7.*",
20 |     package_data = {'miso': ['checkpoints/*.pth'],},
21 |     install_requires=["scikit-learn==1.0.2","scikit_image==0.19.3","torch==1.13.1","torchvision==0.14.1","numpy==1.21.6","Pillow>=6.1.0","opencv-python==4.6.0.66","scipy==1.7.3","einops==0.6.0","scanpy==1.9.1","tqdm==4.64.1"],
22 |     classifiers=[
23 |         "Programming Language :: Python :: 3",
24 |         "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
25 |         "Operating System :: OS Independent",
26 |     ]
27 | )
28 |
29 |
30 |
--------------------------------------------------------------------------------
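Usage note: the two modules above define the backbones that the bundled DINO checkpoints in miso/checkpoints/ (vit256_small_dino.pth for vit_small, vit4k_xs_dino.pth for vit4k_xs) are meant to initialize; the package's own loading path goes through miso/hipt_model_utils.py and miso/hipt_4k.py, which are not shown here. Below is only a minimal sketch of instantiating the backbones directly; the assumption that the checkpoint holds a DINO-style dict with a "teacher" entry and "module."/"backbone." key prefixes is not confirmed by the files above and may need adjusting.

import torch
from miso.vision_transformer import vit_small
from miso.vision_transformer4k import vit4k_xs

model256 = vit_small(patch_size=16)   # ViT-256 backbone, 384-d [CLS] embeddings
model4k = vit4k_xs()                  # ViT-4K backbone, 192-d [CLS] embeddings

# Assumed checkpoint layout; strict=False tolerates any extra projection-head keys.
state = torch.load("miso/checkpoints/vit256_small_dino.pth", map_location="cpu")
if isinstance(state, dict) and "teacher" in state:
    state = state["teacher"]
state = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state.items()}
model256.load_state_dict(state, strict=False)
model256.eval()

# A 256x256 RGB tile yields one [CLS] embedding of dimension 384; the positional
# encoding is interpolated from the 224x224 grid by interpolate_pos_encoding.
with torch.no_grad():
    emb = model256(torch.randn(1, 3, 256, 256))
print(emb.shape)  # torch.Size([1, 384])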