├── fastglioma ├── utils │ ├── __init__.py │ ├── format_slide_embedding.py │ └── common.py ├── datasets │ ├── __init__.py │ ├── improc.py │ ├── embedding_dataset.py │ └── emb_proc.py ├── models │ ├── __init__.py │ ├── cnn.py │ ├── vit.py │ ├── mil.py │ └── resnet.py ├── tf │ ├── README.md │ ├── config │ │ └── feedforward.yaml │ ├── feedforward.py │ └── transformer.py ├── eval │ ├── config │ │ ├── save_hidisc.yaml │ │ ├── eval_scm.yaml │ │ └── eval_hidisc.yaml │ ├── save_embedding.py │ └── eval_knn.py ├── inference │ ├── config │ │ └── infer.yaml │ └── run_inference.py ├── train │ ├── config │ │ ├── train_scm.yaml │ │ ├── train_ordmet.yaml │ │ └── train_hidisc.yaml │ ├── train_slide.py │ ├── train_scorer.py │ └── train_patch.py └── losses │ ├── hidisc.py │ ├── vicreg.py │ ├── ordmet.py │ └── supcon.py ├── figures ├── Figure_1.png └── Figure_2.png ├── .gitignore ├── setup.py ├── LICENSE ├── THIRD_PARTY └── README.md /fastglioma/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/Figure_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLNeurosurg/fastglioma/HEAD/figures/Figure_1.png -------------------------------------------------------------------------------- /figures/Figure_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MLNeurosurg/fastglioma/HEAD/figures/Figure_2.png -------------------------------------------------------------------------------- /fastglioma/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .improc import get_srh_base_aug 2 | from .srh_dataset import PatchDataset, SlideDataset, slide_collate_fn -------------------------------------------------------------------------------- /fastglioma/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .cnn import MLP, ContrastiveLearningNetwork, VICRegNetwork 2 | from .mil import MIL_Classifier, TransformerMIL -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | env/ 2 | logs/ 3 | model/ 4 | __pycache__/ 5 | *.egg-info/ 6 | out/ 7 | .vscode/ 8 | slurm*.out 9 | .DS_Store 10 | *.pt 11 | *.npy 12 | *.npz 13 | .ipynb_checkpoints 14 | lightning_logs/ 15 | *bak 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="fastglioma", 5 | version="1.0", 6 | packages=find_packages(), 7 | install_requires=[ 8 | "setuptools", 9 | "pip", 10 | "pytest", 11 | "yapf", 12 | "tqdm", 13 | "pyyaml", 14 | "pytest==8.3.2", 15 | "pandas==1.5.3", 16 | "numpy==1.24.4", 17 | "matplotlib==3.6.3", 18 | "tifffile==2020.10.1", 19 | "scikit-learn==1.4.1.post1", 20 | "scikit-image", 21 | "opencv-python==3.4.18.65", 22 | "torch==1.13.0", 23 | "torchvision==0.14.0", 24 | "pytorch-lightning==1.8.4", 25 | "huggingface-hub==0.24.6", 26 | "timm", 27 | "tensorboard" # yapf:disable 28 | ]) 29 | -------------------------------------------------------------------------------- /fastglioma/tf/README.md: 
-------------------------------------------------------------------------------- 1 | ## TensorFlow implementation 2 | 3 | This directory contains the TensorFlow implementation of the FastGlioma model, for investigational use on the NIO imager. `resnet.py` and `transformer.py` are re-implementations of the PyTorch models in the `models/` directory. Similar to the `inference/` directory, the `feedforward.yaml` file specifies the model and inference parameters, and the `feedforward.py` script runs prediction on the OpenSRH dataset, starting from the FastGlioma PyTorch checkpoint on HuggingFace. `feedforward.py` has an additional flag, `eval/compare_to_torch`, to directly compare the outputs of the PyTorch and TensorFlow models. 4 | 5 | All intermediate outputs and final logits of the TensorFlow implementation match the original PyTorch implementation within an absolute tolerance of 1e-5. TensorFlow 2.15.1 and an updated version of NumPy (1.26.0) were used for this implementation. 6 | 7 | -------------------------------------------------------------------------------- /fastglioma/eval/config/save_hidisc.yaml: -------------------------------------------------------------------------------- 1 | infra: 2 | log_dir: ./ # where all the experiments are 3 | exp_name: fastglioma/ # where you want to save the inference results 4 | comment: save_embed # can use this to customize for each save 5 | seed: 1000 6 | data: 7 | db_root: /path/to/opensrh/ 8 | studies: val # train/val or specify studies (e.g., ["NIO_001", "NIO_004"]) 9 | train_augmentation: [] 10 | valid_augmentation: 11 | - which: inpaint_rows_always_apply 12 | params: 13 | image_size: 300 14 | y_skip: 5 15 | srh_base_augmentation: three_channels 16 | rand_aug_prob: 0.3 17 | model: 18 | backbone: 19 | which: resnet34 20 | params: 21 | num_channel_in: 3 # {1: lowres, 3: highres} 22 | mlp_hidden: [] 23 | num_embedding_out: 128 24 | eval: 25 | predict_batch_size: 128 26 | ckpt_path: /path/to/pretrained/ckpt 27 | save_by_slide: 28 | saving_dir: /path/to/embeddings/ 29 | tag: tag1 -------------------------------------------------------------------------------- /fastglioma/eval/config/eval_scm.yaml: -------------------------------------------------------------------------------- 1 | infra: 2 | log_dir: ./ # where all the experiments are 3 | exp_name: fastglioma/ # use the same name as the training experiment 4 | comment: scm_vicreg_dev # can use this to customize for each experiment 5 | seed: 1000 6 | data: 7 | db_root: /path/to/opensrh/ 8 | embedding_root: /path/to/embeddings/ 9 | train_augmentation: [] 10 | valid_augmentation: same 11 | rand_aug_prob: 1. 12 | tag: [tag1] # specify tag for eval 13 | model: 14 | backbone: 15 | which: transformer 16 | params: 17 | embed_dim: 512 18 | depth: 2 19 | num_heads: 4 20 | pos_emb_type: FFPEG 21 | pos_emb_grad: True 22 | prefix_len: 8 23 | mlp_hidden: [512] 24 | num_embedding_out: 128 25 | train_alg: scm 26 | eval: 27 | predict_batch_size: 128 28 | knn: 29 | batch_size: 128 30 | k: 10 31 | t: 0.07 32 | ckpt_path: relative/path/to/checkpoint.ckpt # eg.
hash_datetime_expname_comment/models/ckpt-epochXX-accXXX.ckpt 33 | -------------------------------------------------------------------------------- /fastglioma/eval/config/eval_hidisc.yaml: -------------------------------------------------------------------------------- 1 | infra: 2 | log_dir: ./ # where all the experiments are 3 | exp_name: fastglioma/ # use the same name as the training experiment 4 | comment: patient_disc_dev # can use this to customize for each experiment 5 | seed: 1000 6 | data: 7 | db_root: /path/to/opensrh/ 8 | train_augmentation: # specify inpaint inrows for fastsrh 9 | - which: inpaint_rows_always_apply 10 | params: 11 | image_size: 300 12 | y_skip: 5 13 | valid_augmentation: same 14 | srh_base_augmentation: ch2_only # specify srh base augmentation 15 | rand_aug_prob: 1. 16 | model: 17 | backbone: 18 | which: resnet34 19 | params: 20 | num_channel_in: 3 # 3 if fullsrh, 1 if fastsrh 21 | mlp_hidden: [] 22 | num_embedding_out: 128 23 | train_alg: hidisc 24 | eval: 25 | predict_batch_size: 128 26 | knn: 27 | batch_size: 1024 28 | k: 200 29 | t: 0.07 30 | ckpt_path: relative/path/to/checkpoint.ckpt # eg. hash_datetime_expname_comment/models/ckpt-epochXX-accXXX.ckpt 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 University of Michigan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /fastglioma/inference/config/infer.yaml: -------------------------------------------------------------------------------- 1 | infra: 2 | log_dir: ./ # where all the experiments are 3 | exp_name: fastglioma/ # where you want to save the inference results 4 | comment: inference # can use this to customize for each inference 5 | hf_repo: mlinslab/fastglioma # specify huggingface repo to be used 6 | seed: 1000 7 | data: 8 | db_root: /path/to/opensrh/ 9 | studies: ["NIO_001"] # val or specify studies (e.g., ["NIO_001", "NIO_004"]) 10 | patch_input: highres # {highres, lowres} 11 | use_patient_class: true # whether to use patient or slide class for inference 12 | model: 13 | patch: 14 | backbone: 15 | which: resnet34 16 | params: 17 | num_channel_in: 3 # {1: lowres, 3: highres} 18 | mlp_hidden: [] 19 | num_embedding_out: 128 20 | slide: 21 | mil: 22 | which: transformer 23 | params: 24 | embed_dim: 512 25 | depth: 2 26 | num_heads: 4 27 | pos_emb_type: FFPEG 28 | pos_emb_grad: True 29 | prefix_len: 8 30 | mlp_hidden: [512] 31 | eval: 32 | predict_batch_size: 4 # keep low to avoid opening too many image files at once 33 | ckpt_path: fastglioma_highres_model.ckpt # path to ckpt in huggingface repo -------------------------------------------------------------------------------- /fastglioma/tf/config/feedforward.yaml: -------------------------------------------------------------------------------- 1 | infra: 2 | log_dir: ./ # where all the experiments are 3 | exp_name: fastglioma/ # where you want to save the inference results 4 | comment: inference_tensorflow # can use this to customize for each inference 5 | hf_repo: mlinslab/fastglioma # specify huggingface repo to be used 6 | seed: 1000 7 | data: 8 | db_root: /path/to/opensrh 9 | studies: ["NIO_001"] # val or specify studies (e.g., ["NIO_001", "NIO_004"]) 10 | patch_input: highres # {highres, lowres} 11 | use_patient_class: true # whether to use patient or slide class for inference 12 | model: 13 | patch: 14 | backbone: 15 | which: resnet34 16 | params: 17 | num_channel_in: 3 # {1: lowres, 3: highres} 18 | mlp_hidden: [] 19 | num_embedding_out: 128 20 | slide: 21 | mil: 22 | which: transformer 23 | params: 24 | embed_dim: 512 25 | depth: 2 26 | num_heads: 4 27 | pos_emb_type: FFPEG 28 | pos_emb_grad: True 29 | prefix_len: 8 30 | mlp_hidden: [512] 31 | eval: 32 | predict_batch_size: 1 # tf feedforward only supported on cpu, keep at 1 for now 33 | ckpt_path: fastglioma_highres_model.ckpt # path to ckpt in huggingface repo 34 | compare_to_torch: true # compare to torch model on dummy data -------------------------------------------------------------------------------- /fastglioma/train/config/train_scm.yaml: -------------------------------------------------------------------------------- 1 | infra: 2 | log_dir: ./ # where all the experiments are 3 | exp_name: fastglioma/ # create a subdirectory for each set of experiments 4 | comment: scm_vicreg_dev # can use this to customize for each experiment 5 | seed: 1000 6 | data: 7 | db_root: /path/to/opensrh/ 8 | embedding_root: /path/to/embeddings/ 9 | train_augmentation: 10 | - which: random_splitting 11 | params: 12 | masking_ratio: [0.7, 0.3] 13 | - which: random_cropping 14 | params: 15 | masking_size_ranges: [[100, 200], [50, 150]] 16 | masking_aspect_ratio_range: [[0.5, 2], [0.5, 2]] 17 | - which: random_masking 18 | params: 19 | masking_ratio_ranges: [[0.1, 0.8], [0.1, 0.8]] 20 | valid_augmentation: same 21 | tag: 
[tag1, tag2, tag3] # specify embedding tags 22 | rand_aug_prob: 1. 23 | num_transforms: 2 24 | balance_study_per_class: false 25 | model: 26 | backbone: 27 | which: transformer 28 | params: 29 | embed_dim: 512 30 | depth: 2 31 | num_heads: 4 32 | pos_emb_type: FFPEG 33 | pos_emb_grad: True 34 | prefix_len: 8 35 | mlp_hidden: [512] 36 | num_embedding_out: 128 37 | training: 38 | objective: 39 | which: vicreg 40 | params: 41 | std_coeff: 10. 42 | sim_coeff: 10. 43 | cov_coeff: 1. 44 | epsilon: 1.0e-4 45 | batch_size: 4 46 | num_epochs: 40000 47 | optimizer: adamw # [sgd, adam, adamw] 48 | learn_rate: 3.0e-4 49 | scheduler: 50 | which: cos_warmup 51 | params: 52 | num_warmup_steps: 0.1 53 | num_cycles: 0.5 54 | imagenet_backbone_checkpoint: null 55 | eval_ckpt_ep_freq: 1 56 | amp: 32 57 | deterministic: false -------------------------------------------------------------------------------- /fastglioma/train/config/train_ordmet.yaml: -------------------------------------------------------------------------------- 1 | infra: 2 | log_dir: ./ # where all the experiments are 3 | exp_name: fastglioma/ # create a subdirectory for each set of experiments 4 | comment: ordmet_dev # can use this to customize for each experiment 5 | seed: 1000 6 | data: 7 | db_root: /path/to/opensrh/ 8 | embedding_root: /path/to/embeddings/ 9 | meta_fname: opensrh_withslideclass.json 10 | use_patient_class: false 11 | train_augmentation: 12 | - which: random_splitting 13 | params: 14 | masking_ratio: [0.7, 0.3] 15 | - which: random_cropping 16 | params: 17 | masking_size_ranges: [[100, 200], [50, 150]] 18 | masking_aspect_ratio_range: [[0.5, 2], [0.5, 2]] 19 | - which: random_masking 20 | params: 21 | masking_ratio_ranges: [[0.1, 0.8], [0.1, 0.8]] 22 | valid_augmentation: [] 23 | tag: [tag1, tag2, tag3] 24 | rand_aug_prob: 1. 25 | num_transforms: 2 26 | balance_study_per_class: false 27 | model: 28 | backbone: 29 | which: transformer 30 | params: 31 | embed_dim: 512 32 | depth: 2 33 | num_heads: 4 34 | pos_emb_type: FFPEG 35 | pos_emb_grad: True 36 | prefix_len: 8 37 | mlp_hidden: [512] 38 | training: 39 | load_backbone: 40 | ckpt_path: /path/to/pretrained/ckpt 41 | finetune: false 42 | objective: 43 | which: ordmet 44 | params: {} 45 | batch_size: 16 46 | num_epochs: 40000 47 | optimizer: adamw # [sgd, adam, adamw] 48 | learn_rate: 0.00001875 49 | scheduler: 50 | which: cos_warmup 51 | params: 52 | num_warmup_steps: 0.1 53 | num_cycles: 0.5 54 | imagenet_backbone_checkpoint: null 55 | eval_ckpt_ep_freq: 400 56 | amp: 32 57 | deterministic: false -------------------------------------------------------------------------------- /fastglioma/losses/hidisc.py: -------------------------------------------------------------------------------- 1 | """HiDisc loss module. 2 | 3 | Copyright (c) 2024 University of Michigan. All rights reserved. 4 | Licensed under the MIT License. See LICENSE for license information. 
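A minimal usage sketch (illustrative; the 5-D input shape [patients, slides, patches, views, embedding_dim] is inferred from the reshaping in forward below, and the lambda/temperature values mirror train_hidisc.yaml):

    import torch
    import torch.nn.functional as F
    from fastglioma.losses.hidisc import HiDiscLoss

    loss_fn = HiDiscLoss(lambda_patient=1.0,
                         lambda_slide=1.0,
                         lambda_patch=1.0,
                         supcon_loss_params={"temperature": 0.07})
    # 2 patients x 2 slides/patient x 2 patches/slide x 2 views, 128-d embeddings
    feats = F.normalize(torch.randn(2, 2, 2, 2, 128), dim=-1)
    out = loss_fn(feats)  # dict with patient_loss, slide_loss, patch_loss, sum_loss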
5 | """ 6 | 7 | from typing import Tuple, Dict, Optional, Any 8 | import torch 9 | from torch import nn 10 | 11 | from fastglioma.losses.supcon import SupConLoss 12 | 13 | 14 | class HiDiscLoss(nn.Module): 15 | """Computes the HiDisc loss 16 | 17 | Input representation needs to be normalized 18 | """ 19 | 20 | def __init__(self, 21 | lambda_patient: Optional[float] = 1.0, 22 | lambda_slide: Optional[float] = 1.0, 23 | lambda_patch: Optional[float] = 1.0, 24 | supcon_loss_params: Optional[Dict] = {}): 25 | super(HiDiscLoss, self).__init__() 26 | self.criterion = SupConLoss(**supcon_loss_params) 27 | self.lambda_patient_ = lambda_patient 28 | self.lambda_slide_ = lambda_slide 29 | self.lambda_patch_ = lambda_patch 30 | 31 | def forward(self, features, labels=None): 32 | emb_sz = features.shape[-1] 33 | sz_prod = lambda x: torch.prod(torch.tensor(x)) 34 | feat_shape = features.shape 35 | 36 | patient_emb = features.reshape(feat_shape[0], -1, emb_sz) 37 | slide_emb = features.reshape(sz_prod(feat_shape[0:2]), -1, emb_sz) 38 | patch_emb = features.reshape(sz_prod(feat_shape[0:3]), -1, emb_sz) 39 | 40 | patient_loss = self.criterion(patient_emb, None) 41 | slide_loss = self.criterion(slide_emb, None) 42 | patch_loss = self.criterion(patch_emb, None) 43 | 44 | loss = ((self.lambda_patient_ * patient_loss) + 45 | (self.lambda_slide_ * slide_loss) + 46 | (self.lambda_patch_ * patch_loss)) 47 | 48 | return { 49 | "patient_loss": patient_loss, 50 | "slide_loss": slide_loss, 51 | "patch_loss": patch_loss, 52 | "sum_loss": loss 53 | } -------------------------------------------------------------------------------- /fastglioma/losses/vicreg.py: -------------------------------------------------------------------------------- 1 | """VICReg loss module. 2 | 3 | Copyright (c) 2024 University of Michigan. All rights reserved. 4 | Licensed under the MIT License. See LICENSE for license information. 
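A minimal usage sketch (illustrative; the [batch, n_views, embedding_dim] input shape is inferred from forward below, and the coefficients mirror train_scm.yaml):

    import torch
    from fastglioma.losses.vicreg import GeneralVICRegLoss

    loss_fn = GeneralVICRegLoss(embedding_dim=512,
                                sim_coeff=10.,
                                std_coeff=10.,
                                cov_coeff=1.)
    x = torch.randn(8, 2, 512)  # 8 samples, 2 views each, 512-d embeddings
    out = loss_fn(x)  # dict with keys "inv", "var", "cov", "loss"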
5 | """ 6 | 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | 11 | 12 | def off_diagonal(x): 13 | n, m = x.shape 14 | assert n == m 15 | return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() 16 | 17 | 18 | class GeneralVICRegLoss(nn.Module): 19 | """VICReg Loss""" 20 | 21 | def __init__(self, embedding_dim: int, sim_coeff: float=25, std_coeff: float=25, 22 | cov_coeff: float=1, epsilon: float=1.e-4): 23 | super(GeneralVICRegLoss, self).__init__() 24 | self.embedding_dim = embedding_dim 25 | self.sim_coeff_ = sim_coeff 26 | self.std_coeff_ = std_coeff 27 | self.cov_coeff_ = cov_coeff 28 | self.epsilon_ = epsilon 29 | 30 | @staticmethod 31 | def get_loss_names(): 32 | return ["inv", "var", "cov", "loss"] 33 | 34 | def var_loss(self, x): 35 | std_x = torch.sqrt(x.var(dim=0) + self.epsilon_) 36 | std_x = torch.mean(F.relu(1 - std_x)) 37 | return std_x 38 | 39 | def cov_loss(self, x): 40 | cov_x = (x.T @ x) / (x.shape[0] - 1) 41 | return off_diagonal(cov_x).pow_(2).sum().div(self.embedding_dim) 42 | 43 | def forward(self, x, _=None): # _=None for future supervised version 44 | 45 | # inv loss 46 | n_views = x.shape[1] 47 | if n_views == 1: 48 | repr_loss = 0 49 | else: 50 | repr_loss = torch.mean( 51 | torch.stack([ 52 | F.mse_loss(x[:, i, :], x[:, j, :]) for i in range(n_views) 53 | for j in range(i + 1, n_views) 54 | ])) 55 | 56 | x = x - x.mean(dim=0) 57 | 58 | # var, cov loss 59 | std_loss = torch.mean( 60 | torch.stack( 61 | [self.var_loss(x_.squeeze()) for x_ in x.split(1, dim=1)])) 62 | cov_loss = torch.mean( 63 | torch.stack( 64 | [self.cov_loss(x_.squeeze()) for x_ in x.split(1, dim=1)])) 65 | 66 | loss = (self.sim_coeff_ * repr_loss + self.std_coeff_ * std_loss + 67 | self.cov_coeff_ * cov_loss) 68 | 69 | return { 70 | "inv": repr_loss, 71 | "var": std_loss, 72 | "cov": cov_loss, 73 | "loss": loss 74 | } -------------------------------------------------------------------------------- /fastglioma/train/config/train_hidisc.yaml: -------------------------------------------------------------------------------- 1 | infra: 2 | log_dir: ./ # where all the experiments are 3 | exp_name: fastglioma/ # create a subdirectory for each set of experiments 4 | comment: patient_disc_dev # can use this to customize for each experiment 5 | seed: 1000 6 | data: 7 | db_root: /nfs/turbo/umms-tocho/data/opensrh/ 8 | train_augmentation: 9 | - which: inpaint_rows 10 | params: 11 | image_size: 300 12 | y_skip: 5 13 | - which: random_horiz_flip 14 | params: {} 15 | - which: random_vert_flip 16 | params: {} 17 | - which: gaussian_noise 18 | params: {} 19 | - which: color_jitter 20 | params: {} 21 | - which: random_autocontrast 22 | params: {} 23 | - which: random_solarize 24 | params: 25 | threshold: 0.2 26 | - which: random_sharpness 27 | params: 28 | sharpness_factor: 2 29 | - which: gaussian_blur 30 | params: 31 | kernel_size: 5 32 | sigma: 1 33 | - which: random_affine 34 | params: 35 | degrees: 10 36 | translate: [0.1, 0.3] 37 | - which: random_resized_crop 38 | params: 39 | size: 300 40 | - which: random_erasing 41 | params: {} 42 | valid_augmentation: same 43 | srh_base_augmentation: three_channels #ch2_only 44 | rand_aug_prob: 0.3 45 | hidisc: 46 | num_slide_samples: 2 47 | num_patch_samples: 2 48 | num_transforms: 2 49 | balance_study_per_class: true 50 | model: 51 | backbone: 52 | which: resnet34 53 | params: 54 | num_channel_in: 3 55 | mlp_hidden: [] 56 | num_embedding_out: 128 57 | training: 58 | objective: 59 | which: hidisc 60 | params: 61 | lambda_patient: 
1.0 62 | lambda_slide: 1.0 63 | lambda_patch: 1.0 64 | supcon_params: 65 | temperature: 0.07 66 | base_temperature: 0.07 67 | contrast_mode: all 68 | batch_size: 128 69 | num_epochs: 40000 70 | optimizer: adamw # [sgd, adam, adamw] 71 | learn_rate: 1.0e-3 72 | scheduler: 73 | which: cos_warmup 74 | params: 75 | num_warmup_steps: 0.1 76 | num_cycles: 0.5 77 | imagenet_backbone_checkpoint: null 78 | eval_ckpt_ep_freq: 400 79 | amp: 32 80 | deterministic: false -------------------------------------------------------------------------------- /fastglioma/models/cnn.py: -------------------------------------------------------------------------------- 1 | """Model wrappers. 2 | 3 | Copyright (c) 2024 University of Michigan. All rights reserved. 4 | Licensed under the MIT License. See LICENSE for license information. 5 | """ 6 | 7 | from typing import Dict, List 8 | from itertools import chain 9 | 10 | import torch 11 | from torch import nn as nn 12 | 13 | 14 | class MLP(nn.Module): 15 | """MLP for classification head. 16 | 17 | Forward pass returns a tensor. 18 | """ 19 | 20 | def __init__(self, n_in: int, hidden_layers: List[int], 21 | n_out: int) -> None: 22 | super().__init__() 23 | layers_in = [n_in] + hidden_layers 24 | layers_out = hidden_layers + [n_out] 25 | 26 | layers_list = list( 27 | chain.from_iterable((nn.Linear(a, b), nn.ReLU()) 28 | for a, b in zip(layers_in, layers_out)))[:-1] 29 | self.layers = nn.Sequential(*layers_list) 30 | 31 | def init_weights(m): 32 | if isinstance(m, nn.Linear): 33 | torch.nn.init.xavier_uniform_(m.weight) 34 | m.bias.data.fill_(0.01) 35 | 36 | self.layers.apply(init_weights) 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | return self.layers(x) 40 | 41 | 42 | class ContrastiveLearningNetwork(torch.nn.Module): 43 | """A network consists of a backbone and projection head. 44 | 45 | Forward pass returns the normalized embeddings after a projection layer. 46 | """ 47 | 48 | def __init__(self, backbone: callable, proj: callable): 49 | super(ContrastiveLearningNetwork, self).__init__() 50 | self.bb = backbone() 51 | self.proj = proj() 52 | 53 | def forward(self, x: torch.Tensor) -> torch.Tensor: 54 | bb_out = self.bb(x) 55 | #bb_out_norm = torch.nn.functional.normalize(bb_out, p=2.0, dim=1) 56 | proj_out = self.proj(bb_out) 57 | proj_out_norm = torch.nn.functional.normalize(proj_out, p=2.0, dim=1) 58 | 59 | return proj_out_norm.unsqueeze(1) 60 | 61 | 62 | class VICRegNetwork(torch.nn.Module): 63 | """A network consists of a backbone and projection head. 64 | 65 | Forward pass returns the normalized embeddings after a projection layer. 
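(In the current implementation the projection output is returned without an explicit L2 normalization; compare ContrastiveLearningNetwork.forward above, which normalizes.)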
66 | """ 67 | 68 | def __init__(self, backbone: callable, proj: callable): 69 | super(VICRegNetwork, self).__init__() 70 | self.bb = backbone() 71 | self.proj = proj() 72 | 73 | def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: 74 | return self.proj(self.bb(x, **kwargs)) -------------------------------------------------------------------------------- /fastglioma/utils/format_slide_embedding.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | import os 4 | import copy 5 | import shutil 6 | import gzip 7 | import numpy as np 8 | import pandas as pd 9 | from typing import Optional, List, TypedDict, Any, NamedTuple 10 | from collections import defaultdict 11 | from tqdm import tqdm 12 | from glob import glob 13 | import torch 14 | 15 | 16 | def nio_patchpath_to_coord(patch_name: str): 17 | """Converts SRH patch pathname to its 300x300 coordinate on the WSI 18 | Example: `/.../NIO_UM_135-6-0_1000_300_600.tif` -> (1, 5) 19 | """ 20 | ix, jx, iy, jy = patch_name.split("/")[-1].split("-")[-1].split( 21 | ".")[0].split("_") 22 | return (3 * (int(ix) // 1000) + int(iy) // 300, 23 | 3 * (int(jx) // 1000) + int(jy) // 300) 24 | 25 | 26 | def prediction_to_slide_embedding(saving_dir: str, 27 | tag: str, 28 | embed_path: str = "", 29 | predictions=None): 30 | 31 | """Preassign data indices by either slide/patient.""" 32 | if embed_path and not predictions: 33 | logging.info(f"Loading {embed_path}") 34 | if embed_path.endswith(".gz"): 35 | with gzip.open(embed_path) as f: 36 | predictions = torch.load(f) 37 | else: 38 | predictions = torch.load(embed_path) 39 | logging.info(f"Loading {embed_path} - OK") 40 | 41 | assert ((len(predictions["path"]) == len(predictions["label"])) or 42 | (len(predictions["label"]) == len(predictions["embeddings"]))) 43 | 44 | if not (embed_path or predictions): 45 | raise ValueError("embed_path or predictions should be specified.") 46 | 47 | # make dictionary for all slides from the patch predictions 48 | slide_instances_ = defaultdict(list) 49 | for idx in tqdm(range(len(predictions["path"]))): 50 | path_i = predictions["path"][idx] 51 | label_i = predictions["label"][idx] 52 | embedding_i = predictions["embeddings"][idx] 53 | patient_name = path_i.split("/")[-4] 54 | slide_name = path_i.split("/")[-3] 55 | slide_instances_[patient_name + "." 
+ str(slide_name)].append( 56 | [path_i, label_i, embedding_i]) 57 | 58 | # process each slide and save 59 | for slide_id in tqdm(slide_instances_): 60 | patches = slide_instances_[slide_id] 61 | coords = [nio_patchpath_to_coord(p[0]) for p in patches] 62 | 63 | # sort patch order 64 | sorted_indices = sorted(enumerate(coords), key=lambda x: x[1]) 65 | ordered_idx = [index for index, _ in sorted_indices] 66 | patches_path = [] 67 | patches_label = [] 68 | patches_embeddings = [] 69 | patches_coords = [] 70 | for idx in ordered_idx: 71 | patches_path.append(patches[idx][0]) 72 | patches_label.append(patches[idx][1]) 73 | patches_embeddings.append(patches[idx][2]) 74 | patches_coords.append(coords[idx]) 75 | 76 | patches_label = torch.stack(patches_label) 77 | patches_embeddings = torch.stack(patches_embeddings) 78 | slide_data = { 79 | "path": patches_path, 80 | "labels": patches_label, 81 | "embeddings": patches_embeddings, 82 | "coords": patches_coords 83 | } 84 | 85 | # save 86 | institute = patches[0][0].split("/")[-5] 87 | slide_data_dir = os.path.join(saving_dir, institute, 88 | slide_id.split(".")[0], 89 | slide_id.split(".")[1]) 90 | if not os.path.exists(slide_data_dir): 91 | os.makedirs(slide_data_dir) 92 | slide_npath = os.path.join(slide_data_dir, slide_id + f"-{tag}.pt") 93 | torch.save(slide_data, slide_npath) 94 | 95 | logging.debug(f"{slide_id} DONE") 96 | -------------------------------------------------------------------------------- /fastglioma/losses/ordmet.py: -------------------------------------------------------------------------------- 1 | """Ordinal Metric Learning loss module. 2 | 3 | Copyright (c) 2024 University of Michigan. All rights reserved. 4 | Licensed under the MIT License. See LICENSE for license information. 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | import numpy as np 11 | from typing import Optional 12 | 13 | cat_data = lambda x: torch.cat([x[0], x[1]], dim=0) 14 | cat_label = lambda x: torch.cat([x, x], dim=0) 15 | 16 | 17 | def uncat_data(emb): 18 | half_sz = int(emb.shape[0] / 2) 19 | f1, f2 = torch.split(emb, [half_sz, half_sz], dim=0) 20 | return torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) 21 | 22 | 23 | def off_diagonal(x): 24 | n, m = x.shape 25 | assert n == m 26 | return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() 27 | 28 | 29 | class OrdinalMetricLoss(nn.Module): 30 | """Computes the Ordinal Metric Learning Loss 31 | 32 | Adapted from HobbitLong/SupContrast. 33 | See THIRD_PARTY for third party license info. 34 | https://github.com/HobbitLong/SupContrast/blob/master/losses.py 35 | 36 | Author: Yonglong Tian (yonglong@mit.edu) 37 | Date: May 07, 2020 38 | """ 39 | 40 | def __init__(self, pos_label: Optional[str] = "none", **kwargs): 41 | """ 42 | Args: 43 | pos_label: str, optional 44 | How to define relationships between samples with the same label ("positive pairs"). 45 | Must be one of {"none", "same", "lower", "upper", "both"}. 46 | none: no loss is computed for positive pairs 47 | same: target is set to 0.5 for positive pairs 48 | lower: target is set to 0 for positive pairs 49 | upper: target is set to 1 for positive pairs 50 | Defaults to "none". 
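A minimal usage sketch (illustrative; score/label shapes follow the forward docstring below). With ordinal labels [0, 1, 2] and pos_label="none", every pair (i, j) with label_i > label_j receives target 1 for the score difference s_i - s_j under BCEWithLogitsLoss, so higher-grade samples are pushed toward higher scores:

    import torch
    from fastglioma.losses.ordmet import OrdinalMetricLoss

    crit = OrdinalMetricLoss(pos_label="none")
    scores = torch.tensor([[0.2], [0.5], [1.3]])
    labels = torch.tensor([[0.], [1.], [2.]])
    out = crit(scores, labels)  # {"loss": tensor(...)}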
51 | """ 52 | super(OrdinalMetricLoss, self).__init__() 53 | 54 | self.pos_label = pos_label 55 | self.crit = torch.nn.BCEWithLogitsLoss(reduction="none") 56 | 57 | @staticmethod 58 | def get_loss_names(): 59 | return ["loss"] 60 | 61 | def forward(self, scores, labels=None): 62 | """ 63 | Args: 64 | scores: torch.Tensor, shape [N, 1] 65 | labels: torch.Tensor, shape [N, 1] 66 | Returns: 67 | loss: torch.Tensor, shape [1] 68 | """ 69 | device = (torch.device('cuda') 70 | if scores.is_cuda else torch.device('cpu')) 71 | 72 | scores = scores.reshape(-1, 1) 73 | 74 | batch_size = scores.shape[0] 75 | labels = labels.contiguous().view(-1, 1) 76 | if labels.shape[0] != batch_size: 77 | raise ValueError('Num of labels does not match num of features') 78 | 79 | upper_mask = torch.gt(labels, labels.T).repeat(1, 1) 80 | lower_mask = torch.lt(labels, labels.T).repeat(1, 1) 81 | 82 | logits_mask = torch.scatter( # mask-out self-contrast cases 83 | torch.ones_like(upper_mask), 1, 84 | torch.arange(batch_size * 1).view(-1, 1).to(device), 0).float() 85 | 86 | diff_mat = scores.repeat(1, batch_size) - scores.repeat(1, batch_size).T #yapf:disable 87 | 88 | if self.pos_label == "none": 89 | neg_mask = (upper_mask | lower_mask) 90 | loss = (self.crit(diff_mat, upper_mask.float()) * 91 | neg_mask).sum(1) / neg_mask.sum(1) 92 | elif self.pos_label == "same": 93 | mask = torch.eq(labels, labels.T).float().repeat(1, 1) 94 | label_mat = torch.where(upper_mask, upper_mask.float(), mask * 0.5) 95 | loss = (self.crit(diff_mat, label_mat) * 96 | logits_mask).sum(1) / logits_mask.sum(1) 97 | elif self.pos_label == "lower": 98 | loss = (self.crit(diff_mat, upper_mask.float()) * 99 | logits_mask).sum(1) / logits_mask.sum(1) 100 | elif self.pos_label == "upper": 101 | loss = (self.crit(diff_mat, (~lower_mask).float()) * 102 | logits_mask).sum(1) / logits_mask.sum(1) 103 | else: 104 | raise ValueError( 105 | "`pos_label` must be one of {none, same, lower, upper}") 106 | 107 | loss = loss.view(1, batch_size).mean() 108 | 109 | return {"loss": loss} 110 | -------------------------------------------------------------------------------- /THIRD_PARTY: -------------------------------------------------------------------------------- 1 | This project incorporate and adapts from components from the following repos. 2 | Their copyright notices and licenses are included below. 3 | 4 | # pytorch/vision 5 | 6 | BSD 3-Clause License 7 | 8 | Copyright (c) Soumith Chintala 2016, 9 | All rights reserved. 10 | 11 | Redistribution and use in source and binary forms, with or without 12 | modification, are permitted provided that the following conditions are met: 13 | 14 | * Redistributions of source code must retain the above copyright notice, this 15 | list of conditions and the following disclaimer. 16 | 17 | * Redistributions in binary form must reproduce the above copyright notice, 18 | this list of conditions and the following disclaimer in the documentation 19 | and/or other materials provided with the distribution. 20 | 21 | * Neither the name of the copyright holder nor the names of its 22 | contributors may be used to endorse or promote products derived from 23 | this software without specific prior written permission. 24 | 25 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 26 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 27 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 28 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 29 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 30 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 31 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 32 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 33 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 34 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 35 | 36 | # HobbitLong/SupContrast 37 | 38 | BSD 2-Clause License 39 | 40 | Copyright (c) 2020, Yonglong Tian 41 | All rights reserved. 42 | 43 | Redistribution and use in source and binary forms, with or without 44 | modification, are permitted provided that the following conditions are met: 45 | 46 | 1. Redistributions of source code must retain the above copyright notice, this 47 | list of conditions and the following disclaimer. 48 | 49 | 2. Redistributions in binary form must reproduce the above copyright notice, 50 | this list of conditions and the following disclaimer in the documentation 51 | and/or other materials provided with the distribution. 52 | 53 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 54 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 55 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 56 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 57 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 58 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 59 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 60 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 61 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 62 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 63 | 64 | 65 | # IgorSusmelj/barlowtwins 66 | 67 | MIT License 68 | 69 | Copyright (c) 2022 Igor Susmelj 70 | 71 | Permission is hereby granted, free of charge, to any person obtaining a copy of 72 | this software and associated documentation files (the "Software"), to deal in 73 | the Software without restriction, including without limitation the rights to 74 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 75 | of the Software, and to permit persons to whom the Software is furnished to do 76 | so, subject to the following conditions: 77 | 78 | The above copyright notice and this permission notice shall be included in all 79 | copies or substantial portions of the Software. 80 | 81 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 82 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 83 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 84 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 85 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 86 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 87 | SOFTWARE. -------------------------------------------------------------------------------- /fastglioma/losses/supcon.py: -------------------------------------------------------------------------------- 1 | """SupCon / SimCLR loss function 2 | 3 | Adapted from HobbitLong/SupContrast. 
4 | See THIRD_PARTY for third party license info. 5 | https://github.com/HobbitLong/SupContrast/blob/master/losses.py 6 | 7 | Author: Yonglong Tian (yonglong@mit.edu) 8 | Date: May 07, 2020 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | class SupConLoss(nn.Module): 16 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 17 | It also supports the unsupervised contrastive loss in SimCLR. 18 | https://github.com/HobbitLong/SupContrast source code here 19 | NOTE: Loss expects the representations have already been normalized!! 20 | """ 21 | 22 | def __init__(self, 23 | temperature=0.07, 24 | contrast_mode='all', 25 | base_temperature=0.07): 26 | super(SupConLoss, self).__init__() 27 | self.temperature = temperature 28 | self.contrast_mode = contrast_mode 29 | self.base_temperature = base_temperature 30 | 31 | def forward(self, features, labels=None, mask=None): 32 | """Compute loss for model. If both `labels` and `mask` are None, 33 | it degenerates to SimCLR unsupervised loss: 34 | https://arxiv.org/pdf/2002.05709.pdf 35 | Args: 36 | features: hidden vector of shape [bsz, n_views, ...]. 37 | labels: ground truth of shape [bsz]. 38 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 39 | has the same class as sample i. Can be asymmetric. 40 | Returns: 41 | A loss scalar. 42 | """ 43 | device = (torch.device('cuda') 44 | if features.is_cuda else torch.device('cpu')) 45 | 46 | if len(features.shape) < 3: 47 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 48 | 'at least 3 dimensions are required') 49 | if len(features.shape) > 3: 50 | features = features.view(features.shape[0], features.shape[1], -1) 51 | 52 | batch_size = features.shape[0] 53 | if labels is not None and mask is not None: 54 | raise ValueError('Cannot define both `labels` and `mask`') 55 | elif labels is None and mask is None: 56 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 57 | elif labels is not None: 58 | labels = labels.contiguous().view(-1, 1) 59 | if labels.shape[0] != batch_size: 60 | raise ValueError( 61 | 'Num of labels does not match num of features') 62 | mask = torch.eq(labels, labels.T).float().to(device) 63 | else: 64 | mask = mask.float().to(device) 65 | 66 | contrast_count = features.shape[1] 67 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 68 | if self.contrast_mode == 'one': 69 | anchor_feature = features[:, 0] 70 | anchor_count = 1 71 | elif self.contrast_mode == 'all': 72 | anchor_feature = contrast_feature 73 | anchor_count = contrast_count 74 | else: 75 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 76 | 77 | # compute logits 78 | anchor_dot_contrast = torch.div( 79 | torch.matmul(anchor_feature, contrast_feature.T), self.temperature) 80 | # for numerical stability 81 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 82 | logits = anchor_dot_contrast - logits_max.detach() 83 | 84 | # tile mask 85 | mask = mask.repeat(anchor_count, contrast_count) 86 | # mask-out self-contrast cases 87 | logits_mask = torch.scatter( 88 | torch.ones_like(mask), 1, 89 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 0) 90 | mask = mask * logits_mask 91 | 92 | # compute log_prob 93 | exp_logits = torch.exp(logits) * logits_mask 94 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 95 | 96 | # compute mean of log-likelihood over positive 97 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 98 | 99 | # loss 100 | loss = 
-(self.temperature / self.base_temperature) * mean_log_prob_pos 101 | loss = loss.view(anchor_count, batch_size).mean() 102 | 103 | return loss -------------------------------------------------------------------------------- /fastglioma/eval/save_embedding.py: -------------------------------------------------------------------------------- 1 | """Save embeddings script. 2 | 3 | Copyright (c) 2024 University of Michigan. All rights reserved. 4 | Licensed under the MIT License. See LICENSE for license information. 5 | """ 6 | 7 | import os 8 | import logging 9 | from shutil import copy2 10 | from functools import partial 11 | from typing import List, Union, Dict, Any 12 | 13 | import gzip 14 | import yaml 15 | import numpy as np 16 | import pandas as pd 17 | from tqdm import tqdm 18 | 19 | import torch 20 | from torchvision.transforms import Compose 21 | 22 | import pytorch_lightning as pl 23 | 24 | from fastglioma.datasets.srh_dataset import PatchDataset #SlideDataset, slide_collate_fn 25 | from fastglioma.datasets.improc import get_transformations 26 | from fastglioma.utils.common import (parse_args, get_exp_name, config_loggers, 27 | get_num_worker) 28 | 29 | from fastglioma.models.resnet import resnet_backbone 30 | from fastglioma.models.cnn import MLP, ContrastiveLearningNetwork 31 | from fastglioma.models.mil import MIL_forward, MIL_Classifier, TransformerMIL 32 | 33 | from fastglioma.train.train_hidisc import HiDiscSystem 34 | 35 | from fastglioma.utils.format_slide_embedding import prediction_to_slide_embedding 36 | 37 | 38 | def get_predictions( 39 | cf: Dict[str, Any], 40 | exp_root: str) -> Dict[str, Union[torch.Tensor, List[str]]]: 41 | """Run forward pass on the dataset, and generate embeddings and logits""" 42 | _, valid_xform = get_transformations(cf) 43 | 44 | dset = PatchDataset( 45 | data_root=cf["data"]["db_root"], 46 | studies=cf["data"]["studies"], 47 | transform=valid_xform, 48 | balance_patch_per_class=False, 49 | use_patient_class=cf["data"]["use_patient_class"]) 50 | 51 | loader = torch.utils.data.DataLoader( 52 | dset, 53 | batch_size=cf["eval"]["predict_batch_size"], 54 | drop_last=False, 55 | pin_memory=True, 56 | num_workers=get_num_worker(), 57 | # collate_fn=slide_collate_fn, 58 | persistent_workers=True) 59 | 60 | # load lightning checkpoint 61 | ckpt_path = os.path.join(cf["infra"]["log_dir"], cf["infra"]["exp_name"], 62 | cf["eval"]["ckpt_path"]) 63 | 64 | # Load model from ckpt 65 | model = HiDiscSystem.load_from_checkpoint(ckpt_path, 66 | cf=cf, 67 | num_it_per_ep=0) 68 | 69 | # Create trainer 70 | trainer = pl.Trainer(accelerator="gpu", 71 | devices=1, 72 | max_epochs=-1, 73 | default_root_dir=exp_root, 74 | enable_checkpointing=False, 75 | logger=False) 76 | 77 | predictions = trainer.predict(model, dataloaders=loader) 78 | 79 | def process_predictions(predictions): 80 | # Combine predictions into a single dictionary 81 | pred = {} 82 | for k in predictions[0].keys(): 83 | if k == "path": 84 | pred[k] = [pk for p in predictions for pk in p[k][0]] 85 | else: 86 | pred[k] = torch.cat([p[k] for p in predictions]) 87 | 88 | return pred 89 | 90 | predictions = process_predictions(predictions) 91 | return predictions 92 | 93 | 94 | def setup_eval_paths(cf, get_exp_name, cmt_append): 95 | """Get name of the ouput dirs and create them in the file system.""" 96 | log_root = cf["infra"]["log_dir"] 97 | exp_name = cf["infra"]["exp_name"] 98 | instance_name = cf["eval"]["ckpt_path"].split("/")[0] 99 | eval_instance_name = "_".join([get_exp_name(cf), cmt_append]) 100 
| exp_root = os.path.join(log_root, exp_name, instance_name, "evals", 101 | eval_instance_name) 102 | 103 | # Generate needed folders 104 | pred_dir = os.path.join(exp_root, 'predictions') 105 | config_dir = os.path.join(exp_root, 'config') 106 | for dir_name in [pred_dir, config_dir]: 107 | if not os.path.exists(dir_name): 108 | os.makedirs(dir_name) 109 | 110 | return exp_root, pred_dir, partial(copy2, dst=config_dir) 111 | 112 | 113 | def main(): 114 | """Driver script for pipeline.""" 115 | cf_fd = parse_args() 116 | cf = yaml.load(cf_fd, Loader=yaml.FullLoader) 117 | exp_root, pred_dir, cp_config = setup_eval_paths(cf, get_exp_name, "") 118 | pl.seed_everything(cf["infra"]["seed"]) 119 | 120 | # Logging and copying config files 121 | cp_config(cf_fd.name) 122 | config_loggers(exp_root) 123 | 124 | logging.info("Generating predictions") 125 | predictions = get_predictions(cf, exp_root) 126 | 127 | # save embeddings 128 | prediction_to_slide_embedding( 129 | saving_dir=cf["eval"]["save_by_slide"]["saving_dir"], 130 | tag=cf["eval"]["save_by_slide"]["tag"], 131 | predictions=predictions) 132 | 133 | 134 | if __name__ == "__main__": 135 | main() 136 | -------------------------------------------------------------------------------- /fastglioma/models/vit.py: -------------------------------------------------------------------------------- 1 | """Vision transformer model. 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | All rights reserved. 5 | 6 | This source code is licensed under the license found in the 7 | LICENSE file in the root directory of this source tree. 8 | -------------------------------------------------------- 9 | References: 10 | timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 11 | DeiT: https://github.com/facebookresearch/deit 12 | -------------------------------------------------------- 13 | 14 | Copyright (c) 2024 University of Michigan. All rights reserved. 15 | Licensed under the MIT License. See LICENSE for license information. 16 | """ 17 | import math 18 | from functools import partial 19 | import numpy as np 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | import timm.models.vision_transformer 25 | 26 | 27 | class Attention(nn.Module): 28 | 29 | def __init__(self, 30 | dim, 31 | num_heads=8, 32 | qkv_bias=False, 33 | qk_scale=None, 34 | attn_drop=0., 35 | proj_drop=0.): 36 | super().__init__() 37 | self.num_heads = num_heads 38 | head_dim = dim // num_heads 39 | self.scale = qk_scale or head_dim**-0.5 40 | 41 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 42 | self.attn_drop = nn.Dropout(attn_drop) 43 | self.proj = nn.Linear(dim, dim) 44 | self.proj_drop = nn.Dropout(proj_drop) 45 | 46 | def forward(self, x): 47 | B, N, C = x.shape 48 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, 49 | C // self.num_heads).permute(2, 0, 3, 1, 4) 50 | q, k, v = qkv[0], qkv[1], qkv[2] 51 | 52 | attn = (q @ k.transpose(-2, -1)) * self.scale 53 | attn = attn.softmax(dim=-1) 54 | attn = self.attn_drop(attn) 55 | 56 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 57 | x = self.proj(x) 58 | x = self.proj_drop(x) 59 | return x, attn 60 | 61 | 62 | def drop_path(x, drop_prob: float = 0., training: bool = False): 63 | if drop_prob == 0. 
or not training: 64 | return x 65 | keep_prob = 1 - drop_prob 66 | shape = (x.shape[0], ) + (1, ) * ( 67 | x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 68 | random_tensor = keep_prob + torch.rand( 69 | shape, dtype=x.dtype, device=x.device) 70 | random_tensor.floor_() # binarize 71 | output = x.div(keep_prob) * random_tensor 72 | return output 73 | 74 | 75 | class DropPath(nn.Module): 76 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 77 | """ 78 | 79 | def __init__(self, drop_prob=None): 80 | super(DropPath, self).__init__() 81 | self.drop_prob = drop_prob 82 | 83 | def forward(self, x): 84 | return drop_path(x, self.drop_prob, self.training) 85 | 86 | 87 | class MLP(nn.Module): 88 | 89 | def __init__(self, 90 | in_features, 91 | hidden_features=None, 92 | out_features=None, 93 | act_layer=nn.GELU, 94 | drop=0.): 95 | super().__init__() 96 | out_features = out_features or in_features 97 | hidden_features = hidden_features or in_features 98 | self.fc1 = nn.Linear(in_features, hidden_features) 99 | self.act = act_layer() 100 | self.fc2 = nn.Linear(hidden_features, out_features) 101 | self.drop = nn.Dropout(drop) 102 | 103 | def forward(self, x): 104 | x = self.fc1(x) 105 | x = self.act(x) 106 | x = self.drop(x) 107 | x = self.fc2(x) 108 | x = self.drop(x) 109 | return x 110 | 111 | 112 | class Block(nn.Module): 113 | 114 | def __init__(self, 115 | dim, 116 | num_heads, 117 | mlp_ratio=4., 118 | qkv_bias=False, 119 | qk_scale=None, 120 | drop=0., 121 | attn_drop=0., 122 | drop_path=0., 123 | act_layer=nn.GELU, 124 | norm_layer=nn.LayerNorm): 125 | super().__init__() 126 | self.norm1 = norm_layer(dim) 127 | self.attn = Attention(dim, 128 | num_heads=num_heads, 129 | qkv_bias=qkv_bias, 130 | qk_scale=qk_scale, 131 | attn_drop=attn_drop, 132 | proj_drop=drop) 133 | self.drop_path = DropPath( 134 | drop_path) if drop_path > 0. else nn.Identity() 135 | self.norm2 = norm_layer(dim) 136 | mlp_hidden_dim = int(dim * mlp_ratio) 137 | self.mlp = MLP(in_features=dim, 138 | hidden_features=mlp_hidden_dim, 139 | act_layer=act_layer, 140 | drop=drop) 141 | 142 | def forward(self, x, return_attention=False): 143 | y, attn = self.attn(self.norm1(x)) 144 | if return_attention: 145 | return attn 146 | x = x + self.drop_path(y) 147 | residual = x 148 | 149 | x = self.drop_path(self.mlp(self.norm2(x))) 150 | 151 | x = residual + x 152 | return x -------------------------------------------------------------------------------- /fastglioma/utils/common.py: -------------------------------------------------------------------------------- 1 | """Common modules for FastGlioma training, evaluation, and inference. 2 | 3 | Copyright (c) 2024 University of Michigan. All rights reserved. 4 | Licensed under the MIT License. See LICENSE for license information. 
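A minimal usage sketch of the cosine warmup scheduler helper defined below (illustrative; the learning rate mirrors the training configs and the step counts are arbitrary):

    import torch
    from fastglioma.utils.common import get_cosine_schedule_with_warmup

    opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=3.0e-4)
    sched = get_cosine_schedule_with_warmup(opt,
                                            num_warmup_steps=100,
                                            num_training_steps=1000)
    for _ in range(1000):
        opt.step()
        sched.step()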
5 | """ 6 | 7 | import os 8 | import math 9 | import logging 10 | import argparse 11 | from shutil import copy2 12 | from datetime import datetime 13 | from functools import partial 14 | from typing import Tuple, Dict, Optional, Any 15 | 16 | import uuid 17 | 18 | import torch 19 | from torch import optim 20 | from torch.optim.lr_scheduler import StepLR, LambdaLR 21 | from torchvision.transforms import Compose 22 | 23 | import pytorch_lightning as pl 24 | 25 | 26 | def get_optimizer_func(cf: Dict[str, Any]) -> callable: 27 | """Return a optimizer callable based on config value""" 28 | lr = cf["training"]["learn_rate"] 29 | if cf["training"]["optimizer"] == "adamw": 30 | return partial(optim.AdamW, lr=lr) 31 | elif cf["training"]["optimizer"] == "adam": 32 | return partial(optim.Adam, lr=lr) 33 | elif cf["training"]["optimizer"] == "sgd": 34 | return partial(optim.SGD, lr=lr, momentum=0.9) 35 | else: 36 | raise NotImplementedError() 37 | 38 | 39 | def get_scheduler_func(cf: Dict[str, Any], 40 | num_it_per_ep: int = 0) -> Optional[callable]: 41 | """Return a scheduler callable based on config value.""" 42 | if "scheduler" not in cf["training"]: 43 | return None 44 | 45 | if cf["training"]["scheduler"]["which"] == "step_lr": 46 | step_size = convert_epoch_to_iter( 47 | cf["training"]["scheduler"]["params"]["step_unit"], 48 | cf["training"]["scheduler"]["params"]["step_size"], num_it_per_ep) 49 | return partial(StepLR, 50 | step_size=step_size, 51 | gamma=cf["training"]["scheduler"]["params"]["gamma"]) 52 | elif cf["training"]["scheduler"]["which"] == "cos_warmup": 53 | num_epochs = cf['training']['num_epochs'] 54 | 55 | num_warmup_steps = cf['training']['scheduler']['params'][ 56 | 'num_warmup_steps'] 57 | if isinstance(num_warmup_steps, float): # fraction of total train 58 | cf['training']['scheduler']['params']['num_warmup_steps'] = int( 59 | num_warmup_steps * num_epochs * num_it_per_ep) 60 | 61 | return partial(get_cosine_schedule_with_warmup, 62 | num_training_steps=num_it_per_ep * num_epochs, 63 | **cf["training"]["scheduler"]["params"]) 64 | else: 65 | raise NotImplementedError() 66 | 67 | 68 | def convert_epoch_to_iter(unit: str, steps: int, num_it_per_ep: int) -> int: 69 | """Converts number of epochs / iterations to number of iterations.""" 70 | if unit == "epoch": 71 | return num_it_per_ep * steps # per epoch 72 | elif unit == "iter": 73 | return steps 74 | else: 75 | NotImplementedError("unit must be one of [epoch, iter]") 76 | 77 | 78 | def get_cosine_schedule_with_warmup(optimizer: torch.optim.Optimizer, 79 | num_warmup_steps: int, 80 | num_training_steps: int, 81 | num_cycles: float = 0.5, 82 | last_epoch: int = -1): 83 | """Create cosine learn rate scheduler with linear warm up built in.""" 84 | 85 | def lr_lambda(current_step): 86 | if current_step < num_warmup_steps: 87 | return float(current_step) / float(max(1, num_warmup_steps)) 88 | progress = float(current_step - num_warmup_steps) / float( 89 | max(1, num_training_steps - num_warmup_steps)) 90 | return max( 91 | 0.0, 0.5 * 92 | (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 93 | 94 | return LambdaLR(optimizer, lr_lambda, last_epoch) 95 | 96 | 97 | def config_loggers(exp_root): 98 | """Config logger for the experiments 99 | 100 | Sets string format and where to save. 
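An emitted record looks like the following (illustrative values):
        [INFO|14:02:07|root|common.py:114|config_loggers] Exp root ./fastglioma/<run_name>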
101 | """ 102 | 103 | logging_format_str = "[%(levelname)-s|%(asctime)s|%(name)s|" + \ 104 | "%(filename)s:%(lineno)d|%(funcName)s] %(message)s" 105 | logging.basicConfig(level=logging.INFO, 106 | format=logging_format_str, 107 | datefmt="%H:%M:%S", 108 | handlers=[ 109 | logging.FileHandler( 110 | os.path.join(exp_root, 'train.log')), 111 | logging.StreamHandler() 112 | ], 113 | force=True) 114 | logging.info("Exp root {}".format(exp_root)) 115 | 116 | formatter = logging.Formatter(logging_format_str, datefmt="%H:%M:%S") 117 | logger = logging.getLogger("pytorch_lightning.core") 118 | logger.setLevel(logging.INFO) 119 | logger.addHandler(logging.FileHandler(os.path.join(exp_root, 'train.log'))) 120 | for h in logger.handlers: 121 | h.setFormatter(formatter) 122 | 123 | 124 | def setup_ddp_exp_name(exp_name: str): 125 | if pl.utilities.rank_zero.rank_zero_only.rank != 0: 126 | return os.path.join(exp_name, "high_rank") 127 | else: 128 | return exp_name 129 | 130 | 131 | def setup_output_dirs(cf: Dict, get_exp_name: callable, 132 | cmt_append: str) -> Tuple[str, str, callable]: 133 | """Get name of the ouput dirs and create them in the file system.""" 134 | log_root = cf["infra"]["log_dir"] 135 | instance_name = "_".join([get_exp_name(cf), cmt_append]) 136 | exp_name = setup_ddp_exp_name(cf["infra"]["exp_name"]) 137 | exp_root = os.path.join(log_root, exp_name, instance_name) 138 | 139 | model_dir = os.path.join(exp_root, 'models') 140 | config_dir = os.path.join(exp_root, 'config') 141 | 142 | for dir_name in [model_dir, config_dir]: 143 | if not os.path.exists(dir_name): 144 | os.makedirs(dir_name) 145 | return exp_root, model_dir, partial(copy2, dst=config_dir) 146 | 147 | 148 | def parse_args(): 149 | """Get config file handle from command line argument.""" 150 | parser = argparse.ArgumentParser() 151 | parser.add_argument('-c', 152 | '--config', 153 | type=argparse.FileType('r'), 154 | required=True, 155 | help='config file for training') 156 | args = parser.parse_args() 157 | return args.config 158 | 159 | 160 | def get_exp_name(cf): 161 | """Generate experiment name with a hash, time, and comments in config.""" 162 | time = datetime.now().strftime("%b%d-%H-%M-%S") 163 | return "-".join([uuid.uuid4().hex[:8], time, cf["infra"]["comment"]]) 164 | 165 | 166 | def get_num_worker(): 167 | """Estimate number of cpu workers.""" 168 | try: 169 | num_worker = len(os.sched_getaffinity(0)) 170 | except Exception: 171 | num_worker = os.cpu_count() 172 | 173 | if num_worker > 1: 174 | return num_worker - 1 175 | else: 176 | return torch.cuda.device_count() * 4 -------------------------------------------------------------------------------- /fastglioma/inference/run_inference.py: -------------------------------------------------------------------------------- 1 | """Inference modules and script. 2 | 3 | Copyright (c) 2024 University of Michigan. All rights reserved. 4 | Licensed under the MIT License. See LICENSE for license information. 
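Typical invocation (illustrative paths; the required -c/--config flag comes from parse_args in fastglioma/utils/common.py):

    python fastglioma/inference/run_inference.py -c fastglioma/inference/config/infer.yaml

Predictions are written to <log_dir>/<exp_name>/<run_name>/predictions/predictions.pt.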
5 | """ 6 | 7 | import os 8 | import logging 9 | from shutil import copy2 10 | from functools import partial 11 | from typing import List, Union, Dict, Any 12 | 13 | import yaml 14 | import numpy as np 15 | import pandas as pd 16 | from tqdm import tqdm 17 | 18 | import torch 19 | from torchvision.transforms import Compose 20 | 21 | import pytorch_lightning as pl 22 | 23 | from fastglioma.datasets.srh_dataset import SlideDataset, slide_collate_fn 24 | from fastglioma.datasets.improc import get_srh_base_aug, get_strong_aug 25 | from fastglioma.utils.common import (parse_args, get_exp_name, config_loggers, 26 | get_num_worker) 27 | 28 | from fastglioma.models.resnet import resnet_backbone 29 | from fastglioma.models.cnn import MLP, ContrastiveLearningNetwork 30 | from fastglioma.models.mil import MIL_forward, MIL_Classifier, TransformerMIL 31 | 32 | from huggingface_hub import hf_hub_download 33 | 34 | 35 | class FastGliomaInferenceSystem(pl.LightningModule): 36 | """Lightning system for FastGlioma inference on OpenSRH.""" 37 | 38 | def __init__(self, cf: Dict[str, Any], num_it_per_ep: int): 39 | super().__init__() 40 | self.cf_ = cf 41 | 42 | if cf["model"]["patch"]["backbone"]["which"] == "resnet34": 43 | bb = partial( 44 | resnet_backbone, 45 | arch=cf["model"]["patch"]["backbone"]["which"], 46 | num_channel_in=cf["model"]["patch"]["backbone"]["params"].get( 47 | "num_channel_in", 3)) 48 | else: 49 | raise NotImplementedError() 50 | 51 | if cf["model"]["slide"]["mil"]["which"] == "transformer": 52 | mil = partial(MIL_forward, 53 | mil=partial(TransformerMIL, 54 | **cf["model"]["slide"]["mil"]["params"])) 55 | else: 56 | raise NotImplementedError() 57 | 58 | mlp = partial(MLP, 59 | n_in=mil().num_out, 60 | hidden_layers=cf["model"]["slide"]["mlp_hidden"], 61 | n_out=1) 62 | self.model = MIL_Classifier(bb, mil, mlp) 63 | 64 | self.criterion = self.train_loss = self.val_loss = None 65 | self.num_it_per_ep_ = num_it_per_ep 66 | 67 | @staticmethod 68 | def get_kth_view(data: List[List[torch.Tensor]], k: int): 69 | return [d[k] for d in data] 70 | 71 | def forward(self, batch): 72 | return self.model(self.get_kth_view(batch["image"], 0), 73 | coords=self.get_kth_view(batch["coords"], 0)) 74 | 75 | @torch.inference_mode() 76 | def predict_step(self, batch, batch_idx): 77 | out = self.forward(batch) 78 | 79 | return { 80 | "path": [batch["path"]], 81 | "label": batch["label"], 82 | "logits": out["logits"], 83 | "embeddings": out["embeddings"] 84 | } 85 | 86 | 87 | def get_predictions( 88 | cf: Dict[str, Any], 89 | exp_root: str) -> Dict[str, Union[torch.Tensor, List[str]]]: 90 | """Run forward pass on the dataset, and generate embeddings and logits""" 91 | 92 | def get_transform(cf): 93 | return Compose( 94 | get_srh_base_aug(base_aug=("three_channels" if cf["data"]["patch_input"] == "highres" else "ch2_only")) + 95 | (get_strong_aug([{"which": "inpaint_rows_always_apply", "params": {"image_size": 300, "y_skip": 5}}], 1.) 
if cf["data"]["patch_input"] == "lowres" else []) 96 | ) 97 | 98 | dset = SlideDataset( 99 | data_root=cf["data"]["db_root"], 100 | studies=cf["data"]["studies"], 101 | transform=get_transform(cf), 102 | balance_slide_per_class=False, 103 | use_patient_class=cf["data"]["use_patient_class"]) 104 | 105 | loader = torch.utils.data.DataLoader( 106 | dset, 107 | batch_size=cf["eval"]["predict_batch_size"], 108 | drop_last=False, 109 | pin_memory=True, 110 | num_workers=get_num_worker(), 111 | collate_fn=slide_collate_fn, 112 | persistent_workers=True) 113 | 114 | # Load model from huggingface repo 115 | ckpt_path = hf_hub_download(repo_id=cf["infra"]["hf_repo"], 116 | filename=cf["eval"]["ckpt_path"]) 117 | model = FastGliomaInferenceSystem.load_from_checkpoint(ckpt_path, 118 | cf=cf, 119 | num_it_per_ep=0) 120 | 121 | # Create trainer 122 | trainer = pl.Trainer(accelerator="gpu", 123 | devices=1, 124 | max_epochs=-1, 125 | default_root_dir=exp_root, 126 | enable_checkpointing=False, 127 | logger=False) 128 | 129 | predictions = trainer.predict(model, dataloaders=loader) 130 | 131 | def process_predictions(predictions): 132 | # Combine predictions into a single dictionary 133 | pred = {} 134 | for k in predictions[0].keys(): 135 | if k == "path": 136 | pred[k] = [pk for p in predictions for pk in p[k][0]] 137 | else: 138 | pred[k] = torch.cat([p[k] for p in predictions]) 139 | 140 | pred["logits"] = pred["logits"].squeeze(1) 141 | pred["scores"] = torch.sigmoid(pred["logits"]) 142 | pred["label"] = [{v: k for k, v in dset.class_to_idx_.items()}[l.item()] for l in pred["label"]] #yapf:disable 143 | pred["slide"] = ["/".join(imp[0].split("/")[:9]) for imp in pred["path"]] #yapf:disable 144 | 145 | # Sort predictions by slide name 146 | sorted_indices = sorted(range(len(pred['slide'])), 147 | key=lambda k: pred['slide'][k]) 148 | 149 | # Apply the same ordering to all keys in pred 150 | for key in pred: 151 | if isinstance(pred[key], list): 152 | pred[key] = [pred[key][i] for i in sorted_indices] 153 | elif isinstance(pred[key], torch.Tensor): 154 | pred[key] = pred[key][sorted_indices] 155 | 156 | del pred["path"] 157 | 158 | return pred 159 | 160 | predictions = process_predictions(predictions) 161 | return predictions 162 | 163 | 164 | def setup_eval_paths(cf, get_exp_name, cmt_append): 165 | """Get name of the ouput dirs and create them in the file system.""" 166 | log_root = cf["infra"]["log_dir"] 167 | exp_name = cf["infra"]["exp_name"] 168 | eval_instance_name = "_".join([get_exp_name(cf), cmt_append]) 169 | exp_root = os.path.join(log_root, exp_name, eval_instance_name) 170 | 171 | # Generate needed folders 172 | pred_dir = os.path.join(exp_root, 'predictions') 173 | config_dir = os.path.join(exp_root, 'config') 174 | for dir_name in [pred_dir, config_dir]: 175 | if not os.path.exists(dir_name): 176 | os.makedirs(dir_name) 177 | 178 | return exp_root, pred_dir, partial(copy2, dst=config_dir) 179 | 180 | 181 | def main(): 182 | """Driver script for inference pipeline.""" 183 | cf_fd = parse_args() 184 | cf = yaml.load(cf_fd, Loader=yaml.FullLoader) 185 | exp_root, pred_dir, cp_config = setup_eval_paths(cf, get_exp_name, "") 186 | pl.seed_everything(cf["infra"]["seed"]) 187 | 188 | # Logging and copying config files 189 | cp_config(cf_fd.name) 190 | config_loggers(exp_root) 191 | 192 | logging.info("Generating predictions") 193 | predictions = get_predictions(cf, exp_root) 194 | torch.save(predictions, os.path.join(pred_dir, "predictions.pt")) 195 | 196 | 197 | if __name__ == "__main__": 198 
| main() 199 | -------------------------------------------------------------------------------- /fastglioma/datasets/improc.py: -------------------------------------------------------------------------------- 1 | """Image processing functions designed to work with OpenSRH datasets. 2 | 3 | Copyright (c) 2024 University of Michigan. All rights reserved. 4 | Licensed under the MIT License. See LICENSE for license information. 5 | """ 6 | 7 | from typing import Optional, List, Tuple, Dict, Callable 8 | from functools import partial 9 | 10 | import random 11 | import tifffile 12 | import numpy as np 13 | 14 | import torch 15 | from torchvision.transforms import functional as F 16 | from torch.nn import ModuleList 17 | from torchvision.transforms import ( 18 | Normalize, RandomApply, Compose, RandomHorizontalFlip, RandomVerticalFlip, 19 | Resize, RandAugment, RandomErasing, RandomAutocontrast, Grayscale, 20 | RandomSolarize, ColorJitter, RandomAdjustSharpness, GaussianBlur, 21 | RandomAffine, RandomResizedCrop) 22 | 23 | # Base augmentation modules 24 | class GetThirdChannel(torch.nn.Module): 25 | """Computes the third channel of SRH image 26 | 27 | Compute the third channel of SRH images by subtracting CH3 and CH2. The 28 | channel difference is added to the subtracted_base. 29 | 30 | """ 31 | 32 | def __init__(self, 33 | mode: str = "three_channels", 34 | subtracted_base: float = 5000 / 65536.0): 35 | super().__init__() 36 | 37 | self.subtracted_base = subtracted_base 38 | aug_func_dict = { 39 | "three_channels": self.get_third_channel_, 40 | "ch2_only": self.get_ch2_, 41 | "ch3_only": self.get_ch3_, 42 | "diff_only": self.get_diff_ 43 | } 44 | if mode in aug_func_dict: 45 | self.aug_func = aug_func_dict[mode] 46 | else: 47 | raise ValueError("base_augmentation must be in " + 48 | f"{aug_func_dict.keys()}") 49 | 50 | def get_third_channel_(self, im2: torch.Tensor) -> torch.Tensor: 51 | ch2 = im2[0, :, :] 52 | ch3 = im2[1, :, :] 53 | ch1 = ch3 - ch2 + self.subtracted_base 54 | return torch.stack((ch1, ch2, ch3), dim=0) 55 | 56 | def get_ch2_(self, im2: torch.Tensor) -> torch.Tensor: 57 | return im2[0, :, :].unsqueeze(0) 58 | 59 | def get_ch3_(self, im2: torch.Tensor) -> torch.Tensor: 60 | return im2[1, :, :].unsqueeze(0) 61 | 62 | def get_diff_(self, im2: torch.Tensor) -> torch.Tensor: 63 | ch2 = im2[0, :, :] 64 | ch3 = im2[1, :, :] 65 | ch1 = ch3 - ch2 + self.subtracted_base 66 | 67 | return ch1.unsqueeze(0) 68 | 69 | def forward(self, two_channel_image: torch.Tensor) -> torch.Tensor: 70 | """ 71 | Args: 72 | two_channel_image: a 2 channel np array in the shape H * W * 2 73 | 74 | Returns: 75 | A 1 or 3 channel np array in the shape 3 * H * W 76 | """ 77 | return self.aug_func(two_channel_image) 78 | 79 | 80 | class MinMaxChop(torch.nn.Module): 81 | """Clamps the images to float (0,1) range.""" 82 | 83 | def __init__(self, min_val: float = 0.0, max_val: float = 1.0): 84 | super().__init__() 85 | self.min_ = min_val 86 | self.max_ = max_val 87 | 88 | def __call__(self, image: torch.Tensor) -> torch.Tensor: 89 | return image.clamp(self.min_, self.max_) 90 | 91 | 92 | class GaussianNoise(torch.nn.Module): 93 | """Adds guassian noise to images.""" 94 | 95 | def __init__(self, min_var: float = 0.01, max_var: float = 0.1): 96 | super().__init__() 97 | self.min_var = min_var 98 | self.max_var = max_var 99 | 100 | def __call__(self, tensor): 101 | 102 | var = random.uniform(self.min_var, self.max_var) 103 | noisy = tensor + torch.randn(tensor.size()) * var 104 | noisy = torch.clamp(noisy, min=0., 
max=1.) 105 | return noisy 106 | 107 | 108 | # Strong augmentation modules 109 | class InpaintRows(torch.nn.Module): 110 | 111 | def __init__(self, y_skip: int = 2, image_size: int = 300): 112 | self.y_skip = y_skip 113 | self.image_size = image_size 114 | 115 | def __call__(self, img): 116 | self.original_y = img.shape[1] 117 | mask = np.arange(0, self.original_y, self.y_skip) 118 | img_trans = img[:, mask, :] 119 | img_trans = Resize( 120 | size=(self.image_size, self.image_size), 121 | interpolation=F.InterpolationMode.BILINEAR, 122 | antialias=True)(img_trans) 123 | return img_trans 124 | 125 | def __repr__(self): 126 | return self.__class__.__name__ + '()' 127 | 128 | 129 | def process_read_im(imp: str) -> torch.Tensor: 130 | """Read in two channel image 131 | 132 | Args: 133 | imp: a string that is the path to the tiff image 134 | 135 | Returns: 136 | A 2 channel torch Tensor in the shape 2 * H * W 137 | """ 138 | # reference: https://github.com/pytorch/vision/blob/49468279d9070a5631b6e0198ee562c00ecedb10/torchvision/transforms/functional.py#L133 139 | 140 | return torch.from_numpy(tifffile.imread(imp).astype( 141 | np.float32)).contiguous() 142 | 143 | 144 | # helpers 145 | def get_srh_base_aug(base_aug: str = "three_channels") -> List: 146 | """Base processing augmentations for all SRH images 147 | 148 | Args: 149 | base_aug: specifies which channel subset should be used ('three_channel', 'ch2_only', 'ch3_only', 'diff_only') 150 | 151 | Returns: 152 | An augmented 1 or 3 torch Tensor in the shape of 3 * H * W 153 | """ 154 | u16_min = (0, 0) 155 | u16_max = (65536, 65536) # 2^16 156 | 157 | # if y_skip != 0: 158 | # xform_list = [Normalize(mean=u16_min, std=u16_max), GetThirdChannel(mode=base_aug), MinMaxChop(), InpaintRows(y_skip=y_skip)] 159 | # else: 160 | xform_list = [Normalize(mean=u16_min, std=u16_max), GetThirdChannel(mode=base_aug), MinMaxChop()] 161 | 162 | return xform_list 163 | 164 | 165 | def get_strong_aug(augs, rand_prob) -> List: 166 | """Strong augmentations for training""" 167 | rand_apply = lambda which, **kwargs: RandomApply( 168 | ModuleList([which(**kwargs)]), p=rand_prob) 169 | 170 | callable_dict = { 171 | "resize": Resize, 172 | "inpaint_rows_always_apply": InpaintRows, 173 | "inpaint_rows": partial(rand_apply, which=InpaintRows), 174 | "random_horiz_flip": partial(RandomHorizontalFlip, p=rand_prob), 175 | "random_vert_flip": partial(RandomVerticalFlip, p=rand_prob), 176 | "gaussian_noise": partial(rand_apply, which=GaussianNoise), 177 | "color_jitter": partial(rand_apply, which=ColorJitter), 178 | "random_autocontrast": partial(RandomAutocontrast, p=rand_prob), 179 | "random_solarize": partial(RandomSolarize, p=rand_prob), 180 | "random_sharpness": partial(RandomAdjustSharpness, p=rand_prob), 181 | "drop_color": partial(rand_apply, which=Grayscale), 182 | "gaussian_blur": partial(rand_apply, GaussianBlur), 183 | "random_erasing": partial(RandomErasing, p=rand_prob), 184 | "random_affine": partial(rand_apply, RandomAffine), 185 | "random_resized_crop": partial(rand_apply, RandomResizedCrop) 186 | } 187 | 188 | return [callable_dict[a["which"]](**a["params"]) for a in augs] 189 | 190 | 191 | def get_srh_aug_list(augs, base_aug: str = "three_channels", rand_prob=0.5) -> List: 192 | """Combine base and strong augmentations for training""" 193 | return get_srh_base_aug(base_aug=base_aug) + get_strong_aug(augs, rand_prob) 194 | 195 | 196 | def get_transformations( 197 | cf: Optional[Dict] = None, 198 | strong_aug: Callable = get_strong_aug) -> Tuple[Compose, 
Compose]: 199 | 200 | if cf: 201 | train_augs = cf["data"]["train_augmentation"] 202 | val_augs = cf["data"]["valid_augmentation"] 203 | base_aug = cf["data"]["srh_base_augmentation"] 204 | aug_prob = cf["data"]["rand_aug_prob"] 205 | else: 206 | train_augs = [] 207 | val_augs = [] 208 | base_aug = "three_channels" 209 | aug_prob = 0 210 | 211 | if val_augs == "same": 212 | val_augs = train_augs 213 | 214 | train_xform = Compose(get_srh_aug_list(train_augs, base_aug=base_aug, rand_prob=aug_prob)) 215 | valid_xform = Compose(get_srh_aug_list(val_augs, base_aug=base_aug, rand_prob=aug_prob)) 216 | 217 | return train_xform, valid_xform -------------------------------------------------------------------------------- /fastglioma/datasets/embedding_dataset.py: -------------------------------------------------------------------------------- 1 | """PyTorch embedding datasets designed to work with OpenSRH. 2 | 3 | Copyright (c) 2024 University of Michigan. All rights reserved. 4 | Licensed under the MIT License. See LICENSE for license information. 5 | """ 6 | 7 | import os 8 | import json 9 | import logging 10 | from collections import Counter 11 | from typing import Optional, List, Union, TypedDict, Tuple 12 | import random 13 | 14 | from tqdm import tqdm 15 | import numpy as np 16 | 17 | import torch 18 | from torch.utils.data import Dataset 19 | from torchvision.datasets.folder import is_image_file 20 | from torchvision.transforms import Compose 21 | 22 | 23 | class SlideEmbeddingDataset(Dataset): 24 | """OpenSRH embedding dataset""" 25 | 26 | def __init__(self, 27 | embedding_root: str, 28 | tag: List[str], 29 | data_root: str, 30 | studies: Union[str, List[str]], 31 | transform: callable = None, 32 | target_transform: callable = torch.tensor, 33 | balance_slide_per_class: bool = False, 34 | use_patient_class: bool = True, 35 | check_images_exist: bool = False, 36 | meta_fname: str = "opensrh.json", 37 | num_transforms: int = 1) -> None: 38 | """Inits the OpenSRH dataset, where each instance is a slide.""" 39 | 40 | self.embed_root_ = embedding_root 41 | self.tag_ = tag 42 | 43 | self.data_root_ = data_root 44 | self.transform_ = transform 45 | self.target_transform_ = target_transform 46 | self.use_patient_class_ = use_patient_class 47 | self.check_images_exist_ = check_images_exist 48 | self.meta_fname_ = meta_fname 49 | self.num_transforms_ = num_transforms 50 | self.get_all_meta() 51 | self.get_study_list(studies) 52 | 53 | # Walk through each study 54 | self.instances_ = [] 55 | for p in tqdm(self.studies_): 56 | self.instances_.extend(self.get_study_instances(p)) 57 | 58 | if balance_slide_per_class: 59 | self.replicate_balance_instances() 60 | self.get_weights() 61 | 62 | def get_all_meta(self): 63 | """Read in all metadata files.""" 64 | 65 | try: 66 | with open(os.path.join(self.data_root_, 67 | f"meta/{self.meta_fname_}")) as fd: 68 | self.metadata_ = json.load(fd) 69 | except Exception as e: 70 | logging.critical("Failed to locate dataset.") 71 | raise e 72 | 73 | logging.info(f"Locate OpenSRH dataset at {self.data_root_}") 74 | return 75 | 76 | def get_study_list(self, studies): 77 | """Get a list of studies from default split or list of IDs.""" 78 | 79 | if isinstance(studies, str): 80 | try: 81 | with open( 82 | os.path.join(self.data_root_, 83 | "meta/train_val_split.json")) as fd: 84 | train_val_split = json.load(fd) 85 | except Exception as e: 86 | logging.critical("Failed to locate preset train/val split.") 87 | raise e 88 | 89 | if studies == "train": 90 | self.studies_ = 
train_val_split["train"] 91 | elif studies in ["valid", "val"]: 92 | self.studies_ = train_val_split["val"] 93 | else: 94 | return ValueError( 95 | "studies split must be one of [\"train\", \"val\"]") 96 | elif isinstance(studies, List): 97 | self.studies_ = studies 98 | else: 99 | raise ValueError("studies must be a string representing " + 100 | "train/val split or a list of study numbers") 101 | return 102 | 103 | def get_study_instances(self, patient: str): 104 | """Get all instances from one study.""" 105 | 106 | study_instances = [] 107 | logging.debug(patient) 108 | 109 | for s in self.metadata_[patient]["slides"]: 110 | slide_tag_instances = [] 111 | for tag in self.tag_: 112 | tag_path = os.path.join(self.embed_root_, 113 | "studies", 114 | patient, 115 | s, 116 | f"{patient}.{s}-{tag}.pt") 117 | if os.path.exists(tag_path): 118 | slide_tag_instances.append(tag_path) 119 | 120 | slide_label = ( 121 | self.metadata_[patient]["slides"][s].get( 122 | "slide_class", self.metadata_[patient]["class"]) #yapf:disable 123 | if not self.use_patient_class_ else 124 | self.metadata_[patient]["class"]) 125 | 126 | slide_instance = (f"{patient}/{s}", slide_label, slide_tag_instances) #yapf:disable 127 | 128 | logging.debug(f"slide {patient}/{s} tags {len(slide_tag_instances)}") #yapf:disable 129 | study_instances.append(slide_instance) 130 | 131 | logging.debug(f"patient {patient} slides {len(study_instances)}") 132 | return study_instances 133 | 134 | def process_classes(self): 135 | """Look for all the labels in the dataset. 136 | 137 | Creates the classes_, and class_to_idx_ attributes""" 138 | all_labels = [i[1] for i in self.instances_] 139 | self.classes_ = sorted(set(all_labels)) 140 | self.class_to_idx_ = {c: i for i, c in enumerate(self.classes_)} 141 | logging.info("Labels: {}".format(self.classes_)) 142 | return 143 | 144 | def get_weights(self): 145 | """Count number of instances for each class, and computes weights.""" 146 | 147 | # Get classes 148 | self.process_classes() 149 | all_labels = [self.class_to_idx_[i[1]] for i in self.instances_] 150 | 151 | # Count number of slides in each class 152 | count = Counter(all_labels) 153 | count = torch.Tensor([count[i] for i in range(len(count))]) 154 | logging.info("Count: {}".format(count)) 155 | 156 | # Compute weights 157 | inv_count = 1 / count 158 | self.weights_ = inv_count / torch.sum(inv_count) 159 | logging.debug("Weights: {}".format(self.weights_)) 160 | return self.weights_ 161 | 162 | def replicate_balance_instances(self): 163 | """resample the instances list to balance each class.""" 164 | all_labels = [i[1] for i in self.instances_] 165 | val_sample = max(Counter(all_labels).values()) 166 | 167 | all_instances_ = [] 168 | for l in sorted(set(all_labels)): 169 | instances_l = [i for i in self.instances_ if i[1] == l] 170 | random.shuffle(instances_l) 171 | instances_l = instances_l * (val_sample // len(instances_l) + 1) 172 | all_instances_.extend(sorted(instances_l[:val_sample])) 173 | 174 | self.instances_ = all_instances_ 175 | return 176 | 177 | def __len__(self): 178 | """Returns the length of the dataset""" 179 | return len(self.instances_) 180 | 181 | def __getitem__(self, idx): 182 | """Retrieve a list of slides specified by idx""" 183 | 184 | slide, target, tag_list = self.instances_[idx] 185 | target = self.class_to_idx_[target] 186 | 187 | if self.target_transform_ is not None: 188 | target = self.target_transform_(target) 189 | 190 | instance = { 191 | "embeddings": [None for _ in range(self.num_transforms_)], 192 
| "coords": [None for _ in range(self.num_transforms_)], 193 | "path": tag_list[0], 194 | "label": target, 195 | } #yapf:disable 196 | 197 | for transform_idx in range(self.num_transforms_): 198 | pt_path = random.choice(tag_list) 199 | inst_ = torch.load(pt_path) 200 | 201 | instance["embeddings"][transform_idx] = inst_["embeddings"] 202 | instance["coords"][transform_idx] = torch.tensor(inst_["coords"]) 203 | 204 | del inst_ 205 | 206 | if self.transform_: 207 | instance = self.transform_(instance) 208 | 209 | return instance 210 | -------------------------------------------------------------------------------- /fastglioma/models/mil.py: -------------------------------------------------------------------------------- 1 | """Whole slide transformer model. 2 | 3 | Copyright (c) 2024 University of Michigan. All rights reserved. 4 | Licensed under the MIT License. See LICENSE for license information. 5 | """ 6 | 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | import numpy as np 11 | from fastglioma.models.vit import Block 12 | 13 | 14 | class FFPEG(nn.Module): 15 | """Fourier feature positional embedding generator module. 16 | 17 | References: 18 | Learnable Fourier Features for Multi-Dimensional Spatial Positional Encoding, Li et al. NeurIPS 2021 19 | 20 | Attributes: 21 | embed_dim: dimension of the embedding 22 | dim_ff: dimension of the fourier features 23 | dim_mlp: dimension of the mlp 24 | gamma: std of the normal distribution used to initialize the fourier features 25 | prefix_len: number of registers + cls token 26 | pos_emb_grad: whether to allow the fourier features to be optimized 27 | """ 28 | def __init__(self, 29 | embed_dim: int, 30 | dim_ff: int = 96, 31 | dim_mlp: int = 36, 32 | gamma: float = .25, 33 | prefix_len: int = 0, 34 | pos_emb_grad: bool = True, 35 | **kwargs): 36 | super(FFPEG, self).__init__() 37 | self.dim_ff_ = dim_ff 38 | self.dim_mlp_ = dim_mlp 39 | self.gamma_ = gamma 40 | self.embed_dim_ = embed_dim 41 | self.prefix_len = prefix_len 42 | 43 | self.cls_pos_emb = nn.Parameter(torch.zeros(1, self.prefix_len, embed_dim), #yapf:disable 44 | requires_grad=True) 45 | 46 | self.num_pos_ = 1 # G 47 | self.pos_dim_ = 2 # M 48 | 49 | self._ff_embed = nn.Linear(self.pos_dim_, dim_ff // 2, bias=False) 50 | torch.nn.init.normal_(self._ff_embed.weight, mean=0, std=gamma) 51 | if not pos_emb_grad: 52 | for param in self._ff_embed.parameters(): 53 | param.requires_grad = False 54 | 55 | self._mlp = nn.Sequential(*[ 56 | nn.LayerNorm(dim_ff), 57 | nn.Linear(dim_ff, dim_mlp), 58 | nn.GELU(), 59 | nn.LayerNorm(dim_mlp), 60 | nn.Linear(dim_mlp, embed_dim // self.num_pos_) 61 | ]) 62 | 63 | def init_weights(m): 64 | if isinstance(m, nn.Linear): 65 | torch.nn.init.xavier_uniform_(m.weight) 66 | m.bias.data.fill_(0.01) 67 | 68 | self._mlp.apply(init_weights) 69 | 70 | def forward(self, H, coords, return_ff: bool = False): 71 | bsz, n = H.shape[0], H.shape[1] 72 | n = n - self.prefix_len 73 | 74 | x = coords.unsqueeze(0).float().unsqueeze(-2) # NxGxM (G=1, M=2) 75 | x = x.to(H.device) 76 | 77 | ff_vec = self._ff_embed(x) # NxGx(F/2) 78 | 79 | f = torch.cat([torch.cos(ff_vec), torch.sin(ff_vec)], axis=-1) 80 | f = 1 / np.sqrt(self.dim_ff_) * f # NxGxF 81 | 82 | if return_ff: return f 83 | 84 | pe = self._mlp(f).reshape(bsz, n, self.embed_dim_) 85 | pe = torch.cat((self.cls_pos_emb.repeat(bsz, 1, 1), pe), dim=1) 86 | 87 | return H + pe 88 | 89 | 90 | class MIL_forward(nn.Module): 91 | '''MIL module for batch forward 92 | 93 | Attributes: 94 | mil: 
process bag of instance embeddings (list) to produce a single bag embedding (tensor). 95 | ''' 96 | 97 | def __init__(self, mil: callable): 98 | super().__init__() 99 | self.mil = mil() 100 | self.num_out = self.mil.dim_out 101 | 102 | def forward_mil(self, 103 | bag_embed: list, 104 | return_embed: bool = False, 105 | **kwargs): 106 | '''Forward function for bag input 107 | 108 | Attributes: 109 | bag: A batch of bags, each bag will have the various number of instances. 110 | return_embed: return a batch of bag embeddings. 111 | ''' 112 | if 'coords' in kwargs: 113 | batch_embed = torch.stack([ 114 | self.mil(insta, coords=coords) 115 | for insta, coords in zip(bag_embed, kwargs['coords']) 116 | ]).squeeze(1) # bsz * bagsize * emb_di -> bsz * emb_dim 117 | else: 118 | batch_embed = torch.stack([ 119 | self.mil(insta) for insta in bag_embed 120 | ]).squeeze(1) # bsz * bagsize * emb_di -> bsz * emb_dim 121 | return batch_embed 122 | 123 | def forward(self, bag, return_embed: bool = False, **kwargs): 124 | return self.forward_mil(bag, **kwargs) 125 | 126 | 127 | class MIL_Classifier(nn.Module): 128 | '''MIL module for classification task. 129 | 130 | Attributes: 131 | backbone: process instances (list) into instances embeddings (list). 132 | mil: process bag of instance embeddings (list) to produce a single bag embedding (tensor). 133 | head: process the bag embedding (tensor) to bag logits (tensor). 134 | ''' 135 | 136 | def __init__(self, backbone: callable, mil: callable, head: callable): 137 | super().__init__() 138 | if backbone: 139 | self.backbone = backbone() 140 | if head: 141 | self.head = head() 142 | self.bb = mil() 143 | 144 | def forward(self, bag: list, return_embed: bool = False, **kwargs): 145 | ''' forward for bag input 146 | 147 | Attributes: 148 | bag: A batch of bags, each bag will have the various number of instances. 149 | return_embed: return a batch of bag embeddings. 150 | ''' 151 | bag_embed = bag 152 | if hasattr(self, 'backbone'): 153 | bag_embed = [ 154 | self.backbone(insta) for insta in bag 155 | ] # bsz * bagsize * input_dim -> bsz * bagsize * emb_dim 156 | 157 | batch_embed = self.bb.forward(bag_embed, coords=kwargs['coords']) 158 | if return_embed: 159 | return {"embeddings": batch_embed} 160 | 161 | if hasattr(self, 'head'): 162 | batch_logits = self.head(batch_embed) # bsz * emb_dim -> bsz 163 | return {"logits": batch_logits, "embeddings": batch_embed} 164 | return {"embeddings": batch_embed} 165 | 166 | 167 | class Identity(nn.Identity): 168 | 169 | def __init__(self): 170 | super().__init__() 171 | 172 | def forward(self, x, **kwargs): 173 | return x 174 | 175 | 176 | class TransformerMIL(torch.nn.Module): 177 | """Transformer module for MIL. 
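    A minimal usage sketch (parameter values and shapes are illustrative only,
    not defaults):

        mil = TransformerMIL(embed_dim=512, depth=2, num_heads=4,
                             pos_emb_type="FFPEG", prefix_len=8)
        x = torch.randn(1, 100, 512)             # one bag of 100 patch embeddings
        coords = torch.randint(0, 30, (100, 2))  # patch (x, y) grid locations
        out = mil(x, coords=coords)              # slide embedding of shape (1, 512)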
178 | 179 | Attributes: 180 | global_pool: global pooling method 181 | embed_dim: dimension of the embedding 182 | depth: number of layers 183 | num_heads: number of attention heads 184 | mlp_ratio: ratio of the MLP hidden dimension to the embedding dimension 185 | qkv_bias: whether to use bias in the QKV linear layer 186 | pos_emb_type: type of positional embedding 187 | """ 188 | def __init__(self, 189 | global_pool='token', 190 | embed_dim=768, 191 | depth=12, 192 | num_heads=12, 193 | mlp_ratio=4., 194 | qkv_bias=False, 195 | qk_scale=None, 196 | pos_emb_type=None, 197 | drop_rate=0., 198 | attn_drop_rate=0., 199 | drop_path_rate=0., 200 | norm_layer=nn.LayerNorm, 201 | **kwargs): 202 | super().__init__() 203 | self.global_pool = global_pool 204 | assert self.global_pool in ['', 'avg', 'token'] 205 | assert embed_dim % num_heads == 0, "embed_dim should be divisiable by num_heads in transformer" 206 | self.cls_token = nn.Parameter(torch.zeros(1, kwargs.get("prefix_len", 1), embed_dim)) 207 | 208 | self.pos_embed = Identity() 209 | if pos_emb_type: 210 | self.pos_embed = FFPEG(seq_len=30 * 30, 211 | embed_dim=embed_dim, 212 | **kwargs) 213 | 214 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] 215 | self.blocks = nn.ModuleList([ 216 | Block(dim=embed_dim, 217 | num_heads=num_heads, 218 | mlp_ratio=mlp_ratio, 219 | qkv_bias=qkv_bias, 220 | qk_scale=qk_scale, 221 | drop=drop_rate, 222 | attn_drop=attn_drop_rate, 223 | drop_path=dpr[i], 224 | norm_layer=norm_layer) for i in range(depth) 225 | ]) 226 | self.norm = norm_layer(embed_dim) 227 | self.dim_out = embed_dim 228 | 229 | def forward_features(self, x, **kwargs): 230 | if self.cls_token is not None: 231 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), 232 | dim=1) 233 | x = self.pos_embed(x, **kwargs) 234 | 235 | for blk in self.blocks: 236 | x = blk(x) 237 | x = self.norm(x) 238 | if self.global_pool: 239 | x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] 240 | return x 241 | 242 | def forward_attention(self, x, **kwargs): 243 | if self.cls_token is not None: 244 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), 245 | dim=1) 246 | if len(x.shape) == 2: 247 | x = x.unsqueeze(0) 248 | x = self.pos_embed(x, **kwargs) 249 | 250 | for i, blk in enumerate(self.blocks): 251 | if i < len(self.blocks) - 1: 252 | x = blk(x) 253 | else: 254 | return blk(x, return_attention=True) 255 | 256 | @torch.inference_mode() 257 | def forward_attention_all_blocks(self, x, **kwargs): 258 | if self.cls_token is not None: 259 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), 260 | dim=1) 261 | if len(x.shape) == 2: 262 | x = x.unsqueeze(0) 263 | x = self.pos_embed(x, **kwargs) 264 | 265 | out = [] 266 | for _, blk in enumerate(self.blocks): 267 | out.append(blk(x, return_attention=True)) 268 | x = blk(x) 269 | 270 | return out 271 | 272 | def forward(self, x, **kwargs): 273 | if len(x.shape) == 2: 274 | x = x.unsqueeze(0) 275 | x = self.forward_features(x, **kwargs) 276 | return x -------------------------------------------------------------------------------- /fastglioma/train/train_slide.py: -------------------------------------------------------------------------------- 1 | """Slide SSL with SCM/VICReg pretraining script. 2 | 3 | Copyright (c) 2024 University of Michigan. All rights reserved. 4 | Licensed under the MIT License. See LICENSE for license information. 
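Typical invocation (a sketch; the config filename is an assumption):

    python fastglioma/train/train_slide.py -c fastglioma/train/config/train_scm.yaml

The config is expected to point ``data/embedding_root`` and ``data/tag`` at patch embeddings saved by the patch-level model, select the ``transformer`` backbone under ``model/backbone``, and supply the VICReg objective parameters under ``training/objective/params``.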
5 | """ 6 | 7 | import yaml 8 | import logging 9 | from functools import partial 10 | from typing import Dict, Any, List 11 | 12 | import torch 13 | 14 | import pytorch_lightning as pl 15 | import torchmetrics 16 | 17 | from fastglioma.models.cnn import MLP, VICRegNetwork 18 | from fastglioma.models.mil import TransformerMIL, MIL_forward 19 | from fastglioma.utils.common import (setup_output_dirs, parse_args, 20 | get_exp_name, config_loggers, 21 | get_optimizer_func, get_scheduler_func, 22 | get_num_worker) 23 | from fastglioma.losses.vicreg import GeneralVICRegLoss 24 | 25 | 26 | class SlideSSLSystem(pl.LightningModule): 27 | """Lightning system for slide ssl experiments.""" 28 | 29 | def __init__(self, cf: Dict[str, Any], num_it_per_ep: int): 30 | super().__init__() 31 | self.cf_ = cf 32 | self.num_it_per_ep_ = num_it_per_ep 33 | 34 | if cf["model"]["backbone"]["which"] == "transformer": 35 | mil = partial(MIL_forward, 36 | mil=partial(TransformerMIL, 37 | **cf["model"]["backbone"]["params"])) 38 | else: 39 | raise NotImplementedError() 40 | 41 | mlp = partial(MLP, 42 | n_in=mil().num_out, 43 | hidden_layers=cf["model"]["mlp_hidden"], 44 | n_out=cf["model"]["num_embedding_out"]) 45 | self.model = VICRegNetwork(mil, mlp) 46 | 47 | if "training" in cf: 48 | self.criterion = GeneralVICRegLoss( 49 | embedding_dim=cf["model"]["num_embedding_out"], 50 | **cf["training"]["objective"]["params"]) 51 | self.train_loss = torch.nn.ModuleDict({ 52 | n: torchmetrics.MeanMetric() 53 | for n in GeneralVICRegLoss.get_loss_names() 54 | }) 55 | self.val_loss = torch.nn.ModuleDict({ 56 | n: torchmetrics.MeanMetric() 57 | for n in GeneralVICRegLoss.get_loss_names() 58 | }) 59 | else: 60 | self.train_loss = self.val_loss = None 61 | 62 | @staticmethod 63 | def get_kth_view(data: List[List[torch.Tensor]], k: int): 64 | return [d[k] for d in data] 65 | 66 | def forward(self, batch): 67 | pred = [ 68 | self.model(self.get_kth_view(batch["embeddings"], 0), 69 | coords=self.get_kth_view(batch["coords"], 0)), 70 | self.model(self.get_kth_view(batch["embeddings"], 1), 71 | coords=self.get_kth_view(batch["coords"], 1)) 72 | ] 73 | 74 | pred = torch.stack(pred, dim=1) 75 | pred_gather = self.all_gather(pred, sync_grads=True) 76 | pred_gather = pred_gather.reshape(-1, *pred_gather.shape[-2:]) 77 | 78 | losses = self.criterion(pred_gather) 79 | 80 | return losses 81 | 82 | def training_step(self, batch, _): 83 | losses = self.forward(batch) 84 | bs = len(batch['embeddings']) * torch.distributed.get_world_size() 85 | 86 | for k in self.train_loss: 87 | self.log(f"train/{k}", 88 | losses[k], 89 | on_step=True, 90 | on_epoch=False, 91 | batch_size=bs, 92 | rank_zero_only=True) 93 | self.train_loss[k].update(losses[k], weight=bs) 94 | 95 | return losses["loss"] 96 | 97 | @torch.inference_mode() 98 | def validation_step(self, batch, _): 99 | losses = self.forward(batch) 100 | bs = len(batch['embeddings']) * torch.distributed.get_world_size() 101 | for k in self.val_loss: 102 | self.val_loss[k].update(losses[k], weight=bs) 103 | 104 | def on_train_epoch_end(self): 105 | torch.cuda.empty_cache() 106 | losses = {} 107 | for k in self.train_loss.keys(): 108 | losses[k] = self.train_loss[k].compute() 109 | self.log(f"train/{k}_manualepoch", 110 | losses[k], 111 | on_epoch=True, 112 | sync_dist=True, 113 | rank_zero_only=True) 114 | self.train_loss[k].reset() 115 | logging.info(f"train/manualepoch {losses}") 116 | 117 | @torch.inference_mode() 118 | def on_validation_epoch_end(self): 119 | losses = {} 120 | for k in 
self.val_loss.keys(): 121 | losses[k] = self.val_loss[k].compute() 122 | self.log(f"valid/{k}_manualepoch", 123 | losses[k], 124 | on_epoch=True, 125 | sync_dist=True, 126 | rank_zero_only=True) 127 | self.val_loss[k].reset() 128 | logging.info(f"valid/manualepoch {losses}") 129 | 130 | @torch.inference_mode() 131 | def predict_step(self, batch, batch_idx): 132 | out = self.model.bb( 133 | self.get_kth_view(batch["embeddings"], 0), 134 | coords=self.get_kth_view(batch["coords"], 0)) 135 | 136 | return { 137 | "path": [batch["path"]], 138 | "label": batch["label"], 139 | "embeddings": out 140 | } 141 | 142 | def configure_ddp(self, *args, **kwargs): 143 | logging.basicConfig(level=logging.INFO) 144 | return super().configure_ddp(*args, **kwargs) 145 | 146 | def configure_optimizers(self): 147 | # if not training, no optimizer 148 | if "training" not in self.cf_: 149 | return None 150 | 151 | # get optimizer 152 | opt = get_optimizer_func(self.cf_)(self.model.parameters()) 153 | 154 | # check if use a learn rate scheduler 155 | sched_func = get_scheduler_func(self.cf_, self.num_it_per_ep_) 156 | if not sched_func: 157 | return opt 158 | 159 | # get learn rate scheduler 160 | lr_scheduler_config = { 161 | "scheduler": sched_func(opt), 162 | "interval": "step", 163 | "frequency": 1, 164 | "name": "lr" 165 | } 166 | 167 | return [opt], lr_scheduler_config 168 | 169 | from fastglioma.datasets.embedding_dataset import SlideEmbeddingDataset 170 | from fastglioma.datasets.emb_proc import get_emb_transformations, emb_collate_fn 171 | def get_dataloaders(cf): 172 | """Create dataloader for contrastive experiments.""" 173 | train_xform, valid_xform = get_emb_transformations(cf) 174 | 175 | logging.info(f"train_xform\n{train_xform}") 176 | logging.info(f"valid_xform\n{valid_xform}") 177 | 178 | train_dset = SlideEmbeddingDataset( 179 | data_root=cf["data"]["db_root"], 180 | embedding_root=cf["data"]["embedding_root"], 181 | tag=cf["data"]["tag"], 182 | studies="train", 183 | transform=train_xform, 184 | balance_slide_per_class=cf["data"]["balance_study_per_class"], 185 | num_transforms=cf["data"]["num_transforms"]) 186 | val_dset = SlideEmbeddingDataset( 187 | data_root=cf["data"]["db_root"], 188 | embedding_root=cf["data"]["embedding_root"], 189 | tag=cf["data"]["tag"], 190 | studies="val", 191 | transform=valid_xform, 192 | balance_slide_per_class=False, 193 | num_transforms=cf["data"]["num_transforms"]) 194 | 195 | dataloader_callable = partial(torch.utils.data.DataLoader, 196 | batch_size=cf['training']['batch_size'], 197 | drop_last=False, 198 | pin_memory=True, 199 | num_workers=get_num_worker(), 200 | persistent_workers=True, 201 | collate_fn=emb_collate_fn) 202 | 203 | return dataloader_callable(train_dset, 204 | shuffle=True), dataloader_callable(val_dset, 205 | shuffle=True) 206 | 207 | 208 | def main(): 209 | cf_fd = parse_args() 210 | cf = yaml.load(cf_fd, Loader=yaml.FullLoader) 211 | exp_root, model_dir, cp_config = setup_output_dirs(cf, get_exp_name, "") 212 | pl.seed_everything(cf["infra"]["seed"]) 213 | 214 | # logging and copying config files 215 | cp_config(cf_fd.name) 216 | config_loggers(exp_root) 217 | 218 | train_loader, valid_loader = get_dataloaders(cf) 219 | 220 | logging.info(f"num devices: {torch.cuda.device_count()}") 221 | logging.info(f"num workers in dataloader: {train_loader.num_workers}") 222 | 223 | num_it_per_ep = len(train_loader) 224 | if torch.cuda.device_count() > 1: 225 | num_it_per_ep //= torch.cuda.device_count() 226 | 227 | exp = SlideSSLSystem(cf, 
num_it_per_ep) 228 | 229 | # config loggers 230 | logger = [ 231 | pl.loggers.TensorBoardLogger(save_dir=exp_root, name="tb"), 232 | pl.loggers.CSVLogger(save_dir=exp_root, name="csv") 233 | ] 234 | 235 | # config callbacks 236 | epoch_ckpt = pl.callbacks.ModelCheckpoint( 237 | dirpath=model_dir, 238 | save_top_k=-1, 239 | every_n_epochs=cf["training"]["eval_ckpt_ep_freq"], 240 | filename="ckpt-epoch{epoch}", 241 | auto_insert_metric_name=False) 242 | lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step", 243 | log_momentum=False) 244 | 245 | # create trainer 246 | trainer = pl.Trainer( 247 | accelerator="gpu", 248 | devices=-1, 249 | default_root_dir=exp_root, 250 | strategy=pl.strategies.DDPStrategy(find_unused_parameters=False, 251 | static_graph=True), 252 | logger=logger, 253 | log_every_n_steps=10, 254 | callbacks=[epoch_ckpt, lr_monitor], 255 | max_epochs=cf["training"]["num_epochs"], 256 | check_val_every_n_epoch=cf["training"]["eval_ckpt_ep_freq"], 257 | precision=cf["training"].get("amp", "32"), 258 | deterministic=cf["training"].get("deterministic", False), 259 | num_nodes=1) 260 | trainer.fit(exp, 261 | train_dataloaders=train_loader, 262 | val_dataloaders=valid_loader) 263 | 264 | 265 | if __name__ == '__main__': 266 | main() 267 | -------------------------------------------------------------------------------- /fastglioma/tf/feedforward.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import numpy as np 4 | from tqdm import tqdm 5 | from typing import Dict, Any, List 6 | import logging 7 | 8 | import torch 9 | import tensorflow as tf 10 | from tensorflow import keras 11 | import numpy as np 12 | 13 | from huggingface_hub import hf_hub_download 14 | from fastglioma.inference.run_inference import FastGliomaInferenceSystem 15 | from fastglioma.datasets.srh_dataset import SlideDataset, slide_collate_fn 16 | from fastglioma.datasets.improc import get_srh_base_aug, get_strong_aug 17 | from torchvision.transforms import Compose 18 | 19 | from fastglioma.utils.common import (parse_args, get_exp_name, config_loggers, 20 | get_num_worker) 21 | from fastglioma.inference.run_inference import setup_eval_paths 22 | 23 | 24 | def create_tf_dataset(cf: Dict[str, Any]): 25 | """Create TensorFlow dataset from PyTorch dataset""" 26 | 27 | def get_transform(cf): 28 | return Compose( 29 | get_srh_base_aug( 30 | base_aug=("three_channels" if cf["data"]["patch_input"] == 31 | "highres" else "ch2_only")) + 32 | (get_strong_aug([{ 33 | "which": "inpaint_rows_always_apply", 34 | "params": { 35 | "image_size": 300, 36 | "y_skip": 5 37 | } 38 | }], 1.) 
if cf["data"]["patch_input"] == "lowres" else [])) 39 | 40 | # Use the existing PyTorch dataset 41 | dset = SlideDataset(data_root=cf["data"]["db_root"], 42 | studies=cf["data"]["studies"], 43 | transform=get_transform(cf), 44 | balance_slide_per_class=False, 45 | use_patient_class=cf["data"]["use_patient_class"]) 46 | 47 | # Create PyTorch dataloader 48 | loader = torch.utils.data.DataLoader( 49 | dset, 50 | batch_size=cf["eval"]["predict_batch_size"], 51 | drop_last=False, 52 | collate_fn=slide_collate_fn, 53 | num_workers=get_num_worker()) 54 | 55 | return loader, dset.class_to_idx_ 56 | 57 | 58 | def process_predictions(predictions: Dict, class_to_idx: Dict) -> Dict: 59 | """Process predictions similar to PyTorch version""" 60 | pred = {} 61 | 62 | # Convert logits to probabilities 63 | pred["logits"] = predictions["logits"] 64 | pred["scores"] = tf.sigmoid(pred["logits"]).numpy() 65 | 66 | # Convert numeric labels back to strings 67 | idx_to_class = {v: k for k, v in class_to_idx.items()} 68 | pred["label"] = [idx_to_class[l] for l in predictions["labels"]] 69 | 70 | # Get slide names 71 | pred["slide"] = ["/".join(imp[0][0].split("/")[:9]) for imp in predictions["paths"]] #yapf:disable 72 | 73 | # Store embeddings 74 | pred["embeddings"] = predictions["embeddings"] 75 | 76 | # Sort predictions by slide name 77 | sorted_indices = sorted(range(len(pred['slide'])), 78 | key=lambda k: pred['slide'][k]) 79 | 80 | # Apply sorting to all fields 81 | for key in pred: 82 | if isinstance(pred[key], list): 83 | pred[key] = [pred[key][i] for i in sorted_indices] 84 | elif isinstance(pred[key], np.ndarray): 85 | pred[key] = pred[key][sorted_indices] 86 | 87 | return pred 88 | 89 | 90 | def get_tf_predictions(cf: Dict[str, Any], 91 | model_dict: Dict[str, tf.keras.Model]) -> Dict: 92 | """Run inference using TensorFlow models""" 93 | 94 | loader, class_to_idx = create_tf_dataset(cf) 95 | 96 | all_predictions = { 97 | "logits": [], 98 | "embeddings": [], 99 | "labels": [], 100 | "paths": [] 101 | } 102 | 103 | # Run inference 104 | for batch in tqdm(loader): 105 | # Get first view of images and coordinates 106 | images = tf.cast(np.transpose(batch["image"][0][0].numpy(), 107 | (0, 2, 3, 1)), 108 | dtype=tf.float32) 109 | coords = tf.cast(tf.expand_dims(batch["coords"][0][0].numpy(), axis=0), 110 | dtype=tf.float32) 111 | 112 | # Forward pass through models 113 | patch_embeddings = tf.squeeze(model_dict["resnet"](images)) 114 | patch_embeddings = tf.expand_dims(patch_embeddings, axis=0) 115 | 116 | slide_embedding = model_dict["transformer"](patch_embeddings, 117 | coords=coords) 118 | logits = model_dict["head"](slide_embedding) 119 | 120 | # Store predictions 121 | all_predictions["logits"].append(logits.numpy()) 122 | all_predictions["embeddings"].append(slide_embedding.numpy()) 123 | all_predictions["labels"].append(batch["label"].numpy()) 124 | all_predictions["paths"].append(batch["path"]) 125 | 126 | # Concatenate results 127 | for key in ["logits", "embeddings", "labels"]: 128 | all_predictions[key] = np.concatenate(all_predictions[key]) 129 | 130 | # Process predictions into final format 131 | predictions = process_predictions(all_predictions, class_to_idx) 132 | return predictions 133 | 134 | 135 | def compare_pytorch_tensorflow_outputs(pl_system, 136 | model_dict, 137 | num_channels=3, 138 | batch_size=4, 139 | num_patches=10): 140 | """Compare outputs between PyTorch and TensorFlow models using dummy data. 
141 | 142 | Args: 143 | pl_system: PyTorch Lightning system containing the PyTorch models 144 | model_dict: Dictionary containing TensorFlow models 145 | batch_size: Number of samples in batch 146 | num_patches: Number of patches per sample 147 | """ 148 | # Create dummy data 149 | dummy_images = torch.randn(batch_size, num_patches, num_channels, 224, 224) 150 | dummy_coords = torch.randn(batch_size, num_patches, 2) 151 | 152 | # PyTorch forward pass 153 | with torch.no_grad(): 154 | # Convert to expected format and run through models 155 | pt_patch_embeddings = pl_system.model.backbone(dummy_images.view(-1, num_channels, 224, 224)) #yapf:disable 156 | pt_patch_embeddings = pt_patch_embeddings.view(batch_size, num_patches, 157 | -1) 158 | pt_slide_embedding = pl_system.model.bb.mil(pt_patch_embeddings, 159 | coords=dummy_coords) 160 | pt_logits = pl_system.model.head(pt_slide_embedding) 161 | 162 | # TensorFlow forward pass 163 | # Reshape and transpose images for TF format (B*N, H, W, C) 164 | tf_images = tf.transpose(dummy_images.numpy(), (0, 1, 3, 4, 2)) 165 | tf_images = tf.reshape(tf_images, (-1, 224, 224, num_channels)) 166 | tf_coords = tf.cast(dummy_coords.numpy(), dtype=tf.float32) 167 | 168 | # Run through TF models 169 | tf_patch_embeddings = tf.squeeze(model_dict["resnet"](tf_images)) 170 | tf_patch_embeddings = tf.reshape(tf_patch_embeddings, 171 | (batch_size, num_patches, -1)) 172 | tf_slide_embedding = model_dict["transformer"](tf_patch_embeddings, 173 | coords=tf_coords) 174 | tf_logits = model_dict["head"](tf_slide_embedding) 175 | 176 | # Compare outputs 177 | print("\nOutput Comparisons:") 178 | print("------------------") 179 | print(f"Patch Embeddings - Max Diff: {np.max(np.abs(pt_patch_embeddings.numpy() - tf_patch_embeddings.numpy())):.6f}") #yapf:disable 180 | print(f"Slide Embeddings - Max Diff: {np.max(np.abs(pt_slide_embedding.numpy() - tf_slide_embedding.numpy())):.6f}") #yapf:disable 181 | print(f"Logits - Max Diff: {np.max(np.abs(pt_logits.numpy() - tf_logits.numpy())):.6f}") #yapf:disable 182 | 183 | 184 | def main(): 185 | """Driver script for inference pipeline.""" 186 | logging.basicConfig(level=logging.DEBUG) 187 | logger = logging.getLogger(__name__) 188 | 189 | cf_fd = parse_args() 190 | cf = yaml.load(cf_fd, Loader=yaml.FullLoader) 191 | exp_root, pred_dir, cp_config = setup_eval_paths(cf, get_exp_name, "") 192 | 193 | torch.manual_seed(cf["infra"]["seed"]) 194 | tf.random.set_seed(cf["infra"]["seed"]) 195 | 196 | # Logging and copying config files 197 | cp_config(cf_fd.name) 198 | config_loggers(exp_root) 199 | 200 | # Load PyTorch model and convert to TF 201 | ckpt_path = hf_hub_download(repo_id=cf["infra"]["hf_repo"], 202 | filename=cf["eval"]["ckpt_path"]) 203 | pl_system = FastGliomaInferenceSystem.load_from_checkpoint(ckpt_path, 204 | cf=cf, 205 | num_it_per_ep=0) 206 | 207 | # Convert models to TensorFlow 208 | from fastglioma.tf.resnet import resnet_backbone, convert_resnet_weights 209 | from fastglioma.tf.transformer import TransformerMIL, convert_pytorch_transformer_to_tf 210 | 211 | # Initialize TF models 212 | tf_resnet = resnet_backbone(arch=cf["model"]["patch"]["backbone"]["which"], 213 | **cf["model"]["patch"]["backbone"]["params"]) 214 | tf_transformer = TransformerMIL(**cf["model"]["slide"]["mil"]["params"]) 215 | 216 | tf_head = [] 217 | 218 | if len(cf["model"]["slide"].get("mlp_hidden", [])) > 0: 219 | for hidden_dim in cf["model"]["slide"]["mlp_hidden"]: 220 | tf_head.append(keras.layers.Dense(hidden_dim, use_bias=True)) 221 | 
tf_head.append(keras.layers.ReLU()) 222 | tf_head.append(keras.layers.Dense(1, use_bias=True)) 223 | 224 | tf_head = keras.Sequential(tf_head) 225 | 226 | # Convert weights 227 | torch_resnet = pl_system.model.backbone.eval() 228 | torch_transformer = pl_system.model.bb.mil.eval() 229 | torch_head = pl_system.model.head.eval() 230 | 231 | tf_resnet = convert_resnet_weights( 232 | torch_resnet, 233 | tf_resnet, 234 | num_channel_in=cf["model"]["patch"]["backbone"]["params"].get( 235 | "num_channel_in", 3)) 236 | tf_transformer = convert_pytorch_transformer_to_tf(torch_transformer, 237 | tf_transformer) 238 | 239 | # Convert head weights 240 | dummy_input = tf.zeros((1, 512), dtype=tf.float32) 241 | _ = tf_head(dummy_input) 242 | state_dict = torch_head.state_dict() 243 | tf_weights = [ 244 | state_dict['layers.0.weight'].numpy().transpose(), 245 | state_dict['layers.0.bias'].numpy(), 246 | state_dict['layers.2.weight'].numpy().transpose(), 247 | state_dict['layers.2.bias'].numpy() 248 | ] 249 | tf_head.set_weights(tf_weights) 250 | 251 | # Create model dictionary 252 | model_dict = { 253 | "resnet": tf_resnet, 254 | "transformer": tf_transformer, 255 | "head": tf_head 256 | } 257 | 258 | if cf["eval"]["compare_to_torch"]: 259 | comparison_results = compare_pytorch_tensorflow_outputs( 260 | pl_system, 261 | model_dict, 262 | num_channels=cf["model"]["patch"]["backbone"]["params"].get( 263 | "num_channel_in", 3)) 264 | 265 | # Run inference 266 | predictions = get_tf_predictions(cf, model_dict) 267 | torch.save(predictions, os.path.join(pred_dir, "tf_predictions.pt")) 268 | 269 | 270 | if __name__ == "__main__": 271 | main() 272 | -------------------------------------------------------------------------------- /fastglioma/train/train_scorer.py: -------------------------------------------------------------------------------- 1 | """Slide-level training with ordmet script. 2 | 3 | Copyright (c) 2024 University of Michigan. All rights reserved. 4 | Licensed under the MIT License. See LICENSE for license information. 
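Typical invocation (a sketch; the config filename is an assumption):

    python fastglioma/train/train_scorer.py -c fastglioma/train/config/train_ordmet.yaml

The scorer fits a transformer MIL head with the ordinal metric loss on precomputed patch embeddings grouped by slide. If ``training/load_backbone/ckpt_path`` is set, the MIL weights are initialized from a slide-SSL checkpoint and can be frozen by setting ``training/load_backbone/finetune`` to false.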
5 | """ 6 | 7 | import yaml 8 | import logging 9 | from functools import partial 10 | from typing import Dict, Any, List 11 | 12 | import torch 13 | 14 | import pytorch_lightning as pl 15 | import torchmetrics 16 | 17 | from fastglioma.models.cnn import MLP 18 | from fastglioma.models.mil import TransformerMIL, MIL_forward, MIL_Classifier 19 | from fastglioma.utils.common import (setup_output_dirs, parse_args, 20 | get_exp_name, config_loggers, 21 | get_optimizer_func, get_scheduler_func, 22 | get_num_worker) 23 | from fastglioma.losses.ordmet import OrdinalMetricLoss 24 | 25 | 26 | class SlideOrdMetSystem(pl.LightningModule): 27 | """Lightning system for slide ssl experiments.""" 28 | 29 | def __init__(self, cf: Dict[str, Any], num_it_per_ep: int): 30 | super().__init__() 31 | self.cf_ = cf 32 | self.num_it_per_ep_ = num_it_per_ep 33 | 34 | if cf["model"]["backbone"]["which"] == "transformer": 35 | mil = partial(MIL_forward, 36 | mil=partial(TransformerMIL, 37 | **cf["model"]["backbone"]["params"])) 38 | else: 39 | raise NotImplementedError() 40 | 41 | mlp = partial(MLP, 42 | n_in=mil().num_out, 43 | hidden_layers=cf["model"]["mlp_hidden"], 44 | n_out=1) 45 | self.model = MIL_Classifier(None, mil, mlp) 46 | 47 | if "training" in cf: 48 | self.criterion = OrdinalMetricLoss(**self.cf_["training"]["objective"]["params"]) 49 | self.train_loss = torchmetrics.MeanMetric() 50 | self.val_loss = torchmetrics.MeanMetric() 51 | else: 52 | self.criterion = self.train_loss = self.val_loss = None 53 | 54 | @staticmethod 55 | def get_kth_view(data: List[List[torch.Tensor]], k: int): 56 | return [d[k] for d in data] 57 | 58 | def training_step(self, batch, batch_idx): 59 | pred = self.model(self.get_kth_view(batch["embeddings"], 0), 60 | coords=self.get_kth_view(batch["coords"], 0))["logits"] 61 | pred_gather = self.all_gather(pred, sync_grads=True) 62 | pred_gather = pred_gather.reshape(-1, 1) 63 | # pred_gather = pred_gather.reshape(-1, *pred_gather.shape[-2:]) 64 | label_gather = self.all_gather(batch["label"]).reshape(-1, 1) 65 | 66 | loss = self.criterion(pred_gather, label_gather)["loss"] 67 | bs = len(batch["embeddings"]) * torch.distributed.get_world_size() 68 | self.log("train/loss", 69 | loss, 70 | on_step=True, 71 | on_epoch=True, 72 | batch_size=bs, 73 | rank_zero_only=True) 74 | self.train_loss.update(loss, weight=bs) 75 | 76 | return loss 77 | 78 | @torch.inference_mode() 79 | def validation_step(self, batch, batch_idx): 80 | bs = len(batch["embeddings"]) * torch.distributed.get_world_size() 81 | 82 | pred = self.model(self.get_kth_view(batch["embeddings"], 0), 83 | coords=self.get_kth_view(batch["coords"], 0))["logits"] 84 | pred_gather = self.all_gather(pred, sync_grads=True) 85 | pred_gather = pred_gather.reshape(-1, 1) 86 | # pred_gather = pred_gather.reshape(-1, *pred_gather.shape[-2:]) 87 | label_gather = self.all_gather(batch["label"]).reshape(-1, 1) 88 | 89 | loss = self.criterion(pred_gather, label_gather)["loss"] 90 | bs = len(batch["embeddings"]) * torch.distributed.get_world_size() 91 | self.log("val/loss", 92 | loss, 93 | on_step=True, 94 | on_epoch=True, 95 | batch_size=bs, 96 | rank_zero_only=True) 97 | self.val_loss.update(loss, weight=bs) 98 | 99 | def on_train_epoch_end(self): 100 | torch.cuda.empty_cache() 101 | 102 | # compute metrics 103 | train_loss = self.train_loss.compute() 104 | 105 | # log metrics 106 | self.log("train/loss", 107 | train_loss, 108 | on_epoch=True, 109 | sync_dist=True, 110 | rank_zero_only=True) 111 | self.train_loss.reset() 112 | 113 | 
log_metrics = {"ap": {}, "auroc": {}} 114 | 115 | @torch.inference_mode() 116 | def on_validation_epoch_end(self): 117 | # compute metrics 118 | val_loss = self.val_loss.compute() 119 | 120 | # log metrics 121 | self.log("val/loss", 122 | val_loss, 123 | on_epoch=True, 124 | sync_dist=True, 125 | rank_zero_only=True) 126 | self.val_loss.reset() 127 | 128 | def predict_step(self, batch, batch_idx): 129 | out = self.model(self.get_kth_view(batch["embeddings"], 0), 130 | coords=self.get_kth_view(batch["coords"], 0)) 131 | 132 | return { 133 | "path": [batch["path"]], 134 | "label": batch["label"], 135 | "logits": out["logits"], 136 | "embeddings": out["embeddings"] 137 | } 138 | 139 | def configure_ddp(self, *args, **kwargs): 140 | logging.basicConfig(level=logging.INFO) 141 | return super().configure_ddp(*args, **kwargs) 142 | 143 | def configure_optimizers(self): 144 | # if not training, no optimizer 145 | if "training" not in self.cf_: 146 | return None 147 | 148 | # get optimizer 149 | opt = get_optimizer_func(self.cf_)(self.model.parameters()) 150 | 151 | # check if use a learn rate scheduler 152 | sched_func = get_scheduler_func(self.cf_, self.num_it_per_ep_) 153 | if not sched_func: 154 | return opt 155 | 156 | # get learn rate scheduler 157 | lr_scheduler_config = { 158 | "scheduler": sched_func(opt), 159 | "interval": "step", 160 | "frequency": 1, 161 | "name": "lr" 162 | } 163 | 164 | return [opt], lr_scheduler_config 165 | 166 | from fastglioma.datasets.embedding_dataset import SlideEmbeddingDataset 167 | from fastglioma.datasets.emb_proc import get_emb_transformations, emb_collate_fn 168 | def get_dataloaders(cf): 169 | """Create dataloader for contrastive experiments.""" 170 | train_xform, valid_xform = get_emb_transformations(cf) 171 | 172 | logging.info(f"train_xform\n{train_xform}") 173 | logging.info(f"valid_xform\n{valid_xform}") 174 | 175 | train_dset = SlideEmbeddingDataset( 176 | data_root=cf["data"]["db_root"], 177 | embedding_root=cf["data"]["embedding_root"], 178 | tag=cf["data"]["tag"], 179 | studies="train", 180 | transform=train_xform, 181 | balance_slide_per_class=cf["data"]["balance_study_per_class"], 182 | use_patient_class=cf["data"]["use_patient_class"], 183 | meta_fname=cf["data"]["meta_fname"], 184 | num_transforms=cf["data"]["num_transforms"]) 185 | val_dset = SlideEmbeddingDataset( 186 | data_root=cf["data"]["db_root"], 187 | embedding_root=cf["data"]["embedding_root"], 188 | tag=cf["data"]["tag"], 189 | studies="val", 190 | transform=valid_xform, 191 | balance_slide_per_class=False, 192 | use_patient_class=cf["data"]["use_patient_class"], 193 | meta_fname=cf["data"]["meta_fname"], 194 | num_transforms=cf["data"]["num_transforms"]) 195 | 196 | dataloader_callable = partial(torch.utils.data.DataLoader, 197 | batch_size=cf['training']['batch_size'], 198 | drop_last=False, 199 | pin_memory=True, 200 | num_workers=get_num_worker(), 201 | persistent_workers=True, 202 | collate_fn=emb_collate_fn) 203 | 204 | return dataloader_callable(train_dset, 205 | shuffle=True), dataloader_callable(val_dset, 206 | shuffle=True) 207 | 208 | 209 | def main(): 210 | cf_fd = parse_args() 211 | cf = yaml.load(cf_fd, Loader=yaml.FullLoader) 212 | exp_root, model_dir, cp_config = setup_output_dirs(cf, get_exp_name, "") 213 | pl.seed_everything(cf["infra"]["seed"]) 214 | 215 | # logging and copying config files 216 | cp_config(cf_fd.name) 217 | config_loggers(exp_root) 218 | 219 | train_loader, valid_loader = get_dataloaders(cf) 220 | 221 | logging.info(f"num devices: 
{torch.cuda.device_count()}") 222 | logging.info(f"num workers in dataloader: {train_loader.num_workers}") 223 | 224 | num_it_per_ep = len(train_loader) 225 | if torch.cuda.device_count() > 1: 226 | num_it_per_ep //= torch.cuda.device_count() 227 | 228 | exp = SlideOrdMetSystem(cf, num_it_per_ep) 229 | 230 | if "load_backbone" in cf["training"]: 231 | # load lightning checkpint 232 | ckpt_dict = torch.load(cf["training"]["load_backbone"].get("ckpt_path", None), 233 | map_location="cpu") 234 | 235 | mil_state_dict = { 236 | k.removeprefix("model.bb.mil."): ckpt_dict["state_dict"][k] 237 | for k in ckpt_dict["state_dict"] if "model.bb.mil" in k 238 | } 239 | 240 | exp.model.bb.mil.load_state_dict(mil_state_dict) 241 | 242 | if not cf["training"]["load_backbone"].get("finetune", True): 243 | for param in exp.model.bb.mil.parameters(): 244 | param.requires_grad = False 245 | exp.model.bb.mil.eval() 246 | 247 | logging.info(f"Loaded checkpoint {cf['training']['load_backbone'].get('ckpt_path', None)}") 248 | 249 | # config loggers 250 | logger = [ 251 | pl.loggers.TensorBoardLogger(save_dir=exp_root, name="tb"), 252 | pl.loggers.CSVLogger(save_dir=exp_root, name="csv") 253 | ] 254 | 255 | # config callbacks 256 | epoch_ckpt = pl.callbacks.ModelCheckpoint( 257 | dirpath=model_dir, 258 | save_top_k=-1, 259 | every_n_epochs=cf["training"]["eval_ckpt_ep_freq"], 260 | filename="ckpt-epoch{epoch}", 261 | auto_insert_metric_name=False) 262 | lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step", 263 | log_momentum=False) 264 | 265 | # create trainer 266 | trainer = pl.Trainer( 267 | accelerator="gpu", 268 | devices=-1, 269 | default_root_dir=exp_root, 270 | strategy=pl.strategies.DDPStrategy(find_unused_parameters=False, 271 | static_graph=True), 272 | logger=logger, 273 | log_every_n_steps=10, 274 | callbacks=[epoch_ckpt, lr_monitor], 275 | max_epochs=cf["training"]["num_epochs"], 276 | check_val_every_n_epoch=cf["training"]["eval_ckpt_ep_freq"], 277 | precision=cf["training"].get("amp", "32"), 278 | deterministic=cf["training"].get("deterministic", False), 279 | num_nodes=1) 280 | trainer.fit(exp, 281 | train_dataloaders=train_loader, 282 | val_dataloaders=valid_loader) 283 | 284 | 285 | if __name__ == '__main__': 286 | main() 287 | -------------------------------------------------------------------------------- /fastglioma/train/train_patch.py: -------------------------------------------------------------------------------- 1 | """Patch SSL with HiDisc pretraining script. 2 | 3 | Copyright (c) 2024 University of Michigan. All rights reserved. 4 | Licensed under the MIT License. See LICENSE for license information. 
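Typical invocation (a sketch; the config filename is an assumption):

    python fastglioma/train/train_patch.py -c fastglioma/train/config/train_hidisc.yaml

HiDisc pretraining combines patient-, slide-, and patch-level discrimination losses, weighted by ``lambda_patient``, ``lambda_slide``, and ``lambda_patch`` under ``training/objective/params``.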
5 | """ 6 | 7 | import yaml 8 | import logging 9 | from functools import partial 10 | from typing import Dict, Any 11 | 12 | import torch 13 | 14 | import pytorch_lightning as pl 15 | import torchmetrics 16 | 17 | from fastglioma.models.resnet import resnet_backbone 18 | from fastglioma.models.cnn import MLP, ContrastiveLearningNetwork 19 | from fastglioma.utils.common import (setup_output_dirs, parse_args, 20 | get_exp_name, config_loggers, 21 | get_optimizer_func, get_scheduler_func, 22 | get_num_worker) 23 | from fastglioma.losses.hidisc import HiDiscLoss 24 | 25 | 26 | class HiDiscSystem(pl.LightningModule): 27 | """Lightning system for hidisc experiments.""" 28 | 29 | def __init__(self, cf: Dict[str, Any], num_it_per_ep: int): 30 | super().__init__() 31 | self.cf_ = cf 32 | 33 | if "resnet" in cf["model"]["backbone"]["which"]: 34 | bb = partial( 35 | resnet_backbone, 36 | arch=cf["model"]["backbone"]["which"], 37 | num_channel_in=cf["model"]["backbone"]["params"].get( 38 | "num_channel_in", 3)) 39 | else: 40 | raise NotImplementedError() 41 | 42 | mlp = partial(MLP, 43 | n_in=bb().num_out, 44 | hidden_layers=cf["model"]["mlp_hidden"], 45 | n_out=cf["model"]["num_embedding_out"]) 46 | self.model = ContrastiveLearningNetwork(bb, mlp) 47 | 48 | if "training" in cf: 49 | crit_params = cf["training"]["objective"]["params"] 50 | self.criterion = HiDiscLoss( 51 | lambda_patient=crit_params["lambda_patient"], 52 | lambda_slide=crit_params["lambda_slide"], 53 | lambda_patch=crit_params["lambda_patch"], 54 | supcon_loss_params=crit_params["supcon_params"]) 55 | self.train_loss = torch.nn.ModuleDict({ 56 | "patient_loss": torchmetrics.MeanMetric(), 57 | "slide_loss": torchmetrics.MeanMetric(), 58 | "patch_loss": torchmetrics.MeanMetric(), 59 | "sum_loss": torchmetrics.MeanMetric() 60 | }) # yapf: disable 61 | self.val_loss = torch.nn.ModuleDict({ 62 | "patient_loss": torchmetrics.MeanMetric(), 63 | "slide_loss": torchmetrics.MeanMetric(), 64 | "patch_loss": torchmetrics.MeanMetric(), 65 | "sum_loss": torchmetrics.MeanMetric() 66 | }) #yapf: disable 67 | else: 68 | self.criterion = self.train_loss = self.val_loss = None 69 | 70 | self.num_it_per_ep_ = num_it_per_ep 71 | 72 | def forward(self, batch): 73 | im_reshaped = batch["image"].reshape(-1, *batch["image"].shape[-3:]) 74 | pred = self.model(im_reshaped) 75 | return pred.reshape(*batch["image"].shape[:4], pred.shape[-1]) 76 | 77 | def training_step(self, batch, _): 78 | im_reshaped = batch["image"].reshape(-1, *batch["image"].shape[-3:]) 79 | pred = self.model(im_reshaped) 80 | pred = pred.reshape(*batch["image"].shape[:4], pred.shape[-1]) 81 | 82 | pred_gather = self.all_gather(pred, sync_grads=True) 83 | pred_gather = pred_gather.reshape(-1, *pred_gather.shape[2:]) 84 | label_gather = self.all_gather(batch["label"]).reshape(-1, 1) 85 | 86 | losses = self.criterion(pred_gather, label_gather) 87 | 88 | bs = batch["image"][0].shape[0] * torch.cuda.device_count() 89 | log_partial = partial(self.log, 90 | on_step=True, 91 | on_epoch=True, 92 | batch_size=bs, 93 | sync_dist=True, 94 | rank_zero_only=True) 95 | for k in self.train_loss: 96 | log_partial(f"train/{k}", losses[k]) 97 | self.train_loss[k].update(losses[k], weight=bs) 98 | 99 | return losses["sum_loss"] 100 | 101 | def validation_step(self, batch, batch_idx): 102 | im_reshaped = batch["image"].reshape(-1, *batch["image"].shape[-3:]) 103 | pred = self.model(im_reshaped) 104 | pred = pred.reshape(*batch["image"].shape[:4], pred.shape[-1]) 105 | 106 | pred_gather = self.all_gather(pred, 
sync_grads=True) 107 | pred_gather = pred_gather.reshape(-1, *pred_gather.shape[2:]) 108 | label_gather = self.all_gather(batch["label"]).reshape(-1, 1) 109 | 110 | losses = self.criterion(pred_gather, label_gather) 111 | 112 | bs = batch["image"][0].shape[0] * torch.cuda.device_count() 113 | for k in self.val_loss: 114 | self.val_loss[k].update(losses[k], weight=bs) 115 | 116 | @torch.inference_mode() 117 | def predict_step(self, batch, batch_idx): 118 | self.model.eval() 119 | 120 | if isinstance(batch["image"], torch.Tensor): 121 | assert len(batch["image"].shape) == 4 122 | out = self.model.bb(batch["image"]) 123 | return { 124 | "path": batch["path"], 125 | "label": batch["label"], 126 | "embeddings": out 127 | } 128 | else: 129 | out = self.model.bb(batch["image"][0][0]) 130 | return { 131 | "path": batch["path"][0], 132 | "label": batch["label"][0], 133 | "embeddings": out 134 | } 135 | 136 | 137 | def on_train_epoch_end(self): 138 | for k in self.train_loss: 139 | train_loss_k = self.train_loss[k].compute() 140 | self.log(f"train/{k}_manualepoch", 141 | train_loss_k, 142 | on_epoch=True, 143 | sync_dist=True, 144 | rank_zero_only=True) 145 | logging.info(f"train/{k}_manualepoch {train_loss_k}") 146 | self.train_loss[k].reset() 147 | 148 | def on_validation_epoch_end(self): 149 | for k in self.val_loss: 150 | val_loss_k = self.val_loss[k].compute() 151 | self.log(f"val/{k}_manualepoch", 152 | val_loss_k, 153 | on_epoch=True, 154 | sync_dist=True, 155 | rank_zero_only=True) 156 | logging.info(f"val/{k}_manualepoch {val_loss_k}") 157 | self.val_loss[k].reset() 158 | 159 | def configure_ddp(self, *args, **kwargs): 160 | logging.basicConfig(level=logging.INFO) 161 | return super().configure_ddp(*args, **kwargs) 162 | 163 | def configure_optimizers(self): 164 | # if not training, no optimizer 165 | if "training" not in self.cf_: 166 | return None 167 | 168 | # get optimizer 169 | opt = get_optimizer_func(self.cf_)(self.model.parameters()) 170 | 171 | # check if use a learn rate scheduler 172 | sched_func = get_scheduler_func(self.cf_, self.num_it_per_ep_) 173 | if not sched_func: 174 | return opt 175 | 176 | # get learn rate scheduler 177 | lr_scheduler_config = { 178 | "scheduler": sched_func(opt), 179 | "interval": "step", 180 | "frequency": 1, 181 | "name": "lr" 182 | } 183 | 184 | return [opt], lr_scheduler_config 185 | 186 | from fastglioma.datasets.srh_dataset import HiDiscDataset 187 | from fastglioma.datasets.improc import get_transformations 188 | 189 | def get_dataloaders(cf): 190 | """Create dataloader for contrastive experiments.""" 191 | train_xform, valid_xform = get_transformations(cf) 192 | 193 | logging.info(f"train_xform\n{train_xform}") 194 | logging.info(f"valid_xform\n{valid_xform}") 195 | 196 | train_dset = HiDiscDataset( 197 | data_root=cf["data"]["db_root"], 198 | studies="train", 199 | transform=train_xform, 200 | balance_study_per_class=cf["data"]["balance_study_per_class"], 201 | num_slide_samples=cf["data"]["hidisc"]["num_slide_samples"], 202 | num_patch_samples=cf["data"]["hidisc"]["num_patch_samples"], 203 | num_transforms=cf["data"]["hidisc"]["num_transforms"]) 204 | val_dset = HiDiscDataset( 205 | data_root=cf["data"]["db_root"], 206 | studies="val", 207 | transform=valid_xform, 208 | balance_study_per_class=False, 209 | num_slide_samples=cf["data"]["hidisc"]["num_slide_samples"], 210 | num_patch_samples=cf["data"]["hidisc"]["num_patch_samples"], 211 | num_transforms=cf["data"]["hidisc"]["num_transforms"]) 212 | 213 | dataloader_callable = 
partial(torch.utils.data.DataLoader, 214 | batch_size=cf['training']['batch_size'], 215 | drop_last=False, 216 | pin_memory=True, 217 | num_workers=get_num_worker(), 218 | persistent_workers=True) 219 | 220 | return dataloader_callable(train_dset, 221 | shuffle=True), dataloader_callable(val_dset, 222 | shuffle=True) 223 | 224 | 225 | def main(): 226 | cf_fd = parse_args() 227 | cf = yaml.load(cf_fd, Loader=yaml.FullLoader) 228 | exp_root, model_dir, cp_config = setup_output_dirs(cf, get_exp_name, "") 229 | pl.seed_everything(cf["infra"]["seed"]) 230 | 231 | # logging and copying config files 232 | cp_config(cf_fd.name) 233 | config_loggers(exp_root) 234 | 235 | train_loader, valid_loader = get_dataloaders(cf) 236 | 237 | logging.info(f"num devices: {torch.cuda.device_count()}") 238 | logging.info(f"num workers in dataloader: {train_loader.num_workers}") 239 | 240 | num_it_per_ep = len(train_loader) 241 | if torch.cuda.device_count() > 1: 242 | num_it_per_ep //= torch.cuda.device_count() 243 | 244 | exp = HiDiscSystem(cf, num_it_per_ep) 245 | 246 | # config loggers 247 | logger = [ 248 | pl.loggers.TensorBoardLogger(save_dir=exp_root, name="tb"), 249 | pl.loggers.CSVLogger(save_dir=exp_root, name="csv") 250 | ] 251 | 252 | # config callbacks 253 | epoch_ckpt = pl.callbacks.ModelCheckpoint( 254 | dirpath=model_dir, 255 | save_top_k=-1, 256 | every_n_epochs=cf["training"]["eval_ckpt_ep_freq"], 257 | filename="ckpt-epoch{epoch}", 258 | auto_insert_metric_name=False) 259 | lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step", 260 | log_momentum=False) 261 | 262 | # create trainer 263 | trainer = pl.Trainer( 264 | accelerator="gpu", 265 | devices=-1, 266 | default_root_dir=exp_root, 267 | strategy=pl.strategies.DDPStrategy(find_unused_parameters=False, 268 | static_graph=True), 269 | logger=logger, 270 | log_every_n_steps=10, 271 | callbacks=[epoch_ckpt, lr_monitor], 272 | max_epochs=cf["training"]["num_epochs"], 273 | check_val_every_n_epoch=cf["training"]["eval_ckpt_ep_freq"], 274 | precision=cf["training"].get("amp", "32"), 275 | deterministic=cf["training"].get("deterministic", False), 276 | num_nodes=1) 277 | trainer.fit(exp, 278 | train_dataloaders=train_loader, 279 | val_dataloaders=valid_loader) 280 | 281 | 282 | if __name__ == '__main__': 283 | main() 284 | -------------------------------------------------------------------------------- /fastglioma/models/resnet.py: -------------------------------------------------------------------------------- 1 | """Resnet model. 2 | 3 | Adapted from torchvision. See THIRD_PARTY for third party license info. 4 | https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py 5 | 6 | Copyright (c) 2024 University of Michigan. All rights reserved. 7 | Licensed under the MIT License. See LICENSE for license information. 
8 | """ 9 | 10 | from typing import Type, Any, Callable, Union, List, Optional, Dict 11 | 12 | import torch 13 | from torch import nn, Tensor 14 | 15 | 16 | def conv3x3(in_planes: int, 17 | out_planes: int, 18 | stride: int = 1, 19 | groups: int = 1, 20 | dilation: int = 1) -> nn.Conv2d: 21 | """3x3 convolution with padding""" 22 | return nn.Conv2d( 23 | in_planes, 24 | out_planes, 25 | kernel_size=3, 26 | stride=stride, 27 | padding=dilation, 28 | groups=groups, 29 | bias=False, 30 | dilation=dilation, 31 | ) 32 | 33 | 34 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 35 | """1x1 convolution""" 36 | return nn.Conv2d(in_planes, 37 | out_planes, 38 | kernel_size=1, 39 | stride=stride, 40 | bias=False) 41 | 42 | 43 | class BasicBlock(nn.Module): 44 | expansion: int = 1 45 | 46 | def __init__( 47 | self, 48 | inplanes: int, 49 | planes: int, 50 | stride: int = 1, 51 | downsample: Optional[nn.Module] = None, 52 | groups: int = 1, 53 | base_width: int = 64, 54 | dilation: int = 1, 55 | norm_layer: Optional[Callable[..., nn.Module]] = None, 56 | ) -> None: 57 | super().__init__() 58 | if norm_layer is None: 59 | norm_layer = nn.BatchNorm2d 60 | if groups != 1 or base_width != 64: 61 | raise ValueError( 62 | "BasicBlock only supports groups=1 and base_width=64") 63 | if dilation > 1: 64 | raise NotImplementedError( 65 | "Dilation > 1 not supported in BasicBlock") 66 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 67 | self.conv1 = conv3x3(inplanes, planes, stride) 68 | self.bn1 = norm_layer(planes) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.conv2 = conv3x3(planes, planes) 71 | self.bn2 = norm_layer(planes) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x: Tensor) -> Tensor: 76 | identity = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | 85 | if self.downsample is not None: 86 | identity = self.downsample(x) 87 | 88 | out += identity 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class Bottleneck(nn.Module): 95 | # Bottleneck in torchvision places the stride for downsampling at 3x3 96 | # convolution(self.conv2) while original implementation places the stride 97 | # at the first 1x1 convolution(self.conv1) according to "Deep residual 98 | # learning for image recognition"https://arxiv.org/abs/1512.03385. This 99 | # variant is also known as ResNet V1.5 and improves accuracy according to 100 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
101 | 102 | expansion: int = 4 103 | 104 | def __init__( 105 | self, 106 | inplanes: int, 107 | planes: int, 108 | stride: int = 1, 109 | downsample: Optional[nn.Module] = None, 110 | groups: int = 1, 111 | base_width: int = 64, 112 | dilation: int = 1, 113 | norm_layer: Optional[Callable[..., nn.Module]] = None, 114 | ) -> None: 115 | super().__init__() 116 | if norm_layer is None: 117 | norm_layer = nn.BatchNorm2d 118 | width = int(planes * (base_width / 64.0)) * groups 119 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 120 | self.conv1 = conv1x1(inplanes, width) 121 | self.bn1 = norm_layer(width) 122 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 123 | self.bn2 = norm_layer(width) 124 | self.conv3 = conv1x1(width, planes * self.expansion) 125 | self.bn3 = norm_layer(planes * self.expansion) 126 | self.relu = nn.ReLU(inplace=True) 127 | self.downsample = downsample 128 | self.stride = stride 129 | 130 | def forward(self, x: Tensor) -> Tensor: 131 | identity = x 132 | 133 | out = self.conv1(x) 134 | out = self.bn1(out) 135 | out = self.relu(out) 136 | 137 | out = self.conv2(out) 138 | out = self.bn2(out) 139 | out = self.relu(out) 140 | 141 | out = self.conv3(out) 142 | out = self.bn3(out) 143 | 144 | if self.downsample is not None: 145 | identity = self.downsample(x) 146 | 147 | out += identity 148 | out = self.relu(out) 149 | 150 | return out 151 | 152 | 153 | class ResNetBackbone(nn.Module): 154 | """A ResNet backbone model. 155 | 156 | ResNet architecture based on torchvision implementation. It does not 157 | include the dense fc layer at the end. The forward function returns the 158 | final latent representations. 159 | """ 160 | 161 | def __init__( 162 | self, 163 | num_channel_in: int, 164 | in_planes: int, 165 | layer_planes: List[int], 166 | block: Type[Union[BasicBlock, Bottleneck]], 167 | layers: List[int], 168 | zero_init_residual: bool = False, 169 | groups: int = 1, 170 | width_per_group: int = 64, 171 | replace_stride_with_dilation: Optional[List[bool]] = None, 172 | norm_layer: Optional[Callable[..., nn.Module]] = None) -> None: 173 | 174 | super(ResNetBackbone, self).__init__() 175 | if norm_layer is None: 176 | norm_layer = nn.BatchNorm2d 177 | self._norm_layer = norm_layer 178 | self.inplanes = in_planes 179 | self.dilation = 1 180 | if replace_stride_with_dilation is None: 181 | # each element in the tuple indicates if we should replace 182 | # the 2x2 stride with a dilated convolution instead 183 | replace_stride_with_dilation = [False, False, False] 184 | if len(replace_stride_with_dilation) != 3: 185 | raise ValueError("replace_stride_with_dilation should be None " 186 | "or a 3-element tuple, got {}".format( 187 | replace_stride_with_dilation)) 188 | self.groups = groups 189 | self.base_width = width_per_group 190 | 191 | # ---------------------------------------------------------------------- 192 | # make layers 193 | self.conv1 = nn.Conv2d(num_channel_in, self.inplanes, kernel_size=7, 194 | stride=2, padding=3, bias=False) # yapf: disable 195 | self.bn1 = norm_layer(self.inplanes) 196 | self.relu = nn.ReLU(inplace=True) 197 | 198 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 199 | self.layer1 = self._make_layer(block, layer_planes[0], layers[0]) 200 | 201 | self.layer2 = self._make_layer(block, layer_planes[1], layers[1], 202 | stride=2, dilate=replace_stride_with_dilation[0]) # yapf: disable 203 | 204 | self.layer3 = self._make_layer(block, layer_planes[2], layers[2], 205 | stride=2, 
dilate=replace_stride_with_dilation[1]) # yapf: disable 206 | 207 | self.layer4 = self._make_layer(block, layer_planes[3], layers[3], 208 | stride=2, dilate=replace_stride_with_dilation[2]) # yapf: disable 209 | 210 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 211 | self.num_out = block.expansion * layer_planes[-1] 212 | # ---------------------------------------------------------------------- 213 | # init layers 214 | for m in self.modules(): 215 | if isinstance(m, nn.Conv2d): 216 | nn.init.kaiming_normal_(m.weight, 217 | mode='fan_out', 218 | nonlinearity='relu') 219 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 220 | nn.init.constant_(m.weight, 1) 221 | nn.init.constant_(m.bias, 0) 222 | 223 | # Zero-initialize the last BN in each residual branch, 224 | # so that the residual branch starts with zeros, and each 225 | # residual block behaves like an identity. 226 | # This improves the model by 0.2~0.3% according to 227 | # https://arxiv.org/abs/1706.02677 228 | if zero_init_residual: 229 | for m in self.modules(): 230 | if isinstance(m, Bottleneck): 231 | nn.init.constant_(m.bn3.weight, 0) # type:ignore[arg-type] 232 | elif isinstance(m, BasicBlock): 233 | nn.init.constant_(m.bn2.weight, 0) # type:ignore[arg-type] 234 | 235 | def _make_layer(self, 236 | block: Type[Union[BasicBlock, Bottleneck]], 237 | planes: int, 238 | blocks: int, 239 | stride: int = 1, 240 | dilate: bool = False) -> nn.Sequential: 241 | norm_layer = self._norm_layer 242 | downsample = None 243 | previous_dilation = self.dilation 244 | if dilate: 245 | self.dilation *= stride 246 | stride = 1 247 | if stride != 1 or self.inplanes != planes * block.expansion: 248 | downsample = nn.Sequential( 249 | conv1x1(self.inplanes, planes * block.expansion, stride), 250 | norm_layer(planes * block.expansion), 251 | ) 252 | 253 | layers = [] 254 | layers.append( 255 | block(self.inplanes, planes, stride, downsample, self.groups, 256 | self.base_width, previous_dilation, norm_layer)) 257 | self.inplanes = planes * block.expansion 258 | for _ in range(1, blocks): 259 | layers.append( 260 | block(self.inplanes, 261 | planes, 262 | groups=self.groups, 263 | base_width=self.base_width, 264 | dilation=self.dilation, 265 | norm_layer=norm_layer)) 266 | 267 | return nn.Sequential(*layers) 268 | 269 | def _forward_impl(self, x: Tensor) -> Dict: 270 | x0 = self.conv1(x) 271 | x0 = self.bn1(x0) 272 | x0 = self.relu(x0) 273 | x1 = self.maxpool(x0) 274 | 275 | x2 = self.layer1(x1) 276 | x3 = self.layer2(x2) 277 | x4 = self.layer3(x3) 278 | x5 = self.layer4(x4) 279 | 280 | x6 = self.avgpool(x5) 281 | x6 = torch.flatten(x6, 1) 282 | 283 | return x6 284 | 285 | def forward(self, x: Tensor) -> Dict: 286 | return self._forward_impl(x) 287 | 288 | 289 | def resnet_backbone(arch: str = 'resnet50', 290 | num_channel_in: int = 3, 291 | in_planes: int = 64, 292 | layer_planes: List[int] = [64, 128, 256, 512], 293 | **kwargs: Any) -> ResNetBackbone: 294 | """Creates a resnet backbone.""" 295 | blocks = { 296 | 'resnet18': BasicBlock, 297 | 'resnet34': BasicBlock, 298 | 'resnet50': Bottleneck, 299 | 'resnet101': Bottleneck, 300 | 'resnet152': Bottleneck, 301 | } 302 | 303 | layers = { 304 | 'resnet18': [2, 2, 2, 2], 305 | 'resnet34': [3, 4, 6, 3], 306 | 'resnet50': [3, 4, 6, 3], 307 | 'resnet101': [3, 4, 23, 3], 308 | 'resnet152': [3, 8, 36, 3], 309 | } 310 | 311 | return ResNetBackbone(num_channel_in=num_channel_in, 312 | in_planes=in_planes, 313 | layer_planes=layer_planes, 314 | block=blocks[arch], 315 | layers=layers[arch], 316 | **kwargs) 
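For reference, a minimal usage sketch of the `resnet_backbone` factory defined above (illustrative only, not part of the original file; the architecture choice, batch size, and 300×300 patch size below are arbitrary examples):

```python
# Illustrative sketch (assumed arguments): build a ResNet-34 backbone with
# 3 input channels and inspect its pooled output features.
import torch

from fastglioma.models.resnet import resnet_backbone

bb = resnet_backbone(arch="resnet34", num_channel_in=3)
x = torch.randn(2, 3, 300, 300)   # arbitrary batch of two 3-channel patches
feats = bb(x)                     # forward() returns the pooled latent features
print(feats.shape, bb.num_out)    # torch.Size([2, 512]) 512
```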
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FastGlioma: foundation models for fast, label-free detection of glioma infiltration 2 | 3 | [**Paper**](https://www.nature.com/articles/s41586-024-08169-3) / 4 | [**Interactive Demo**](https://fastglioma.mlins.org) / 5 | [**Models**](https://huggingface.co/mlinslab/fastglioma) / 6 | [**MLiNS Lab**](https://mlins.org) 7 | 8 | Code repository for our paper 'Foundation models for fast, label-free detection of glioma infiltration.' We employ a foundational model training strategy to predict the degree of diffuse glioma infiltration intraoperatively using stimulated Raman histology and deep learning. 9 | 10 | ## TL;DR 11 | 12 | *Image tumor with **Fast SRH** >> **FastGlioma** >> Degree of tumor infiltration out* (end-to-end: ~10 seconds) 13 | 14 | ## Abstract 15 | 16 | A critical challenge in glioma treatment is detecting tumor infiltration during surgery to achieve safe maximal resection. Unfortunately, safely resectable residual tumor is found in the majority of glioma patients after surgery, causing early recurrence and decreased survival. We present **FastGlioma**, a visual foundation model for fast (<10 seconds) and accurate detection of glioma infiltration in fresh, unprocessed surgical tissue. FastGlioma was pretrained using large-scale self-supervision (∼4 million images) on rapid, label-free, optical microscopy, and fine-tuned to output a normalized score that indicates the degree of tumor infiltration within whole slide optical images. In a prospective, multicenter, international testing cohort of diffuse glioma patients (n=220), FastGlioma was able to detect and quantify the degree of tumor infiltration with an average area under the ROC curve of 92.1 ± 0.9\%. FastGlioma outperformed image-guided and fluorescence-guided adjuncts for detecting tumor infiltration during surgery by a wide margin in a head-to-head, prospective study (n=129). FastGlioma performance remained high across diverse patient demographics, medical centers, and diffuse glioma molecular subtypes as defined by the World Health Organization (WHO). FastGlioma shows zero-shot generalization to other adult and pediatric brain tumor diagnoses, demonstrating the potential for our foundation model to serve as a general-purpose adjunct for guiding brain tumor surgeries. These findings represent the transformative potential of medical foundation models to unlock the role of artificial intelligence in the care of cancer patients. 17 | 18 | ## Intended Use 19 | *FastGlioma is for investigational use only*. FastGlioma is intended for patients who have adult-type diffuse gliomas as defined by the World Health Organization (WHO). These include: 20 | 21 |   1. Astrocytoma, IDH-mutant 22 | 23 |   2. Oligodendroglioma, IDH-mutant, and 1p/19q-codeleted 24 | 25 |   3. Glioblastoma, IDH-wildtype 26 | 27 | Study neurosurgeons were allowed to include patients based on (1) a previous pathologic diagnosis of adult-type diffuse glioma or (2) high likelihood of adult-type diffuse glioma diagnosis based on clinical presentation and radiographic features. Intraoperative pathologic diagnosis via frozen sectioning or SRH imaging was completed in the majority of patients to provide further preliminary evidence of diffuse glioma diagnosis prior to margin sampling for FastGlioma. 
While our preliminary data show good zero-shot performance on a variety of other tumors and clinical settings, FastGlioma is **not** intended for surgical resection guidance around eloquent cortical or subcortical structures, pediatric patients, non-primary brain tumors, or non-neoplastic pathologic tissue. 28 | 29 | FastGlioma was trained using ordinal labels that correspond to four increasing degrees of tumor infiltration: 0, 1, 2, or 3. However, because tumor infiltration is a continuous variable, FastGlioma outputs a continuous normalized score between 0-1 to indicate the degree of tumor infiltration. Based on training and testing results, we recommend the following guidelines for interpreting FastGlioma scores: 30 | 31 | | Pathologist Score | FastGlioma range | Interpretation | 32 | |----------|----------|----------| 33 | | Score 0 | 0-25% | Normal or non-neoplastic tissue | 34 | | Score 1 | 26-50% | Atypical cells, cannot rule out tumor | 35 | | Score 2 | 51-85% | Sparse tumor infiltration | 36 | | Score 3 | 86-100% | Dense tumor infiltration | 37 | 38 | Please note that the nontumor-tumor threshold corresponds to a FastGlioma score of 50%. An illustrative mapping from FastGlioma scores to these interpretations is sketched in the Results section below. We hope to provide surgeons with real-time, accurate, and clinically actionable diagnostic information. Ultimately, we leave the decision to resect additional tissue to the operating surgeon and the clinical context. 39 | 40 | ## Overview 41 | 42 | ![Overview](/figures/Figure_1.png) 43 | 44 | **FastGlioma workflow.** A patient with a suspected diffuse glioma undergoes surgical resection. During tumor resection, the surgeon samples tissue from the surgical margin. The portable SRH imaging system acquires microscopic images in the operating room, performed by a single technician using simple touchscreen instructions. A freshly excised surgical specimen is loaded directly into a custom microscope slide and inserted into the SRH imager without the need for tissue processing. Additional details on image acquisition can be found in Extended Data Fig. 1. SRH images can be virtually stained using an H&E-like color scheme for clinician review as shown above. A whole slide SRH image is divided into patches, and each patch undergoes a feedforward pass through a patch tokenizer (Extended Data Fig. 3a). The patch tokens, plus an appended classification token, are then input into a whole slide SRH encoder that is a vision transformer. The patch tokenizer and whole slide encoder are pretrained as a visual foundation model using large-scale self-supervision (Extended Data Fig. 3b). For tumor infiltration scoring, a slide scorer model is fine-tuned to output a normalized continuous score between 0-1 that predicts the degree of tumor infiltration within the whole slide image, corresponding to a 4-tier whole slide ordinal infiltration scale as defined by expert neuropathologists (Extended Data Figs. 2 and 4). Ordinal labels are weak because they apply at the slide level only. Despite the weak labels, FastGlioma provides regional interpretability by identifying areas within whole slide SRH images with a high probability of tumor infiltration. Scale bars, 100 microns. 45 | 46 | ## Results 47 | 48 | ![Results](/figures/Figure_2.png) 49 | 50 | **FastGlioma performance.** a, Prediction results for the full prospective, international, multicenter testing cohort of diffuse glioma patients (n = 220) are shown. ROC curves (plotted as mean ± s.d.) show average performance for predicting four levels of tumor infiltration. See Extended Data Fig. 6 for subgroup analysis.
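As referenced in the Intended Use section above, a hypothetical helper (not part of the FastGlioma codebase) that maps a normalized FastGlioma score onto the recommended interpretation ranges:

```python
# Hypothetical convenience function illustrating the recommended score
# interpretation table; the thresholds mirror the 0-25% / 26-50% / 51-85% /
# 86-100% ranges, with 0.50 as the suggested nontumor-tumor threshold.
def interpret_fastglioma_score(score: float) -> str:
    assert 0.0 <= score <= 1.0, "FastGlioma scores are normalized to [0, 1]"
    if score <= 0.25:
        return "Normal or non-neoplastic tissue"         # pathologist score 0
    if score <= 0.50:
        return "Atypical cells, cannot rule out tumor"   # pathologist score 1
    if score <= 0.85:
        return "Sparse tumor infiltration"               # pathologist score 2
    return "Dense tumor infiltration"                    # pathologist score 3


print(interpret_fastglioma_score(0.62))  # Sparse tumor infiltration
```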
SRH foundation model pretraining showed strong prediction performance without fine-tuning. FastGlioma, which includes fine-tuning with ordinal metric learning, had a 3.2\% increase in overall performance. FastGlioma outperforms models trained using standard supervised training (84.7 ± 1.1\% mAUC) as shown in Supplementary Data Table 4. b, Box and whisker plots of FastGlioma infiltration scores by ground truth value are shown in the standardized quartile format. Scores had a strong correlation with ground truth ordinal scores (ρ = 0.77, 95\% confidence interval 0.74-0.78). Individual scores are shown in a histogram and correspond to AUROC values in 2a. c, FastGlioma performance on full resolution versus low resolution SRH images is shown (plotted as mean ± s.d.). FastGlioma allows for a 10X increase in imaging speed with minimal performance tradeoff. d, Whole slide SRH representations are plotted on a linear discriminant axis. FastGlioma learned representations that rank whole slide SRH images on a near-linear tumor infiltration axis. e, Subgroup analysis by WHO adult-type diffuse glioma subtypes (ROC curves plotted as mean ± s.d.). FastGlioma performs well across all three adult-type diffuse gliomas. Importantly, FastGlioma performs well on lower grade gliomas where tumor infiltration and tissue cellularity can be low (Extended Data Fig. 7). Low-grade tumors and lower degrees of tumor infiltration are major challenges for other surgical adjuncts, such as fluorescence-guided surgery. 51 | 52 | © This code is made available for academic purposes. Imaging and clinical information for this project was collected with IRB approval (HUM00083059) and is protected under HIPAA. Representative images and predictions can be found at [**fastglioma.mlins.org**](https://fastglioma.mlins.org). 53 | 54 | # Training, evaluation, and inference 55 | 56 | This repository currently supports inference on the [OpenSRH dataset](https://opensrh.mlins.org/), the largest publicly available stimulated Raman histology dataset, with FastGlioma models available on [HuggingFace](https://huggingface.co/mlinslab/fastglioma/), as well as training/evaluation scripts for developing your own self-supervised SRH foundation models. 57 | 58 | ## Directory organization 59 | ``` 60 | fastglioma/ 61 | ├── fastglioma/ # Library for FastGlioma training 62 | │ ├── datasets/ # PyTorch OpenSRH datasets 63 | │ ├── losses/ # FastGlioma loss functions with contrastive/ordinal metric learning 64 | │ ├── models/ # PyTorch models for training, evaluation, and inference 65 | │ ├── utils/ # Utility functions 66 | │ ├── train/ # Training scripts 67 | │ ├── eval/ # Evaluation scripts 68 | │ ├── inference/ # Inference scripts 69 | │ ├── tf/ # TensorFlow implementation/inference scripts 70 | ├── figures/ # Figures in the README file 71 | ├── THIRD_PARTY # License information for third party code 72 | ├── setup.py # Setup file including list of dependencies 73 | ├── LICENSE # MIT license for the repo 74 | └── README.md 75 | ``` 76 | 77 | ## Installation 78 | 79 | 1. Clone the FastGlioma GitHub repo 80 | ```console 81 | git clone git@github.com:MLNeurosurg/fastglioma.git 82 | ``` 83 | 2. Install Miniconda: follow the instructions 84 | [here](https://docs.conda.io/en/latest/miniconda.html) 85 | 3. Create a conda environment 86 | ```console 87 | conda create -n fastglioma python=3.9 88 | ``` 89 | 4. Activate the conda environment 90 | ```console 91 | conda activate fastglioma 92 | ``` 93 | 5. Install the package and dependencies 94 | ```console 95 | 96 | pip install -e .
97 | ``` 98 | 99 | ## Dataset and Models 100 | 101 | The OpenSRH dataset and FastGlioma models are available for non-commercial use. Please download the OpenSRH dataset from the [OpenSRH website](https://opensrh.mlins.org/) according to the instructions provided. Additionally, please request access to FastGlioma models on [Hugging Face](https://huggingface.co/mlinslab/fastglioma/). 102 | 103 | ## Inference 104 | 105 | 1. Log into Hugging Face and navigate to the inference directory 106 | ```console 107 | huggingface-cli login 108 | cd fastglioma/inference 109 | ``` 110 | 2. Specify the inference configuration file 111 | ```console 112 | vi config/infer.yaml 113 | ``` 114 | 3. Generate predictions 115 | ```console 116 | python run_inference.py -c config/infer.yaml 117 | ``` 118 | 119 | ## Training and evaluation 120 | 121 | 1. Train SRH patch and whole-slide foundation models 122 | ```console 123 | cd fastglioma/ 124 | 125 | # Train/evaluate patch tokenizer 126 | python train/train_patch.py -c train/config/train_hidisc.yaml 127 | python eval/eval_knn.py -c eval/config/eval_hidisc.yaml 128 | 129 | # Save patch embeddings 130 | python eval/save_embedding.py -c eval/config/save_hidisc.yaml 131 | 132 | # Train/evaluate slide transformer 133 | python train/train_slide.py -c train/config/train_scm.yaml 134 | python eval/eval_knn.py -c eval/config/eval_scm.yaml 135 | ``` 136 | 137 | 2. Fine-tune the slide foundation model for tumor infiltration scoring 138 | ```console 139 | python train/train_scorer.py -c train/config/train_ordmet.yaml 140 | ``` 141 | 142 | Note that the OpenSRH dataset only includes patient-level annotations. For training the slide scorer, we include a `slide_class` key for each slide in the metadata, corresponding to the ordinal tumor infiltration label. This can be toggled by setting the `use_patient_class` flag to `False` in the config file. 143 | ```console 144 | "NIO_001": { 145 | "patient_id": "NIO_001", 146 | "class": "hgg", 147 | "slides": { 148 | "1": { 149 | "slide_id": "1", 150 | "slide_class": "0", # added slide-level tumor infiltration label from 0-3 151 | "tumor_patches": ... 152 | }, 153 | ... 154 | } 155 | } 156 | ``` 157 | ## License Information 158 | The code is licensed under the MIT License. 159 | See LICENSE for license information and THIRD_PARTY for third party notices. -------------------------------------------------------------------------------- /fastglioma/tf/transformer.py: -------------------------------------------------------------------------------- 1 | """TransformerMIL in TensorFlow. 2 | 3 | Matches PyTorch implementation exactly. 4 | 5 | Copyright (c) 2024 University of Michigan. All rights reserved. 6 | Licensed under the MIT License. See LICENSE for license information.
7 | """ 8 | 9 | import tensorflow as tf 10 | from tensorflow import keras 11 | import numpy as np 12 | import logging 13 | 14 | 15 | def gelu(x): 16 | cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0))) 17 | return x * cdf 18 | 19 | 20 | class FFPEG(keras.layers.Layer): 21 | 22 | def __init__(self, 23 | embed_dim, 24 | dim_ff=96, 25 | dim_mlp=36, 26 | gamma=0.25, 27 | prefix_len=0, 28 | pos_emb_grad=True, 29 | **kwargs): 30 | super().__init__() 31 | self.dim_ff_ = dim_ff 32 | self.dim_mlp_ = dim_mlp 33 | self.gamma_ = gamma 34 | self.embed_dim_ = embed_dim 35 | self.prefix_len = prefix_len 36 | 37 | # Equivalent to PyTorch's nn.Parameter 38 | self.cls_pos_emb = self.add_weight(shape=(1, prefix_len, embed_dim), 39 | initializer='zeros', 40 | trainable=True, 41 | name='cls_pos_emb') 42 | 43 | self.num_pos_ = 1 44 | self.pos_dim_ = 2 45 | 46 | self._ff_embed = keras.layers.Dense( 47 | dim_ff // 2, 48 | use_bias=False, 49 | kernel_initializer=keras.initializers.RandomNormal(mean=0., 50 | stddev=gamma)) 51 | 52 | if not pos_emb_grad: 53 | self._ff_embed.trainable = False 54 | 55 | self._mlp = keras.Sequential([ 56 | keras.layers.LayerNormalization(epsilon=1e-5), 57 | keras.layers.Dense(dim_mlp), 58 | keras.layers.Lambda(gelu), # Custom GELU implementation 59 | keras.layers.LayerNormalization(epsilon=1e-5), 60 | keras.layers.Dense(embed_dim // self.num_pos_) 61 | ]) 62 | 63 | def call(self, H, coords, return_ff=False): 64 | bsz = tf.shape(H)[0] 65 | n = tf.shape(H)[1] - self.prefix_len 66 | 67 | x = tf.expand_dims(tf.expand_dims(tf.cast(coords, tf.float32), 0), -2) 68 | 69 | ff_vec = self._ff_embed(x) 70 | 71 | f = tf.concat([tf.cos(ff_vec), tf.sin(ff_vec)], axis=-1) 72 | f = f / tf.sqrt(float(self.dim_ff_)) 73 | 74 | if return_ff: 75 | return f 76 | 77 | pe = self._mlp(f) 78 | pe = tf.reshape(pe, [bsz, n, self.embed_dim_]) 79 | pe = tf.concat([tf.repeat(self.cls_pos_emb, bsz, axis=0), pe], axis=1) 80 | 81 | return H + pe 82 | 83 | 84 | class Attention(keras.layers.Layer): 85 | 86 | def __init__(self, 87 | dim, 88 | num_heads=8, 89 | qkv_bias=False, 90 | attn_drop=0., 91 | proj_drop=0.): 92 | super().__init__() 93 | self.num_heads = num_heads 94 | self.dim = dim 95 | self.head_dim = dim // num_heads 96 | self.scale = self.head_dim**-0.5 97 | 98 | # Initialize weights with same distribution as PyTorch 99 | self.qkv = keras.layers.Dense(dim * 3, 100 | use_bias=qkv_bias, 101 | kernel_initializer='glorot_uniform', 102 | bias_initializer='zeros') 103 | self.attn_drop = keras.layers.Dropout(attn_drop) 104 | self.proj = keras.layers.Dense(dim, 105 | kernel_initializer='glorot_uniform', 106 | bias_initializer='zeros') 107 | self.proj_drop = keras.layers.Dropout(proj_drop) 108 | 109 | def call(self, x, training=False): 110 | B = tf.shape(x)[0] 111 | N = tf.shape(x)[1] 112 | C = self.dim 113 | 114 | # Match PyTorch's exact reshape and permute operations 115 | qkv = self.qkv(x) # (B, N, 3*C) 116 | qkv = tf.reshape(qkv, [B, N, 3, self.num_heads, C // self.num_heads]) 117 | qkv = tf.transpose( 118 | qkv, [2, 0, 3, 1, 4]) # (3, B, num_heads, N, C//num_heads) 119 | q, k, v = qkv[0], qkv[1], qkv[2] 120 | 121 | # Exact PyTorch attention computation order 122 | attn = tf.matmul(q, k, transpose_b=True) # (B, num_heads, N, N) 123 | attn = attn * self.scale 124 | attn = tf.nn.softmax(attn, axis=-1) 125 | attn = self.attn_drop(attn, training=training) 126 | 127 | x = tf.matmul(attn, v) # (B, num_heads, N, C//num_heads) 128 | x = tf.transpose(x, [0, 2, 1, 3]) # (B, N, num_heads, C//num_heads) 129 | x = 
tf.reshape(x, [B, N, C]) 130 | 131 | x = self.proj(x) 132 | x = self.proj_drop(x, training=training) 133 | 134 | return x, attn 135 | 136 | 137 | class Block(keras.layers.Layer): 138 | 139 | def __init__(self, 140 | dim, 141 | num_heads, 142 | mlp_ratio=4., 143 | qkv_bias=False, 144 | drop=0., 145 | attn_drop=0., 146 | **kwargs): 147 | super().__init__() 148 | self.norm1 = keras.layers.LayerNormalization( 149 | epsilon=1e-5) # Match PyTorch epsilon 150 | self.attn = Attention(dim=dim, 151 | num_heads=num_heads, 152 | qkv_bias=qkv_bias, 153 | attn_drop=attn_drop, 154 | proj_drop=drop) 155 | 156 | self.norm2 = keras.layers.LayerNormalization( 157 | epsilon=1e-5) # Match PyTorch epsilon 158 | mlp_hidden_dim = int(dim * mlp_ratio) 159 | 160 | # Use custom GELU that matches PyTorch exactly 161 | self.mlp = keras.Sequential([ 162 | keras.layers.Dense(mlp_hidden_dim, 163 | kernel_initializer='glorot_uniform', 164 | bias_initializer='zeros'), 165 | keras.layers.Lambda(gelu), # Custom GELU implementation 166 | keras.layers.Dropout(drop), 167 | keras.layers.Dense(dim, 168 | kernel_initializer='glorot_uniform', 169 | bias_initializer='zeros'), 170 | keras.layers.Dropout(drop) 171 | ]) 172 | 173 | def call(self, x, training=False, return_attention=False): 174 | norm_x = self.norm1(x) 175 | y, attn = self.attn(norm_x, training=training) 176 | if return_attention: 177 | return attn 178 | 179 | x = x + y 180 | residual = x 181 | 182 | x = self.mlp(self.norm2(x), training=training) 183 | x = residual + x 184 | 185 | return x 186 | 187 | 188 | class TransformerMIL(keras.Model): 189 | 190 | def __init__(self, 191 | global_pool='token', 192 | embed_dim=768, 193 | depth=12, 194 | num_heads=12, 195 | mlp_ratio=4., 196 | qkv_bias=False, 197 | pos_emb_type=None, 198 | drop_rate=0., 199 | attn_drop_rate=0., 200 | **kwargs): 201 | super().__init__() 202 | self.global_pool = global_pool 203 | assert self.global_pool in ['', 'avg', 'token'] 204 | 205 | self.cls_token = self.add_weight(shape=(1, kwargs.get("prefix_len", 206 | 1), embed_dim), 207 | initializer='zeros', 208 | trainable=True, 209 | name='cls_token') 210 | 211 | self.pos_embed = keras.layers.Lambda(lambda x: x) 212 | if pos_emb_type: 213 | self.pos_embed = FFPEG(embed_dim=embed_dim, **kwargs) 214 | 215 | self.blocks = [ 216 | Block(dim=embed_dim, 217 | num_heads=num_heads, 218 | mlp_ratio=mlp_ratio, 219 | qkv_bias=qkv_bias, 220 | drop=drop_rate, 221 | attn_drop=attn_drop_rate) for _ in range(depth) 222 | ] 223 | 224 | self.norm = keras.layers.LayerNormalization(epsilon=1e-5) 225 | self.dim_out = embed_dim 226 | 227 | def call(self, x, coords=None, training=False): 228 | if len(tf.shape(x)) == 2: 229 | x = tf.expand_dims(x, 0) 230 | 231 | batch_size = tf.shape(x)[0] 232 | x = tf.concat([tf.tile(self.cls_token, [batch_size, 1, 1]), x], axis=1) 233 | 234 | if coords is not None: 235 | x = self.pos_embed(x, coords) 236 | else: 237 | x = self.pos_embed(x) 238 | 239 | for block in self.blocks: 240 | x = block(x, training=training) 241 | 242 | x = self.norm(x) 243 | 244 | if self.global_pool: 245 | x = tf.reduce_mean( 246 | x[:, 1:], axis=1) if self.global_pool == 'avg' else x[:, 0] 247 | 248 | return x 249 | 250 | 251 | def convert_pytorch_transformer_to_tf(pytorch_model, tf_model): 252 | # Build model first 253 | batch_size = 1 254 | n_patches = 16 255 | dummy_input = tf.zeros((batch_size, n_patches, pytorch_model.dim_out), 256 | dtype=tf.float32) 257 | dummy_coords = tf.zeros((1, n_patches, 2), dtype=tf.float32) 258 | _ = tf_model(dummy_input, 
coords=dummy_coords) 259 | 260 | state_dict = pytorch_model.state_dict() 261 | tf_weights = [] 262 | 263 | # 1. cls_pos_emb (1, 8, 512) 264 | tf_weights.append(state_dict['pos_embed.cls_pos_emb'].numpy()) 265 | 266 | # 2. ffpeg/dense/kernel (2, 48) 267 | tf_weights.append(state_dict['pos_embed._ff_embed.weight'].numpy().transpose()) #yapf:disable 268 | 269 | # 3-10. FFPEG MLP weights 270 | tf_weights.append(state_dict['pos_embed._mlp.0.weight'].numpy()) # layer_norm gamma (96,) #yapf:disable 271 | tf_weights.append(state_dict['pos_embed._mlp.0.bias'].numpy()) # layer_norm beta (96,) #yapf:disable 272 | tf_weights.append(state_dict['pos_embed._mlp.1.weight'].numpy().transpose()) # dense_1 kernel (96, 36) #yapf:disable 273 | tf_weights.append(state_dict['pos_embed._mlp.1.bias'].numpy()) # dense_1 bias (36,) #yapf:disable 274 | tf_weights.append(state_dict['pos_embed._mlp.3.weight'].numpy()) # layer_norm_1 gamma (36,) #yapf:disable 275 | tf_weights.append(state_dict['pos_embed._mlp.3.bias'].numpy()) # layer_norm_1 beta (36,) #yapf:disable 276 | tf_weights.append(state_dict['pos_embed._mlp.4.weight'].numpy().transpose()) # dense_2 kernel (36, 512) #yapf:disable 277 | tf_weights.append(state_dict['pos_embed._mlp.4.bias'].numpy()) # dense_2 bias (512,) #yapf:disable 278 | 279 | # Map transformer blocks 280 | for i in range(len(pytorch_model.blocks)): 281 | # Layer norm 2/4 (512,) 282 | tf_weights.append(state_dict[f'blocks.{i}.norm1.weight'].numpy()) 283 | tf_weights.append(state_dict[f'blocks.{i}.norm1.bias'].numpy()) 284 | 285 | # Attention weights 286 | qkv_weight = state_dict[f'blocks.{i}.attn.qkv.weight'].numpy() 287 | tf_weights.append(qkv_weight.transpose()) # dense_3/7 kernel (512, 1536) #yapf:disable 288 | 289 | proj_weight = state_dict[f'blocks.{i}.attn.proj.weight'].numpy() 290 | tf_weights.append(proj_weight.transpose()) # dense_4/8 kernel (512, 512) #yapf:disable 291 | tf_weights.append(state_dict[f'blocks.{i}.attn.proj.bias'].numpy()) # dense_4/8 bias (512,) #yapf:disable 292 | 293 | # Layer norm 3/5 (512,) 294 | tf_weights.append(state_dict[f'blocks.{i}.norm2.weight'].numpy()) 295 | tf_weights.append(state_dict[f'blocks.{i}.norm2.bias'].numpy()) 296 | 297 | # MLP weights 298 | tf_weights.append(state_dict[f'blocks.{i}.mlp.fc1.weight'].numpy().transpose()) # dense_5/9 kernel (512, 2048) #yapf:disable 299 | tf_weights.append(state_dict[f'blocks.{i}.mlp.fc1.bias'].numpy()) # dense_5/9 bias (2048,) #yapf:disable 300 | tf_weights.append(state_dict[f'blocks.{i}.mlp.fc2.weight'].numpy().transpose()) # dense_6/10 kernel (2048, 512) #yapf:disable 301 | tf_weights.append(state_dict[f'blocks.{i}.mlp.fc2.bias'].numpy()) # dense_6/10 bias (512,) #yapf:disable 302 | 303 | # Final layer norm (512,) 304 | tf_weights.append(state_dict['norm.weight'].numpy()) 305 | tf_weights.append(state_dict['norm.bias'].numpy()) 306 | 307 | # cls_token (1, 8, 512) 308 | tf_weights.append(state_dict['cls_token'].numpy()) 309 | 310 | # Verify shapes match 311 | logging.debug("\nVerifying weight shapes:") 312 | for i, (w, tw) in enumerate(zip(tf_model.weights, tf_weights)): 313 | logging.debug(f"{i}: {w.name} - Expected: {w.shape}, Got: {tw.shape}") 314 | assert w.shape == tw.shape, f"Shape mismatch for {w.name}: Expected {w.shape}, got {tw.shape}" 315 | 316 | tf_model.set_weights(tf_weights) 317 | return tf_model 318 | -------------------------------------------------------------------------------- /fastglioma/datasets/emb_proc.py: -------------------------------------------------------------------------------- 1 
| import random 2 | import logging 3 | from functools import partial 4 | from typing import List, Tuple, Dict, Optional, Any 5 | import math 6 | 7 | import numpy as np 8 | 9 | import torch 10 | from torch.nn import ModuleList 11 | 12 | import torchvision 13 | from torchvision.transforms import Compose, RandomApply 14 | from torchvision.transforms import functional as F 15 | 16 | from torchvision.transforms.transforms import _setup_angle, _check_sequence_input 17 | from torch import Tensor 18 | 19 | 20 | def emb_collate_fn(batch): 21 | return { 22 | "embeddings": [data['embeddings'] for data in batch], 23 | "label": torch.stack([data['label'] for data in batch]), 24 | "path": [data['path'] for data in batch], 25 | "coords": [data['coords'] for data in batch] 26 | } 27 | 28 | 29 | def get_emb_transformations( 30 | cf: Optional[Dict] = None) -> Tuple[Compose, Compose]: 31 | 32 | if cf: 33 | train_dict = cf["data"]["train_augmentation"] 34 | valid_dict = cf["data"]["valid_augmentation"] 35 | aug_prob = cf["data"]["rand_aug_prob"] 36 | else: 37 | train_dict = [] 38 | valid_dict = [] 39 | aug_prob = 0 40 | 41 | if valid_dict == "same": 42 | valid_dict = train_dict 43 | 44 | rand_apply_p = lambda which, **kwargs: RandomApply( 45 | ModuleList([which(**kwargs)]), p=aug_prob) 46 | rand_apply = lambda which, p, **kwargs: RandomApply( 47 | ModuleList([which(**kwargs)]), p=p) 48 | 49 | callable_dict = { 50 | "random_splitting": partial(rand_apply_p, NViewRandomSplitting), 51 | "random_masking": partial(rand_apply_p, NViewRandomMasking), 52 | "random_cropping": partial(rand_apply_p, NViewRandomCropping), 53 | } 54 | 55 | def validate_aug_names(aug_cf): 56 | aug_names = [a["which"] for a in aug_cf] 57 | aug_set = set(aug_names) 58 | if ("random_splitting" in aug_set) or ("random_partitioning" 59 | in aug_set): 60 | assert ( 61 | ("random_splitting" in aug_set) ^ 62 | ("random_partitioning" in aug_set) 63 | ), "random_splitting and random_partitioning are mutually exclusive" 64 | 65 | validate_aug_names(train_dict) 66 | validate_aug_names(valid_dict) 67 | 68 | train_xform = Compose( 69 | [callable_dict[a["which"]](**a["params"]) for a in train_dict]) 70 | valid_xform = Compose( 71 | [callable_dict[a["which"]](**a["params"]) for a in valid_dict]) 72 | 73 | logging.info(f"train_xform\n{train_xform}") 74 | logging.info(f"valid_xform\n{valid_xform}") 75 | 76 | return train_xform, valid_xform 77 | 78 | 79 | class NViewRandomSplitting(torch.nn.Module): 80 | """Randomly split all the tokens into N random views.""" 81 | 82 | def __init__(self, 83 | masking_ratio: List[float] = [0.7, 0.3], 84 | fixed_order=False): 85 | super().__init__() 86 | assert sum(masking_ratio) == 1 87 | self.n_views_ = len(masking_ratio) 88 | self.splitting_ratio_ = torch.tensor(masking_ratio) 89 | 90 | self.idx_f_ = self.get_two_idx if self.n_views_ == 2 else self.get_n_idx 91 | self.asgmt_order_f_ = torch.arange if fixed_order else torch.randperm 92 | 93 | def get_n_idx(self, length): 94 | elt_sizes = (self.splitting_ratio_ * length).floor() 95 | elt_sizes[-1] = length - elt_sizes[:-1].sum() 96 | end = elt_sizes.to(int).cumsum(dim=0) 97 | start = torch.hstack((torch.tensor([0]), end[:-1])) 98 | all_idx = torch.randperm(length) 99 | return [all_idx[i:j] for i, j in zip(start, end)] 100 | 101 | def get_two_idx(self, length): 102 | sample_size = int(self.splitting_ratio_[0] * length) 103 | all_idx = torch.randperm(length) 104 | return [all_idx[:sample_size], all_idx[sample_size:]] 105 | 106 | def forward(self, inst: Dict[str, Any]): 107 | 
length = len(inst['embeddings'][0]) 108 | assert len(inst['embeddings']) == self.n_views_, ( 109 | f"length of embedding is {len(inst['embeddings'])}, " + 110 | f"n_views is {self.n_views_}") 111 | idxs = self.idx_f_(length) 112 | asgmt_order = self.asgmt_order_f_(self.n_views_) 113 | 114 | inst["embeddings"] = [ 115 | emb[idxs[asgmt_order[i]], :] 116 | for i, emb in enumerate(inst["embeddings"]) 117 | ] 118 | inst["coords"] = [ 119 | emb[idxs[asgmt_order[i]], :] 120 | for i, emb in enumerate(inst["coords"]) 121 | ] 122 | 123 | return inst 124 | 125 | 126 | class NViewRandomPartitioning(torch.nn.Module): 127 | 128 | def __init__(self): 129 | raise NotImplementedError() 130 | 131 | def forward(self, inst: Dict[str, Any]): 132 | raise NotImplementedError() 133 | 134 | 135 | class NViewRandomMasking(torch.nn.Module): 136 | 137 | def __init__(self, 138 | masking_ratio_ranges: List[List[float]] = [[0.8, 1], 139 | [0.3, 0.7]], 140 | max_num_tokens: Optional[List[int]] = None, 141 | fixed_order=False): 142 | super().__init__() 143 | self.n_views_ = len(masking_ratio_ranges) 144 | mrr = torch.tensor(masking_ratio_ranges) 145 | self.masking_range_ = torch.diff(mrr, dim=1).squeeze(dim=1) 146 | self.masking_min_ = mrr[:, 0] 147 | if max_num_tokens: 148 | self.num_max_token_ = torch.tensor(max_num_tokens) 149 | else: 150 | self.num_max_token_ = None 151 | print(self.num_max_token_) 152 | self.asgmt_order_f_ = torch.arange if fixed_order else torch.randperm 153 | 154 | def forward(self, inst: Dict[str, Any]): 155 | lengths = torch.tensor([len(e) for e in inst['embeddings']]) 156 | assert len(inst['embeddings']) == self.n_views_, ( 157 | f"length of embedding is {len(inst['embeddings'])}, " + 158 | f"n_views is {self.n_views_}") 159 | 160 | asgmt_ord = self.asgmt_order_f_(self.n_views_) 161 | sizes = ((torch.rand(self.n_views_) * self.masking_range_[asgmt_ord] + 162 | self.masking_min_[asgmt_ord]) * lengths).round().to(int) 163 | 164 | if self.num_max_token_ is not None: 165 | sizes = torch.minimum(sizes, self.num_max_token_) 166 | 167 | idxs = [torch.randperm(l)[:s] for s, l in zip(sizes, lengths)] 168 | 169 | inst["embeddings"] = [ 170 | emb[idxs[i], :] for i, emb in enumerate(inst["embeddings"]) 171 | ] 172 | inst["coords"] = [ 173 | emb[idxs[i], :] for i, emb in enumerate(inst["coords"]) 174 | ] 175 | 176 | return inst 177 | 178 | 179 | class NViewOneRandomCropping(torch.nn.Module): 180 | 181 | def __init__(self, 182 | masking_size_ranges: List[int] = [1500, 1700], 183 | masking_aspect_ratio_range: List[float] = [1, 1], 184 | min_crop_area_thres=1600): 185 | super().__init__() 186 | msr = torch.tensor(masking_size_ranges) 187 | self.masking_area_range_ = msr[1] - msr[0] 188 | self.masking_area_min_ = msr[0] 189 | 190 | marr = torch.tensor(masking_aspect_ratio_range) 191 | self.aspect_range_ = marr[1] - marr[0] 192 | self.aspect_min_ = marr[0] 193 | 194 | self.min_crop_area_thres_ = min_crop_area_thres 195 | assert min_crop_area_thres >= masking_size_ranges[-1] 196 | 197 | def forward(self, inst: Dict[str, Any]): 198 | num_tokens = [len(e) for e in inst['embeddings']] 199 | assert len(set(num_tokens)) == 1 200 | num_tokens = num_tokens[0] 201 | if num_tokens <= self.min_crop_area_thres_: return inst 202 | 203 | coords_uncropped = inst["coords"][0] 204 | 205 | # attempt to exclude some edge regions for a center-ish crop 206 | min_r, min_c = coords_uncropped.min(dim=0).values 207 | max_r, max_c = coords_uncropped.max(dim=0).values 208 | exclude_region_side = torch.sqrt( 209 | 
torch.tensor(self.min_crop_area_thres_)) // 2 210 | 211 | if max_r - min_r > exclude_region_side * 2 + 1: 212 | max_r = max_r - exclude_region_side 213 | min_r = min_r + exclude_region_side 214 | if max_c - min_c > exclude_region_side * 2 + 1: 215 | max_c = max_c - exclude_region_side 216 | min_c = min_c + exclude_region_side 217 | coords_filt = coords_uncropped[ 218 | filt_coords(coords_uncropped, min_r, max_r, min_c, max_c), :] 219 | 220 | if len(coords_filt) == 0: 221 | logging.warning( 222 | f"bug found when computing one crop for {inst['path']}." + 223 | f"coords_filt shape {coords_filt.shape}" 224 | f"uncropped min {coords_uncropped.min(dim=0).values}; " + 225 | f"uncropped max {coords_uncropped.max(dim=0).values}; " + 226 | f"(min_r, min_c, max_r, max_c) ({(min_r, min_c, max_r, max_c)});" 227 | + f"exclude_region_side {exclude_region_side}") 228 | coords_filt = coords_uncropped 229 | 230 | centroid_idx = (len(coords_filt) * torch.rand(1)).to(int) 231 | centroid = coords_filt[centroid_idx, :].squeeze() 232 | 233 | # get bbox size 234 | area = ((torch.rand(1) * self.masking_area_range_ + 235 | self.masking_area_min_)).round().to(int) 236 | aspect = (torch.rand(1) * self.aspect_range_ + self.aspect_min_) 237 | dr = torch.sqrt(area / aspect) 238 | dc = (dr * aspect) 239 | dr = (dr / 2).round() 240 | dc = (dc / 2).round() 241 | 242 | r0, r1 = centroid[0] - dr, centroid[0] + dr 243 | c0, c1 = centroid[1] - dc, centroid[1] + dc 244 | 245 | idxs = filt_coords(coords_uncropped, r0, r1, c0, c1) 246 | 247 | inst["embeddings"] = [emb[idxs, :] for emb in inst["embeddings"]] 248 | inst["coords"] = [emb[idxs, :] for emb in inst["coords"]] 249 | 250 | return inst 251 | 252 | 253 | def filt_coords(coords, min_r, max_r, min_c, max_c): 254 | return torch.logical_and( 255 | torch.logical_and(coords[:, 0] > min_r, coords[:, 0] < max_r), 256 | torch.logical_and(coords[:, 1] > min_c, coords[:, 1] < max_c)) 257 | 258 | 259 | class NViewRandomCropping(torch.nn.Module): 260 | 261 | def __init__(self, 262 | masking_size_ranges: List[List[int]] = [[100, 900], 263 | [100, 900]], 264 | masking_aspect_ratio_range: List[List[float]] = [[0.3, 3], 265 | [0.3, 3]], 266 | fixed_order=False): 267 | super().__init__() 268 | assert len(masking_size_ranges) == len(masking_aspect_ratio_range) 269 | self.n_views_ = len(masking_size_ranges) 270 | 271 | msr = torch.tensor(masking_size_ranges) 272 | self.masking_area_range_ = torch.diff(msr, dim=1).squeeze(dim=1) 273 | self.masking_area_min_ = msr[:, 0] 274 | 275 | marr = torch.tensor(masking_aspect_ratio_range) 276 | self.aspect_range_ = torch.diff(marr, dim=1).squeeze(dim=1) 277 | self.aspect_min_ = marr[:, 0] 278 | 279 | self.msr_ = masking_size_ranges 280 | self.marr_ = masking_aspect_ratio_range 281 | 282 | self.asgmt_order_f_ = torch.arange if fixed_order else torch.randperm 283 | 284 | def forward(self, inst: Dict[str, Any]): 285 | lengths = torch.tensor([len(e) for e in inst['embeddings']]) 286 | assert len(inst['embeddings']) == self.n_views_, ( 287 | f"length of embedding is {len(inst['embeddings'])}, " + 288 | f"n_views is {self.n_views_}") 289 | 290 | # randomly picking a centroid on the slide 291 | centroid_idx = (lengths * torch.rand(self.n_views_)).to(int) 292 | centroids = torch.stack( 293 | [i[j, :] for i, j in zip(inst["coords"], centroid_idx)]) 294 | 295 | # get bbox size 296 | asgmt_ord = self.asgmt_order_f_(self.n_views_) 297 | areas = ( 298 | (torch.rand(self.n_views_) * self.masking_area_range_[asgmt_ord] + 299 | 
self.masking_area_min_[asgmt_ord])).round().to(int) 300 | aspects = (torch.rand(self.n_views_) * self.aspect_range_[asgmt_ord] + 301 | self.aspect_min_[asgmt_ord]) 302 | dr = torch.sqrt(areas / aspects) 303 | dc = (dr * aspects) 304 | dr = (dr / 2).round() 305 | dc = (dc / 2).round() 306 | 307 | r_min, r_max = centroids[:, 0] - dr, centroids[:, 0] + dr 308 | c_min, c_max = centroids[:, 1] - dc, centroids[:, 1] + dc 309 | 310 | idxs = [ 311 | filt_coords(coords, r0, r1, c0, c1) 312 | for (coords, r0, r1, c0, 313 | c1) in zip(inst["coords"], r_min, r_max, c_min, c_max) 314 | ] 315 | 316 | inst["embeddings"] = [ 317 | emb[idxs[i], :] for i, emb in enumerate(inst["embeddings"]) 318 | ] 319 | inst["coords"] = [ 320 | emb[idxs[i], :] for i, emb in enumerate(inst["coords"]) 321 | ] 322 | 323 | return inst -------------------------------------------------------------------------------- /fastglioma/eval/eval_knn.py: -------------------------------------------------------------------------------- 1 | """kNN evaluation modules and script. 2 | 3 | Copyright (c) 2024 University of Michigan. All rights reserved. 4 | Licensed under the MIT License. See LICENSE for license information. 5 | """ 6 | 7 | import os 8 | import logging 9 | from shutil import copy2 10 | from functools import partial 11 | from typing import List, Union, Dict, Any 12 | 13 | import yaml 14 | import numpy as np 15 | import pandas as pd 16 | from sklearn.metrics import confusion_matrix 17 | from tqdm import tqdm 18 | 19 | import torch 20 | from torchvision.transforms import Compose 21 | 22 | import pytorch_lightning as pl 23 | from torchmetrics import AveragePrecision, Accuracy 24 | 25 | from fastglioma.datasets.srh_dataset import PatchDataset 26 | from fastglioma.datasets.embedding_dataset import SlideEmbeddingDataset 27 | from fastglioma.datasets.improc import get_transformations 28 | from fastglioma.datasets.emb_proc import get_emb_transformations, emb_collate_fn 29 | from fastglioma.utils.common import (parse_args, get_exp_name, config_loggers, 30 | get_num_worker) 31 | from fastglioma.train.train_patch import HiDiscSystem 32 | from fastglioma.train.train_slide import SlideSSLSystem 33 | 34 | 35 | # code for kNN prediction is from the github repo IgorSusmelj/barlowtwins 36 | # https://github.com/IgorSusmelj/barlowtwins/blob/main/utils.py 37 | def knn_predict(feature, feature_bank, feature_labels, classes: int, 38 | knn_k: int, knn_t: float): 39 | """Helper method to run kNN predictions on features from a feature bank. 40 | 41 | Args: 42 | feature: Tensor of shape [N, D] consisting of N D-dimensional features 43 | feature_bank: Tensor of a database of features used for kNN 44 | feature_labels: Labels for the features in our feature_bank 45 | classes: Number of classes (e.g. 
10 for CIFAR-10) 46 | knn_k: Number of k neighbors used for kNN 47 | knn_t: Temperature 48 | """ 49 | # cos similarity between each feature vector and feature bank ---> [B, N] 50 | sim_matrix = torch.mm(feature, feature_bank) 51 | # [B, K] 52 | sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1) 53 | # [B, K] 54 | sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), 55 | dim=-1, 56 | index=sim_indices) 57 | # we do a reweighting of the similarities 58 | sim_weight = (sim_weight / knn_t).exp() 59 | # counts for each class 60 | one_hot_label = torch.zeros(feature.size(0) * knn_k, 61 | classes, 62 | device=sim_labels.device) 63 | # [B*K, C] 64 | one_hot_label = one_hot_label.scatter(dim=-1, 65 | index=sim_labels.view(-1, 1), 66 | value=1.0) 67 | # weighted score ---> [B, C] 68 | pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * 69 | sim_weight.unsqueeze(dim=-1), 70 | dim=1) 71 | pred_labels = pred_scores.argsort(dim=-1, descending=True) 72 | return pred_labels, pred_scores 73 | 74 | 75 | def get_embeddings_patch(cf: Dict[str, Any], 76 | exp_root: str) -> Dict[str, Union[torch.Tensor, List[str]]]: 77 | """Run forward pass on the dataset, and generate embeddings and logits""" 78 | train_xform, valid_xform = get_transformations(cf) 79 | 80 | logging.info(f"train_xform \n{train_xform}") 81 | logging.info(f"valid_xform \n{valid_xform}") 82 | 83 | # get dataset / loader 84 | train_dset = PatchDataset(data_root=cf["data"]["db_root"], 85 | studies="train", 86 | transform=train_xform, 87 | balance_patch_per_class=False) 88 | 89 | train_loader = torch.utils.data.DataLoader( 90 | train_dset, 91 | batch_size=cf["eval"]["predict_batch_size"], 92 | drop_last=False, 93 | pin_memory=True, 94 | num_workers=get_num_worker(), 95 | persistent_workers=True) 96 | 97 | val_dset = PatchDataset(data_root=cf["data"]["db_root"], 98 | studies="val", 99 | transform=valid_xform, 100 | balance_patch_per_class=False) 101 | 102 | val_loader = torch.utils.data.DataLoader( 103 | val_dset, 104 | batch_size=cf["eval"]["predict_batch_size"], 105 | drop_last=False, 106 | pin_memory=True, 107 | num_workers=get_num_worker(), 108 | persistent_workers=True) 109 | 110 | # load lightning checkpoint 111 | ckpt_path = os.path.join(cf["infra"]["log_dir"], cf["infra"]["exp_name"], 112 | cf["eval"]["ckpt_path"]) 113 | 114 | model = HiDiscSystem.load_from_checkpoint(ckpt_path, 115 | cf=cf, 116 | num_it_per_ep=0, 117 | max_epochs=-1, 118 | nc=0) 119 | 120 | # create trainer 121 | trainer = pl.Trainer(accelerator="gpu", 122 | devices=1, 123 | max_epochs=-1, 124 | default_root_dir=exp_root, 125 | enable_checkpointing=False, 126 | logger=False) 127 | 128 | # generate predictions 129 | train_predictions = trainer.predict(model, dataloaders=train_loader) 130 | val_predictions = trainer.predict(model, dataloaders=val_loader) 131 | 132 | def process_predictions(predictions): 133 | pred = {} 134 | for k in predictions[0].keys(): 135 | if k == "path": 136 | pred[k] = [pk for p in predictions for pk in p[k][0]] 137 | else: 138 | pred[k] = torch.cat([p[k] for p in predictions]) 139 | return pred 140 | 141 | train_predictions = process_predictions(train_predictions) 142 | val_predictions = process_predictions(val_predictions) 143 | 144 | train_embs = torch.nn.functional.normalize(train_predictions["embeddings"], 145 | p=2, 146 | dim=1).T 147 | val_embs = torch.nn.functional.normalize(val_predictions["embeddings"], 148 | p=2, 149 | dim=1) 150 | 151 | # knn evaluation 152 | batch_size = 
cf["eval"]["knn"]["batch_size"] 153 | all_scores = [] 154 | for k in tqdm(range(val_embs.shape[0] // batch_size + 1)): 155 | start_coeff = batch_size * k 156 | end_coeff = min(batch_size * (k + 1), val_embs.shape[0]) 157 | val_embs_k = val_embs[start_coeff:end_coeff] # 1536 x 2048 158 | 159 | pred_labels, pred_scores = knn_predict( 160 | val_embs_k, 161 | train_embs, 162 | train_predictions["label"], 163 | len(train_loader.dataset.classes_), 164 | knn_k=cf["eval"]["knn"]["k"], 165 | knn_t=cf["eval"]["knn"]["t"]) 166 | 167 | all_scores.append( 168 | torch.nn.functional.normalize(pred_scores, p=1, dim=1)) 169 | torch.cuda.empty_cache() 170 | 171 | val_predictions["logits"] = torch.vstack(all_scores) 172 | return val_predictions 173 | 174 | 175 | def get_embeddings_slide(cf: Dict[str, Any], 176 | exp_root: str) -> Dict[str, Union[torch.Tensor, List[str]]]: 177 | """Run forward pass on the dataset, and generate embeddings and logits""" 178 | train_xform, valid_xform = get_emb_transformations(cf) 179 | 180 | logging.info(f"train_xform \n{train_xform}") 181 | logging.info(f"valid_xform \n{valid_xform}") 182 | 183 | # get dataset / loader 184 | train_dset = SlideEmbeddingDataset(data_root=cf["data"]["db_root"], 185 | embedding_root=cf["data"]["embedding_root"], 186 | tag=cf["data"]["tag"], 187 | studies="train", 188 | transform=train_xform, 189 | balance_slide_per_class=False, 190 | num_transforms=1) 191 | 192 | train_loader = torch.utils.data.DataLoader( 193 | train_dset, 194 | batch_size=cf["eval"]["predict_batch_size"], 195 | drop_last=False, 196 | pin_memory=True, 197 | num_workers=get_num_worker(), 198 | persistent_workers=True, 199 | collate_fn=emb_collate_fn) 200 | 201 | val_dset = SlideEmbeddingDataset(data_root=cf["data"]["db_root"], 202 | embedding_root=cf["data"]["embedding_root"], 203 | tag=cf["data"]["tag"], 204 | studies="val", 205 | transform=valid_xform, 206 | balance_slide_per_class=False, 207 | num_transforms=1) 208 | 209 | val_loader = torch.utils.data.DataLoader( 210 | val_dset, 211 | batch_size=cf["eval"]["predict_batch_size"], 212 | drop_last=False, 213 | pin_memory=True, 214 | num_workers=get_num_worker(), 215 | persistent_workers=True, 216 | collate_fn=emb_collate_fn) 217 | 218 | # load lightning checkpoint 219 | ckpt_path = os.path.join(cf["infra"]["log_dir"], cf["infra"]["exp_name"], 220 | cf["eval"]["ckpt_path"]) 221 | 222 | model = SlideSSLSystem.load_from_checkpoint(ckpt_path, 223 | cf=cf, 224 | num_it_per_ep=0, 225 | max_epochs=-1, 226 | nc=0) 227 | 228 | # create trainer 229 | trainer = pl.Trainer(accelerator="gpu", 230 | devices=1, 231 | max_epochs=-1, 232 | default_root_dir=exp_root, 233 | enable_checkpointing=False, 234 | logger=False) 235 | 236 | # generate predictions 237 | train_predictions = trainer.predict(model, dataloaders=train_loader) 238 | val_predictions = trainer.predict(model, dataloaders=val_loader) 239 | 240 | def process_predictions(predictions): 241 | pred = {} 242 | for k in predictions[0].keys(): 243 | if k == "path": 244 | pred[k] = [pk for p in predictions for pk in p[k][0]] 245 | else: 246 | pred[k] = torch.cat([p[k] for p in predictions]) 247 | return pred 248 | 249 | train_predictions = process_predictions(train_predictions) 250 | val_predictions = process_predictions(val_predictions) 251 | 252 | train_embs = torch.nn.functional.normalize(train_predictions["embeddings"], 253 | p=2, 254 | dim=1).T 255 | val_embs = torch.nn.functional.normalize(val_predictions["embeddings"], 256 | p=2, 257 | dim=1) 258 | 259 | # knn evaluation 260 | 
batch_size = cf["eval"]["knn"]["batch_size"] 261 | all_scores = [] 262 | for k in tqdm(range(val_embs.shape[0] // batch_size + 1)): 263 | start_coeff = batch_size * k 264 | end_coeff = min(batch_size * (k + 1), val_embs.shape[0]) 265 | val_embs_k = val_embs[start_coeff:end_coeff] # 1536 x 2048 266 | 267 | pred_labels, pred_scores = knn_predict( 268 | val_embs_k, 269 | train_embs, 270 | train_predictions["label"], 271 | len(train_loader.dataset.classes_), 272 | knn_k=cf["eval"]["knn"]["k"], 273 | knn_t=cf["eval"]["knn"]["t"]) 274 | 275 | all_scores.append( 276 | torch.nn.functional.normalize(pred_scores, p=1, dim=1)) 277 | torch.cuda.empty_cache() 278 | 279 | val_predictions["logits"] = torch.vstack(all_scores) 280 | return val_predictions 281 | 282 | 283 | def make_specs_patch(predictions: Dict[str, Union[torch.Tensor, List[str]]]) -> None: 284 | """Compute all specs for an experiment""" 285 | 286 | # aggregate prediction into a dataframe 287 | pred = pd.DataFrame.from_dict({ 288 | "path": 289 | predictions["path"], 290 | "labels": [l.item() for l in list(predictions["label"])], 291 | "logits": [l.tolist() for l in list(predictions["logits"])] 292 | }) 293 | pred["logits"] = pred["logits"].apply( 294 | lambda x: torch.nn.functional.softmax(torch.tensor(x), dim=0)) 295 | 296 | # add patient and slide info from patch paths 297 | pred["patient"] = pred["path"].apply(lambda x: x.split("/")[-4]) 298 | pred["slide"] = pred["path"].apply(lambda x: "/".join( 299 | [x.split("/")[-4], x.split("/")[-3]])) 300 | 301 | # aggregate logits 302 | get_agged_logits = lambda pred, mode: pd.DataFrame( 303 | pred.groupby(by=[mode, "labels"])["logits"].apply( 304 | lambda x: [sum(y) for y in zip(*x)])).reset_index() 305 | 306 | slides = get_agged_logits(pred, "slide") 307 | patients = get_agged_logits(pred, "patient") 308 | 309 | normalize_f = lambda x: torch.nn.functional.normalize(x, dim=1, p=1) 310 | patch_logits = normalize_f(torch.tensor(np.vstack(pred["logits"]))) 311 | slides_logits = normalize_f(torch.tensor(np.vstack(slides["logits"]))) 312 | patient_logits = normalize_f(torch.tensor(np.vstack(patients["logits"]))) 313 | 314 | patch_label = torch.tensor(pred["labels"]) 315 | slides_label = torch.tensor(slides["labels"]) 316 | patient_label = torch.tensor(patients["labels"]) 317 | 318 | # generate metrics 319 | def get_all_metrics(logits, label): 320 | map = AveragePrecision(task="multiclass", num_classes=7) 321 | acc = Accuracy(task="multiclass", num_classes=7) 322 | t2 = Accuracy(task="multiclass", num_classes=7, top_k=2) 323 | t3 = Accuracy(task="multiclass", num_classes=7, top_k=3) 324 | mca = Accuracy(task="multiclass", num_classes=7, average="macro") 325 | 326 | acc_val = acc(logits, label) 327 | t2_val = t2(logits, label) 328 | t3_val = t3(logits, label) 329 | mca_val = mca(logits, label) 330 | map_val = map(logits, label) 331 | 332 | return torch.stack((acc_val, t2_val, t3_val, mca_val, map_val)) 333 | 334 | all_metrics = torch.vstack((get_all_metrics(patch_logits, patch_label), 335 | get_all_metrics(slides_logits, slides_label), 336 | get_all_metrics(patient_logits, 337 | patient_label))) 338 | all_metrics = pd.DataFrame(all_metrics, 339 | columns=["acc", "t2", "t3", "mca", "map"], 340 | index=["patch", "slide", "patient"]) 341 | 342 | # generate confusion matrices 343 | patch_conf = confusion_matrix(y_true=patch_label, 344 | y_pred=patch_logits.argmax(dim=1)) 345 | 346 | slide_conf = confusion_matrix(y_true=slides_label, 347 | y_pred=slides_logits.argmax(dim=1)) 348 | 349 | patient_conf = 
confusion_matrix(y_true=patient_label, 350 | y_pred=patient_logits.argmax(dim=1)) 351 | 352 | print("\nmetrics") 353 | print(all_metrics) 354 | print("\npatch confusion matrix") 355 | print(patch_conf) 356 | print("\nslide confusion matrix") 357 | print(slide_conf) 358 | print("\npatient confusion matrix") 359 | print(patient_conf) 360 | 361 | return 362 | 363 | 364 | def make_specs_slide(predictions: Dict[str, Union[torch.Tensor, List[str]]]) -> None: 365 | """Compute all specs for an experiment""" 366 | 367 | # aggregate prediction into a dataframe 368 | pred = pd.DataFrame.from_dict({ 369 | "path": 370 | predictions["path"], 371 | "labels": [l.item() for l in list(predictions["label"])], 372 | "logits": [l.tolist() for l in list(predictions["logits"])] 373 | }) 374 | pred["logits"] = pred["logits"].apply( 375 | lambda x: torch.nn.functional.softmax(torch.tensor(x), dim=0)) 376 | 377 | # add patient info from slide paths 378 | pred["patient"] = pred["path"].apply(lambda x: x.split("/")[-3]) 379 | 380 | # aggregate logits 381 | get_agged_logits = lambda pred, mode: pd.DataFrame( 382 | pred.groupby(by=[mode, "labels"])["logits"].apply( 383 | lambda x: [sum(y) for y in zip(*x)])).reset_index() 384 | 385 | patients = get_agged_logits(pred, "patient") 386 | 387 | normalize_f = lambda x: torch.nn.functional.normalize(x, dim=1, p=1) 388 | slides_logits = normalize_f(torch.tensor(np.vstack(pred["logits"]))) 389 | patient_logits = normalize_f(torch.tensor(np.vstack(patients["logits"]))) 390 | 391 | slides_label = torch.tensor(pred["labels"]) 392 | patient_label = torch.tensor(patients["labels"]) 393 | 394 | # generate metrics 395 | def get_all_metrics(logits, label): 396 | map = AveragePrecision(task="multiclass", num_classes=7) 397 | acc = Accuracy(task="multiclass", num_classes=7) 398 | t2 = Accuracy(task="multiclass", num_classes=7, top_k=2) 399 | t3 = Accuracy(task="multiclass", num_classes=7, top_k=3) 400 | mca = Accuracy(task="multiclass", num_classes=7, average="macro") 401 | 402 | acc_val = acc(logits, label) 403 | t2_val = t2(logits, label) 404 | t3_val = t3(logits, label) 405 | mca_val = mca(logits, label) 406 | map_val = map(logits, label) 407 | 408 | return torch.stack((acc_val, t2_val, t3_val, mca_val, map_val)) 409 | 410 | all_metrics = torch.vstack((get_all_metrics(slides_logits, slides_label), 411 | get_all_metrics(patient_logits, 412 | patient_label))) 413 | all_metrics = pd.DataFrame(all_metrics, 414 | columns=["acc", "t2", "t3", "mca", "map"], 415 | index=["slide", "patient"]) 416 | 417 | # generate confusion matrices 418 | slide_conf = confusion_matrix(y_true=slides_label, 419 | y_pred=slides_logits.argmax(dim=1)) 420 | 421 | patient_conf = confusion_matrix(y_true=patient_label, 422 | y_pred=patient_logits.argmax(dim=1)) 423 | 424 | print("\nmetrics") 425 | print(all_metrics) 426 | print("\nslide confusion matrix") 427 | print(slide_conf) 428 | print("\npatient confusion matrix") 429 | print(patient_conf) 430 | 431 | return 432 | 433 | 434 | def setup_eval_paths(cf, get_exp_name, cmt_append): 435 | """Get name of the output dirs and create them in the file system.""" 436 | log_root = cf["infra"]["log_dir"] 437 | exp_name = cf["infra"]["exp_name"] 438 | instance_name = cf["eval"]["ckpt_path"].split("/")[0] 439 | eval_instance_name = "_".join([get_exp_name(cf), cmt_append]) 440 | exp_root = os.path.join(log_root, exp_name, instance_name, "evals", 441 | eval_instance_name) 442 | 443 | # generate needed folders, evals will be embedded in experiment folders 444 | pred_dir
= os.path.join(exp_root, 'predictions') 445 | config_dir = os.path.join(exp_root, 'config') 446 | for dir_name in [pred_dir, config_dir]: 447 | if not os.path.exists(dir_name): 448 | os.makedirs(dir_name) 449 | 450 | # if there is a previously generated prediction, also return the 451 | # prediction filename so we don't have to predict again 452 | if cf["eval"].get("eval_predictions", None): 453 | other_eval_instance_name = cf["eval"]["eval_predictions"] 454 | pred_fname = os.path.join(log_root, exp_name, instance_name, "evals", 455 | other_eval_instance_name, "predictions", 456 | "predictions.pt") 457 | else: 458 | pred_fname = None 459 | 460 | return exp_root, pred_dir, partial(copy2, dst=config_dir), pred_fname 461 | 462 | 463 | def main(): 464 | """Driver script for evaluation pipeline.""" 465 | cf_fd = parse_args() 466 | cf = yaml.load(cf_fd, Loader=yaml.FullLoader) 467 | exp_root, pred_dir, cp_config, pred_fname = setup_eval_paths( 468 | cf, get_exp_name, "") 469 | pl.seed_everything(cf["infra"]["seed"]) 470 | 471 | # logging and copying config files 472 | cp_config(cf_fd.name) 473 | config_loggers(exp_root) 474 | 475 | # get predictions 476 | if not cf["eval"].get("eval_predictions", None): 477 | logging.info("generating predictions") 478 | if cf["model"]["train_alg"] == "hidisc": 479 | predictions = get_embeddings_patch(cf, exp_root) 480 | elif cf["model"]["train_alg"] == "scm": 481 | predictions = get_embeddings_slide(cf, exp_root) 482 | else: 483 | raise NotImplementedError(f"train_alg {cf['model']['train_alg']} not implemented") 484 | torch.save(predictions, os.path.join(pred_dir, "predictions.pt")) 485 | else: 486 | logging.info("loading predictions") 487 | predictions = torch.load(pred_fname) 488 | 489 | # generate specs 490 | if cf["model"]["train_alg"] == "hidisc": 491 | make_specs_patch(predictions) 492 | elif cf["model"]["train_alg"] == "scm": 493 | make_specs_slide(predictions) 494 | else: 495 | raise NotImplementedError(f"train_alg {cf['model']['train_alg']} not implemented") 496 | 497 | 498 | if __name__ == "__main__": 499 | main() --------------------------------------------------------------------------------
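The weighted kNN voting used by `eval_knn.py` can be exercised in isolation. Below is a minimal sketch of calling `knn_predict` directly: the tensor shapes, the random feature bank, and the `knn_k`/`knn_t` values are illustrative stand-ins (not values taken from this repository), and the import assumes the `fastglioma` package and its dependencies are installed. As in `get_embeddings_patch` and `get_embeddings_slide`, the query and bank features are L2-normalized and the bank is passed transposed so that a matrix product yields cosine similarities.

```python
import torch

from fastglioma.eval.eval_knn import knn_predict

# Toy data for illustration only: 32 bank features, 4 queries,
# 16-dim embeddings, 7 classes (all shapes/values are made up).
torch.manual_seed(0)
num_bank, num_query, dim, num_classes = 32, 4, 16, 7

bank_feats = torch.nn.functional.normalize(torch.randn(num_bank, dim), p=2, dim=1)
bank_labels = torch.randint(0, num_classes, (num_bank,))
queries = torch.nn.functional.normalize(torch.randn(num_query, dim), p=2, dim=1)

# Bank is passed as [D, N] so torch.mm(queries, bank) gives cosine similarities;
# the top-k similarities are sharpened by the temperature knn_t and accumulated
# into per-class scores.
pred_labels, pred_scores = knn_predict(queries,
                                       bank_feats.T,
                                       bank_labels,
                                       classes=num_classes,
                                       knn_k=5,
                                       knn_t=0.07)

print(pred_labels[:, 0])  # top-1 predicted class per query
print(pred_scores.shape)  # [num_query, num_classes] weighted class scores
```

In the evaluation script itself, `pred_scores` is subsequently L1-normalized and stored under the `"logits"` key of the prediction dictionary before the patch-, slide-, and patient-level metrics are computed.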