├── README.md ├── code ├── Data_Generate.py ├── argument.py ├── entmax │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── activations.cpython-37.pyc │ │ ├── entmax.cpython-37.pyc │ │ ├── losses.cpython-37.pyc │ │ └── root_finding.cpython-37.pyc │ ├── activations.py │ ├── entmax.py │ ├── losses.py │ ├── root_finding.py │ ├── test_losses.py │ ├── test_mask.py │ ├── test_root_finding.py │ └── test_topk.py ├── evaluate.py ├── local_utils │ ├── Gempooling.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── dice_bce_loss.cpython-37.pyc │ │ ├── focal_loss.cpython-37.pyc │ │ ├── init_log.cpython-37.pyc │ │ ├── kappa_CE.cpython-37.pyc │ │ ├── label_smoothing.cpython-37.pyc │ │ ├── load_pretrainedmodel.cpython-37.pyc │ │ ├── load_pretrainedmodel.cpython-38.pyc │ │ ├── metrics.cpython-37.pyc │ │ ├── misc.cpython-37.pyc │ │ ├── misc.cpython-38.pyc │ │ ├── poly_ly.cpython-37.pyc │ │ ├── seed_everything.cpython-37.pyc │ │ ├── tools.cpython-37.pyc │ │ └── tools.cpython-38.pyc │ ├── bn_update.py │ ├── dice_bce_loss.py │ ├── focal_loss.py │ ├── init_log.py │ ├── kappa_CE.py │ ├── label_smoothing.py │ ├── load_pretrainedmodel.py │ ├── metrics.py │ ├── misc.py │ ├── poly_ly.py │ ├── seed_everything.py │ └── tools.py ├── spectr.py ├── spectr_block.py ├── train_main.py └── vit_modeling.py ├── dataset └── four_fold.json └── model └── link.md /README.md: -------------------------------------------------------------------------------- 1 | # SpecTr: Spectral Transformer for Microscopic Hyperspectral Pathology Image Segmentation (TCSVT 2023) 2 | 3 | Official Code for "SpecTr: Spectral Transformer for Microscopic Hyperspectral Pathology Image Segmentation" 4 | by Boxiang Yun, Baiying Lei, Jieneng Chen, Huiyu Wang, Song Qiu, Wei Shen, Qingli Li, Yan Wang* 5 | 6 | ## Introduction 7 | (TCSVT 2023) Official code for "[SpecTr: Spectral Transformer for Microscopic Hyperspectral Pathology Image Segmentation](https://ieeexplore.ieee.org/abstract/document/10288474)". 8 | ![SpecTr](https://github.com/DeepMed-Lab-ECNU/SpecTr/assets/36001411/38346e9e-bf97-441f-a099-b3b4f729b584) 9 | 10 | 11 | ## Requirements 12 | This repository is based on PyTorch 1.10, CUDA 11.1, Python 3.9.7, and segmentation-models-pytorch 0.3.3. All experiments in our paper were conducted on an NVIDIA GeForce RTX 3090 GPU with identical experimental settings. 13 | 14 | ## Usage 15 | We provide the `code`, `dataset`, and `model` directories for the MDC dataset. 16 | 17 | The official dataset can be found at [MDC](http://bio-hsi.ecnu.edu.cn/). However, due to its size, we also provide preprocessed [data](https://www.kaggle.com/datasets/hfutybx/mhsi-choledoch-dataset-preprocessed-dataset) (including denoising and resizing operations) for reproducing our paper experiments. 18 | 19 | Download the dataset and move it to the `dataset` folder. 20 | 21 | To train a model, 22 | ``` 23 | CUDA_VISIBLE_DEVICES=0 python train_main.py -r ./dataset/MDC -sn 60 -cut 192 -e 75 24 | ``` 25 | 26 | To test a model, 27 | ``` 28 | CUDA_VISIBLE_DEVICES=0 python evaluate.py -r ./dataset/MDC -sn 60 -cut 192 -name SpecTr_XXXX 29 | ``` 30 | 31 | ## Acknowledgements 32 | Some modules in our code were inspired by [vit-pytorch](https://github.com/lucidrains/vit-pytorch) and [segmentation_models.pytorch](https://github.com/qubvel/segmentation_models.pytorch). We appreciate these authors' effort in providing open-source code for the community, and we hope our work can also contribute to related research.
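## Data layout and loading (reference sketch)

The following is a minimal, non-authoritative sketch inferred from the default arguments in `evaluate.py` (`--dataset_hyper MHSI`, `--dataset_mask Mask`, `--dataset_divide four_fold.json`); the sample names and the `.hdr` extension below are placeholders/assumptions. The root passed via `-r ./dataset/MDC` is expected to look roughly like:

```
dataset/MDC/
├── MHSI/            # ENVI hyperspectral cubes, one per sample
├── Mask/            # binary segmentation masks, <sample>.png
└── four_fold.json   # four-fold cross-validation split (a copy ships under dataset/ in this repo)
```

And a sketch of how a single sample is read, mirroring `Data_Generate.py` (which opens cubes with `envi.open` from the `spectral` package):

```python
import cv2
import numpy as np
import spectral.io.envi as envi  # pip install spectral

# Hypothetical sample name; real names come from four_fold.json.
hdr_path = "dataset/MDC/MHSI/sample_001.hdr"
mask_path = "dataset/MDC/Mask/sample_001.png"

img = envi.open(hdr_path)[:, :, :]        # (H, W, spectral bands) ndarray
mask = cv2.imread(mask_path, 0) / 255     # (H, W), values in {0, 1}

# Same ordering as Data_Generate_Cho with outtype='3d': (1, bands, H, W)
img = np.transpose(img, (2, 0, 1)).astype(np.float32)[None]
mask = mask[None].astype(np.float32)
print(img.shape, mask.shape)
```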
33 | 34 | ## Related Work 35 | "[Factor Space and Spectrum for Medical Hyperspectral Image Segmentation (MICCAI 2023)](https://link.springer.com/chapter/10.1007/978-3-031-43901-8_15)" 36 | 37 | ## Questions 38 | If you encounter any issues accessing the dataset, such as being unable to sign in to [MDC](http://bio-hsi.ecnu.edu.cn/), please contact me at `52265904012@stu.ecnu.edu.cn`. 39 | -------------------------------------------------------------------------------- /code/Data_Generate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Jun 23 17:14:29 2020 5 | @author: Boxiang Yun School:ECNU&HFUT Email:971950297@qq.com 6 | """ 7 | from torch.utils.data.dataset import Dataset 8 | import skimage.io 9 | #from skimage.metrics import normalized_mutual_information 10 | from sklearn.metrics import normalized_mutual_info_score 11 | import numpy as np 12 | import cv2 13 | import os 14 | from argument import Transform 15 | from spectral import * 16 | from spectral import open_image 17 | import random 18 | import math 19 | from scipy.ndimage import zoom 20 | import warnings 21 | warnings.filterwarnings('ignore') 22 | from einops import repeat 23 | 24 | class Data_Generate_Cho(Dataset): 25 | def __init__(self, img_paths, seg_paths=None, 26 | cutting=None, transform=None, 27 | channels=None, outtype='3d'): 28 | self.img_paths = img_paths 29 | self.seg_paths = seg_paths 30 | self.transform = transform 31 | self.cutting = cutting 32 | self.channels = channels 33 | self.outtype = outtype 34 | 35 | def __getitem__(self, index): 36 | img_path = self.img_paths[index] 37 | mask_path = self.seg_paths[index] 38 | mask = cv2.imread(mask_path, 0)/255 39 | img = envi.open(img_path)[:, :, :] 40 | img = img[:, :, self.channels] if self.channels is not None else img 41 | 42 | if img.shape[:2] != mask.shape:  # compare spatial size only; img is (H, W, bands) 43 | mask = cv2.resize(mask, (img.shape[1], img.shape[0])) 44 | 45 | if self.transform is not None: 46 | img, mask = self.transform((img, mask)) 47 | 48 | mask = mask.astype(np.uint8) 49 | if self.cutting is not None: 50 | while True:  # resample until the random crop contains foreground 51 | xx = random.randint(0, img.shape[0] - self.cutting) 52 | yy = random.randint(0, img.shape[1] - self.cutting) 53 | patch_img = img[xx:xx + self.cutting, yy:yy + self.cutting] 54 | patch_mask = mask[xx:xx + self.cutting, yy:yy + self.cutting] 55 | if patch_mask.sum() != 0: break 56 | img = patch_img 57 | mask = patch_mask 58 | 59 | 60 | img = img[:, :, None] if len(img.shape) == 2 else img 61 | img = np.transpose(img, (2, 0, 1)) 62 | if self.outtype == '3d': 63 | img = img[None] 64 | mask = mask[None, ].astype(np.float32) 65 | img = img.astype(np.float32) 66 | return img, mask 67 | 68 | def __len__(self): 69 | return len(self.img_paths) -------------------------------------------------------------------------------- /code/argument.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import albumentations as A 4 | 5 | def resize(image, size=(128, 128)): 6 | return cv2.resize(image, size) 7 | 8 | def add_gaussian_noise(x, sigma): 9 | x += np.random.randn(*x.shape) * sigma 10 | x = np.clip(x, 0., 1.)
11 | return x 12 | 13 | def _evaluate_ratio(ratio): 14 | if ratio <= 0.: 15 | return False 16 | return np.random.uniform() < ratio 17 | 18 | def norm(img): 19 | return (img-img.min())/(img.max()-img.min()) 20 | 21 | 22 | def apply_aug(aug, image, mask=None): 23 | if mask is None: 24 | return aug(image=image)['image'] 25 | else: 26 | augment = aug(image=image,mask=mask) 27 | return augment['image'],augment['mask'] 28 | 29 | 30 | class Transform: 31 | def __init__(self, size=None, train=True, 32 | BrightContrast_ration=0., noise_ratio=0., cutout_ratio=0., scale_ratio=0., 33 | gamma_ratio=0., grid_distortion_ratio=0., elastic_distortion_ratio=0., 34 | piece_affine_ratio=0., ssr_ratio=0., Rotate_ratio=0.,Flip_ratio=0.): 35 | 36 | self.size = size 37 | self.train = train 38 | self.noise_ratio = noise_ratio 39 | self.BrightContrast_ration = BrightContrast_ration 40 | self.cutout_ratio = cutout_ratio 41 | self.grid_distortion_ratio = grid_distortion_ratio 42 | self.elastic_distortion_ratio = elastic_distortion_ratio 43 | self.piece_affine_ratio = piece_affine_ratio 44 | self.ssr_ratio = ssr_ratio 45 | self.Rotate_ratio = Rotate_ratio 46 | self.Flip_ratio = Flip_ratio 47 | self.scale_ratio = scale_ratio 48 | self.gamma_ratio = gamma_ratio 49 | 50 | def __call__(self, example): 51 | if self.train: 52 | x, y = example 53 | else: 54 | x = example 55 | # --- Augmentation --- 56 | # --- Train/Test common preprocessing --- 57 | 58 | if self.size is not None: 59 | x = resize(x, size=self.size) 60 | 61 | # albumentations... 62 | 63 | # # 1. blur 64 | if _evaluate_ratio(self.BrightContrast_ration): 65 | x = apply_aug(A.RandomBrightnessContrast(p=1.0), x) 66 | # 67 | if _evaluate_ratio(self.noise_ratio): 68 | r = np.random.uniform() 69 | if r < 0.50: 70 | x = apply_aug(A.GaussNoise(var_limit=5. 
/ 255., p=1.0), x) 71 | else: 72 | x = apply_aug(A.MultiplicativeNoise(p=1.0), x) 73 | 74 | if _evaluate_ratio(self.grid_distortion_ratio): 75 | x,y = apply_aug(A.GridDistortion(p=1.0), x,y) 76 | 77 | if _evaluate_ratio(self.elastic_distortion_ratio): 78 | x,y = apply_aug(A.ElasticTransform( 79 | sigma=50, alpha=1, p=1.0), x,y) 80 | 81 | if _evaluate_ratio(self.gamma_ratio): 82 | x, y = apply_aug(A.RandomGamma((70, 150), p=1.0), x, y) 83 | # 84 | if _evaluate_ratio(self.Rotate_ratio): 85 | x,y = apply_aug(A.Rotate(p=1.0),x,y) 86 | 87 | if _evaluate_ratio(self.Flip_ratio): 88 | x,y = apply_aug(A.Flip(p=1.0),x,y) 89 | 90 | if _evaluate_ratio(self.scale_ratio): 91 | x, y = apply_aug(A.RandomScale(p=1.0, scale_limit=0.15), x, y) 92 | 93 | if self.train: 94 | return x, y 95 | else: 96 | return x 97 | -------------------------------------------------------------------------------- /code/entmax/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.1.dev0" 2 | 3 | from entmax.activations import sparsemax, entmax15, Sparsemax, Entmax15 4 | from entmax.root_finding import ( 5 | sparsemax_bisect, 6 | entmax_bisect, 7 | SparsemaxBisect, 8 | EntmaxBisect, 9 | ) 10 | from entmax.losses import ( 11 | sparsemax_loss, 12 | entmax15_loss, 13 | sparsemax_bisect_loss, 14 | entmax_bisect_loss, 15 | SparsemaxLoss, 16 | SparsemaxBisectLoss, 17 | Entmax15Loss, 18 | EntmaxBisectLoss, 19 | ) 20 | from entmax.entmax import EntmaxAlpha 21 | -------------------------------------------------------------------------------- /code/entmax/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/entmax/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code/entmax/__pycache__/activations.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/entmax/__pycache__/activations.cpython-37.pyc -------------------------------------------------------------------------------- /code/entmax/__pycache__/entmax.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/entmax/__pycache__/entmax.cpython-37.pyc -------------------------------------------------------------------------------- /code/entmax/__pycache__/losses.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/entmax/__pycache__/losses.cpython-37.pyc -------------------------------------------------------------------------------- /code/entmax/__pycache__/root_finding.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/entmax/__pycache__/root_finding.cpython-37.pyc -------------------------------------------------------------------------------- /code/entmax/activations.py: -------------------------------------------------------------------------------- 1 | """ 2 | An implementation of entmax (Peters et al., 2019). 
See 3 | https://arxiv.org/pdf/1905.05702 for detailed description. 4 | 5 | This builds on previous work with sparsemax (Martins & Astudillo, 2016). 6 | See https://arxiv.org/pdf/1602.02068. 7 | """ 8 | 9 | # Author: Ben Peters 10 | # Author: Vlad Niculae 11 | # License: MIT 12 | 13 | import torch 14 | import torch.nn as nn 15 | from torch.autograd import Function 16 | 17 | 18 | def _make_ix_like(X, dim): 19 | d = X.size(dim) 20 | rho = torch.arange(1, d + 1, device=X.device, dtype=X.dtype) 21 | view = [1] * X.dim() 22 | view[0] = -1 23 | return rho.view(view).transpose(0, dim) 24 | 25 | 26 | def _roll_last(X, dim): 27 | if dim == -1: 28 | return X 29 | elif dim < 0: 30 | dim = X.dim() - dim 31 | 32 | perm = [i for i in range(X.dim()) if i != dim] + [dim] 33 | return X.permute(perm) 34 | 35 | 36 | def _sparsemax_threshold_and_support(X, dim=-1, k=None): 37 | """Core computation for sparsemax: optimal threshold and support size. 38 | 39 | Parameters 40 | ---------- 41 | X : torch.Tensor 42 | The input tensor to compute thresholds over. 43 | 44 | dim : int 45 | The dimension along which to apply sparsemax. 46 | 47 | k : int or None 48 | number of largest elements to partial-sort over. For optimal 49 | performance, should be slightly bigger than the expected number of 50 | nonzeros in the solution. If the solution is more than k-sparse, 51 | this function is recursively called with a 2*k schedule. 52 | If `None`, full sorting is performed from the beginning. 53 | 54 | Returns 55 | ------- 56 | tau : torch.Tensor like `X`, with all but the `dim` dimension intact 57 | the threshold value for each vector 58 | support_size : torch LongTensor, shape like `tau` 59 | the number of nonzeros in each vector. 60 | """ 61 | 62 | if k is None or k >= X.shape[dim]: # do full sort 63 | topk, _ = torch.sort(X, dim=dim, descending=True) 64 | else: 65 | topk, _ = torch.topk(X, k=k, dim=dim) 66 | 67 | topk_cumsum = topk.cumsum(dim) - 1 68 | rhos = _make_ix_like(topk, dim) 69 | support = rhos * topk > topk_cumsum 70 | 71 | support_size = support.sum(dim=dim).unsqueeze(dim) 72 | tau = topk_cumsum.gather(dim, support_size - 1) 73 | tau /= support_size.to(X.dtype) 74 | 75 | if k is not None and k < X.shape[dim]: 76 | unsolved = (support_size == k).squeeze(dim) 77 | 78 | if torch.any(unsolved): 79 | in_ = _roll_last(X, dim)[unsolved] 80 | tau_, ss_ = _sparsemax_threshold_and_support(in_, dim=-1, k=2 * k) 81 | _roll_last(tau, dim)[unsolved] = tau_ 82 | _roll_last(support_size, dim)[unsolved] = ss_ 83 | 84 | return tau, support_size 85 | 86 | 87 | def _entmax_threshold_and_support(X, dim=-1, k=None): 88 | """Core computation for 1.5-entmax: optimal threshold and support size. 89 | 90 | Parameters 91 | ---------- 92 | X : torch.Tensor 93 | The input tensor to compute thresholds over. 94 | 95 | dim : int 96 | The dimension along which to apply 1.5-entmax. 97 | 98 | k : int or None 99 | number of largest elements to partial-sort over. For optimal 100 | performance, should be slightly bigger than the expected number of 101 | nonzeros in the solution. If the solution is more than k-sparse, 102 | this function is recursively called with a 2*k schedule. 103 | If `None`, full sorting is performed from the beginning. 104 | 105 | Returns 106 | ------- 107 | tau : torch.Tensor like `X`, with all but the `dim` dimension intact 108 | the threshold value for each vector 109 | support_size : torch LongTensor, shape like `tau` 110 | the number of nonzeros in each vector. 
111 | """ 112 | 113 | if k is None or k >= X.shape[dim]: # do full sort 114 | Xsrt, _ = torch.sort(X, dim=dim, descending=True) 115 | else: 116 | Xsrt, _ = torch.topk(X, k=k, dim=dim) 117 | 118 | rho = _make_ix_like(Xsrt, dim) 119 | mean = Xsrt.cumsum(dim) / rho 120 | mean_sq = (Xsrt ** 2).cumsum(dim) / rho 121 | ss = rho * (mean_sq - mean ** 2) 122 | delta = (1 - ss) / rho 123 | 124 | # NOTE this is not exactly the same as in reference algo 125 | # Fortunately it seems the clamped values never wrongly 126 | # get selected by tau <= sorted_z. Prove this! 127 | delta_nz = torch.clamp(delta, 0) 128 | tau = mean - torch.sqrt(delta_nz) 129 | 130 | support_size = (tau <= Xsrt).sum(dim).unsqueeze(dim) 131 | tau_star = tau.gather(dim, support_size - 1) 132 | 133 | if k is not None and k < X.shape[dim]: 134 | unsolved = (support_size == k).squeeze(dim) 135 | 136 | if torch.any(unsolved): 137 | X_ = _roll_last(X, dim)[unsolved] 138 | tau_, ss_ = _entmax_threshold_and_support(X_, dim=-1, k=2 * k) 139 | _roll_last(tau_star, dim)[unsolved] = tau_ 140 | _roll_last(support_size, dim)[unsolved] = ss_ 141 | 142 | return tau_star, support_size 143 | 144 | 145 | class SparsemaxFunction(Function): 146 | @classmethod 147 | def forward(cls, ctx, X, dim=-1, k=None): 148 | ctx.dim = dim 149 | max_val, _ = X.max(dim=dim, keepdim=True) 150 | X = X - max_val # same numerical stability trick as softmax 151 | tau, supp_size = _sparsemax_threshold_and_support(X, dim=dim, k=k) 152 | output = torch.clamp(X - tau, min=0) 153 | ctx.save_for_backward(supp_size, output) 154 | return output 155 | 156 | @classmethod 157 | def backward(cls, ctx, grad_output): 158 | supp_size, output = ctx.saved_tensors 159 | dim = ctx.dim 160 | grad_input = grad_output.clone() 161 | grad_input[output == 0] = 0 162 | 163 | v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze(dim) 164 | v_hat = v_hat.unsqueeze(dim) 165 | grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) 166 | return grad_input, None, None 167 | 168 | 169 | class Entmax15Function(Function): 170 | @classmethod 171 | def forward(cls, ctx, X, dim=0, k=None): 172 | ctx.dim = dim 173 | 174 | max_val, _ = X.max(dim=dim, keepdim=True) 175 | X = X - max_val # same numerical stability trick as for softmax 176 | X = X / 2 # divide by 2 to solve actual Entmax 177 | 178 | tau_star, _ = _entmax_threshold_and_support(X, dim=dim, k=k) 179 | 180 | Y = torch.clamp(X - tau_star, min=0) ** 2 181 | ctx.save_for_backward(Y) 182 | return Y 183 | 184 | @classmethod 185 | def backward(cls, ctx, dY): 186 | Y, = ctx.saved_tensors 187 | gppr = Y.sqrt() # = 1 / g'' (Y) 188 | dX = dY * gppr 189 | q = dX.sum(ctx.dim) / gppr.sum(ctx.dim) 190 | q = q.unsqueeze(ctx.dim) 191 | dX -= q * gppr 192 | return dX, None, None 193 | 194 | 195 | def sparsemax(X, dim=-1, k=None): 196 | """sparsemax: normalizing sparse transform (a la softmax). 197 | 198 | Solves the projection: 199 | 200 | min_p ||x - p||_2 s.t. p >= 0, sum(p) == 1. 201 | 202 | Parameters 203 | ---------- 204 | X : torch.Tensor 205 | The input tensor. 206 | 207 | dim : int 208 | The dimension along which to apply sparsemax. 209 | 210 | k : int or None 211 | number of largest elements to partial-sort over. For optimal 212 | performance, should be slightly bigger than the expected number of 213 | nonzeros in the solution. If the solution is more than k-sparse, 214 | this function is recursively called with a 2*k schedule. 215 | If `None`, full sorting is performed from the beginning. 
216 | 217 | Returns 218 | ------- 219 | P : torch tensor, same shape as X 220 | The projection result, such that P.sum(dim=dim) == 1 elementwise. 221 | """ 222 | 223 | return SparsemaxFunction.apply(X, dim, k) 224 | 225 | 226 | def entmax15(X, dim=-1, k=None): 227 | """1.5-entmax: normalizing sparse transform (a la softmax). 228 | 229 | Solves the optimization problem: 230 | 231 | max_p - H_1.5(p) s.t. p >= 0, sum(p) == 1. 232 | 233 | where H_1.5(p) is the Tsallis alpha-entropy with alpha=1.5. 234 | 235 | Parameters 236 | ---------- 237 | X : torch.Tensor 238 | The input tensor. 239 | 240 | dim : int 241 | The dimension along which to apply 1.5-entmax. 242 | 243 | k : int or None 244 | number of largest elements to partial-sort over. For optimal 245 | performance, should be slightly bigger than the expected number of 246 | nonzeros in the solution. If the solution is more than k-sparse, 247 | this function is recursively called with a 2*k schedule. 248 | If `None`, full sorting is performed from the beginning. 249 | 250 | Returns 251 | ------- 252 | P : torch tensor, same shape as X 253 | The projection result, such that P.sum(dim=dim) == 1 elementwise. 254 | """ 255 | 256 | return Entmax15Function.apply(X, dim, k) 257 | 258 | 259 | class Sparsemax(nn.Module): 260 | def __init__(self, dim=-1, k=None): 261 | """sparsemax: normalizing sparse transform (a la softmax). 262 | 263 | Solves the projection: 264 | 265 | min_p ||x - p||_2 s.t. p >= 0, sum(p) == 1. 266 | 267 | Parameters 268 | ---------- 269 | dim : int 270 | The dimension along which to apply sparsemax. 271 | 272 | k : int or None 273 | number of largest elements to partial-sort over. For optimal 274 | performance, should be slightly bigger than the expected number of 275 | nonzeros in the solution. If the solution is more than k-sparse, 276 | this function is recursively called with a 2*k schedule. 277 | If `None`, full sorting is performed from the beginning. 278 | """ 279 | self.dim = dim 280 | self.k = k 281 | super(Sparsemax, self).__init__() 282 | 283 | def forward(self, X): 284 | return sparsemax(X, dim=self.dim, k=self.k) 285 | 286 | 287 | class Entmax15(nn.Module): 288 | def __init__(self, dim=-1, k=None): 289 | """1.5-entmax: normalizing sparse transform (a la softmax). 290 | 291 | Solves the optimization problem: 292 | 293 | max_p - H_1.5(p) s.t. p >= 0, sum(p) == 1. 294 | 295 | where H_1.5(p) is the Tsallis alpha-entropy with alpha=1.5. 296 | 297 | Parameters 298 | ---------- 299 | dim : int 300 | The dimension along which to apply 1.5-entmax. 301 | 302 | k : int or None 303 | number of largest elements to partial-sort over. For optimal 304 | performance, should be slightly bigger than the expected number of 305 | nonzeros in the solution. If the solution is more than k-sparse, 306 | this function is recursively called with a 2*k schedule. 307 | If `None`, full sorting is performed from the beginning. 
308 | """ 309 | self.dim = dim 310 | self.k = k 311 | super(Entmax15, self).__init__() 312 | 313 | def forward(self, X): 314 | return entmax15(X, dim=self.dim, k=self.k) 315 | -------------------------------------------------------------------------------- /code/entmax/entmax.py: -------------------------------------------------------------------------------- 1 | ## Implementation of Entmax has been adapted from https://github.com/deep-spin/entmax/ 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch import nn 6 | from torch.autograd import Function 7 | 8 | home = str(Path.home()) 9 | 10 | 11 | class AlphaChooser(torch.nn.Module): 12 | def __init__(self, head_count): 13 | super(AlphaChooser, self).__init__() 14 | self.pre_alpha = nn.Parameter(torch.randn(head_count)) 15 | 16 | def forward(self): 17 | alpha = 1 + torch.sigmoid(self.pre_alpha) 18 | return torch.clamp(alpha, min=1.01, max=2) 19 | 20 | 21 | class EntmaxAlpha(nn.Module): 22 | def __init__(self, head_count, dim=0): 23 | super(EntmaxAlpha, self).__init__() 24 | self.dim = dim 25 | self.alpha_chooser = nn.Parameter(AlphaChooser(head_count)()) 26 | self.alpha = self.alpha_chooser 27 | 28 | def forward(self, att_scores): 29 | batch_size, head_count, query_len, key_len = att_scores.size() 30 | 31 | expanded_alpha = ( 32 | self.alpha.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 33 | ) # [1,nb_heads,1,1] 34 | expanded_alpha = expanded_alpha.expand( 35 | (batch_size, -1, query_len, 1) 36 | ) # [bs, nb_heads, query_len,1] 37 | p_star = entmax_bisect(att_scores, expanded_alpha) 38 | return p_star 39 | 40 | 41 | class EntmaxBisectFunction(Function): 42 | @classmethod 43 | def _gp(cls, x, alpha): 44 | return x ** (alpha - 1) 45 | 46 | @classmethod 47 | def _gp_inv(cls, y, alpha): 48 | return y ** (1 / (alpha - 1)) 49 | 50 | @classmethod 51 | def _p(cls, X, alpha): 52 | return cls._gp_inv(torch.clamp(X, min=0), alpha) 53 | 54 | @classmethod 55 | def forward(cls, ctx, X, alpha=1.5, dim=-1, n_iter=50, ensure_sum_one=True): 56 | 57 | if not isinstance(alpha, torch.Tensor): 58 | alpha = torch.tensor(alpha, dtype=X.dtype, device=X.device) 59 | 60 | alpha_shape = list(X.shape) 61 | alpha_shape[dim] = 1 62 | alpha = alpha.expand(*alpha_shape) 63 | 64 | ctx.alpha = alpha 65 | ctx.dim = dim 66 | d = X.shape[dim] 67 | 68 | X = X * (alpha - 1) 69 | 70 | max_val, _ = X.max(dim=dim, keepdim=True) 71 | 72 | tau_lo = max_val - cls._gp(1, alpha) 73 | tau_hi = max_val - cls._gp(1 / d, alpha) 74 | 75 | f_lo = cls._p(X - tau_lo, alpha).sum(dim) - 1 76 | 77 | dm = tau_hi - tau_lo 78 | 79 | for it in range(n_iter): 80 | 81 | dm /= 2 82 | tau_m = tau_lo + dm 83 | p_m = cls._p(X - tau_m, alpha) 84 | f_m = p_m.sum(dim) - 1 85 | 86 | mask = (f_m * f_lo >= 0).unsqueeze(dim) 87 | tau_lo = torch.where(mask, tau_m, tau_lo) 88 | 89 | if ensure_sum_one: 90 | p_m /= p_m.sum(dim=dim).unsqueeze(dim=dim) 91 | 92 | ctx.save_for_backward(p_m) 93 | 94 | return p_m 95 | 96 | @classmethod 97 | def backward(cls, ctx, dY): 98 | (Y,) = ctx.saved_tensors 99 | 100 | gppr = torch.where(Y > 0, Y ** (2 - ctx.alpha), Y.new_zeros(1)) 101 | 102 | dX = dY * gppr 103 | q = dX.sum(ctx.dim) / gppr.sum(ctx.dim) 104 | q = q.unsqueeze(ctx.dim) 105 | dX -= q * gppr 106 | 107 | d_alpha = None 108 | if ctx.needs_input_grad[1]: 109 | 110 | # alpha gradient computation 111 | # d_alpha = (partial_y / partial_alpha) * dY 112 | # NOTE: ensure alpha is not close to 1 113 | # since there is an indetermination 114 | # batch_size, _ = dY.shape 115 | 116 | # shannon terms 117 | S = torch.where(Y > 0, Y 
* torch.log(Y), Y.new_zeros(1)) 118 | # shannon entropy 119 | ent = S.sum(ctx.dim).unsqueeze(ctx.dim) 120 | Y_skewed = gppr / gppr.sum(ctx.dim).unsqueeze(ctx.dim) 121 | 122 | d_alpha = dY * (Y - Y_skewed) / ((ctx.alpha - 1) ** 2) 123 | d_alpha -= dY * (S - Y_skewed * ent) / (ctx.alpha - 1) 124 | d_alpha = d_alpha.sum(ctx.dim).unsqueeze(ctx.dim) 125 | 126 | return dX, d_alpha, None, None, None 127 | 128 | 129 | def entmax_bisect(X, alpha=1.5, dim=-1, n_iter=50, ensure_sum_one=True): 130 | """alpha-entmax: normalizing sparse transform (a la softmax). 131 | Solves the optimization problem: 132 | max_p - H_a(p) s.t. p >= 0, sum(p) == 1. 133 | where H_a(p) is the Tsallis alpha-entropy with custom alpha >= 1, 134 | using a bisection (root finding, binary search) algorithm. 135 | This function is differentiable with respect to both X and alpha. 136 | Parameters 137 | ---------- 138 | X : torch.Tensor 139 | The input tensor. 140 | alpha : float or torch.Tensor 141 | Tensor of alpha parameters (> 1) to use. If scalar 142 | or python float, the same value is used for all rows, otherwise, 143 | it must have shape (or be expandable to) 144 | alpha.shape[j] == (X.shape[j] if j != dim else 1) 145 | A value of alpha=2 corresponds to sparsemax, and alpha=1 corresponds to 146 | softmax (but computing it this way is likely unstable). 147 | dim : int 148 | The dimension along which to apply alpha-entmax. 149 | n_iter : int 150 | Number of bisection iterations. For float32, 24 iterations should 151 | suffice for machine precision. 152 | ensure_sum_one : bool, 153 | Whether to divide the result by its sum. If false, the result might 154 | sum to close but not exactly 1, which might cause downstream problems. 155 | Returns 156 | ------- 157 | P : torch tensor, same shape as X 158 | The projection result, such that P.sum(dim=dim) == 1 elementwise. 
159 | """ 160 | return EntmaxBisectFunction.apply(X, alpha, dim, n_iter, ensure_sum_one) 161 | -------------------------------------------------------------------------------- /code/entmax/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | 5 | from entmax.activations import sparsemax, entmax15 6 | from entmax.root_finding import entmax_bisect, sparsemax_bisect 7 | 8 | 9 | class _GenericLoss(nn.Module): 10 | def __init__(self, ignore_index=-100, reduction="elementwise_mean"): 11 | assert reduction in ["elementwise_mean", "sum", "none"] 12 | self.reduction = reduction 13 | self.ignore_index = ignore_index 14 | super(_GenericLoss, self).__init__() 15 | 16 | def forward(self, X, target): 17 | loss = self.loss(X, target) 18 | if self.ignore_index >= 0: 19 | ignored_positions = target == self.ignore_index 20 | size = float((target.size(0) - ignored_positions.sum()).item()) 21 | loss.masked_fill_(ignored_positions, 0.0) 22 | else: 23 | size = float(target.size(0)) 24 | if self.reduction == "sum": 25 | loss = loss.sum() 26 | elif self.reduction == "elementwise_mean": 27 | loss = loss.sum() / size 28 | return loss 29 | 30 | 31 | class _GenericLossFunction(Function): 32 | @classmethod 33 | def forward(cls, ctx, X, target, alpha, proj_args): 34 | """ 35 | X (FloatTensor): n x num_classes 36 | target (LongTensor): n, the indices of the target classes 37 | """ 38 | assert X.shape[0] == target.shape[0] 39 | 40 | p_star = cls.project(X, alpha, **proj_args) 41 | loss = cls.omega(p_star, alpha) 42 | 43 | p_star.scatter_add_(1, target.unsqueeze(1), torch.full_like(p_star, -1)) 44 | loss += torch.einsum("ij,ij->i", p_star, X) 45 | ctx.save_for_backward(p_star) 46 | 47 | return loss 48 | 49 | @classmethod 50 | def backward(cls, ctx, grad_output): 51 | p_star, = ctx.saved_tensors 52 | grad = grad_output.unsqueeze(1) * p_star 53 | ret = (grad,) 54 | 55 | # pad with as many Nones as needed 56 | return ret + (None,) * (1 + cls.n_fwd_args) 57 | 58 | 59 | class SparsemaxLossFunction(_GenericLossFunction): 60 | 61 | n_fwd_args = 1 62 | 63 | @classmethod 64 | def project(cls, X, alpha, k): 65 | return sparsemax(X, dim=-1, k=k) 66 | 67 | @classmethod 68 | def omega(cls, p_star, alpha): 69 | return (1 - (p_star ** 2).sum(dim=1)) / 2 70 | 71 | @classmethod 72 | def forward(cls, ctx, X, target, k=None): 73 | return super().forward(ctx, X, target, alpha=2, proj_args=dict(k=k)) 74 | 75 | 76 | class SparsemaxBisectLossFunction(_GenericLossFunction): 77 | 78 | n_fwd_args = 1 79 | 80 | @classmethod 81 | def project(cls, X, alpha, n_iter): 82 | return sparsemax_bisect(X, n_iter=n_iter) 83 | 84 | @classmethod 85 | def omega(cls, p_star, alpha): 86 | return (1 - (p_star ** 2).sum(dim=1)) / 2 87 | 88 | @classmethod 89 | def forward(cls, ctx, X, target, n_iter=50): 90 | return super().forward( 91 | ctx, X, target, alpha=2, proj_args=dict(n_iter=n_iter) 92 | ) 93 | 94 | 95 | class Entmax15LossFunction(_GenericLossFunction): 96 | 97 | n_fwd_args = 1 98 | 99 | @classmethod 100 | def project(cls, X, alpha, k=None): 101 | return entmax15(X, dim=-1, k=k) 102 | 103 | @classmethod 104 | def omega(cls, p_star, alpha): 105 | return (1 - (p_star * torch.sqrt(p_star)).sum(dim=1)) / 0.75 106 | 107 | @classmethod 108 | def forward(cls, ctx, X, target, k=None): 109 | return super().forward(ctx, X, target, alpha=1.5, proj_args=dict(k=k)) 110 | 111 | 112 | class EntmaxBisectLossFunction(_GenericLossFunction): 113 | 114 | 
n_fwd_args = 2 115 | 116 | @classmethod 117 | def project(cls, X, alpha, n_iter): 118 | return entmax_bisect(X, alpha=alpha, n_iter=n_iter, ensure_sum_one=True) 119 | 120 | @classmethod 121 | def omega(cls, p_star, alpha): 122 | return (1 - (p_star ** alpha).sum(dim=1)) / (alpha * (alpha - 1)) 123 | 124 | @classmethod 125 | def forward(cls, ctx, X, target, alpha=1.5, n_iter=50): 126 | return super().forward( 127 | ctx, X, target, alpha, proj_args=dict(n_iter=n_iter) 128 | ) 129 | 130 | 131 | def sparsemax_loss(X, target, k=None): 132 | """sparsemax loss: sparse alternative to cross-entropy 133 | 134 | Computed using a partial sorting strategy. 135 | 136 | Parameters 137 | ---------- 138 | X : torch.Tensor, shape=(n_samples, n_classes) 139 | The input 2D tensor of predicted scores 140 | 141 | target : torch.LongTensor, shape=(n_samples,) 142 | The ground truth labels, 0 <= target < n_classes. 143 | 144 | k : int or None 145 | number of largest elements to partial-sort over. For optimal 146 | performance, should be slightly bigger than the expected number of 147 | nonzeros in the solution. If the solution is more than k-sparse, 148 | this function is recursively called with a 2*k schedule. 149 | If `None`, full sorting is performed from the beginning. 150 | 151 | Returns 152 | ------- 153 | losses, torch.Tensor, shape=(n_samples,) 154 | The loss incurred at each sample. 155 | """ 156 | return SparsemaxLossFunction.apply(X, target, k) 157 | 158 | 159 | def sparsemax_bisect_loss(X, target, n_iter=50): 160 | """sparsemax loss: sparse alternative to cross-entropy 161 | 162 | Computed using bisection. 163 | 164 | Parameters 165 | ---------- 166 | X : torch.Tensor, shape=(n_samples, n_classes) 167 | The input 2D tensor of predicted scores 168 | 169 | target : torch.LongTensor, shape=(n_samples,) 170 | The ground truth labels, 0 <= target < n_classes. 171 | 172 | n_iter : int 173 | Number of bisection iterations. For float32, 24 iterations should 174 | suffice for machine precision. 175 | 176 | Returns 177 | ------- 178 | losses, torch.Tensor, shape=(n_samples,) 179 | The loss incurred at each sample. 180 | """ 181 | return SparsemaxBisectLossFunction.apply(X, target, n_iter) 182 | 183 | 184 | def entmax15_loss(X, target, k=None): 185 | """1.5-entmax loss: sparse alternative to cross-entropy 186 | 187 | Computed using a partial sorting strategy. 188 | 189 | Parameters 190 | ---------- 191 | X : torch.Tensor, shape=(n_samples, n_classes) 192 | The input 2D tensor of predicted scores 193 | 194 | target : torch.LongTensor, shape=(n_samples,) 195 | The ground truth labels, 0 <= target < n_classes. 196 | 197 | k : int or None 198 | number of largest elements to partial-sort over. For optimal 199 | performance, should be slightly bigger than the expected number of 200 | nonzeros in the solution. If the solution is more than k-sparse, 201 | this function is recursively called with a 2*k schedule. 202 | If `None`, full sorting is performed from the beginning. 203 | 204 | Returns 205 | ------- 206 | losses, torch.Tensor, shape=(n_samples,) 207 | The loss incurred at each sample. 208 | """ 209 | return Entmax15LossFunction.apply(X, target, k) 210 | 211 | 212 | def entmax_bisect_loss(X, target, alpha=1.5, n_iter=50): 213 | """alpha-entmax loss: sparse alternative to cross-entropy 214 | 215 | Computed using bisection, supporting arbitrary alpha > 1. 
216 | 217 | Parameters 218 | ---------- 219 | X : torch.Tensor, shape=(n_samples, n_classes) 220 | The input 2D tensor of predicted scores 221 | 222 | target : torch.LongTensor, shape=(n_samples,) 223 | The ground truth labels, 0 <= target < n_classes. 224 | 225 | alpha : float or torch.Tensor 226 | Tensor of alpha parameters (> 1) to use for each row of X. If scalar 227 | or python float, the same value is used for all rows. A value of 228 | alpha=2 corresponds to sparsemax, and alpha=1 corresponds to softmax 229 | (but computing it this way is likely unstable). 230 | 231 | n_iter : int 232 | Number of bisection iterations. For float32, 24 iterations should 233 | suffice for machine precision. 234 | 235 | Returns 236 | ------- 237 | losses, torch.Tensor, shape=(n_samples,) 238 | The loss incurred at each sample. 239 | """ 240 | return EntmaxBisectLossFunction.apply(X, target, alpha, n_iter) 241 | 242 | 243 | class SparsemaxBisectLoss(_GenericLoss): 244 | def __init__( 245 | self, n_iter=50, ignore_index=-100, reduction="elementwise_mean" 246 | ): 247 | self.n_iter = n_iter 248 | super(SparsemaxBisectLoss, self).__init__(ignore_index, reduction) 249 | 250 | def loss(self, X, target): 251 | return sparsemax_bisect_loss(X, target, self.n_iter) 252 | 253 | 254 | class SparsemaxLoss(_GenericLoss): 255 | def __init__(self, k=None, ignore_index=-100, reduction="elementwise_mean"): 256 | self.k = k 257 | super(SparsemaxLoss, self).__init__(ignore_index, reduction) 258 | 259 | def loss(self, X, target): 260 | return sparsemax_loss(X, target, self.k) 261 | 262 | 263 | class EntmaxBisectLoss(_GenericLoss): 264 | def __init__( 265 | self, 266 | alpha=1.5, 267 | n_iter=50, 268 | ignore_index=-100, 269 | reduction="elementwise_mean", 270 | ): 271 | self.alpha = alpha 272 | self.n_iter = n_iter 273 | super(EntmaxBisectLoss, self).__init__(ignore_index, reduction) 274 | 275 | def loss(self, X, target): 276 | return entmax_bisect_loss(X, target, self.alpha, self.n_iter) 277 | 278 | 279 | class Entmax15Loss(_GenericLoss): 280 | def __init__(self, k=100, ignore_index=-100, reduction="elementwise_mean"): 281 | self.k = k 282 | super(Entmax15Loss, self).__init__(ignore_index, reduction) 283 | 284 | def loss(self, X, target): 285 | return entmax15_loss(X, target, self.k) 286 | -------------------------------------------------------------------------------- /code/entmax/root_finding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bisection implementation of alpha-entmax (Peters et al., 2019). 3 | Backward pass wrt alpha per (Correia et al., 2019). See 4 | https://arxiv.org/pdf/1905.05702 for detailed description. 
5 | """ 6 | # Author: Goncalo M Correia 7 | # Author: Ben Peters 8 | # Author: Vlad Niculae 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.autograd import Function 13 | 14 | 15 | class EntmaxBisectFunction(Function): 16 | @classmethod 17 | def _gp(cls, x, alpha): 18 | return x ** (alpha - 1) 19 | 20 | @classmethod 21 | def _gp_inv(cls, y, alpha): 22 | return y ** (1 / (alpha - 1)) 23 | 24 | @classmethod 25 | def _p(cls, X, alpha): 26 | return cls._gp_inv(torch.clamp(X, min=0), alpha) 27 | 28 | @classmethod 29 | def forward(cls, ctx, X, alpha=1.5, dim=-1, n_iter=50, ensure_sum_one=True): 30 | 31 | if not isinstance(alpha, torch.Tensor): 32 | alpha = torch.tensor(alpha, dtype=X.dtype, device=X.device) 33 | # if alpha.device != X.device: 34 | # alpha = alpha.to(X.device) 35 | 36 | alpha_shape = list(X.shape) 37 | alpha_shape[dim] = 1 38 | alpha = alpha.expand(*alpha_shape) 39 | 40 | ctx.alpha = alpha 41 | ctx.dim = dim 42 | d = X.shape[dim] 43 | 44 | X = X * (alpha - 1) 45 | 46 | max_val, _ = X.max(dim=dim, keepdim=True) 47 | 48 | tau_lo = max_val - cls._gp(1, alpha) 49 | tau_hi = max_val - cls._gp(1 / d, alpha) 50 | 51 | f_lo = cls._p(X - tau_lo, alpha).sum(dim) - 1 52 | 53 | dm = tau_hi - tau_lo 54 | 55 | for it in range(n_iter): 56 | 57 | dm /= 2 58 | tau_m = tau_lo + dm 59 | p_m = cls._p(X - tau_m, alpha) 60 | f_m = p_m.sum(dim) - 1 61 | 62 | mask = (f_m * f_lo >= 0).unsqueeze(dim) 63 | tau_lo = torch.where(mask, tau_m, tau_lo) 64 | 65 | if ensure_sum_one: 66 | p_m /= p_m.sum(dim=dim).unsqueeze(dim=dim) 67 | 68 | ctx.save_for_backward(p_m) 69 | 70 | return p_m 71 | 72 | @classmethod 73 | def backward(cls, ctx, dY): 74 | Y, = ctx.saved_tensors 75 | 76 | gppr = torch.where(Y > 0, Y ** (2 - ctx.alpha), Y.new_zeros(1)) 77 | 78 | dX = dY * gppr 79 | q = dX.sum(ctx.dim) / gppr.sum(ctx.dim) 80 | q = q.unsqueeze(ctx.dim) 81 | dX -= q * gppr 82 | 83 | d_alpha = None 84 | if ctx.needs_input_grad[1]: 85 | 86 | # alpha gradient computation 87 | # d_alpha = (partial_y / partial_alpha) * dY 88 | # NOTE: ensure alpha is not close to 1 89 | # since there is an indetermination 90 | # batch_size, _ = dY.shape 91 | 92 | # shannon terms 93 | S = torch.where(Y > 0, Y * torch.log(Y), Y.new_zeros(1)) 94 | # shannon entropy 95 | ent = S.sum(ctx.dim).unsqueeze(ctx.dim) 96 | Y_skewed = gppr / gppr.sum(ctx.dim).unsqueeze(ctx.dim) 97 | 98 | d_alpha = dY * (Y - Y_skewed) / ((ctx.alpha - 1) ** 2) 99 | d_alpha -= dY * (S - Y_skewed * ent) / (ctx.alpha - 1) 100 | d_alpha = d_alpha.sum(ctx.dim).unsqueeze(ctx.dim) 101 | 102 | return dX, d_alpha, None, None, None 103 | 104 | 105 | # slightly more efficient special case for sparsemax 106 | class SparsemaxBisectFunction(EntmaxBisectFunction): 107 | @classmethod 108 | def _gp(cls, x, alpha): 109 | return x 110 | 111 | @classmethod 112 | def _gp_inv(cls, y, alpha): 113 | return y 114 | 115 | @classmethod 116 | def _p(cls, x, alpha): 117 | return torch.clamp(x, min=0) 118 | 119 | @classmethod 120 | def forward(cls, ctx, X, dim=-1, n_iter=50, ensure_sum_one=True): 121 | return super().forward( 122 | ctx, X, alpha=2, dim=dim, n_iter=50, ensure_sum_one=True 123 | ) 124 | 125 | @classmethod 126 | def backward(cls, ctx, dY): 127 | Y, = ctx.saved_tensors 128 | gppr = (Y > 0).to(dtype=dY.dtype) 129 | dX = dY * gppr 130 | q = dX.sum(ctx.dim) / gppr.sum(ctx.dim) 131 | q = q.unsqueeze(ctx.dim) 132 | dX -= q * gppr 133 | return dX, None, None, None 134 | 135 | 136 | def entmax_bisect(X, alpha=1.5, dim=-1, n_iter=50, ensure_sum_one=True): 137 | """alpha-entmax: 
normalizing sparse transform (a la softmax). 138 | 139 | Solves the optimization problem: 140 | 141 | max_p - H_a(p) s.t. p >= 0, sum(p) == 1. 142 | 143 | where H_a(p) is the Tsallis alpha-entropy with custom alpha >= 1, 144 | using a bisection (root finding, binary search) algorithm. 145 | 146 | This function is differentiable with respect to both X and alpha. 147 | 148 | Parameters 149 | ---------- 150 | X : torch.Tensor 151 | The input tensor. 152 | 153 | alpha : float or torch.Tensor 154 | Tensor of alpha parameters (> 1) to use. If scalar 155 | or python float, the same value is used for all rows, otherwise, 156 | it must have shape (or be expandable to) 157 | alpha.shape[j] == (X.shape[j] if j != dim else 1) 158 | A value of alpha=2 corresponds to sparsemax, and alpha=1 corresponds to 159 | softmax (but computing it this way is likely unstable). 160 | 161 | dim : int 162 | The dimension along which to apply alpha-entmax. 163 | 164 | n_iter : int 165 | Number of bisection iterations. For float32, 24 iterations should 166 | suffice for machine precision. 167 | 168 | ensure_sum_one : bool, 169 | Whether to divide the result by its sum. If false, the result might 170 | sum to close but not exactly 1, which might cause downstream problems. 171 | 172 | Returns 173 | ------- 174 | P : torch tensor, same shape as X 175 | The projection result, such that P.sum(dim=dim) == 1 elementwise. 176 | """ 177 | return EntmaxBisectFunction.apply(X, alpha, dim, n_iter, ensure_sum_one) 178 | 179 | 180 | def sparsemax_bisect(X, dim=-1, n_iter=50, ensure_sum_one=True): 181 | """sparsemax: normalizing sparse transform (a la softmax), via bisection. 182 | 183 | Solves the projection: 184 | 185 | min_p ||x - p||_2 s.t. p >= 0, sum(p) == 1. 186 | 187 | Parameters 188 | ---------- 189 | X : torch.Tensor 190 | The input tensor. 191 | 192 | dim : int 193 | The dimension along which to apply sparsemax. 194 | 195 | n_iter : int 196 | Number of bisection iterations. For float32, 24 iterations should 197 | suffice for machine precision. 198 | 199 | ensure_sum_one : bool, 200 | Whether to divide the result by its sum. If false, the result might 201 | sum to close but not exactly 1, which might cause downstream problems. 202 | 203 | Note: This function does not yet support normalizing along anything except 204 | the last dimension. Please use transposing and views to achieve more 205 | general behavior. 206 | 207 | Returns 208 | ------- 209 | P : torch tensor, same shape as X 210 | The projection result, such that P.sum(dim=dim) == 1 elementwise. 211 | """ 212 | return SparsemaxBisectFunction.apply(X, dim, n_iter, ensure_sum_one) 213 | 214 | 215 | class SparsemaxBisect(nn.Module): 216 | def __init__(self, dim=-1, n_iter=None): 217 | """sparsemax: normalizing sparse transform (a la softmax) via bisection 218 | 219 | Solves the projection: 220 | 221 | min_p ||x - p||_2 s.t. p >= 0, sum(p) == 1. 222 | 223 | Parameters 224 | ---------- 225 | dim : int 226 | The dimension along which to apply sparsemax. 227 | 228 | n_iter : int 229 | Number of bisection iterations. For float32, 24 iterations should 230 | suffice for machine precision. 231 | """ 232 | self.dim = dim 233 | self.n_iter = n_iter 234 | super().__init__() 235 | 236 | def forward(self, X): 237 | return sparsemax_bisect(X, dim=self.dim, n_iter=self.n_iter) 238 | 239 | 240 | class EntmaxBisect(nn.Module): 241 | def __init__(self, alpha=1.5, dim=-1, n_iter=50): 242 | """alpha-entmax: normalizing sparse map (a la softmax) via bisection. 
243 | 244 | Solves the optimization problem: 245 | 246 | max_p - H_a(p) s.t. p >= 0, sum(p) == 1. 247 | 248 | where H_a(p) is the Tsallis alpha-entropy with custom alpha >= 1, 249 | using a bisection (root finding, binary search) algorithm. 250 | 251 | Parameters 252 | ---------- 253 | alpha : float or torch.Tensor 254 | Tensor of alpha parameters (> 1) to use. If scalar 255 | or python float, the same value is used for all rows, otherwise, 256 | it must have shape (or be expandable to) 257 | alpha.shape[j] == (X.shape[j] if j != dim else 1) 258 | A value of alpha=2 corresponds to sparsemax; alpha=1 corresponds 259 | to softmax (but computing it this way is likely unstable). 260 | 261 | dim : int 262 | The dimension along which to apply alpha-entmax. 263 | 264 | n_iter : int 265 | Number of bisection iterations. For float32, 24 iterations should 266 | suffice for machine precision. 267 | 268 | """ 269 | self.dim = dim 270 | self.n_iter = n_iter 271 | self.alpha = alpha 272 | super().__init__() 273 | 274 | def forward(self, X): 275 | return entmax_bisect( 276 | X, alpha=self.alpha, dim=self.dim, n_iter=self.n_iter 277 | ) 278 | -------------------------------------------------------------------------------- /code/entmax/test_losses.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.autograd import gradcheck, grad 4 | from functools import partial 5 | 6 | from entmax.losses import ( 7 | SparsemaxLoss, 8 | Entmax15Loss, 9 | SparsemaxBisectLoss, 10 | EntmaxBisectLoss, 11 | ) 12 | 13 | 14 | # make data 15 | Xs = [ 16 | torch.randn(4, 10, dtype=torch.float64, requires_grad=True) 17 | for _ in range(5) 18 | ] 19 | 20 | ys = [torch.max(torch.randn_like(X), dim=1)[1] for X in Xs] 21 | 22 | 23 | losses = [ 24 | SparsemaxLoss, 25 | partial(SparsemaxLoss, k=5), 26 | Entmax15Loss, 27 | partial(Entmax15Loss, k=5), 28 | SparsemaxBisectLoss, 29 | EntmaxBisectLoss, 30 | ] 31 | 32 | 33 | @pytest.mark.parametrize("Loss", losses) 34 | def test_non_neg(Loss): 35 | 36 | for X, y in zip(Xs, ys): 37 | ls = Loss(reduction="none") 38 | lval = ls(X, y) 39 | assert torch.all(lval >= 0) 40 | 41 | 42 | @pytest.mark.parametrize("Loss", losses) 43 | @pytest.mark.parametrize("ignore_index", (False, True)) 44 | @pytest.mark.parametrize("reduction", ("sum", "elementwise_mean")) 45 | def test_loss(Loss, ignore_index, reduction): 46 | 47 | for X, y in zip(Xs, ys): 48 | iix = y[0] if ignore_index else -100 49 | ls = Loss(ignore_index=iix, reduction=reduction) 50 | gradcheck(ls, (X, y), eps=1e-5) 51 | 52 | 53 | @pytest.mark.parametrize("Loss", losses) 54 | def test_index_ignored(Loss): 55 | 56 | x = torch.randn(20, 6, dtype=torch.float64, requires_grad=True) 57 | _, y = torch.max(torch.randn_like(x), dim=1) 58 | 59 | loss_ignore = Loss(reduction="sum", ignore_index=y[0]) 60 | loss_noignore = Loss(reduction="sum", ignore_index=-100) 61 | 62 | assert loss_ignore(x, y) < loss_noignore(x, y) 63 | 64 | 65 | if __name__ == "__main__": 66 | test_sparsemax_loss() 67 | test_entmax_loss() 68 | test_sparsemax_bisect_loss() 69 | test_entmax_bisect_loss() 70 | -------------------------------------------------------------------------------- /code/entmax/test_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from entmax.activations import Sparsemax, Entmax15 5 | 6 | from entmax.root_finding import sparsemax_bisect, entmax_bisect 7 | 8 | funcs = [ 9 | Sparsemax(dim=1), 10 | 
Entmax15(dim=1), 11 | Sparsemax(dim=1, k=512), 12 | Entmax15(dim=1, k=512), 13 | sparsemax_bisect, 14 | entmax_bisect, 15 | ] 16 | 17 | 18 | @pytest.mark.parametrize("func", funcs) 19 | @pytest.mark.parametrize("dtype", (torch.float32, torch.float64)) 20 | def test_mask(func, dtype): 21 | torch.manual_seed(42) 22 | x = torch.randn(2, 6, dtype=dtype) 23 | x[:, 3:] = -float("inf") 24 | x0 = x[:, :3] 25 | 26 | y = func(x) 27 | y0 = func(x0) 28 | 29 | y[:, :3] -= y0 30 | 31 | assert torch.allclose(y, torch.zeros_like(y)) 32 | 33 | 34 | @pytest.mark.parametrize("alpha", (1.25, 1.5, 1.75, 2.25)) 35 | def test_mask_alphas(alpha): 36 | torch.manual_seed(42) 37 | x = torch.randn(2, 6) 38 | x[:, 3:] = -float("inf") 39 | x0 = x[:, :3] 40 | 41 | y = entmax_bisect(x, alpha) 42 | y0 = entmax_bisect(x0, alpha) 43 | 44 | y[:, :3] -= y0 45 | 46 | assert torch.allclose(y, torch.zeros_like(y)) 47 | -------------------------------------------------------------------------------- /code/entmax/test_root_finding.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from itertools import product 3 | from functools import partial 4 | 5 | import torch 6 | from torch.autograd import gradcheck 7 | 8 | from entmax.root_finding import sparsemax_bisect, entmax_bisect 9 | from entmax.activations import sparsemax, entmax15 10 | 11 | 12 | # @pytest.mark.parametrize("dim", (0, 1, 2)) 13 | # def test_dim(dim, Map): 14 | # for _ in range(10): 15 | # x = torch.randn(5, 6, 7, requires_grad=True, dtype=torch.float64) 16 | # # gradcheck(f, (x,)) 17 | 18 | 19 | def test_sparsemax(): 20 | x = 0.5 * torch.randn(4, 6, dtype=torch.float32) 21 | p1 = sparsemax(x, 1) 22 | p2 = sparsemax_bisect(x) 23 | assert torch.sum((p1 - p2) ** 2) < 1e-7 24 | 25 | 26 | def test_entmax15(): 27 | x = 0.5 * torch.randn(4, 6, dtype=torch.float32) 28 | p1 = entmax15(x, 1) 29 | p2 = entmax_bisect(x, alpha=1.5) 30 | assert torch.sum((p1 - p2) ** 2) < 1e-7 31 | 32 | 33 | def test_sparsemax_grad(): 34 | x = torch.randn(4, 6, dtype=torch.float64, requires_grad=True) 35 | gradcheck(sparsemax_bisect, (x,), eps=1e-5) 36 | 37 | 38 | @pytest.mark.parametrize("alpha", (1.2, 1.5, 1.75, 2.25)) 39 | def test_entmax_grad(alpha): 40 | alpha = torch.tensor(alpha, dtype=torch.float64, requires_grad=True) 41 | x = torch.randn(4, 6, dtype=torch.float64, requires_grad=True) 42 | gradcheck(entmax_bisect, (x, alpha), eps=1e-5) 43 | 44 | 45 | def test_entmax_correct_multiple_alphas(): 46 | n = 4 47 | x = torch.randn(n, 6, dtype=torch.float64, requires_grad=True) 48 | alpha = 1.05 + torch.rand((n, 1), dtype=torch.float64, requires_grad=True) 49 | 50 | p1 = entmax_bisect(x, alpha) 51 | p2_ = [ 52 | entmax_bisect(x[i].unsqueeze(0), alpha[i].item()).squeeze() 53 | for i in range(n) 54 | ] 55 | p2 = torch.stack(p2_) 56 | 57 | assert torch.allclose(p1, p2) 58 | 59 | 60 | def test_entmax_grad_multiple_alphas(): 61 | 62 | n = 4 63 | x = torch.randn(n, 6, dtype=torch.float64, requires_grad=True) 64 | alpha = 1.05 + torch.rand((n, 1), dtype=torch.float64, requires_grad=True) 65 | gradcheck(entmax_bisect, (x, alpha), eps=1e-5) 66 | 67 | 68 | @pytest.mark.parametrize("dim", (0, 1, 2, 3)) 69 | def test_arbitrary_dimension(dim): 70 | shape = [3, 4, 2, 5] 71 | X = torch.randn(*shape, dtype=torch.float64) 72 | 73 | alpha_shape = shape 74 | alpha_shape[dim] = 1 75 | 76 | alphas = 1.05 + torch.rand(alpha_shape, dtype=torch.float64) 77 | 78 | P = entmax_bisect(X, alpha=alphas, dim=dim) 79 | 80 | ranges = [ 81 | list(range(k)) if i != dim else 
[slice(None)] 82 | for i, k in enumerate(shape) 83 | ] 84 | 85 | for ix in product(*ranges): 86 | x = X[ix].unsqueeze(0) 87 | alpha = alphas[ix].item() 88 | p_true = entmax_bisect(x, alpha=alpha, dim=-1) 89 | assert torch.allclose(P[ix], p_true) 90 | 91 | 92 | @pytest.mark.parametrize("dim", (0, 1, 2, 3)) 93 | def test_arbitrary_dimension_grad(dim): 94 | shape = [3, 4, 2, 5] 95 | 96 | alpha_shape = shape 97 | alpha_shape[dim] = 1 98 | 99 | f = partial(entmax_bisect, dim=dim) 100 | 101 | X = torch.randn(*shape, dtype=torch.float64, requires_grad=True) 102 | alphas = 1.05 + torch.rand( 103 | alpha_shape, dtype=torch.float64, requires_grad=True 104 | ) 105 | gradcheck(f, (X, alphas), eps=1e-5) 106 | -------------------------------------------------------------------------------- /code/entmax/test_topk.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.autograd import gradcheck 4 | 5 | 6 | from entmax.activations import ( 7 | _sparsemax_threshold_and_support, 8 | _entmax_threshold_and_support, 9 | Sparsemax, 10 | Entmax15, 11 | ) 12 | 13 | 14 | @pytest.mark.parametrize("dim", (0, 1, 2)) 15 | @pytest.mark.parametrize("Map", (Sparsemax, Entmax15)) 16 | def test_mapping(dim, Map): 17 | f = Map(dim=dim, k=3) 18 | x = torch.randn(3, 4, 5, requires_grad=True, dtype=torch.float64) 19 | gradcheck(f, (x,)) 20 | 21 | 22 | @pytest.mark.parametrize("dim", (0, 1, 2)) 23 | @pytest.mark.parametrize("coef", (0.00001, 0.5, 10000)) 24 | def test_entmax_topk(dim, coef): 25 | x = coef * torch.randn(3, 4, 5) 26 | tau1, supp1 = _entmax_threshold_and_support(x, dim=dim, k=None) 27 | tau2, supp2 = _entmax_threshold_and_support(x, dim=dim, k=5) 28 | 29 | assert torch.all(tau1 == tau2) 30 | assert torch.all(supp1 == supp2) 31 | 32 | 33 | @pytest.mark.parametrize("dim", (0, 1, 2)) 34 | @pytest.mark.parametrize("coef", (0.00001, 0.5, 10000)) 35 | @pytest.mark.parametrize("k", (5, 30)) 36 | def test_sparsemax_topk(dim, coef, k): 37 | 38 | x = coef * torch.randn(3, 4, 5) 39 | tau1, supp1 = _sparsemax_threshold_and_support(x, dim=dim, k=None) 40 | tau2, supp2 = _sparsemax_threshold_and_support(x, dim=dim, k=k) 41 | 42 | assert torch.all(tau1 == tau2) 43 | assert torch.all(supp1 == supp2) 44 | -------------------------------------------------------------------------------- /code/evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Apr 4 09:53:41 2023 4 | @author: Boxiang Yun School:ECNU Email:boxiangyun@gmail.com 5 | """ 6 | import argparse 7 | import json 8 | from skimage.metrics import hausdorff_distance 9 | import torch 10 | 11 | import numpy as np 12 | import math 13 | from spectr import SpecTr 14 | 15 | from torch.utils.data import DataLoader 16 | import os 17 | from local_utils.seed_everything import seed_reproducer 18 | import pandas as pd 19 | 20 | from tqdm import tqdm 21 | from Data_Generate import Data_Generate_Cho 22 | from local_utils.metrics import iou, dice 23 | 24 | 25 | def main(args): 26 | seed_reproducer(42) 27 | root_path = args.root_path 28 | dataset_divide = args.dataset_divide 29 | experiment_name = args.experiment_name 30 | checkpoint = args.checkpoint 31 | worker = args.worker 32 | device = args.device 33 | batch = args.batch 34 | 35 | model_path = os.path.join(checkpoint, experiment_name) 36 | 37 | images_root_path = os.path.join(root_path, args.dataset_hyper) 38 | mask_root_path = os.path.join(root_path, 
args.dataset_mask) 39 | dataset_json = os.path.join(root_path, dataset_divide) 40 | with open(dataset_json, 'r') as load_f: 41 | dataset_dict = json.load(load_f) 42 | 43 | device = torch.device(device) 44 | channels = args.channels_index 45 | spectral_number = args.spectral_number if channels is None else len(args.channels_index) 46 | multi_class = 1 47 | 48 | labels, outs = [], [] 49 | model = SpecTr(choose_translayer=args.choose_translayer, 50 | spatial_size=(args.cutting, args.cutting), 51 | max_seq=spectral_number, 52 | classes=multi_class, 53 | decode_choice=args.decode_choice, 54 | init_values=args.init_values) 55 | model = model.to(device) 56 | 57 | history = {'val_iou': [], 'val_dice': [], 'val_haus': []} 58 | #For slide window operation in the validation stage 59 | def patch_index(shape, patchsize, stride): 60 | s, h, w = shape 61 | sx = (w - patchsize[1]) // stride[1] + 1 62 | sy = (h - patchsize[0]) // stride[0] + 1 63 | sz = (s - patchsize[2]) // stride[2] + 1 64 | 65 | for x in range(sx): 66 | xs = stride[1] * x 67 | for y in range(sy): 68 | ys = stride[0] * y 69 | for z in range(sz): 70 | zs = stride[2] * z 71 | yield slice(zs, zs + patchsize[2]), slice(ys, ys + patchsize[0]), slice(xs, xs + patchsize[1]) 72 | 73 | for k in args.fold: 74 | val_files = dataset_dict[f'fold{k}'] 75 | val_images_path = [os.path.join(images_root_path, i) for i in dataset_dict[f'fold{k}']] 76 | val_masks_path = [os.path.join(mask_root_path, f'{i[:-4]}.png') for i in dataset_dict[f'fold{k}']] 77 | print(f'the number of valfiles is {len(val_files)}') 78 | val_db = Data_Generate_Cho(val_images_path, val_masks_path, cutting=None, transform=None, 79 | channels=channels, outtype=args.outtype) 80 | val_loader = DataLoader(val_db, batch_size=1, shuffle=False, num_workers=worker) 81 | 82 | model.load_state_dict(torch.load(os.path.join(model_path, f'final_{k}fold_74.pth'), 83 | map_location=lambda storage, loc: storage.cuda())) 84 | 85 | print('now start evaluate ...') 86 | model.eval() 87 | for idx, sample in enumerate(tqdm(val_loader)): 88 | image, label = sample 89 | image = image.squeeze() 90 | spectrum_shape, shape_h, shape_w = image.shape 91 | patch_idx = list(patch_index((spectrum_shape, shape_h, shape_w), (args.cutting, args.cutting, spectrum_shape), 92 | (64, 128, 1))) # origan shape is 256, 320; 128=320-192, 64=256-192 93 | num_collect = torch.zeros((shape_h, shape_w), dtype=torch.uint8).to(device) 94 | pred_collect = torch.zeros((shape_h, shape_w)).to(device) 95 | for i in range(0, len(patch_idx), batch): 96 | with torch.no_grad(): 97 | output = model(torch.stack([image[x] for x in patch_idx[i:i + batch]])[None].to(device)).squeeze(1) 98 | for j in range(output.size(0)): 99 | num_collect[patch_idx[i + j][1:]] += 1 100 | pred_collect[patch_idx[i + j][1:]] += output[j] 101 | 102 | out = pred_collect / num_collect.float() 103 | out[torch.isnan(out)] = 0 104 | out, label = out.cpu().detach().numpy()[None][None], label.cpu().detach().numpy() 105 | out = np.where(out > 0.5, 1, 0) 106 | label = np.where(label > 0.5, 1, 0) 107 | labels.extend(label) 108 | outs.extend(out) 109 | 110 | val_iou = np.array([iou(o[0], l[0]) for l, o in zip(labels, outs)]) 111 | val_dice = np.array([dice(o[0], l[0]) for l, o in zip(labels, outs)]) 112 | val_haus = np.array([hausdorff_distance(o[0], l[0]) for l, o in zip(labels, outs)]) #if 113 | # hausdorff_distance(o[0], l[0]) != float('inf')]) 114 | 115 | history['val_iou'].append(val_iou.mean()) 116 | history['val_dice'].append(val_dice.mean()) 117 | 
history['val_haus'].append(val_haus.mean()) 118 | history['val_iou'].append(val_iou.std()) 119 | history['val_dice'].append(val_dice.std()) 120 | history['val_haus'].append(val_haus.std()) 121 | 122 | print(f"the valid dataset iou & dice & hausdorff_distance is {val_iou.mean()} & {val_dice.mean()} & {val_haus.mean()}") 123 | history_pd = pd.DataFrame(history) 124 | history_pd.to_csv(os.path.join(model_path, f'metric.csv'), index=False) 125 | 126 | 127 | if __name__ == '__main__': 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('--root_path', '-r', type=str, default='./Cholangiocarcinoma/L') 130 | parser.add_argument('--dataset_hyper', '-dh', type=str, default='MHSI') 131 | parser.add_argument('--dataset_mask', '-dm', type=str, default='Mask') 132 | parser.add_argument('--dataset_divide', '-dd', type=str, default='four_fold.json') 133 | parser.add_argument('--device', '-dev', type=str, default='cuda:0') 134 | parser.add_argument('--fold', '-fold', type=int, default=[1, 2, 3, 4], nargs='+') 135 | 136 | parser.add_argument('--spectral_number', '-sn', default=60, type=int) 137 | parser.add_argument('--channels_index', '-chi', type=int, default=None, nargs='+') 138 | parser.add_argument('--worker', '-nw', type=int, 139 | default=4) 140 | parser.add_argument('--batch', '-b', type=int, default=1) 141 | parser.add_argument('--outtype', '-outt', type=str, 142 | default='3d') 143 | parser.add_argument('--checkpoint', '-o', type=str, default='checkpoint') 144 | parser.add_argument('--experiment_name', '-name', type=str, 145 | default='SpecTr_XXX') 146 | parser.add_argument('--choose_translayer', '-ct', nargs='+', type=int, default=[0, 1, 1, 1]) 147 | parser.add_argument('--cutting', '-cut', default=192, type=int) 148 | parser.add_argument('--epochs', '-e', type=int, default=75) 149 | parser.add_argument('--decode_choice', '-dc', default='3D', choices=['3D', 'decoder_2D']) 150 | parser.add_argument('--init_values', '-initv', type=float, default=0.01) 151 | 152 | args = parser.parse_args() 153 | 154 | main(args) 155 | -------------------------------------------------------------------------------- /code/local_utils/Gempooling.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Mar 21 12:25:42 2020 4 | 5 | @author: ybx 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.parameter import Parameter 10 | import torch.nn.functional as F 11 | 12 | def gem(x, p=3, eps=1e-6): 13 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p) 14 | 15 | 16 | class GeM(nn.Module): 17 | 18 | def __init__(self, p=3, eps=1e-6): 19 | super(GeM,self).__init__() 20 | self.p = Parameter(torch.ones(1)*p) 21 | self.eps = eps 22 | 23 | def forward(self, x): 24 | return gem(x, p=self.p, eps=self.eps) 25 | 26 | def __repr__(self): 27 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' 28 | 29 | -------------------------------------------------------------------------------- /code/local_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/local_utils/__init__.py -------------------------------------------------------------------------------- /code/local_utils/__pycache__/__init__.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/local_utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /code/local_utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/local_utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /code/local_utils/__pycache__/dice_bce_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/local_utils/__pycache__/dice_bce_loss.cpython-37.pyc -------------------------------------------------------------------------------- /code/local_utils/__pycache__/focal_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/local_utils/__pycache__/focal_loss.cpython-37.pyc -------------------------------------------------------------------------------- /code/local_utils/__pycache__/init_log.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/local_utils/__pycache__/init_log.cpython-37.pyc -------------------------------------------------------------------------------- /code/local_utils/__pycache__/kappa_CE.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/local_utils/__pycache__/kappa_CE.cpython-37.pyc -------------------------------------------------------------------------------- /code/local_utils/__pycache__/label_smoothing.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/local_utils/__pycache__/label_smoothing.cpython-37.pyc -------------------------------------------------------------------------------- /code/local_utils/__pycache__/load_pretrainedmodel.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/local_utils/__pycache__/load_pretrainedmodel.cpython-37.pyc -------------------------------------------------------------------------------- /code/local_utils/__pycache__/load_pretrainedmodel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/local_utils/__pycache__/load_pretrainedmodel.cpython-38.pyc -------------------------------------------------------------------------------- /code/local_utils/__pycache__/metrics.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/local_utils/__pycache__/metrics.cpython-37.pyc -------------------------------------------------------------------------------- /code/local_utils/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/local_utils/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /code/local_utils/__pycache__/misc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/local_utils/__pycache__/misc.cpython-38.pyc -------------------------------------------------------------------------------- /code/local_utils/__pycache__/poly_ly.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/local_utils/__pycache__/poly_ly.cpython-37.pyc -------------------------------------------------------------------------------- /code/local_utils/__pycache__/seed_everything.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/local_utils/__pycache__/seed_everything.cpython-37.pyc -------------------------------------------------------------------------------- /code/local_utils/__pycache__/tools.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/local_utils/__pycache__/tools.cpython-37.pyc -------------------------------------------------------------------------------- /code/local_utils/__pycache__/tools.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepMed-Lab-ECNU/SpecTr/8afce0f14e2ef9727cebd148850855f23a48ce5d/code/local_utils/__pycache__/tools.cpython-38.pyc -------------------------------------------------------------------------------- /code/local_utils/bn_update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | 4 | 5 | def bn_update(loader, model, device=None): 6 | r"""Updates BatchNorm running_mean, running_var buffers in the model. 7 | 8 | It performs one pass over data in `loader` to estimate the activation 9 | statistics for BatchNorm layers in the model. 10 | 11 | Args: 12 | loader (torch.utils.data.DataLoader): dataset loader to compute the 13 | activation statistics on. Each data batch should be either a 14 | tensor, or a list/tuple whose first element is a tensor 15 | containing data. 16 | 17 | model (torch.nn.Module): model for which we seek to update BatchNorm 18 | statistics. 19 | 20 | device (torch.device, optional): If set, data will be trasferred to 21 | :attr:`device` before being passed into :attr:`model`. 
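        Example (illustrative sketch; the variable names are placeholders and
        each batch yielded by ``loader`` is assumed to be a dict holding the
        input under the ``'example'`` key, which is how the loop below
        unpacks it)::

            loader = torch.utils.data.DataLoader(train_set, batch_size=8)
            bn_update(loader, model, device=torch.device('cuda:0'))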
22 | """ 23 | print('update bN..') 24 | 25 | if not _check_bn(model): 26 | return 27 | was_training = model.training 28 | model.train() 29 | momenta = {} 30 | model.apply(_reset_bn) 31 | model.apply(lambda module: _get_momenta(module, momenta)) 32 | n = 0 33 | for input in tqdm(loader): 34 | input = input['example'] 35 | if isinstance(input, (list, tuple)): 36 | input = input[0] 37 | b = input.size(0) 38 | 39 | momentum = b / float(n + b) 40 | for module in momenta.keys(): 41 | module.momentum = momentum 42 | 43 | if device is not None: 44 | input = input.to(device) 45 | 46 | model(input) 47 | n += b 48 | 49 | model.apply(lambda module: _set_momenta(module, momenta)) 50 | model.train(was_training) 51 | 52 | 53 | def _check_bn_apply(module, flag): 54 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 55 | flag[0] = True 56 | 57 | 58 | def _check_bn(model): 59 | flag = [False] 60 | model.apply(lambda module: _check_bn_apply(module, flag)) 61 | return flag[0] 62 | 63 | 64 | def _reset_bn(module): 65 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 66 | module.running_mean = torch.zeros_like(module.running_mean) 67 | module.running_var = torch.ones_like(module.running_var) 68 | 69 | 70 | def _get_momenta(module, momenta): 71 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 72 | momenta[module] = module.momentum 73 | 74 | 75 | def _set_momenta(module, momenta): 76 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 77 | module.momentum = momenta[module] 78 | 79 | 80 | -------------------------------------------------------------------------------- /code/local_utils/dice_bce_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from segmentation_models_pytorch.utils.losses import DiceLoss 3 | 4 | class Dice_BCE_Loss(nn.Module): 5 | def __init__(self, bce_weight=0.5, dice_weight=0.5): 6 | super(Dice_BCE_Loss, self).__init__() 7 | self.bce_weight = bce_weight 8 | self.dice_weight = dice_weight 9 | self.dice_loss = DiceLoss() 10 | self.bce_loss = nn.BCELoss() 11 | 12 | def forward(self, input, target): 13 | return self.bce_weight * self.dice_loss(input, target) + self.dice_weight * self.bce_loss(input, target) 14 | -------------------------------------------------------------------------------- /code/local_utils/focal_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def onehot_encoding(label, n_classes): 6 | return torch.zeros(label.size(0), n_classes).to(label.device).scatter_( 7 | 1, label.view(-1, 1), 1) 8 | 9 | class FocalLoss(nn.Module): 10 | def __init__(self, gamma=0, alpha=None, size_average=True): 11 | super(FocalLoss, self).__init__() 12 | self.gamma = gamma 13 | self.alpha = alpha 14 | if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha]) 15 | if isinstance(alpha,list): self.alpha = torch.Tensor(alpha) 16 | self.size_average = size_average 17 | 18 | def forward(self, input, target): 19 | if input.dim()>2: 20 | input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W 21 | input = input.transpose(1,2) # N,C,H*W => N,H*W,C 22 | input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C 23 | target = target.view(-1,1) 24 | 25 | logpt = F.log_softmax(input) 26 | logpt = logpt.gather(1,target) 27 | logpt = logpt.view(-1) 28 | pt = logpt.data.exp() 29 | 30 | if 
self.alpha is not None: 31 | if self.alpha.type()!=input.data.type(): 32 | self.alpha = self.alpha.type_as(input.data) 33 | at = self.alpha.gather(0,target.data.view(-1)) 34 | logpt = logpt * at 35 | 36 | loss = -1 * (1-pt)**self.gamma * logpt 37 | if self.size_average: 38 | return loss.mean() 39 | else: 40 | return loss.sum() -------------------------------------------------------------------------------- /code/local_utils/init_log.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sun Apr 4 09:53:41 2023 5 | @author: Boxiang Yun School:ECNU Email:boxiangyun@gmail.com 6 | """ 7 | import logging 8 | import os 9 | import random 10 | from argparse import ArgumentParser 11 | from logging import Logger 12 | from logging.handlers import TimedRotatingFileHandler 13 | 14 | # Third party libraries 15 | import cv2 16 | import numpy as np 17 | import pandas as pd 18 | import torch 19 | 20 | def mkdir(path: str): 21 | """Create directory. 22 | 23 | Create directory if it is not exist, else do nothing. 24 | 25 | Parameters 26 | ---------- 27 | path: str 28 | Path of your directory. 29 | 30 | Examples 31 | -------- 32 | mkdir("data/raw/train/") 33 | """ 34 | try: 35 | if path is None: 36 | pass 37 | else: 38 | os.stat(path) 39 | except Exception: 40 | os.makedirs(path) 41 | 42 | def init_logger(log_name, log_dir=None): 43 | """日志模块 44 | Reference: https://juejin.im/post/5bc2bd3a5188255c94465d31 45 | 日志器初始化 46 | 日志模块功能: 47 | 1. 日志同时打印到到屏幕和文件 48 | 2. 默认保留近一周的日志文件 49 | 日志等级: 50 | NOTSET(0)、DEBUG(10)、INFO(20)、WARNING(30)、ERROR(40)、CRITICAL(50) 51 | 如果设定等级为10, 则只会打印10以上的信息 52 | 53 | Parameters 54 | ---------- 55 | log_name : str 56 | 日志文件名 57 | log_dir : str 58 | 日志保存的目录 59 | 60 | Returns 61 | ------- 62 | RootLogger 63 | Python日志实例 64 | """ 65 | 66 | mkdir(log_dir) 67 | 68 | # 若多处定义Logger,根据log_name确保日志器的唯一性 69 | if log_name not in Logger.manager.loggerDict: 70 | logging.root.handlers.clear() 71 | logger = logging.getLogger(log_name) 72 | logger.setLevel(logging.DEBUG) 73 | 74 | # 定义日志信息格式 75 | datefmt = "%Y-%m-%d %H:%M:%S" 76 | format_str = "[%(asctime)s] %(filename)s[%(lineno)4s] : %(levelname)s %(message)s" 77 | formatter = logging.Formatter(format_str, datefmt) 78 | 79 | # 日志等级INFO以上输出到屏幕 80 | console_handler = logging.StreamHandler() 81 | console_handler.setLevel(logging.INFO) 82 | console_handler.setFormatter(formatter) 83 | logger.addHandler(console_handler) 84 | 85 | if log_dir is not None: 86 | # 日志等级INFO以上输出到{log_name}.log文件 87 | file_info_handler = TimedRotatingFileHandler( 88 | filename=os.path.join(log_dir, "%s.log" % log_name), when="D", backupCount=7 89 | ) 90 | file_info_handler.setFormatter(formatter) 91 | file_info_handler.setLevel(logging.INFO) 92 | logger.addHandler(file_info_handler) 93 | 94 | logger = logging.getLogger(log_name) 95 | 96 | return logger -------------------------------------------------------------------------------- /code/local_utils/kappa_CE.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sun Apr 4 09:53:41 2023 5 | @author: Boxiang Yun School:ECNU Email:boxiangyun@gmail.com 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | class Kappa_CE(nn.Module): 12 | def __init__(self, reduction='mean'): 13 | super(Kappa_CE, self).__init__() 14 | self.reduction = reduction 15 | 16 | def forward(self, input, 
target): 17 | n_classes = input.size(1) 18 | onehot = F.one_hot(target,n_classes) 19 | logp = F.log_softmax(input, dim=1) 20 | weight = (torch.abs((torch.argmax(input,dim=1) - target)) + 1).view(-1,1) 21 | loss = torch.sum(-logp * onehot * weight, dim=1) 22 | 23 | if self.reduction == 'none': 24 | return loss 25 | elif self.reduction == 'mean': 26 | return loss.mean() 27 | elif self.reduction == 'sum': 28 | return loss.sum() 29 | else: 30 | raise ValueError( 31 | '`reduction` must be one of \'none\', \'mean\', or \'sum\'.') 32 | 33 | if __name__ == '__main__': 34 | kloss = Kappa_CE() 35 | a = torch.tensor([5,5,5,5]) 36 | b = torch.tensor([[0.1,0,0,0,0,0.9], 37 | [0,0,0,0,0.9,0.1], 38 | [0,0,0,0.9,0,0.1], 39 | [0,0,0.9,0,0,0.1]]) 40 | print(kloss(b,a)) -------------------------------------------------------------------------------- /code/local_utils/label_smoothing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sun Apr 4 09:53:41 2023 5 | @author: Boxiang Yun School:ECNU Email:boxiangyun@gmail.com 6 | """ 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | def onehot_encoding(label, n_classes): 12 | return torch.zeros(label.size(0), n_classes).to(label.device).scatter_( 13 | 1, label.view(-1, 1), 1) 14 | 15 | 16 | def cross_entropy_loss(input, target, reduction): 17 | logp = F.log_softmax(input, dim=1) 18 | loss = torch.sum(-logp * target, dim=1) 19 | if reduction == 'none': 20 | return loss 21 | elif reduction == 'mean': 22 | return loss.mean() 23 | elif reduction == 'sum': 24 | return loss.sum() 25 | else: 26 | raise ValueError( 27 | '`reduction` must be one of \'none\', \'mean\', or \'sum\'.') 28 | 29 | def label_smoothing_criterion(epsilon=0.1, reduction='mean'): 30 | def _label_smoothing_criterion(preds, targets): 31 | n_classes = preds.size(1) 32 | device = preds.device 33 | 34 | onehot = onehot_encoding(targets, n_classes).float().to(device) 35 | targets = onehot * (1 - epsilon) + torch.ones_like(onehot).to( 36 | device) * epsilon / n_classes 37 | loss = cross_entropy_loss(preds, targets, reduction) 38 | if reduction == 'none': 39 | return loss 40 | elif reduction == 'mean': 41 | return loss.mean() 42 | elif reduction == 'sum': 43 | return loss.sum() 44 | else: 45 | raise ValueError( 46 | '`reduction` must be one of \'none\', \'mean\', or \'sum\'.') 47 | 48 | return _label_smoothing_criterion -------------------------------------------------------------------------------- /code/local_utils/load_pretrainedmodel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sun Apr 4 09:53:41 2023 5 | @author: Boxiang Yun School:ECNU Email:boxiangyun@gmail.com 6 | """ 7 | import torch.nn as nn 8 | import torch 9 | 10 | def load_pretrainedmodel(loaded_model, pretrainedmodel_path): 11 | 12 | save_model = torch.load(pretrainedmodel_path) 13 | model_dict = loaded_model.state_dict() 14 | state_dict = {k:v for k,v in save_model.items() if 'encoders' in k} 15 | model_dict.update(state_dict) 16 | loaded_model.load_state_dict(model_dict) 17 | return loaded_model 18 | 19 | def load_gvt_pretrainedmodel(loaded_model, pretrainedmodel_path): 20 | 21 | save_model = torch.load(pretrainedmodel_path) 22 | model_dict = loaded_model.state_dict() 23 | state_dict = {k:v for idx,(k,v) in enumerate(save_model.items()) if 'head' not in k} 24 | model_dict.update(state_dict) 25 | 
loaded_model.load_state_dict(model_dict) 26 | return loaded_model 27 | 28 | def load_pvt_pretrainedmodel(loaded_model, pretrainedmodel_path, pretrained_pos=False): 29 | save_model = torch.load(pretrainedmodel_path) 30 | model_dict = loaded_model.state_dict() 31 | state_dict = {k:v for k,v in save_model.items() if ('cls_token' not in k) and ('head' not in k)} 32 | state_dict = {f'encoder.{k}':v for k,v in state_dict.items()} 33 | print(state_dict.keys()) 34 | model_dict.update(state_dict) 35 | print(f"success load {len(state_dict.keys())} parameters") 36 | 37 | loaded_model.load_state_dict(model_dict) 38 | 39 | return loaded_model 40 | def load_swin_pretrainedmodel(loaded_model,pretrainedmodel_path): 41 | save_model = (torch.load(pretrainedmodel_path)) 42 | state_dict = {k:v for k,v in save_model['state_dict'].items() if ('backbone' in k)} 43 | state_dict = {f'{k[9:]}':v for k,v in state_dict.items()} 44 | loaded_model.encoder.load_state_dict(state_dict) 45 | return loaded_model 46 | #def load_transformer_pretrainedmodel(loaded_model, pretrainedmodel_path, pretrained_pos=False): 47 | # save_model = torch.load(pretrainedmodel_path) 48 | # model_dict = loaded_model.state_dict() 49 | # if pretrained_pos: 50 | # save_model['pos_embed'] = save_model['pos_embed'].resize_(1,10,768)# get [:9] from [:16] sequence length 51 | # state_dict = {k:v for k,v in save_model.items() if k in model_dict} 52 | # else: 53 | # state_dict = {k:v for k,v in save_model.items() if k in model_dict and k not in 'pos_embed'} 54 | # 55 | # print(f"success load {len(state_dict.keys())} parameters") 56 | # model_dict.update(state_dict) 57 | # loaded_model.load_state_dict(model_dict) 58 | # 59 | # return loaded_model 60 | 61 | if __name__ == '__main__': 62 | x = torch.load('/home/ubuntu/T/redhouse/utils/attention_3d/\ 63 | redhouse-checkpoint-new/forpretrained_3dUnet_03class_0117/best_forpretrained_3dUnet_03class_0117.pth') 64 | for k,v in x.items(): 65 | print(k) 66 | # 67 | -------------------------------------------------------------------------------- /code/local_utils/metrics.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import recall_score,f1_score 2 | import numpy as np 3 | from skimage.metrics import hausdorff_distance 4 | from sklearn.metrics import jaccard_score 5 | from einops import rearrange 6 | 7 | def iou(y_hat, y): 8 | y_hat = y_hat.reshape(-1) 9 | y =y.reshape(-1) 10 | return jaccard_score(y_hat, y) 11 | 12 | def eval_f1score(pred, label): 13 | final_score = f1_score(label, pred, average='macro') 14 | return final_score 15 | 16 | 17 | def assert_shape(test, reference): 18 | assert test.shape == reference.shape, "Shape mismatch: {} and {}".format( 19 | test.shape, reference.shape) 20 | 21 | 22 | class ConfusionMatrix: 23 | def __init__(self, test=None, reference=None): 24 | 25 | self.tp = None 26 | self.fp = None 27 | self.tn = None 28 | self.fn = None 29 | self.size = None 30 | self.reference_empty = None 31 | self.reference_full = None 32 | self.test_empty = None 33 | self.test_full = None 34 | self.set_reference(reference) 35 | self.set_test(test) 36 | 37 | def set_test(self, test): 38 | 39 | self.test = test 40 | self.reset() 41 | 42 | def set_reference(self, reference): 43 | 44 | self.reference = reference 45 | self.reset() 46 | 47 | def reset(self): 48 | 49 | self.tp = None 50 | self.fp = None 51 | self.tn = None 52 | self.fn = None 53 | self.size = None 54 | self.test_empty = None 55 | self.test_full = None 56 | self.reference_empty = 
None 57 | self.reference_full = None 58 | 59 | def compute(self): 60 | 61 | if self.test is None or self.reference is None: 62 | raise ValueError("'test' and 'reference' must both be set to compute confusion matrix.") 63 | 64 | assert_shape(self.test, self.reference) 65 | 66 | self.tp = int(((self.test != 0) * (self.reference != 0)).sum()) 67 | self.fp = int(((self.test != 0) * (self.reference == 0)).sum()) 68 | self.tn = int(((self.test == 0) * (self.reference == 0)).sum()) 69 | self.fn = int(((self.test == 0) * (self.reference != 0)).sum()) 70 | self.size = int(np.prod(self.reference.shape, dtype=np.int64)) 71 | self.test_empty = not np.any(self.test) 72 | self.test_full = np.all(self.test) 73 | self.reference_empty = not np.any(self.reference) 74 | self.reference_full = np.all(self.reference) 75 | 76 | def get_matrix(self): 77 | 78 | for entry in (self.tp, self.fp, self.tn, self.fn): 79 | if entry is None: 80 | self.compute() 81 | break 82 | 83 | return self.tp, self.fp, self.tn, self.fn 84 | 85 | def get_size(self): 86 | 87 | if self.size is None: 88 | self.compute() 89 | return self.size 90 | 91 | def get_existence(self): 92 | 93 | for case in (self.test_empty, self.test_full, self.reference_empty, self.reference_full): 94 | if case is None: 95 | self.compute() 96 | break 97 | 98 | return self.test_empty, self.test_full, self.reference_empty, self.reference_full 99 | 100 | 101 | def dice(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, eps=1e-7, beta=1, **kwargs): 102 | """2TP / (2TP + FP + FN)""" 103 | 104 | if confusion_matrix is None: 105 | confusion_matrix = ConfusionMatrix(test, reference) 106 | 107 | tp, fp, tn, fn = confusion_matrix.get_matrix() 108 | test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence() 109 | 110 | # if test_empty and reference_empty: 111 | # if nan_for_nonexisting: 112 | # return float("NaN") 113 | # else: 114 | # return 0. 115 | 116 | return float(((1 + beta ** 2) * tp + eps) / ((1 + beta ** 2) * tp + fp + beta ** 2 * fn + eps)) 117 | 118 | 119 | def jaccard(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, eps=1e-7, **kwargs): 120 | """TP / (TP + FP + FN)""" 121 | 122 | if confusion_matrix is None: 123 | confusion_matrix = ConfusionMatrix(test, reference) 124 | 125 | tp, fp, tn, fn = confusion_matrix.get_matrix() 126 | test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence() 127 | 128 | if test_empty and reference_empty: 129 | if nan_for_nonexisting: 130 | return float("NaN") 131 | else: 132 | return 0. 
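    # eps-smoothed Jaccard index: (TP + eps) / (TP + FP + FN + eps)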
133 | 134 | return float((tp + eps) / (tp + fp + fn + eps)) 135 | 136 | 137 | def sensitivity(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, eps=1e-7, **kwargs): 138 | """TP / (TP + FN)""" 139 | if confusion_matrix is None: 140 | confusion_matrix = ConfusionMatrix(test, reference) 141 | 142 | tp, fp, tn, fn = confusion_matrix.get_matrix() 143 | 144 | return float((tp + eps) / (tp + fn + eps)) 145 | 146 | 147 | def specificity(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, eps=1e-7, **kwargs): 148 | """TN / (TN + FP)""" 149 | if confusion_matrix is None: 150 | confusion_matrix = ConfusionMatrix(test, reference) 151 | 152 | tp, fp, tn, fn = confusion_matrix.get_matrix() 153 | 154 | return float((tn + eps) / (tn + fp + eps)) 155 | 156 | 157 | # def f1_score(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, beta=1, eps=1e-7, **kwargs): 158 | # """TN / (TN + FP)""" 159 | # if confusion_matrix is None: 160 | # confusion_matrix = ConfusionMatrix(test, reference) 161 | # tp, fp, tn, fn = confusion_matrix.get_matrix() 162 | # 163 | # return ((1 + beta ** 2) * tp + eps) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + eps) 164 | 165 | 166 | def auc(test=None, reference=None): 167 | return roc_auc_score(reference, test) 168 | 169 | 170 | # ((1 + beta ** 2) * tp + eps) \ ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + eps) 171 | 172 | 173 | def hausdorff_distance_case(pred, label): 174 | pred,label = pred.astype(np.uint8),label.astype(np.uint8) 175 | assert pred.shape[0] == 1, "exactly one predicted class map is expected" 176 | 177 | return hausdorff_distance(pred.reshape(-1),label.reshape(-1)) 178 | 179 | if __name__ == '__main__': 180 | pred = [0, 1, 2, 0, 1, 2] 181 | label = [2, 1, 0, 2, 1, 0] 182 | score = recall_score(pred, label, average='macro') 183 | print(score) -------------------------------------------------------------------------------- /code/local_utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value 6 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 7 | """ 8 | def __init__(self): 9 | self.reset() 10 | 11 | def reset(self): 12 | self.val = 0 13 | self.avg = 0 14 | self.sum = 0 15 | self.count = 0 16 | 17 | def update(self, val, n=1): 18 | self.val = val 19 | self.sum += val * n 20 | self.count += n 21 | self.avg = self.sum / self.count 22 | 23 | 24 | def B_postprocess_output(network_output: torch.Tensor): 25 | out = network_output.cpu().data.numpy() 26 | #pred = np.argmax(out, axis=1).tolist() 27 | pred = np.where(out>0.5,1,0) 28 | return pred 29 | 30 | 31 | if __name__ == '__main__': 32 | out = torch.tensor([[1,2,3],[8,1,4]]) 33 | # print(postprocess_output(out)) -------------------------------------------------------------------------------- /code/local_utils/poly_ly.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9): 17 | return initial_lr * (1 - epoch / max_epochs)**exponent -------------------------------------------------------------------------------- /code/local_utils/seed_everything.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sun Apr 4 09:53:41 2023 5 | @author: Boxiang Yun School:ECNU Email:boxiangyun@gmail.com 6 | """ 7 | import os 8 | import random 9 | 10 | # Third party libraries 11 | import numpy as np 12 | import torch 13 | 14 | def seed_reproducer(seed=2020): 15 | """Reproducer for pytorch experiment. 16 | 17 | Parameters 18 | ---------- 19 | seed: int, optional (default = 2020) 20 | Radnom seed. 21 | 22 | Example 23 | ------- 24 | seed_reproducer(seed=2020). 25 | """ 26 | random.seed(seed) 27 | os.environ["PYTHONHASHSEED"] = str(seed) 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | if torch.cuda.is_available(): 31 | torch.cuda.manual_seed(seed) 32 | torch.cuda.manual_seed_all(seed)#set all gpus seed 33 | torch.backends.cudnn.deterministic = True 34 | torch.backends.cudnn.benchmark = False#if input data type and channels' changes arent' large use it improve train efficient 35 | torch.backends.cudnn.enabled = True 36 | -------------------------------------------------------------------------------- /code/local_utils/tools.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import numpy as np 3 | import torch 4 | import shutil 5 | import torch.nn as nn 6 | 7 | class EarlyStopping: 8 | """Early stops the training if validation loss doesn't improve after a given patience.""" 9 | def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'): 10 | """ 11 | Args: 12 | patience (int): How long to wait after last time validation loss improved. 13 | Default: 7 14 | verbose (bool): If True, prints a message for each validation loss improvement. 15 | Default: False 16 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 17 | Default: 0 18 | path (str): Path for the checkpoint to be saved to. 
19 | Default: 'checkpoint.pt' 20 | """ 21 | self.patience = patience 22 | self.verbose = verbose 23 | self.counter = 0 24 | self.best_score = None 25 | self.early_stop = False 26 | self.val_loss_min = np.Inf 27 | self.delta = delta 28 | self.path = path 29 | 30 | def __call__(self, val_loss, model): 31 | 32 | score = -val_loss 33 | 34 | if self.best_score is None: 35 | self.best_score = score 36 | self.save_checkpoint(val_loss, model) 37 | elif score < self.best_score + self.delta: 38 | self.counter += 1 39 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 40 | if self.counter >= self.patience: 41 | self.early_stop = True 42 | else: 43 | self.best_score = score 44 | self.save_checkpoint(val_loss, model) 45 | self.counter = 0 46 | 47 | def save_checkpoint(self, val_loss, model): 48 | '''Saves model when validation loss decrease.''' 49 | if self.verbose: 50 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') 51 | torch.save(model.state_dict(), self.path) 52 | self.val_loss_min = val_loss 53 | 54 | def save_dict(save_file:str, dict_obj: dict): 55 | with open(save_file, 'w', encoding='utf-8') as f: 56 | writer = csv.writer(f) 57 | for key in dict_obj: 58 | writer.writerow([key, dict_obj[key]]) 59 | 60 | def save_pyfile(save_file:str, out_file:str): 61 | shutil.copy(save_file, out_file) 62 | 63 | 64 | def get_layer(model, name): 65 | layer = model 66 | for attr in name.split("."): 67 | layer = getattr(layer, attr) 68 | return layer 69 | 70 | def set_layer(model, name, layer): 71 | try: 72 | attrs, name = name.rsplit(".", 1) 73 | model = get_layer(model, attrs) 74 | except ValueError: 75 | pass 76 | setattr(model, name, layer) 77 | 78 | 79 | def convert_2d_gn(model): 80 | for name, module in model.named_modules(): 81 | if isinstance(module, nn.BatchNorm2d): 82 | # Get current bn layer 83 | bn = get_layer(model, name) 84 | # Create new gn layer 85 | if bn.num_features == 1: 86 | gn = nn.GroupNorm(1, bn.num_features) 87 | else: 88 | gn = nn.GroupNorm(8, bn.num_features) 89 | # Assign gn 90 | print("Swapping {} with {}".format(bn, gn)) 91 | set_layer(model, name, gn) 92 | return model 93 | 94 | 95 | -------------------------------------------------------------------------------- /code/spectr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sun Apr 4 09:53:41 2023 5 | @author: Boxiang Yun School:ECNU Email:boxiangyun@gmail.com 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from vit_modeling import Spectral_ZipBlock_four, ParallelBlock_CAT 11 | from spectr_block import Trans_block, AdaptivePool_Encoder 12 | from typing import Optional, Union 13 | 14 | class Conv3x3GNReLU(nn.Module): 15 | def __init__(self, in_channels, out_channels, upsample=False, up_size=(7, 128, 128)): 16 | super().__init__() 17 | self.upsample = upsample 18 | self.block = nn.Sequential( 19 | nn.Conv3d( 20 | in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False 21 | ), 22 | # nn.GroupNorm(32, out_channels), 23 | nn.GroupNorm(32, out_channels), 24 | nn.ReLU(inplace=True), 25 | ) 26 | self.up_size = up_size 27 | 28 | def forward(self, x): 29 | x = self.block(x) 30 | if self.upsample: 31 | x = F.interpolate(x, size=self.up_size, mode='trilinear') 32 | # x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) 33 | return x 34 | 35 | class GNConv3x3ReLU(nn.Module): 36 | 
def __init__(self, in_channels, out_channels, upsample=False, up_size=(128, 128)): 37 | super().__init__() 38 | self.upsample = upsample 39 | self.block = nn.Sequential( 40 | nn.GroupNorm(32, in_channels), 41 | nn.Conv3d( 42 | in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False 43 | ), 44 | nn.ReLU(inplace=True), 45 | ) 46 | self.up_size = up_size 47 | 48 | def forward(self, x): 49 | x = self.block(x) 50 | if self.upsample: 51 | x = F.interpolate(x, size=self.up_size, mode='trilinear') 52 | # x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) 53 | return x 54 | 55 | class GNConv3x3ReLU_2D(nn.Module): 56 | def __init__(self, in_channels, out_channels, upsample=False): 57 | super().__init__() 58 | self.upsample = upsample 59 | self.block = nn.Sequential( 60 | nn.GroupNorm(8, in_channels), 61 | nn.Conv2d( 62 | in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False 63 | ), 64 | nn.ReLU(inplace=True), 65 | ) 66 | 67 | def forward(self, x): 68 | x = self.block(x) 69 | if self.upsample: 70 | # x = F.interpolate(x, size=self.up_size, mode='trilinear') 71 | x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) 72 | return x 73 | 74 | class Conv3x3GNReLU_2D(nn.Module): 75 | def __init__(self, in_channels, out_channels, upsample=False): 76 | super().__init__() 77 | self.upsample = upsample 78 | self.block = nn.Sequential( 79 | nn.Conv2d( 80 | in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False 81 | ), 82 | nn.GroupNorm(8, out_channels), 83 | nn.ReLU(inplace=True), 84 | ) 85 | 86 | def forward(self, x): 87 | x = self.block(x) 88 | if self.upsample: 89 | x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) 90 | return x 91 | 92 | class FPNBlock(nn.Module): 93 | def __init__(self, pyramid_channels, skip_channels, out_size): 94 | super().__init__() 95 | self.skip_conv = nn.Conv3d(skip_channels, pyramid_channels, kernel_size=1) 96 | self.out_size = out_size 97 | 98 | def forward(self, x, skip=None): 99 | x = F.interpolate(x, size=self.out_size, mode="trilinear") 100 | skip = self.skip_conv(skip) 101 | x = x + skip 102 | return x 103 | 104 | class FPNBlock_2D(nn.Module): 105 | def __init__(self, pyramid_channels, skip_channels): 106 | super().__init__() 107 | self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1) 108 | 109 | def forward(self, x, skip=None): 110 | x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) 111 | skip = self.skip_conv(skip) 112 | x = x + skip 113 | return x 114 | 115 | class SegmentationBlock(nn.Module): 116 | def __init__(self, in_channels, out_channels, n_upsamples=0, up_size=(7, 128, 128), decode_order='cgr'): 117 | super().__init__() 118 | if decode_order == 'cgr': 119 | blocks = [ 120 | Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples), up_size=up_size[n_upsamples - 1])] 121 | else: 122 | blocks = [ 123 | GNConv3x3ReLU(in_channels, out_channels, upsample=bool(n_upsamples), up_size=up_size[n_upsamples - 1])] 124 | 125 | if n_upsamples > 1: 126 | for t in range(1, n_upsamples): 127 | if decode_order == 'cgr': 128 | blocks.append( 129 | Conv3x3GNReLU(out_channels, out_channels, upsample=True, up_size=up_size[n_upsamples - t - 1])) 130 | else: 131 | blocks.append( 132 | GNConv3x3ReLU(out_channels, out_channels, upsample=True, up_size=up_size[n_upsamples - t - 1])) 133 | 134 | self.block = nn.Sequential(*blocks) 135 | 136 | def forward(self, x): 137 | x = self.block(x) 138 | return x 139 | 140 | 
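# Shape sketch for SegmentationBlock above (illustrative only; the numbers assume the
# SpecTr defaults decoder_pyramid_channels=128, decoder_segmentation_channels=64,
# max_seq=60, a 192x192 crop and zoom_spectral=True, giving per-level sizes
# [[60, 192, 192], [30, 96, 96], [15, 48, 48], [7, 24, 24]]). A level built with
# n_upsamples=2 upsamples first to up_size[1] and then to up_size[0], so every
# pyramid level ends up at the finest resolution before MergeBlock concatenates them:
#
#     blk = SegmentationBlock(128, 64, n_upsamples=2,
#                             up_size=[[60, 192, 192], [30, 96, 96], [15, 48, 48], [7, 24, 24]])
#     blk(torch.randn(1, 128, 15, 48, 48)).shape  # -> torch.Size([1, 64, 60, 192, 192])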
class SegmentationBlock_2D(nn.Module): 141 | def __init__(self, in_channels, out_channels, n_upsamples=0, decode_order='cgr'): 142 | super().__init__() 143 | if decode_order == 'cgr': 144 | blocks = [Conv3x3GNReLU_2D(in_channels, out_channels, upsample=bool(n_upsamples))] 145 | else: 146 | blocks = [ 147 | GNConv3x3ReLU_2D(in_channels, out_channels, upsample=bool(n_upsamples))] 148 | 149 | if n_upsamples > 1: 150 | for t in range(1, n_upsamples): 151 | if decode_order == 'cgr': 152 | blocks.append(Conv3x3GNReLU_2D(out_channels, out_channels, upsample=True)) 153 | else: 154 | blocks.append(GNConv3x3ReLU_2D(out_channels, out_channels, upsample=True)) 155 | 156 | self.block = nn.Sequential(*blocks) 157 | 158 | def forward(self, x): 159 | x = self.block(x) 160 | return x 161 | 162 | class MergeBlock(nn.Module): 163 | def __init__(self, policy): 164 | super().__init__() 165 | if policy not in ["add", "cat"]: 166 | raise ValueError( 167 | "`merge_policy` must be one of: ['add', 'cat'], got {}".format( 168 | policy 169 | ) 170 | ) 171 | self.policy = policy 172 | 173 | def forward(self, x): 174 | if self.policy == 'add': 175 | return sum(x) 176 | elif self.policy == 'cat': 177 | return torch.cat(x, dim=1) 178 | else: 179 | raise ValueError( 180 | "`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy) 181 | ) 182 | 183 | class FPNDecoder(nn.Module): 184 | def __init__( 185 | self, 186 | encoder_channels, 187 | encoder_depth=5, 188 | pyramid_channels=256, 189 | segmentation_channels=128, 190 | dropout=0.2, 191 | max_seq=60, 192 | spatial_size=(128, 128), 193 | coscale_depth=1, 194 | coscale_entmax='softmax', 195 | use_coscale=True, 196 | decode_order='cgr', 197 | use_layerscale=False, 198 | init_values=1, 199 | zoom_spectral=True, 200 | ): 201 | super().__init__() 202 | 203 | self.out_channels = segmentation_channels * 4 204 | if encoder_depth < 3: 205 | raise ValueError("Encoder depth for FPN decoder cannot be less than 3, got {}.".format(encoder_depth)) 206 | 207 | encoder_channels = encoder_channels[encoder_depth - 4:][::-1] 208 | # encoder_channels = encoder_channels[:encoder_depth + 1] 209 | if zoom_spectral: 210 | f_spatials = [[max_seq // 2 ** i, spatial_size[1] // 2 ** i, spatial_size[0] // 2 ** i] for i in 211 | range(encoder_depth)] 212 | else: 213 | f_spatials = [[max_seq, spatial_size[1] // 2 ** i, spatial_size[0] // 2 ** i] for i in 214 | range(encoder_depth)] 215 | 216 | f_spectrums = [i[0] for i in f_spatials] 217 | self.use_coscale = use_coscale 218 | 219 | if use_coscale: 220 | self.parallel_blocks = nn.ModuleList([ParallelBlock_CAT( 221 | dims=encoder_channels[::-1], num_heads=8, mlp_ratios=[3, 3, 3, 3], 222 | drop=dropout, attn_drop=dropout, drop_path=dropout, 223 | use_entmax15=coscale_entmax, 224 | use_layerscale=use_layerscale, 225 | init_values=init_values, 226 | ) 227 | for idx_p in range(coscale_depth)]) 228 | 229 | self.p5 = nn.Conv3d(encoder_channels[0], pyramid_channels, kernel_size=1) 230 | self.p4 = FPNBlock(pyramid_channels, encoder_channels[1], out_size=tuple(f_spatials[-2])) 231 | self.p3 = FPNBlock(pyramid_channels, encoder_channels[2], out_size=tuple(f_spatials[-3])) 232 | self.p2 = FPNBlock(pyramid_channels, encoder_channels[3], out_size=tuple(f_spatials[-4])) 233 | 234 | self.seg_blocks = nn.ModuleList([ 235 | SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=n_upsamples, up_size=f_spatials, 236 | decode_order=decode_order) 237 | for n_upsamples in [3, 2, 1, 0] 238 | ]) 239 | 240 | self.merge = MergeBlock('cat') 241 | 
self.dropout = nn.Dropout3d(p=dropout, inplace=True) 242 | 243 | def forward(self, *features): 244 | c2, c3, c4, c5 = features[-4:] 245 | 246 | if self.use_coscale: 247 | for blk in self.parallel_blocks: 248 | c2, c3, c4, c5 = blk(c2, c3, c4, c5) 249 | p5 = self.p5(c5) 250 | p4 = self.p4(p5, c4) 251 | p3 = self.p3(p4, c3) 252 | p2 = self.p2(p3, c2) 253 | 254 | feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])] 255 | x = self.merge(feature_pyramid) 256 | x = self.dropout(x) 257 | return x 258 | 259 | class FPNDecoder_2D(nn.Module): 260 | def __init__( 261 | self, 262 | encoder_channels, 263 | encoder_depth=5, 264 | pyramid_channels=256, 265 | segmentation_channels=128, 266 | dropout=0.2, 267 | max_seq=60, 268 | spatial_size=(128, 128), 269 | coscale_depth=1, 270 | condense_entmax='adaptive_entmax', 271 | coscale_entmax='adaptive_entmax', 272 | use_coscale=True, 273 | decode_order='cgr', 274 | use_layerscale=False 275 | ): 276 | super().__init__() 277 | self.out_channels = segmentation_channels * 4 278 | if encoder_depth < 3: 279 | raise ValueError("Encoder depth for FPN decoder cannot be less than 3, got {}.".format(encoder_depth)) 280 | 281 | encoder_channels = encoder_channels[encoder_depth - 4:][::-1] 282 | # encoder_channels = encoder_channels[:encoder_depth + 1] 283 | f_spatials = [[max_seq // 2 ** i, spatial_size[1] // 2 ** i, spatial_size[0] // 2 ** i] for i in 284 | range(encoder_depth)] 285 | 286 | self.use_coscale = use_coscale 287 | if use_coscale: 288 | self.parallel_blocks = nn.ModuleList([ 289 | ParallelBlock_CAT( 290 | dims=encoder_channels[::-1], num_heads=8, mlp_ratios=[3, 3, 3, 3], 291 | drop=dropout, attn_drop=dropout, drop_path=dropout, 292 | use_entmax15=coscale_entmax, 293 | use_layerscale=use_layerscale 294 | ) 295 | for _ in range(coscale_depth)] 296 | ) 297 | 298 | self.zip_blocks = nn.ModuleList([ 299 | Spectral_ZipBlock_four( 300 | dims=encoder_channels[::-1], num_heads=8, mlp_ratios=[3, 3, 3, 3], 301 | drop=dropout, attn_drop=dropout, drop_path=dropout, 302 | use_entmax15=condense_entmax, use_layerscale=False, 303 | ) 304 | for _ in range(1)]) 305 | 306 | self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1) 307 | self.p4 = FPNBlock_2D(pyramid_channels, encoder_channels[1]) 308 | self.p3 = FPNBlock_2D(pyramid_channels, encoder_channels[2]) 309 | self.p2 = FPNBlock_2D(pyramid_channels, encoder_channels[3]) 310 | 311 | self.seg_blocks = nn.ModuleList([ 312 | SegmentationBlock_2D(pyramid_channels, segmentation_channels, n_upsamples=n_upsamples, 313 | decode_order=decode_order) 314 | for n_upsamples in [3, 2, 1, 0] 315 | ]) 316 | 317 | self.merge = MergeBlock('cat') 318 | self.dropout = nn.Dropout2d(p=dropout, inplace=True) 319 | 320 | def forward(self, *features): 321 | c2, c3, c4, c5 = features[-4:] 322 | if self.use_coscale: 323 | for blk in self.parallel_blocks: 324 | c2, c3, c4, c5 = blk(c2, c3, c4, c5) 325 | 326 | for blk in self.zip_blocks: 327 | c2, c3, c4, c5 = blk(c2, c3, c4, c5) 328 | 329 | p5 = self.p5(c5) 330 | p4 = self.p4(p5, c4) 331 | p3 = self.p3(p4, c3) 332 | p2 = self.p2(p3, c2) 333 | 334 | feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])] 335 | x = self.merge(feature_pyramid) 336 | x = self.dropout(x) 337 | 338 | return x 339 | 340 | def number_of_features_per_level(init_channel_number, num_levels): 341 | return [init_channel_number * 2 ** k for k in range(num_levels)] 342 | 343 | 344 | class Spectr_backbone(nn.Module): 345 | """ 346 | 
non-zoom-spectr and downsample by transformer block and after conv 347 | """ 348 | 349 | def __init__(self, in_channels, f_maps=64, encode_layer_order='gcr', choose_translayer=[0, 1, 1, 1], 350 | tran_enc_layers=[1, 1, 1, 1], dropout_att=0.1, dropout=0.1, 351 | num_levels=4, spatial_size=(256, 256), zoom_spectral=True, 352 | transformer_dim=3, init_values=1.0, use_layerscale=True, 353 | conv_kernel_size=(1, 3, 3), conv_padding=(0, 1, 1), max_seq=60, use_entmax15='entmax_bisect', 354 | att_blocks='att'): 355 | 356 | super(Spectr_backbone, self).__init__() 357 | 358 | assert len(tran_enc_layers) == num_levels, "input correct choiced transformer layers" 359 | 360 | if isinstance(f_maps, int): 361 | f_maps = number_of_features_per_level(f_maps, num_levels=num_levels) 362 | 363 | if zoom_spectral: 364 | if isinstance(spatial_size, int): 365 | f_spatials = [[max_seq // 2 ** i, spatial_size // 2 ** i, spatial_size // 2 ** i] for i in 366 | range(num_levels)] 367 | else: 368 | f_spatials = [[max_seq // 2 ** i, spatial_size[1] // 2 ** i, spatial_size[0] // 2 ** i] for i in 369 | range(num_levels)] 370 | else: 371 | if isinstance(spatial_size, int): 372 | f_spatials = [[max_seq, spatial_size // 2 ** i, spatial_size // 2 ** i] for i in range(num_levels)] 373 | else: 374 | f_spatials = [[max_seq, spatial_size[1] // 2 ** i, spatial_size[0] // 2 ** i] for i in 375 | range(num_levels)] 376 | 377 | # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)` 378 | self.out_channels = [] 379 | encoders = [] 380 | for i, out_feature_num in enumerate(f_maps): 381 | if choose_translayer[i]: 382 | transf = Trans_block(out_feature_num, spatial_size=f_spatials[i][1:], depth_trans=tran_enc_layers[i], 383 | dropout=dropout, attention_dropout_rate=dropout_att, 384 | use_entmax15=use_entmax15, att_blocks=att_blocks, 385 | transformer_dim=transformer_dim, init_values=init_values, 386 | seq_length=f_spatials[i][0], use_layerscale=use_layerscale) 387 | else: 388 | transf = None 389 | self.out_channels.append(out_feature_num) 390 | 391 | if i == 0: 392 | encoder = AdaptivePool_Encoder(in_channels, out_feature_num, 393 | # skip pooling in the first encoder 394 | apply_pooling=False, 395 | conv_layer_order=encode_layer_order, 396 | conv_kernel_size=conv_kernel_size, 397 | padding=conv_padding, 398 | output_size=f_spatials[i], 399 | transform=transf) 400 | 401 | else: 402 | encoder = AdaptivePool_Encoder(f_maps[i - 1], out_feature_num, 403 | apply_pooling=True, 404 | conv_layer_order=encode_layer_order, 405 | conv_kernel_size=conv_kernel_size, 406 | padding=conv_padding, 407 | output_size=f_spatials[i], 408 | transform=transf) 409 | 410 | encoders.append(encoder) 411 | self.encoders = nn.ModuleList(encoders) 412 | 413 | # in the last layer a 1×1 convolution reduces the number of output 414 | 415 | def forward(self, x): 416 | # encoder part 417 | encoders_features = [] 418 | for idx, encoder in enumerate(self.encoders): 419 | x = encoder(x) 420 | # reverse the encoder outputs to be aligned with the decoder 421 | encoders_features.append(x) 422 | 423 | return encoders_features 424 | 425 | 426 | class SegmentationHead(nn.Sequential): 427 | def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1): 428 | conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) 429 | upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() 430 | activation = nn.Sigmoid() if activation == 
'sigmoid' else nn.Softmax(dim=1) 431 | super().__init__(conv2d, upsampling, activation) 432 | 433 | 434 | class SpecTr(nn.Module): 435 | """ 436 | Args: 437 | 438 | num_levels: A number of stages used in encoder in range [3, 5]. Each stage generate features 439 | two times smaller in spatial dimentions than previous one (e.g. for depth 0 we will have features 440 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 441 | Default is 4 442 | f_maps: Dimension of the encoder's feature map from the first stage (e.g. 32 -> 64 -> 128 -> 256). 443 | encode_layer_order: The order of encoder module operation e.g.: cgr means Conv -> GroupNorm -> ReLU 444 | decode_layer_order: The order of decoder module operation e.g.: cgr means Conv -> GroupNorm -> ReLU 445 | choose_translayer: Choose the encoder stages which contain transformer e.g. [0, 1, 1, 1] means the 2nd, 3rd, 4th 446 | encoder layer with Conv+Transformer, and the first layer with Conv. 447 | tran_enc_layers: the number of layers in transformer 448 | spatial_size: the input HSI spatial size 449 | zoom_spectral: False : don't downsample spectral dimension in encoder. True downsample spectra e.g. (spectral, h, w) 450 | -> (spectral/2, h/2, w/2) in each encoder stages. 451 | conv_kernel_size: 3D conv kernel size 452 | conv_padding: 3D conv padding setting 453 | max_seq: spectral number 454 | use_entmax15: wo. & w. sparsity operation (choose ['softmax', 'adaptive_entmax']) in attention block. 455 | decoder_pyramid_channels: A number of convolution filters in decoder_pyramid blocks 456 | decoder_segmentation_channels: A number of convolution filters in segmentation blocks 457 | dropout_att: Dropout ration in self-attention attention map. Default is 0.1 458 | dropout: Dropout ration in self-attention FFN modules. Default is 0.1 459 | in_channels: input channel in encoder, default is 1 460 | classes: the output class number, default is 1. 461 | activation: the activation of segmentation head, default is 'sigmod'. 462 | upsampling: Final upsampling factor. Default is 1 to preserve input-output spatial shape identity 463 | att_blocks: wo. & w. layerscale operation (choose ['att', 'layerscale']) in attention block. 464 | decode_choice: Decode A. : 3D decoder and Decoder B. Lite 2D decoder for faster spectral (choose ['3D', 'decoder_2D']) 465 | coscale_depth: the number of transformer layer in Inter-scale Spatiospectral Feature Extractor 466 | use_coscale: Use Inter-scale Spatiospectral Feature Extractor. 467 | transformer_dim: the ffn's MLP dimension ration in the transfromer 468 | use_layerscale: Use layerscale operation 469 | init_values: the init_values alpha in layerscale operation 470 | condense_entmax: wo. & w. sparsity operation (choose ['softmax', 'adaptive_entmax']) on attention block in decode lite. 471 | coscale_entmax: wo. & w. 
sparsity operation (choose ['softmax', 'adaptive_entmax']) on attention block in 472 | Inter-scale Spatiospectral Feature Extractor 473 | 474 | Returns: 475 | ``torch.nn.Module``: **SpecTr** 476 | 477 | """ 478 | 479 | def __init__( 480 | self, 481 | num_levels: int = 4, 482 | f_maps: int = 32, 483 | encode_layer_order: str = 'scr', 484 | decode_layer_order: str = 'cgr', 485 | choose_translayer: list = [0, 1, 1, 1], 486 | tran_enc_layers: list = [1, 1, 1, 1], 487 | spatial_size: tuple = (256, 256), 488 | zoom_spectral: bool = True, 489 | conv_kernel_size: Union[tuple, int] = (1, 3, 3), 490 | conv_padding: Union[tuple, int] = (0, 1, 1), 491 | max_seq: int = 60, 492 | use_entmax15: str = 'adaptive_entmax', 493 | decoder_pyramid_channels: int = 128, 494 | decoder_segmentation_channels: int = 64, 495 | dropout_att: float = 0.1, 496 | dropout: float = 0.1, 497 | in_channels: int = 1, 498 | classes: int = 1, 499 | activation: Optional[str] = 'sigmoid', 500 | upsampling: int = 1, 501 | att_blocks: str = 'layerscale', 502 | decode_choice: str = '3D', 503 | coscale_depth: int = 1, 504 | use_coscale: bool = True, 505 | transformer_dim: int = 3, 506 | use_layerscale: bool = True, 507 | init_values: float = 1.0, 508 | condense_entmax: str = 'adaptive_entmax', 509 | coscale_entmax: str = 'adaptive_entmax', 510 | ): 511 | super().__init__() 512 | 513 | self.decode_choice = decode_choice 514 | assert len(choose_translayer) == len(tran_enc_layers), "transformer block number must equal depth of length " 515 | 516 | self.encoder = Spectr_backbone( 517 | in_channels, f_maps=f_maps, encode_layer_order=encode_layer_order, dropout_att=dropout_att, dropout=dropout, 518 | choose_translayer=choose_translayer, tran_enc_layers=tran_enc_layers, num_levels=num_levels, 519 | spatial_size=spatial_size, zoom_spectral=zoom_spectral, 520 | conv_kernel_size=conv_kernel_size, conv_padding=conv_padding, max_seq=max_seq, 521 | use_entmax15=use_entmax15, att_blocks=att_blocks, transformer_dim=transformer_dim, init_values=init_values, 522 | use_layerscale=use_layerscale 523 | ) 524 | 525 | if decode_choice == '3D': 526 | self.decoder = FPNDecoder( 527 | encoder_channels=self.encoder.out_channels, 528 | encoder_depth=num_levels, 529 | pyramid_channels=decoder_pyramid_channels, 530 | segmentation_channels=decoder_segmentation_channels, 531 | max_seq=max_seq, 532 | spatial_size=spatial_size, 533 | coscale_depth=coscale_depth, 534 | coscale_entmax=coscale_entmax, 535 | use_coscale=use_coscale, 536 | decode_order=decode_layer_order, 537 | use_layerscale=use_layerscale, 538 | zoom_spectral=zoom_spectral) 539 | 540 | elif decode_choice == "decoder_2D": 541 | self.decoder = FPNDecoder_2D( 542 | encoder_channels=self.encoder.out_channels, 543 | encoder_depth=num_levels, 544 | pyramid_channels=decoder_pyramid_channels, 545 | segmentation_channels=decoder_segmentation_channels, 546 | max_seq=max_seq, 547 | spatial_size=spatial_size, 548 | coscale_depth=coscale_depth, 549 | condense_entmax=condense_entmax, 550 | coscale_entmax=coscale_entmax, 551 | use_coscale=use_coscale, 552 | decode_order=decode_layer_order, 553 | use_layerscale=use_layerscale, 554 | ) 555 | else: 556 | raise ValueError("please choice correct decode methods : '3D', 'decoder_2D'!") 557 | 558 | self.segmentation_head = SegmentationHead( 559 | in_channels=self.decoder.out_channels, 560 | out_channels=classes, 561 | activation=activation, 562 | kernel_size=1, 563 | upsampling=upsampling, 564 | ) 565 | 566 | def forward(self, x): 567 | """Sequentially pass `x` trough model`s 
encoder, decoder and heads""" 568 | features = self.encoder(x) 569 | decoder_output = self.decoder(*features) 570 | if self.decode_choice != 'decoder_2D': 571 | decoder_output = decoder_output.mean(2) 572 | masks = self.segmentation_head(decoder_output) 573 | else: 574 | masks = self.segmentation_head(decoder_output) 575 | return masks 576 | 577 | def forward_encoder(self, x): 578 | """Sequentially pass `x` trough model`s encoder, decoder and heads""" 579 | features = self.encoder(x) 580 | return features 581 | -------------------------------------------------------------------------------- /code/spectr_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from vit_modeling import Transformer 4 | from einops import rearrange 5 | 6 | def conv3d(in_channels, out_channels, kernel_size, bias, padding): 7 | return nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias) 8 | 9 | class Spectral_Normalize(nn.Module): 10 | """ 11 | create a list of modules with different spetral channel's normalize(bn,gn,in,ln) 12 | """ 13 | def __init__(self, num_features, num_spectral, eps=1e-5, momentum=0.1, affine=True, 14 | track_running_stats=True,normalize_type='g'): 15 | super(Spectral_Normalize, self).__init__() 16 | self.num_spectral= num_spectral 17 | # self.bns = nn.ModuleList([nn.modules.batchnorm._BatchNorm(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_classes)]) 18 | if normalize_type == 'b': 19 | base_norm = nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) 20 | elif normalize_type == 'g': 21 | num_groups = 8 22 | if num_features < num_groups: 23 | num_groups = 1 24 | base_norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_features, eps=eps, affine=affine) 25 | elif normalize_type == 'i': 26 | base_norm = nn.InstanceNorm2d(num_features, eps, momentum, affine, track_running_stats) 27 | 28 | self.bns = nn.ModuleList( 29 | [base_norm for _ in range(num_spectral)]) 30 | 31 | def reset_running_stats(self): 32 | for bn in self.bns: 33 | bn.reset_running_stats() 34 | 35 | def reset_parameters(self): 36 | for bn in self.bns: 37 | bn.reset_parameters() 38 | 39 | def _check_input_dim(self, input): 40 | if input.dim() != 5: 41 | raise ValueError('expected 5D input (got {}D input)' 42 | .format(input.dim())) 43 | def forward(self, x): 44 | self._check_input_dim(x) 45 | out = torch.zeros_like(x) 46 | for i in range(self.num_spectral): 47 | out[:,:,i] = self.bns[i](x[:,:,i]) 48 | return out 49 | 50 | def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding, num_spectral): 51 | """ 52 | Create a list of modules with together constitute a single conv layer with non-linearity 53 | and optional batchnorm/groupnorm. 54 | 55 | Args: 56 | in_channels (int): number of input channels 57 | out_channels (int): number of output channels 58 | kernel_size(int or tuple): size of the convolving kernel 59 | order (string): order of things, e.g. 
60 | 'cr' -> conv + ReLU 61 | 'gcr' -> groupnorm + conv + ReLU 62 | 'cl' -> conv + LeakyReLU 63 | 'ce' -> conv + ELU 64 | 'bcr' -> batchnorm + conv + ReLU 65 | num_groups (int): number of groups for the GroupNorm 66 | padding (int or tuple): add zero-padding added to all three sides of the input 67 | 68 | Return: 69 | list of tuple (name, module) 70 | """ 71 | assert 'c' in order, "Conv layer MUST be present" 72 | assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer' 73 | 74 | modules = [] 75 | for i, char in enumerate(order): 76 | if char == 'r': 77 | modules.append(('ReLU', nn.ReLU(inplace=True))) 78 | elif char == 'l': 79 | modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True))) 80 | elif char == 'e': 81 | modules.append(('ELU', nn.ELU(inplace=True))) 82 | elif char == 'c': 83 | # add learnable bias only in the absence of batchnorm/groupnorm 84 | bias = not ('g' in order or 'b' in order or 's' in order) 85 | modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding))) 86 | elif char == 'g': 87 | is_before_conv = i < order.index('c') 88 | if is_before_conv: 89 | num_channels = in_channels 90 | else: 91 | num_channels = out_channels 92 | 93 | # use only one group if the given number of groups is greater than the number of channels 94 | if num_channels < num_groups: 95 | num_groups = 1 96 | 97 | assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}' 98 | modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels))) 99 | elif char == 'b': 100 | is_before_conv = i < order.index('c') 101 | if is_before_conv: 102 | modules.append(('batchnorm', nn.BatchNorm3d(in_channels))) 103 | else: 104 | modules.append(('batchnorm', nn.BatchNorm3d(out_channels))) 105 | elif char == 's': 106 | is_before_conv = i < order.index('c') 107 | if is_before_conv: 108 | num_channels = in_channels 109 | else: 110 | num_channels = out_channels 111 | 112 | modules.append(('spectralnorm', Spectral_Normalize(num_features=num_channels, num_spectral=num_spectral))) 113 | else: 114 | raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c', 's']") 115 | 116 | return modules 117 | 118 | class SingleConv(nn.Sequential): 119 | """ 120 | Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order 121 | of operations can be specified via the `order` parameter 122 | 123 | Args: 124 | in_channels (int): number of input channels 125 | out_channels (int): number of output channels 126 | kernel_size (int or tuple): size of the convolving kernel 127 | order (string): determines the order of layers, e.g. 128 | 'cr' -> conv + ReLU 129 | 'crg' -> conv + ReLU + groupnorm 130 | 'cl' -> conv + LeakyReLU 131 | 'ce' -> conv + ELU 132 | num_groups (int): number of groups for the GroupNorm 133 | padding (int or tuple): 134 | """ 135 | 136 | def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8, padding=1, num_spectral=10): 137 | super(SingleConv, self).__init__() 138 | 139 | for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding, num_spectral=num_spectral): 140 | self.add_module(name, module) 141 | 142 | class DoubleConv(nn.Sequential): 143 | """ 144 | A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d). 
145 | We use (Conv3d+ReLU+GroupNorm3d) by default. 146 | This can be changed however by providing the 'order' argument, e.g. in order 147 | to change to Conv3d+BatchNorm3d+ELU use order='cbe'. 148 | Use padded convolutions to make sure that the output (H_out, W_out) is the same 149 | as (H_in, W_in), so that you don't have to crop in the decoder path. 150 | 151 | Args: 152 | in_channels (int): number of input channels 153 | out_channels (int): number of output channels 154 | encoder (bool): if True we're in the encoder path, otherwise we're in the decoder 155 | kernel_size (int or tuple): size of the convolving kernel 156 | order (string): determines the order of layers, e.g. 157 | 'cr' -> conv + ReLU 158 | 'crg' -> conv + ReLU + groupnorm 159 | 'cl' -> conv + LeakyReLU 160 | 'ce' -> conv + ELU 161 | num_groups (int): number of groups for the GroupNorm 162 | padding (int or tuple): add zero-padding added to all three sides of the input 163 | """ 164 | def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr', num_groups=8, 165 | padding=1, num_spectral=10, shape=(192, 192)): 166 | super(DoubleConv, self).__init__() 167 | if encoder: 168 | # we're in the encoder path 169 | conv1_in_channels = in_channels 170 | conv1_out_channels = out_channels // 2 171 | if conv1_out_channels < in_channels: 172 | conv1_out_channels = in_channels 173 | conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels 174 | else: 175 | # we're in the decoder path, decrease the number of channels in the 1st convolution 176 | conv1_in_channels, conv1_out_channels = in_channels, out_channels 177 | conv2_in_channels, conv2_out_channels = out_channels, out_channels 178 | 179 | # conv1 180 | self.add_module('SingleConv1', 181 | SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups, 182 | padding=padding,num_spectral=num_spectral)) 183 | # conv2 184 | self.add_module('SingleConv2', 185 | SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups, 186 | padding=padding,num_spectral = num_spectral)) 187 | 188 | class Trans_block(nn.Module): 189 | def __init__(self, in_channels, spatial_size, depth_trans=2, transformer_dim=3, 190 | dropout=0.1, use_entmax15=False, att_blocks='att', 191 | seq_length=60, attention_dropout_rate=0.1, init_values=1e-1, use_layerscale=True): 192 | 193 | super(Trans_block, self).__init__() 194 | self.spatial_size = spatial_size 195 | self.seq_length = seq_length 196 | self.att_blocks = att_blocks 197 | 198 | self.trans = Transformer(seq_length=seq_length, num_layers=depth_trans, hidden_size=in_channels, 199 | mlp_dim=transformer_dim * in_channels, 200 | num_heads=8, drop_out=dropout, attention_dropout_rate=attention_dropout_rate, 201 | block=att_blocks, use_entmax15=use_entmax15, init_values=init_values, 202 | use_layerscale=use_layerscale) 203 | 204 | def forward(self, x): 205 | shape = x.shape 206 | x = rearrange(x, 'b c s h w -> (b h w) s c') 207 | x, att = self.trans(x) 208 | x = rearrange(x, '(b p1 p2) s c -> b c s p1 p2', p1=shape[-2], p2=shape[-1]) 209 | return x 210 | 211 | class AdaptivePool_Encoder(nn.Module): 212 | """ 213 | A single module from the encoder path consisting of the optional max 214 | pooling layer (one may specify the MaxPool kernel_size to be different 215 | than the standard (2,2,2), e.g. if the volumetric data is anisotropic 216 | (make sure to use complementary scale_factor in the decoder path) followed by 217 | a DoubleConv module. 
218 | Args: 219 | in_channels (int): number of input channels 220 | out_channels (int): number of output channels 221 | conv_kernel_size (int or tuple): size of the convolving kernel 222 | apply_pooling (bool): if True use MaxPool3d before DoubleConv 223 | pool_kernel_size (int or tuple): the size of the window 224 | pool_type (str): pooling layer: 'max' or 'avg' 225 | basic_module(nn.Module): either ResNetBlock or DoubleConv 226 | conv_layer_order (string): determines the order of layers 227 | in `DoubleConv` module. See `DoubleConv` for more info. 228 | num_groups (int): number of groups for the GroupNorm 229 | padding (int or tuple): add zero-padding added to all three sides of the input 230 | """ 231 | 232 | def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True, output_size=(10, 256, 256), 233 | pool_type='max', conv_layer_order='gcr', vis=False, 234 | padding=1, transform=None): 235 | super(AdaptivePool_Encoder, self).__init__() 236 | self.vis = vis 237 | assert pool_type in ['max', 'avg'] 238 | if apply_pooling: 239 | if pool_type == 'max': 240 | self.pooling = nn.AdaptiveMaxPool3d(output_size) 241 | else: 242 | self.pooling = nn.AdaptiveAvgPool3d(output_size) 243 | else: 244 | self.pooling = None 245 | 246 | if transform is not None: 247 | conv_kernel_size = (1, 3, 3) 248 | padding = (0, 1, 1) 249 | 250 | self.basic_module = DoubleConv(in_channels, out_channels, 251 | encoder=True, 252 | kernel_size=conv_kernel_size, 253 | order=conv_layer_order, 254 | num_groups=8, 255 | padding=padding, 256 | num_spectral=output_size[0], 257 | shape=(output_size[1], output_size[2])) 258 | self.transform = transform 259 | 260 | def forward(self, x): 261 | if self.pooling is not None: 262 | x = self.pooling(x) 263 | x1 = self.basic_module(x) 264 | 265 | if self.transform is not None: 266 | x1 = self.transform(x1) 267 | 268 | return x1 269 | -------------------------------------------------------------------------------- /code/train_main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sun Apr 4 09:53:41 2023 5 | @author: Boxiang Yun School:ECNU Email:boxiangyun@gmail.com 6 | """ 7 | import os 8 | import torch 9 | import torch.nn as nn 10 | import argparse 11 | import json 12 | import numpy as np 13 | import pandas as pd 14 | import segmentation_models_pytorch as smp 15 | 16 | from spectr import SpecTr 17 | from torch import optim 18 | from torch.utils.data import DataLoader 19 | 20 | from local_utils.tools import save_dict 21 | from local_utils.seed_everything import seed_reproducer 22 | 23 | from tqdm import tqdm 24 | from Data_Generate import Data_Generate_Cho 25 | from argument import Transform 26 | from local_utils.misc import AverageMeter 27 | from local_utils.tools import EarlyStopping 28 | from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts 29 | from local_utils.metrics import iou, dice 30 | 31 | import warnings 32 | warnings.filterwarnings('ignore') 33 | 34 | 35 | def main(args): 36 | seed_reproducer(42) 37 | 38 | root_path = args.root_path 39 | dataset_hyper = args.dataset_hyper 40 | dataset_mask = args.dataset_mask 41 | dataset_divide = args.dataset_divide 42 | batch = args.batch 43 | lr = args.lr 44 | wd = args.wd 45 | experiment_name = args.experiment_name 46 | output_path = args.output 47 | epochs = args.epochs 48 | cutting = args.cutting 49 | spectral_number = args.spectral_number 50 | fold = args.fold 51 | choose_translayer = 
args.choose_translayer 52 | worker = args.worker 53 | outtype = args.outtype 54 | channels_index = args.channels_index 55 | device = args.device 56 | decode_choice = args.decode_choice 57 | init_values = args.init_values 58 | 59 | images_root_path = os.path.join(root_path, dataset_hyper) 60 | mask_root_path = os.path.join(root_path, dataset_mask) 61 | dataset_json = os.path.join(root_path, dataset_divide) 62 | with open(dataset_json, 'r') as load_f: 63 | dataset_dict = json.load(load_f) 64 | 65 | #Data Augmentation 66 | transform = Transform(Rotate_ratio=0.2, Flip_ratio=0.2) 67 | device = torch.device(device) 68 | 69 | if os.path.exists(f'{output_path}/{experiment_name}') == False: 70 | os.mkdir(f'{output_path}/{experiment_name}') 71 | save_dict(os.path.join(f'{output_path}/{experiment_name}', 'args.csv'), args.__dict__) 72 | 73 | channels = channels_index 74 | spectral_number = spectral_number if channels is None else len(channels_index) 75 | multi_class = 1 76 | 77 | dice_criterion = smp.losses.DiceLoss(eps=1., mode='binary', from_logits=False) 78 | bce_criterion = nn.BCELoss() 79 | 80 | Miou = iou 81 | MDice = dice 82 | 83 | #For slide window operation in the validation stage 84 | def patch_index(shape, patchsize, stride): 85 | s, h, w = shape 86 | sx = (w - patchsize[1]) // stride[1] + 1 87 | sy = (h - patchsize[0]) // stride[0] + 1 88 | sz = (s - patchsize[2]) // stride[2] + 1 89 | 90 | for x in range(sx): 91 | xs = stride[1] * x 92 | for y in range(sy): 93 | ys = stride[0] * y 94 | for z in range(sz): 95 | zs = stride[2] * z 96 | yield slice(zs, zs + patchsize[2]), slice(ys, ys + patchsize[0]), slice(xs, xs + patchsize[1]) 97 | 98 | 99 | for k in fold: 100 | train_fold = list(set([1, 2, 3, 4]) - set([k])) 101 | print(f"train_fold is {train_fold} and valid_fold is {k}") 102 | 103 | train_file_dict = dataset_dict[f'fold{train_fold[0]}'] + dataset_dict[f'fold{train_fold[1]}'] + dataset_dict[ 104 | f'fold{train_fold[2]}'] 105 | 106 | train_images_path = [os.path.join(images_root_path, i) for i in train_file_dict] 107 | train_masks_path = [os.path.join(mask_root_path, f'{i[:-4]}.png') for i in train_file_dict] 108 | val_images_path = [os.path.join(images_root_path, i) for i in dataset_dict[f'fold{k}']] 109 | val_masks_path = [os.path.join(mask_root_path, f'{i[:-4]}.png') for i in dataset_dict[f'fold{k}']] 110 | 111 | train_db = Data_Generate_Cho(train_images_path, train_masks_path, cutting=cutting, 112 | transform=transform, channels=channels, outtype=outtype) 113 | train_loader = DataLoader(train_db, batch_size=batch, shuffle=True, num_workers=worker) 114 | 115 | val_db = Data_Generate_Cho(val_images_path, val_masks_path, cutting=None, transform=None, 116 | channels=channels, outtype=outtype) 117 | val_loader = DataLoader(val_db, batch_size=1, shuffle=False, num_workers=worker) 118 | 119 | model = SpecTr(choose_translayer=choose_translayer, 120 | spatial_size=(cutting, cutting), 121 | max_seq=spectral_number, 122 | classes=multi_class, 123 | decode_choice=decode_choice, 124 | init_values=init_values).to(device) 125 | 126 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd) 127 | scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=1e-8) 128 | 129 | # only record, we are not use early stop. 
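        # With patience=1000 (far larger than the default 75 epochs), early stopping is
        # effectively disabled here: this EarlyStopping instance only tracks the best
        # validation Dice seen so far and writes those weights to
        # best_fold{k}_{experiment_name}.pth.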
130 | early_stopping_val = EarlyStopping(patience=1000, verbose=True, 131 | path=os.path.join(f'{output_path}/{experiment_name}', 132 | f'best_fold{k}_{experiment_name}.pth')) 133 | 134 | history = {'epoch': [], 'LR': [], 'train_loss': [], 'train_iou': [], 'val_dice': [], 'val_iou': [], 135 | 'val_count': []} 136 | 137 | 138 | for epoch in range(epochs): 139 | train_losses = AverageMeter() 140 | val_losses = AverageMeter() 141 | train_iou, val_iou, val_dice = 0, 0, 0 142 | print('now start train ..') 143 | print('epoch {}/{}, LR:{}'.format(epoch + 1, epochs, optimizer.param_groups[0]['lr'])) 144 | train_losses.reset() 145 | model.train() 146 | try: 147 | for idx, sample in enumerate(tqdm(train_loader)): 148 | image, label = sample 149 | image, label = image.to(device), label.to(device) 150 | out = model(image) 151 | loss = dice_criterion(out, label) * 0.5 + bce_criterion(out, label) * 0.5 152 | 153 | optimizer.zero_grad() 154 | loss.backward() 155 | optimizer.step() 156 | train_losses.update(loss.item()) 157 | out = out.cpu().detach().numpy() 158 | label = label.cpu().detach().numpy() 159 | out = np.where(out > 0.5, 1, 0) 160 | label = np.where(label > 0.5, 1, 0) 161 | 162 | train_iou = train_iou + np.mean( 163 | [Miou(out[b], label[b]) for b in range(len(out))]) 164 | 165 | train_iou = train_iou / (idx + 1) 166 | 167 | except RuntimeError as e: 168 | if 'out of memory' in str(e): 169 | print('| WARNING: ran out of memory, please reduce batch') 170 | for p in model.parameters(): 171 | if p.grad is not None: 172 | del p.grad # free some memory 173 | torch.cuda.empty_cache() 174 | return 175 | else: 176 | raise e 177 | 178 | print('now start evaluate ...') 179 | model.eval() 180 | val_losses.reset() 181 | for idx, sample in enumerate(tqdm(val_loader)): 182 | image, label = sample 183 | image = image.squeeze() 184 | spectrum_shape, shape_h, shape_w = image.shape 185 | patch_idx = list(patch_index((spectrum_shape, shape_h, shape_w), (cutting, cutting, spectrum_shape), 186 | (64, 128, 1))) # origan shape is 256, 320; 128=320-192, 64=256-192 187 | num_collect = torch.zeros((shape_h, shape_w), dtype=torch.uint8).to(device) 188 | pred_collect = torch.zeros((shape_h, shape_w)).to(device) 189 | for i in range(0, len(patch_idx), batch): 190 | with torch.no_grad(): 191 | output = model(torch.stack([image[x] for x in patch_idx[i:i + batch]])[None].to(device)).squeeze(1) 192 | for j in range(output.size(0)): 193 | num_collect[patch_idx[i + j][1:]] += 1 194 | pred_collect[patch_idx[i + j][1:]] += output[j] 195 | 196 | out = pred_collect / num_collect.float() 197 | out[torch.isnan(out)] = 0 198 | 199 | out, label = out.cpu().detach().numpy()[None][None], label.cpu().detach().numpy() 200 | 201 | out = np.where(out > 0.5, 1, 0) 202 | label = np.where(label > 0.5, 1, 0) 203 | val_dice = val_dice + MDice(out, label) 204 | val_iou = val_iou + Miou(out, label) 205 | 206 | val_iou = val_iou / (idx + 1) 207 | val_dice = val_dice / (idx + 1) 208 | 209 | print('epoch {}/{}\t LR:{}\t train loss:{}\t train_iou:{}\t val_dice:{}\t val_iou:{}' \ 210 | .format(epoch + 1, epochs, optimizer.param_groups[0]['lr'], train_losses.avg, train_iou, val_dice, 211 | val_iou)) 212 | history['train_loss'].append(train_losses.avg) 213 | history['val_dice'].append(val_dice) 214 | history['val_iou'].append(val_iou) 215 | history['train_iou'].append(train_iou) 216 | 217 | history['epoch'].append(epoch + 1) 218 | history['LR'].append(optimizer.param_groups[0]['lr']) 219 | 220 | scheduler.step() 221 | early_stopping_val(-val_dice, model) 
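            # EarlyStopping treats lower scores as better, so -val_dice is passed: a higher
            # validation Dice then registers as an improvement and refreshes the saved
            # checkpoint, while .counter (appended below) records epochs since the last improvement.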
222 | history['val_count'].append(early_stopping_val.counter) 223 | 224 | if args.save_every_epoch: 225 | if (epoch + 1) % 5 == 0: 226 | torch.save(model.state_dict(), 227 | os.path.join(f'{output_path}/{experiment_name}', f'middle_{k}fold_{epoch}.pth')) 228 | 229 | if epoch + 1 == epochs: 230 | torch.save(model.state_dict(), 231 | os.path.join(f'{output_path}/{experiment_name}', f'final_{k}fold_{epoch}.pth')) 232 | 233 | 234 | if early_stopping_val.early_stop: 235 | print("Early stopping") 236 | break 237 | 238 | history_pd = pd.DataFrame(history) 239 | history_pd.to_csv(os.path.join(f'{output_path}/{experiment_name}', f'log_fold{k}.csv'), index=False) 240 | history_pd = pd.DataFrame(history) 241 | history_pd.to_csv(os.path.join(f'{output_path}/{experiment_name}', f'log_fold{k}.csv'), index=False) 242 | 243 | 244 | if __name__ == '__main__': 245 | parser = argparse.ArgumentParser() 246 | parser.add_argument('--root_path', '-r', type=str, default='./Cholangiocarcinoma/L') 247 | parser.add_argument('--dataset_hyper', '-dh', type=str, default='MHSI') 248 | parser.add_argument('--dataset_mask', '-dm', type=str, default='Mask') 249 | parser.add_argument('--dataset_divide', '-dd', type=str, default='four_fold.json') 250 | parser.add_argument('--fold', '-fold', type=int, default=[1, 2, 3, 4], nargs='+') 251 | parser.add_argument('--device', '-dev', type=str, default='cuda:0') 252 | 253 | parser.add_argument('--worker', '-nw', type=int, 254 | default=4) 255 | parser.add_argument('--outtype', '-outt', type=str, 256 | default='3d') 257 | 258 | parser.add_argument('--batch', '-b', type=int, default=1) 259 | 260 | parser.add_argument('--lr', '-l', default=0.0003, type=float) 261 | parser.add_argument('--wd', '-w', default=5e-4, type=float) 262 | 263 | parser.add_argument('--spectral_number', '-sn', default=60, type=int) 264 | parser.add_argument('--channels_index', '-chi', type=int, default=None, nargs='+') 265 | 266 | parser.add_argument('--output', '-o', type=str, default='./checkpoint') 267 | parser.add_argument('--choose_translayer', '-ct', nargs='+', type=int, default=[0, 1, 1, 1]) 268 | parser.add_argument('--experiment_name', '-name', type=str, default='SpecTr_XXXX') 269 | parser.add_argument('--cutting', '-cut', default=192, type=int) 270 | parser.add_argument('--epochs', '-e', type=int, default=75) 271 | parser.add_argument('--decode_choice', '-dc', default='3D', choices=['3D', 'decoder_2D']) 272 | parser.add_argument('--init_values', '-initv', type=float, default=0.01) 273 | parser.add_argument('--save_every_epoch', '-see', default=False, action='store_true') 274 | 275 | args = parser.parse_args() 276 | main(args) 277 | -------------------------------------------------------------------------------- /code/vit_modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import copy 7 | import math 8 | import torch.nn.functional as F 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import Dropout, Linear, LayerNorm 13 | 14 | from entmax import EntmaxAlpha 15 | from einops import rearrange 16 | 17 | from timm.models.layers import DropPath 18 | 19 | 20 | def swish(x): 21 | return x * torch.sigmoid(x) 22 | 23 | def sharpen(x, T, eps=1e-6): 24 | temp = x ** (1 / T) 25 | return (temp + eps) / (temp.sum(axis=-1, keepdims=True) + eps) 26 | 27 | ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": 
torch.nn.functional.relu, "swish": swish} 28 | 29 | 30 | def np2th(weights, conv=False): 31 | """Possibly convert HWIO to OIHW.""" 32 | if conv: 33 | weights = weights.transpose([3, 2, 0, 1]) 34 | return torch.from_numpy(weights) 35 | 36 | 37 | def max_neg_value(tensor): 38 | return -torch.finfo(tensor.dtype).max 39 | 40 | class Attention(nn.Module): 41 | def __init__(self, hidden_size, num_heads, attention_dropout_rate, 42 | use_entmax15, vis): 43 | super(Attention, self).__init__() 44 | self.vis = vis 45 | 46 | self.num_attention_heads = num_heads 47 | self.attention_head_size = hidden_size // self.num_attention_heads 48 | self.all_head_size = self.num_attention_heads * self.attention_head_size 49 | 50 | self.query = Linear(hidden_size, self.all_head_size) 51 | self.key = Linear(hidden_size, self.all_head_size) 52 | self.value = Linear(hidden_size, self.all_head_size) 53 | 54 | self.out = Linear(hidden_size, hidden_size) 55 | self.attn_dropout = Dropout(attention_dropout_rate) 56 | self.proj_dropout = Dropout(attention_dropout_rate) 57 | 58 | self.use_entmax15 = use_entmax15 59 | if use_entmax15 == 'softmax': 60 | self.att_fn = F.softmax 61 | elif use_entmax15 == 'adaptive_entmax': 62 | self.att_fn = EntmaxAlpha(self.num_attention_heads) 63 | else: 64 | raise ValueError("Oops! That was invalid attention function.Try again...") 65 | 66 | def transpose_for_scores(self, x): 67 | 68 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 69 | x = x.view(*new_x_shape) 70 | return x.permute(0, 2, 1, 3) 71 | 72 | def forward(self, hidden_states): 73 | shape = hidden_states.shape 74 | mixed_query_layer = self.query(hidden_states) 75 | mixed_key_layer = self.key(hidden_states) 76 | mixed_value_layer = self.value(hidden_states) 77 | 78 | query_layer = self.transpose_for_scores(mixed_query_layer) 79 | key_layer = self.transpose_for_scores(mixed_key_layer) 80 | value_layer = self.transpose_for_scores(mixed_value_layer) 81 | 82 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 83 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 84 | 85 | 86 | if self.use_entmax15 == 'adaptive_entmax': 87 | attention_probs = self.att_fn(attention_scores) # sharpen(attention_scores,0.5) 88 | else: 89 | attention_probs = self.att_fn(attention_scores, dim=-1) 90 | 91 | weights = attention_probs if self.vis else None 92 | 93 | attention_probs = self.attn_dropout(attention_probs) 94 | 95 | context_layer = torch.matmul(attention_probs, value_layer) 96 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 97 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 98 | context_layer = context_layer.view(*new_context_layer_shape) 99 | attention_output = self.out(context_layer) 100 | attention_output = self.proj_dropout(attention_output) 101 | return attention_output, weights 102 | 103 | class Attention_query_global(nn.Module):#query spectral attention 104 | def __init__(self, hidden_size, num_heads, attention_dropout_rate, 105 | use_entmax15, vis): 106 | super(Attention_query_global, self).__init__() 107 | self.vis = vis 108 | self.num_attention_heads = num_heads 109 | self.attention_head_size = hidden_size // self.num_attention_heads 110 | self.all_head_size = self.num_attention_heads * self.attention_head_size 111 | 112 | self.query = Linear(hidden_size, self.all_head_size) 113 | self.key = Linear(hidden_size, self.all_head_size) 114 | self.value = Linear(hidden_size, self.all_head_size) 115 | 116 | self.out = 
Linear(hidden_size, hidden_size) 117 | self.attn_dropout = Dropout(attention_dropout_rate) 118 | self.proj_dropout = Dropout(attention_dropout_rate) 119 | 120 | self.use_entmax15 = use_entmax15 121 | if use_entmax15 == 'softmax': 122 | self.att_fn = F.softmax 123 | elif use_entmax15 == 'adaptive_entmax': 124 | self.att_fn = EntmaxAlpha(self.num_attention_heads) 125 | else: 126 | raise ValueError("Oops! That was invalid attention function.Try again...") 127 | 128 | def transpose_for_scores(self, x): 129 | 130 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 131 | x = x.view(*new_x_shape) 132 | return x.permute(0, 2, 1, 3) 133 | 134 | def forward(self, hidden_states): 135 | shape = hidden_states.shape 136 | 137 | mixed_key_layer = self.key(hidden_states) 138 | mixed_value_layer = self.value(hidden_states) 139 | mixed_query_layer = self.query(hidden_states.mean(1).unsqueeze(1)) 140 | 141 | query_layer = self.transpose_for_scores(mixed_query_layer) 142 | key_layer = self.transpose_for_scores(mixed_key_layer) 143 | value_layer = self.transpose_for_scores(mixed_value_layer) 144 | 145 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 146 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 147 | 148 | if self.use_entmax15 == 'adaptive_entmax': 149 | attention_probs = self.att_fn(attention_scores) # sharpen(attention_scores,0.5) 150 | else: 151 | attention_probs = self.att_fn(attention_scores, dim=-1) 152 | 153 | weights = attention_probs if self.vis else None 154 | 155 | attention_probs = self.attn_dropout(attention_probs) 156 | 157 | context_layer = torch.matmul(attention_probs, value_layer) 158 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 159 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 160 | context_layer = context_layer.view(*new_context_layer_shape) 161 | attention_output = self.out(context_layer) 162 | attention_output = self.proj_dropout(attention_output) 163 | return attention_output, weights 164 | 165 | class Mlp(nn.Module): 166 | def __init__(self, hidden_size, mlp_dim, drop_out, out_dim=None): 167 | super(Mlp, self).__init__() 168 | self.fc1 = Linear(hidden_size, mlp_dim) 169 | if out_dim is not None: 170 | self.fc2 = Linear(mlp_dim, out_dim) 171 | else: 172 | self.fc2 = Linear(mlp_dim, hidden_size) 173 | self.act_fn = ACT2FN["gelu"] 174 | self.dropout = Dropout(drop_out) 175 | 176 | self._init_weights() 177 | 178 | def _init_weights(self): 179 | nn.init.xavier_uniform_(self.fc1.weight) 180 | nn.init.xavier_uniform_(self.fc2.weight) 181 | nn.init.normal_(self.fc1.bias, std=1e-6) 182 | nn.init.normal_(self.fc2.bias, std=1e-6) 183 | 184 | def forward(self, x): 185 | x = self.fc1(x) 186 | x = self.act_fn(x) 187 | x = self.dropout(x) 188 | x = self.fc2(x) 189 | x = self.dropout(x) 190 | return x 191 | 192 | class PositionalEncoding(nn.Module): 193 | def __init__(self, hidden_size=784, seq_length=10, drop_out=0.): 194 | super(PositionalEncoding, self).__init__() 195 | self.dropout = nn.Dropout(p=drop_out) 196 | 197 | pe = torch.zeros(seq_length, hidden_size) 198 | position = torch.arange(0, seq_length, dtype=torch.float).unsqueeze(1) 199 | div_term = torch.exp(torch.arange(0, hidden_size, 2).float() * (-math.log(10000.0) / hidden_size)) 200 | pe[:, 0::2] = torch.sin(position * div_term) 201 | pe[:, 1::2] = torch.cos(position * div_term) 202 | pe = pe.unsqueeze(0) # .transpose(0, 1) 203 | self.register_buffer('pe', pe) 204 | 205 | def forward(self, x): 206 
| x = x + self.pe 207 | return self.dropout(x) 208 | 209 | class Embeddings(nn.Module): 210 | """Construct the embeddings from patch, position embeddings. 211 | """ 212 | 213 | def __init__(self, hidden_size=768, seq_length=10, drop_out=0.): 214 | super(Embeddings, self).__init__() 215 | 216 | self.position_embeddings = nn.Parameter(torch.zeros(1, seq_length, hidden_size)) 217 | self.dropout = Dropout(drop_out) 218 | 219 | def forward(self, x): 220 | embeddings = x + self.position_embeddings 221 | embeddings = self.dropout(embeddings) 222 | return embeddings 223 | 224 | class Block(nn.Module): 225 | def __init__(self, hidden_size, mlp_dim, num_heads, drop_out, attention_dropout_rate, use_entmax15, vis): 226 | super(Block, self).__init__() 227 | self.hidden_size = hidden_size 228 | 229 | self.attention_norm = LayerNorm(hidden_size, eps=1e-6) 230 | self.ffn_norm = LayerNorm(hidden_size, eps=1e-6) 231 | self.ffn = Mlp(hidden_size, mlp_dim, drop_out) 232 | self.attn = Attention(hidden_size, num_heads, attention_dropout_rate, 233 | use_entmax15=use_entmax15, vis=vis) 234 | 235 | def forward(self, x): 236 | h = x 237 | x = self.attention_norm(x) 238 | x, weights = self.attn(x) 239 | x = x + h 240 | 241 | h = x 242 | x = self.ffn_norm(x) 243 | x = self.ffn(x) 244 | x = x + h 245 | 246 | return x, weights 247 | 248 | class Block_LayerScale(nn.Module): 249 | def __init__(self, hidden_size, mlp_dim, num_heads, drop_out, attention_dropout_rate, use_entmax15, vis, 250 | init_values=1e-1, use_layerscale=True): 251 | super(Block_LayerScale, self).__init__() 252 | self.hidden_size = hidden_size 253 | self.attention_norm = LayerNorm(hidden_size, eps=1e-6) 254 | self.ffn_norm = LayerNorm(hidden_size, eps=1e-6) 255 | self.ffn = Mlp(hidden_size, mlp_dim, drop_out) 256 | self.attn = Attention(hidden_size, num_heads, attention_dropout_rate, 257 | use_entmax15=use_entmax15, vis=vis) 258 | 259 | if use_layerscale==False: 260 | self.register_buffer("gamma_1", init_values * torch.ones((hidden_size))) 261 | self.register_buffer("gamma_2", init_values * torch.ones((hidden_size))) 262 | else: 263 | self.gamma_1 = nn.Parameter(init_values * torch.ones((hidden_size)), requires_grad=True) 264 | self.gamma_2 = nn.Parameter(init_values * torch.ones((hidden_size)), requires_grad=True) 265 | 266 | def forward(self, x): 267 | 268 | h = x 269 | x = self.attention_norm(x) 270 | x, weights = self.attn(x) 271 | x = self.gamma_1 * x 272 | x = x + h 273 | 274 | h = x 275 | x = self.ffn_norm(x) 276 | x = self.ffn(x) 277 | x = self.gamma_2 * x 278 | x = x + h 279 | return x, weights 280 | 281 | class Encoder(nn.Module): 282 | def __init__(self, num_layers, hidden_size, mlp_dim, num_heads, drop_out, attention_dropout_rate, use_entmax15, vis, 283 | block='att', init_values=1e-1, use_layerscale=True): 284 | super(Encoder, self).__init__() 285 | self.vis = vis 286 | self.layer = nn.ModuleList() 287 | self.encoder_norm = LayerNorm(hidden_size, eps=1e-6) 288 | if block == 'att': 289 | for n in range(num_layers): 290 | layer = Block(hidden_size, mlp_dim, num_heads, drop_out, attention_dropout_rate, use_entmax15, vis, 291 | ) 292 | self.layer.append(copy.deepcopy(layer)) 293 | elif block == 'layerscale': 294 | for n in range(num_layers): 295 | layer = Block_LayerScale(hidden_size, mlp_dim, num_heads, drop_out, attention_dropout_rate, 296 | use_entmax15, vis, init_values=init_values, use_layerscale=use_layerscale) 297 | self.layer.append(copy.deepcopy(layer)) 298 | else: 299 | raise ValueError("Oops! 
That was invalid attention layers.Try 'att','ca', 'layerscale'!!!") 300 | 301 | def forward(self, hidden_states): 302 | attn_weights = [] 303 | for layer_block in self.layer: 304 | hidden_states, weights = layer_block(hidden_states) 305 | if self.vis: 306 | attn_weights.append(weights) 307 | 308 | return hidden_states, attn_weights 309 | 310 | class Transformer(nn.Module): 311 | def __init__(self, seq_length, num_layers, hidden_size, mlp_dim, num_heads, drop_out, 312 | attention_dropout_rate, use_entmax15, vis=False, block='att', 313 | init_values=1e-1, use_layerscale=True): 314 | super(Transformer, self).__init__() 315 | self.vis = vis 316 | self.block = block 317 | 318 | self.embeddings = PositionalEncoding(hidden_size, seq_length, drop_out) 319 | 320 | self.encoder = Encoder(num_layers, hidden_size, mlp_dim, num_heads, drop_out, attention_dropout_rate, 321 | use_entmax15, vis, init_values=init_values, 322 | block=block, use_layerscale=use_layerscale) 323 | 324 | def forward(self, input_ids): 325 | embedding_output = self.embeddings(input_ids) 326 | encoded, attn_weights = self.encoder(embedding_output) 327 | return encoded, attn_weights 328 | 329 | class ParallelBlock_CAT(nn.Module): 330 | """ Parallel block class. """ 331 | 332 | def __init__(self, dims, num_heads, mlp_ratios=[], drop=0., attn_drop=0., 333 | drop_path=0., norm_layer=nn.LayerNorm, vis=False, use_entmax15='softmax', 334 | upsample_mode='trilinear', init_values=1e-2, use_layerscale=False): 335 | super().__init__() 336 | 337 | self.norm12 = norm_layer(dims[1]) 338 | self.norm13 = norm_layer(dims[2]) 339 | self.norm14 = norm_layer(dims[3]) 340 | self.att2 = Attention( 341 | dims[1], num_heads=num_heads, attention_dropout_rate=attn_drop, 342 | use_entmax15=use_entmax15, vis=vis 343 | ) 344 | self.att3 = Attention( 345 | dims[2], num_heads=num_heads, attention_dropout_rate=attn_drop, 346 | use_entmax15=use_entmax15, vis=vis 347 | ) 348 | self.att4 = Attention( 349 | dims[3], num_heads=num_heads, attention_dropout_rate=attn_drop, 350 | use_entmax15=use_entmax15, vis=vis 351 | ) 352 | 353 | # from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 354 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 355 | 356 | # MLP. 
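        # Each scale's MLP consumes its own tokens concatenated with the other two scales
        # resampled to the same spectral length (hence the shared input width
        # mlp_input_dim = dims[1] + dims[2] + dims[3]) and projects back to that scale's
        # channel dimension via out_dim. The highest-resolution input x1 is passed through unchanged.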
357 | self.norm22 = norm_layer(dims[1]) 358 | self.norm23 = norm_layer(dims[2]) 359 | self.norm24 = norm_layer(dims[3]) 360 | 361 | mlp_input_dim = sum((dims[1], dims[2], dims[3])) 362 | self.mlp2 = Mlp(hidden_size=mlp_input_dim, mlp_dim=int(dims[1]*mlp_ratios[1]), 363 | drop_out=drop, out_dim=dims[1]) 364 | self.mlp3 = Mlp(hidden_size=mlp_input_dim, mlp_dim=int(dims[2]*mlp_ratios[2]), 365 | drop_out=drop, out_dim=dims[2]) 366 | self.mlp4 = Mlp(hidden_size=mlp_input_dim, mlp_dim=int(dims[3]*mlp_ratios[3]), 367 | drop_out=drop, out_dim=dims[3]) 368 | 369 | self.upsample_mode = upsample_mode 370 | 371 | self.use_layerscale = use_layerscale 372 | if use_layerscale == False: 373 | self.register_buffer("gamma_1_1", init_values * torch.ones((dims[1]))) 374 | self.register_buffer("gamma_1_2", init_values * torch.ones((dims[1]))) 375 | self.register_buffer("gamma_2_1", init_values * torch.ones((dims[2]))) 376 | self.register_buffer("gamma_2_2", init_values * torch.ones((dims[2]))) 377 | self.register_buffer("gamma_3_1", init_values * torch.ones((dims[3]))) 378 | self.register_buffer("gamma_3_2", init_values * torch.ones((dims[3]))) 379 | else: 380 | self.gamma_1_1 = nn.Parameter(init_values * torch.ones((dims[1])), requires_grad=True) 381 | self.gamma_1_2 = nn.Parameter(init_values * torch.ones((dims[1])), requires_grad=True) 382 | self.gamma_2_1 = nn.Parameter(init_values * torch.ones((dims[2])), requires_grad=True) 383 | self.gamma_2_2 = nn.Parameter(init_values * torch.ones((dims[2])), requires_grad=True) 384 | self.gamma_3_1 = nn.Parameter(init_values * torch.ones((dims[3])), requires_grad=True) 385 | self.gamma_3_2 = nn.Parameter(init_values * torch.ones((dims[3])), requires_grad=True) 386 | 387 | 388 | 389 | def upsample(self, x, scale_size, size): 390 | """ Feature map up-sampling. """ 391 | return self.interpolate(x, scale_size=scale_size, input_size=size) 392 | 393 | def downsample(self, x, scale_size, size): 394 | """ Feature map down-sampling. """ 395 | return self.interpolate(x, scale_size=scale_size, input_size=size) 396 | 397 | def interpolate(self, x, scale_size, input_size): 398 | """ Feature map interpolation. """ 399 | B, S, C = x.shape 400 | S, H, W = input_size 401 | # assert N == H * W 402 | img_tokens = x 403 | 404 | img_tokens = img_tokens.transpose(1, 2).reshape(-1, C, S, H, W) 405 | img_tokens = F.interpolate(img_tokens, size=scale_size, mode=self.upsample_mode) 406 | out = img_tokens.reshape(-1, C, scale_size[0]).transpose(1, 2) 407 | 408 | return out 409 | 410 | def forward(self, x1, x2, x3, x4): 411 | _, (_, _, S2, H2, W2), (_, _, S3, H3, W3), (_, _, S4, H4, W4) = x1.shape, x2.shape, x3.shape, x4.shape 412 | x2 = rearrange(x2, 'b c s h w -> (b h w) s c') 413 | x3 = rearrange(x3, 'b c s h w -> (b h w) s c') 414 | x4 = rearrange(x4, 'b c s h w -> (b h w) s c') 415 | # Conv-Attention. 
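        # Despite the "Conv-Attention" label, each scale here simply runs multi-head
        # self-attention along the spectral axis of its flattened (b*h*w, s, c) tokens,
        # followed by a drop-path residual connection (scaled by the gamma_* factors
        # when LayerScale is enabled).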
416 | 417 | cur2 = self.norm12(x2) 418 | cur3 = self.norm13(x3) 419 | cur4 = self.norm14(x4) 420 | cur2, w2 = self.att2(cur2) 421 | cur3, w3 = self.att3(cur3) 422 | cur4, w4 = self.att4(cur4) 423 | 424 | if self.use_layerscale == True: 425 | x2 = x2 + self.drop_path(cur2) * self.gamma_1_1 426 | x3 = x3 + self.drop_path(cur3) * self.gamma_2_1 427 | x4 = x4 + self.drop_path(cur4) * self.gamma_3_1 428 | else: 429 | x2 = x2 + self.drop_path(cur2) 430 | x3 = x3 + self.drop_path(cur3) 431 | x4 = x4 + self.drop_path(cur4) 432 | 433 | cur2 = self.norm22(x2) 434 | cur3 = self.norm23(x3) 435 | cur4 = self.norm24(x4) 436 | 437 | upsample3_2 = self.upsample(cur3, scale_size=(S2, H2, W2), size=(S3, H3, W3)) 438 | upsample4_3 = self.upsample(cur4, scale_size=(S3, H3, W3), size=(S4, H4, W4)) 439 | upsample4_2 = self.upsample(cur4, scale_size=(S2, H2, W2), size=(S4, H4, W4)) 440 | downsample2_3 = self.downsample(cur2, scale_size=(S3, H3, W3), size=(S2, H2, W2)) 441 | downsample3_4 = self.downsample(cur3, scale_size=(S4, H4, W4), size=(S3, H3, W3)) 442 | downsample2_4 = self.downsample(cur2, scale_size=(S4, H4, W4), size=(S2, H2, W2)) 443 | 444 | 445 | cur2 = torch.cat((cur2, upsample3_2, upsample4_2), dim=-1) 446 | cur3 = torch.cat((cur3, upsample4_3, downsample2_3), dim=-1) 447 | cur4 = torch.cat((cur4, downsample3_4, downsample2_4), dim=-1) 448 | 449 | # MLP. 450 | cur2 = self.mlp2(cur2) 451 | cur3 = self.mlp3(cur3) 452 | cur4 = self.mlp4(cur4) 453 | 454 | if self.use_layerscale == True: 455 | x2 = x2 + self.drop_path(cur2) * self.gamma_1_2 456 | x3 = x3 + self.drop_path(cur3) * self.gamma_2_2 457 | x4 = x4 + self.drop_path(cur4) * self.gamma_3_2 458 | else: 459 | x2 = x2 + self.drop_path(cur2) 460 | x3 = x3 + self.drop_path(cur3) 461 | x4 = x4 + self.drop_path(cur4) 462 | 463 | x2 = rearrange(x2, '(b p1 p2) s c -> b c s p1 p2', p1=H2, p2=W2) 464 | x3 = rearrange(x3, '(b p1 p2) s c -> b c s p1 p2', p1=H3, p2=W3) 465 | x4 = rearrange(x4, '(b p1 p2) s c -> b c s p1 p2', p1=H4, p2=W4) 466 | 467 | return x1, x2, x3, x4 468 | 469 | class Spectral_ZipBlock_four(nn.Module): 470 | """ Parallel block class. """ 471 | def __init__(self, dims, num_heads, mlp_ratios=[], drop=0., attn_drop=0., use_layerscale=False, init_values=1e-1, 472 | drop_path=0., norm_layer=nn.LayerNorm, vis=False, use_entmax15='softmax', 473 | upsample_mode='trilinear'): 474 | super().__init__() 475 | self.norm11 = norm_layer(dims[0]) 476 | self.norm12 = norm_layer(dims[1]) 477 | self.norm13 = norm_layer(dims[2]) 478 | self.norm14 = norm_layer(dims[3]) 479 | self.att1 = Attention_query_global( 480 | dims[0], num_heads=num_heads, attention_dropout_rate=attn_drop, 481 | use_entmax15=use_entmax15, vis=vis 482 | ) 483 | 484 | self.att2 = Attention_query_global( 485 | dims[1], num_heads=num_heads, attention_dropout_rate=attn_drop, 486 | use_entmax15=use_entmax15, vis=vis 487 | ) 488 | self.att3 = Attention_query_global( 489 | dims[2], num_heads=num_heads, attention_dropout_rate=attn_drop, 490 | use_entmax15=use_entmax15, vis=vis 491 | ) 492 | self.att4 = Attention_query_global( 493 | dims[3], num_heads=num_heads, attention_dropout_rate=attn_drop, 494 | use_entmax15=use_entmax15, vis=vis 495 | ) 496 | 497 | # from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 498 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 499 | 500 | # MLP. 
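        # Unlike ParallelBlock_CAT, each MLP keeps its own scale's width
        # (hidden_size = out_dim = dims[i]): the global-query attention modules above
        # (Attention_query_global) condense every spectral sequence to a single token per
        # pixel in the forward pass, so no cross-scale concatenation is needed before these MLPs.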
501 | self.norm21 = norm_layer(dims[0]) 502 | self.norm22 = norm_layer(dims[1]) 503 | self.norm23 = norm_layer(dims[2]) 504 | self.norm24 = norm_layer(dims[3]) 505 | 506 | self.mlp1 = Mlp(hidden_size=dims[0], mlp_dim=int(dims[0] * mlp_ratios[0]), 507 | drop_out=drop, out_dim=dims[0]) 508 | self.mlp2 = Mlp(hidden_size=dims[1], mlp_dim=int(dims[1] * mlp_ratios[1]), 509 | drop_out=drop, out_dim=dims[1]) 510 | self.mlp3 = Mlp(hidden_size=dims[2], mlp_dim=int(dims[2] * mlp_ratios[2]), 511 | drop_out=drop, out_dim=dims[2]) 512 | self.mlp4 = Mlp(hidden_size=dims[3], mlp_dim=int(dims[3] * mlp_ratios[3]), 513 | drop_out=drop, out_dim=dims[3]) 514 | 515 | self.upsample_mode = upsample_mode 516 | 517 | self.use_layerscale = use_layerscale 518 | if use_layerscale == False: 519 | self.gamma_1_1 = torch.ones((dims[0]), requires_grad=True) 520 | self.gamma_1_2 = torch.ones((dims[0]), requires_grad=True) 521 | self.gamma_2_1 = torch.ones((dims[1]), requires_grad=True) 522 | self.gamma_2_2 = torch.ones((dims[1]), requires_grad=True) 523 | self.gamma_3_1 = torch.ones((dims[2]), requires_grad=True) 524 | self.gamma_3_2 = torch.ones((dims[3]), requires_grad=True) 525 | self.gamma_4_1 = torch.ones((dims[3]), requires_grad=True) 526 | self.gamma_4_2 = torch.ones((dims[3]), requires_grad=True) 527 | else: 528 | self.gamma_1_1 = nn.Parameter(init_values * torch.ones((dims[0])), requires_grad=True) 529 | self.gamma_1_2 = nn.Parameter(init_values * torch.ones((dims[0])), requires_grad=True) 530 | self.gamma_2_1 = nn.Parameter(init_values * torch.ones((dims[1])), requires_grad=True) 531 | self.gamma_2_2 = nn.Parameter(init_values * torch.ones((dims[1])), requires_grad=True) 532 | self.gamma_3_1 = nn.Parameter(init_values * torch.ones((dims[2])), requires_grad=True) 533 | self.gamma_3_2 = nn.Parameter(init_values * torch.ones((dims[2])), requires_grad=True) 534 | self.gamma_4_1 = nn.Parameter(init_values * torch.ones((dims[3])), requires_grad=True) 535 | self.gamma_4_2 = nn.Parameter(init_values * torch.ones((dims[3])), requires_grad=True) 536 | 537 | def upsample(self, x, scale_size, size): 538 | """ Feature map up-sampling. """ 539 | return self.interpolate(x, scale_size=scale_size, input_size=size) 540 | 541 | def downsample(self, x, scale_size, size): 542 | """ Feature map down-sampling. """ 543 | return self.interpolate(x, scale_size=scale_size, input_size=size) 544 | 545 | def interpolate(self, x, scale_size, input_size): 546 | """ Feature map interpolation. """ 547 | B, S, C = x.shape 548 | S, H, W = input_size 549 | img_tokens = x 550 | 551 | img_tokens = img_tokens.transpose(1, 2).reshape(-1, C, S, H, W) 552 | img_tokens = F.interpolate(img_tokens, size=scale_size, mode=self.upsample_mode) 553 | out = img_tokens.reshape(-1, C, scale_size[0]).transpose(1, 2) 554 | 555 | return out 556 | 557 | def forward(self, x1, x2, x3, x4): 558 | (_, _, S1, H1, W1), (_, _, S2, H2, W2), (_, _, S3, H3, W3), (_, _, S4, H4, W4) = x1.shape, x2.shape, x3.shape, x4.shape 559 | x1 = rearrange(x1, 'b c s h w -> (b h w) s c') 560 | x2 = rearrange(x2, 'b c s h w -> (b h w) s c') 561 | x3 = rearrange(x3, 'b c s h w -> (b h w) s c') 562 | x4 = rearrange(x4, 'b c s h w -> (b h w) s c') 563 | # Conv-Attention. 
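        # Spectral condensation: Attention_query_global uses the per-pixel spectral mean as
        # its single query, so each cur* collapses to one token of shape (b*h*w, 1, c); the
        # residual below is likewise taken against the spectral mean, and the trailing
        # squeeze(2) removes the now-singleton spectral axis, yielding (b, c, h, w) feature maps.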
564 | 565 | cur1 = self.norm11(x1) 566 | cur2 = self.norm12(x2) 567 | cur3 = self.norm13(x3) 568 | cur4 = self.norm14(x4) 569 | 570 | cur1, w1 = self.att1(cur1) 571 | cur2, w2 = self.att2(cur2) 572 | cur3, w3 = self.att3(cur3) 573 | cur4, w4 = self.att4(cur4) # b 1 c 574 | if self.use_layerscale: 575 | x1 = x1.mean(1).unsqueeze(1) + cur1 * self.gamma_1_1 576 | x2 = x2.mean(1).unsqueeze(1) + cur2 * self.gamma_2_1 577 | x3 = x3.mean(1).unsqueeze(1) + cur3 * self.gamma_3_1 578 | x4 = x4.mean(1).unsqueeze(1) + cur4 * self.gamma_4_1 579 | else: 580 | x1 = x1.mean(1).unsqueeze(1) + cur1 581 | x2 = x2.mean(1).unsqueeze(1) + cur2 582 | x3 = x3.mean(1).unsqueeze(1) + cur3 583 | x4 = x4.mean(1).unsqueeze(1) + cur4 584 | 585 | cur1 = self.norm21(x1) 586 | cur2 = self.norm22(x2) 587 | cur3 = self.norm23(x3) 588 | cur4 = self.norm24(x4) 589 | # MLP. 590 | cur1 = self.mlp1(cur1) 591 | cur2 = self.mlp2(cur2) 592 | cur3 = self.mlp3(cur3) 593 | cur4 = self.mlp4(cur4) 594 | 595 | if self.use_layerscale: 596 | x1 = x1 + cur1 * self.gamma_1_2 597 | x2 = x2 + cur2 * self.gamma_2_2 598 | x3 = x3 + cur3 * self.gamma_3_2 599 | x4 = x4 + cur4 * self.gamma_4_2 600 | else: 601 | x1 = x1 + cur1 602 | x2 = x2 + cur2 603 | x3 = x3 + cur3 604 | x4 = x4 + cur4 605 | 606 | x1 = rearrange(x1, '(b p1 p2) s c -> b c s p1 p2', p1=H1, p2=W1) 607 | x2 = rearrange(x2, '(b p1 p2) s c -> b c s p1 p2', p1=H2, p2=W2) 608 | x3 = rearrange(x3, '(b p1 p2) s c -> b c s p1 p2', p1=H3, p2=W3) 609 | x4 = rearrange(x4, '(b p1 p2) s c -> b c s p1 p2', p1=H4, p2=W4) 610 | 611 | # x1 = x1.mean(2) 612 | x1, x2, x3, x4 = x1.squeeze(2), x2.squeeze(2), x3.squeeze(2), x4.squeeze(2) 613 | return x1, x2, x3, x4 614 | -------------------------------------------------------------------------------- /dataset/four_fold.json: -------------------------------------------------------------------------------- 1 | {"fold1": ["050625-20x-roi4.hdr", "050625-20x-roi6.hdr", "050625-20x-roi1.hdr", "042145-20x-roi4.hdr", "042145-20x-roi9.hdr", "042145-20x-roi6.hdr", "042145_2-20x-roi5.hdr", "042145_2-20x-roi7.hdr", "042145_2-20x-roi8.hdr", "042145-20x-roi8.hdr", "042145-20x-roi7.hdr", "042145-20x-roi1.hdr", "042145-20x-roi10.hdr", "042145_2-20x-roi9.hdr", "052474_2-20x-roi5.hdr", "052474-20x-roi6.hdr", "040037-20x-roi1.hdr", "040037-20x-roi6.hdr", "040037-20x-roi5.hdr", "040037-20x-roi3.hdr", "040037-20x-roi4.hdr", "032414t-20x-roi3.hdr", "032414c-20x-roi5.hdr", "032414t-20x-roi6.hdr", "032414t-20x-roi5.hdr", "032414t-20x-roi2.hdr", "032414c-20x-roi4.hdr", "032414t-20x-roi4.hdr", "052357a-20x-roi3.hdr", "052357_2-20x-roi4.hdr", "052357_2-20x-roi3.hdr", "052357_2-20x-roi1.hdr", "052357-20x-roi5.hdr", "052357_2-20x-roi2.hdr", "052357-20x-roi3.hdr", "052357-20x-roi2.hdr", "052357_2-20x-roi5.hdr", "031541_2-20x-roi4.hdr", "031541_2-20x-roi6.hdr", "031541c-20x-roi3.hdr", "031541vtc-20x-roi1.hdr", "031541vtc-20x-roi5.hdr", "031541vtc-20x-roi4.hdr", "031541_2-20x-roi1.hdr", "031541c-20x-roi6.hdr", "031541_2-20x-roi2.hdr", "050323-20x-roi5.hdr", "050323-20x-roi6.hdr", "050323-20x-roi9.hdr", "050323-20x-roi1.hdr", "050323_2-20x-roi4.hdr", "050323-20x-roi8.hdr", "050323_2-20x-roi5.hdr", "050323-20x-roi2.hdr", "050323-20x-roi3.hdr", "050323_2-20x-roi2.hdr", "050323-20x-roi4.hdr", "050323_2-20x-roi1.hdr", "050323-20x-roi7.hdr", "050323_2-20x-roi3.hdr", "032546-20x-roi6.hdr", "032546-20x-roi4.hdr", "032546-20x-roi2.hdr", "032546-20x-roi3.hdr", "032546-20x-roi1.hdr", "041602-20x-roi3.hdr", "041602-20x-roi10.hdr", "041602-20x-roi1.hdr", "041602_2-20x-roi6.hdr", 
"041602_2-20x-roi3.hdr", "041602-20x-roi8.hdr", "041602_2-20x-roi4.hdr", "041602_2-20x-roi5.hdr", "041602-20x-roi9.hdr", "041602_2-20x-roi1.hdr", "041602_2-20x-roi2.hdr", "041602-20x-roi2.hdr", "033629-20x-roi8.hdr", "033629-20x-roi2.hdr", "033629_3-20x-roi3.hdr", "033629_3-20x-roi6.hdr", "033629-20x-roi9.hdr", "033629_2-20x-roi2.hdr", "033629_2-20x-roi6.hdr", "033629-20x-roi6.hdr", "033629_2-20x-roi4.hdr", "033629-20x-roi1.hdr", "033629-20x-roi7.hdr", "033629_3-20x-roi2.hdr", "033629_2-20x-roi3.hdr", "033629_3-20x-roi1.hdr", "033629-20x-roi10.hdr", "033629_2-20x-roi1.hdr", "033629-20x-roi5.hdr", "033629_3-20x-roi4.hdr", "033629_2-20x-roi5.hdr", "050028_2-20x-roi1.hdr", "050028_2-20x-roi4.hdr", "050028_2-20x-roi7.hdr", "050028-20x-roi4.hdr", "050028_2-20x-roi3.hdr", "050028-20x-roi1.hdr", "050028-20x-roi6.hdr", "050028_2-20x-roi5.hdr", "050028-20x-roi2.hdr", "050028-20x-roi10.hdr", "050028-20x-roi8.hdr", "050028_2-20x-roi2.hdr", "050028-20x-roi5.hdr", "050028_2-20x-roi8.hdr", "050028-20x-roi3.hdr", "050875-20x-roi2.hdr", "050875-20x-roi6.hdr", "050875_2-20x-roi5.hdr", "050875-20x-roi4.hdr", "050875_2-20x-roi3.hdr", "050875_2-20x-roi1.hdr", "050875-20x-roi3.hdr", "050875-20x-roi5.hdr", "050875_2-20x-roi6.hdr", "050875-20x-roi1.hdr", "042338-20x-roi4.hdr", "042338b-20x-roi4.hdr", "042338b-20x-roi1.hdr", "042338b-20x-roi5.hdr", "042338-20x-roi5.hdr", "042338b-20x-roi3.hdr", "042338-20x-roi2.hdr", "042338-20x-roi6.hdr", "042338b-20x-roi9.hdr", "042338b-20x-roi7.hdr", "042338b-20x-roi10.hdr", "042338-20x-roi3.hdr", "042338b-20x-roi8.hdr", "042338b-20x-roi2.hdr", "042338b-20x-roi6.hdr", "042189c_2-20x-roi3.hdr", "042189c_2-20x-roi7.hdr", "042189c_2-20x-roi10.hdr", "042189c-20x-roi2.hdr", "042189c-20x-roi1.hdr", "042189c_2-20x-roi8.hdr", "042189c-20x-roi5.hdr", "042189c_2-20x-roi6.hdr", "042189c_2-20x-roi4.hdr", "042189c_2-20x-roi2.hdr", "042189c_2-20x-roi9.hdr", "030968_2-20x-roi6.hdr", "030968_2-20x-roi1.hdr", "030968-20x-roi2.hdr", "030968-20x-roi5.hdr", "030968-20x-roi1.hdr", "030968_2-20x-roi5.hdr", "032814c-20x-roi2.hdr", "032814c-20x-roi4.hdr", "032814c-20x-roi5.hdr", "032814c-20x-roi6.hdr", "032814c-20x-roi1.hdr", "032814c-20x-roi3.hdr", "042477_2-20x-roi4.hdr", "042477_2-20x-roi6.hdr", "042477_2-20x-roi3.hdr", "042477-20x-roi4.hdr", "042477_2-20x-roi2.hdr", "042477-20x-roi1.hdr", "042477_2-20x-roi5.hdr", "042477-20x-roi6.hdr", "042477_2-20x-roi1.hdr", "042477-20x-roi5.hdr"], "fold2": ["055110_2-20x-roi6.hdr", "055110_2-20x-roi2.hdr", "055110-20x-roi6.hdr", "055110_2-20x-roi3.hdr", "055110_2-20x-roi5.hdr", "055110_2-20x-roi1.hdr", "055110-20x-roi5.hdr", "055110-20x-roi1.hdr", "055110_2-20x-roi4.hdr", "033850c-20x-roi5.hdr", "033850c-20x-roi6.hdr", "033850c-20x-roi3.hdr", "032236c-20x-roi5.hdr", "032236c-20x-roi3.hdr", "032236-20x-roi1.hdr", "032236-20x-roi6.hdr", "032236-20x-roi4.hdr", "032236-20x-roi5.hdr", "032236-20x-roi2.hdr", "034247_2-20x-roi1.hdr", "034247-20x-roi5.hdr", "034247_2-20x-roi5.hdr", "034247_2-20x-roi4.hdr", "034247_2-20x-roi2.hdr", "034247-20x-roi8.hdr", "034247-20x-roi2.hdr", "034247_2-20x-roi6.hdr", "034247-20x-roi1.hdr", "034247-20x-roi6.hdr", "034247-20x-roi3.hdr", "034247-20x-roi9.hdr", "034247-20x-roi7.hdr", "034247-20x-roi4.hdr", "034247-20x-roi10.hdr", "031368-20x-roi2.hdr", "031368c-20x-roi5.hdr", "031368c-20x-roi2.hdr", "031368-20x-roi7.hdr", "031368c-20x-roi3.hdr", "031368-20x-roi3.hdr", "031368-20x-roi4.hdr", "031368-20x-roi10.hdr", "031368-20x-roi6.hdr", "041664-20x-roi9.hdr", "041664_2-20x-roi1.hdr", "041664-20x-roi2.hdr", "041664-20x-roi5.hdr", 
"041664_2-20x-roi3.hdr", "041664-20x-roi4.hdr", "041664_2-20x-roi6.hdr", "041664-20x-roi1.hdr", "041664_2-20x-roi2.hdr", "041664-20x-roi10.hdr", "041664_2-20x-roi5.hdr", "041664-20x-roi3.hdr", "041664_2-20x-roi4.hdr", "032310-20x-roi3.hdr", "032310-20x-roi2.hdr", "032310-20x-roi1.hdr", "055431-20x-roi5.hdr", "055431-20x-roi6.hdr", "055431-20x-roi4.hdr", "055431-20x-roi3.hdr", "052057-20x-roi4.hdr", "052057_2-20x-roi5.hdr", "052057-20x-roi5.hdr", "052057-20x-roi6.hdr", "052057_2-20x-roi4.hdr", "052057_2-20x-roi6.hdr", "052057-20x-roi3.hdr", "052057_2-20x-roi2.hdr", "052057_2-20x-roi1.hdr", "052057-20x-roi2.hdr", "052057_2-20x-roi3.hdr", "052057-20x-roi1.hdr", "051417-20x-roi2.hdr", "051417-20x-roi1.hdr", "051417-20x-roi4.hdr", "032979_2-20x-roi6.hdr", "032979l-20x-roi4.hdr", "032979l-20x-roi6.hdr", "032979-20x-roi4.hdr", "032979-20x-roi5.hdr", "032979_2-20x-roi3.hdr", "032979c-20x-roi5.hdr", "032979_2-20x-roi2.hdr", "032979l-20x-roi5.hdr", "032979c-20x-roi3.hdr", "032979-20x-roi6.hdr", "032979l-20x-roi2.hdr", "032979_2-20x-roi5.hdr", "032979l-20x-roi1.hdr", "032979_2-20x-roi1.hdr", "032979_2-20x-roi4.hdr", "033808-20x-roi4.hdr", "040493-20x-roi1.hdr", "040493-20x-roi6.hdr", "040493-20x-roi4.hdr", "050752-20x-roi6.hdr", "050752-20x-roi1.hdr", "055380-20x-roi2.hdr", "055380-20x-roi4.hdr", "055380-20x-roi6.hdr", "033908-2-20x-roi5.hdr", "033908_3-20x-roi5.hdr", "033908_2-20x-roi6.hdr", "033908_2-20x-roi1.hdr", "033908_2-20x-roi3.hdr", "033908-2-20x-roi4.hdr", "033908-2-20x-roi1.hdr", "033908_2-20x-roi4.hdr", "033908_2-20x-roi2.hdr", "033908-20x-roi10.hdr", "033908_3-20x-roi4.hdr", "033908-2-20x-roi2.hdr", "033908_2-20x-roi5.hdr", "033908_3-20x-roi1.hdr", "051492-20x-roi4.hdr", "051492-20x-roi1.hdr", "051492-20x-roi6.hdr", "051492-20x-roi2.hdr", "051492-20x-roi3.hdr", "033942-20x-roi7.hdr", "033942-20x-roi10.hdr", "033942-20x-roi8.hdr", "033942-20x-roi1.hdr", "033942c-20x-roi6.hdr", "033942c-20x-roi3.hdr", "033942c-20x-roi5.hdr", "033942-20x-roi3.hdr", "033942c-20x-roi2.hdr", "033942-20x-roi9.hdr", "033942c-20x-roi4.hdr", "033942-20x-roi4.hdr"], "fold3": ["053961c-20x-roi1.hdr", "053961c-20x-roi5.hdr", "053961c-20x-roi6.hdr", "053961c-20x-roi4.hdr", "050624-20x-roi5.hdr", "050624_2-20x-roi6.hdr", "050624-20x-roi1.hdr", "050624_2-20x-roi1.hdr", "050624_2-20x-roi5.hdr", "050624-20x-roi6.hdr", "052032_2-20x-roi2.hdr", "052032_2-20x-roi6.hdr", "052032-20x-roi6.hdr", "052032-20x-roi1.hdr", "052032-20x-roi4.hdr", "052032_2-20x-roi4.hdr", "052032-20x-roi2.hdr", "052032-20x-roi5.hdr", "052032_2-20x-roi3.hdr", "052032_2-20x-roi5.hdr", "034653-20x-roi8.hdr", "034653-20x-roi2.hdr", "034653-20x-roi6.hdr", "034653-20x-roi10.hdr", "034653-20x-roi4.hdr", "050833_2-20x-roi4.hdr", "050833_2-20x-roi6.hdr", "050833-20x-roi6.hdr", "050833-20x-roi3.hdr", "050833-20x-roi1.hdr", "050833_2-20x-roi5.hdr", "050833-20x-roi2.hdr", "050833-20x-roi4.hdr", "050833-20x-roi5.hdr", "051147-20x-roi6.hdr", "051147-20x-roi1.hdr", "051147-20x-roi5.hdr", "051147-20x-roi3.hdr", "051147-20x-roi2.hdr", "051147-20x-roi4.hdr", "040186-20x-roi4.hdr", "040186-20x-roi2.hdr", "040186-20x-roi5.hdr", "040186-20x-roi3.hdr", "053686c-20x-roi2.hdr", "053686c-20x-roi3.hdr", "053686c-20x-roi1.hdr", "053065_2-20x-roi4.hdr", "053065d-20x-roi5.hdr", "053065_2-20x-roi5.hdr", "053065-20x-roi5.hdr", "053065_2-20x-roi6.hdr", "053065-20x-roi9.hdr", "053065-20x-roi2.hdr", "053065-20x-roi4.hdr", "053065-20x-roi3.hdr", "041845-20x-roi1.hdr", "041845-20x-roi3.hdr", "041845-20x-roi10.hdr", "041845_2-20x-roi1.hdr", "041845-20x-roi5.hdr", "041845-20x-roi4.hdr", 
"041845-20x-roi7.hdr", "041845-20x-roi8.hdr", "041845-20x-roi6.hdr", "040483-20x-roi4.hdr", "040483-20x-roi1.hdr", "040483-20x-roi3.hdr", "053449-20x-roi2.hdr", "053449-20x-roi4.hdr", "053449-20x-roi3.hdr", "053449-20x-roi5.hdr", "053449-20x-roi6.hdr", "034004a2-20x-roi1.hdr", "034004a2-20x-roi3.hdr", "034004a2-20x-roi4.hdr", "034004a2-20x-roi6.hdr", "034004a2-20x-roi5.hdr", "030907c-20x-roi6.hdr", "030907c-20x-roi3.hdr", "030907-20x-roi3.hdr", "030907c-20x-roi5.hdr", "030907-20x-roi4.hdr", "030907-20x-roi5.hdr", "030907c-20x-roi2.hdr", "051909-20x-roi4.hdr", "051909-20x-roi2.hdr", "051909-20x-roi7.hdr", "051909-20x-roi3.hdr", "051909-20x-roi6.hdr", "051909-20x-roi1.hdr", "051909-20x-roi9.hdr", "051909-20x-roi8.hdr", "051909-20x-roi5.hdr", "034250-20x-roi4.hdr", "034250-20x-roi5.hdr", "034250-20x-roi3.hdr", "034250-20x-roi1.hdr", "034250-20x-roi2.hdr", "032112-20x-roi8.hdr", "032112-20x-roi5.hdr", "032112-20x-roi4.hdr", "032112-20x-roi3.hdr", "032112-20x-roi6.hdr", "032112c-20x-roi5.hdr", "032112-20x-roi7.hdr", "032112-20x-roi1.hdr", "032112c-20x-roi4.hdr", "032112-20x-roi2.hdr"], "fold4": ["030979-20x-roi3.hdr", "030979-20x-roi6.hdr", "030979-20x-roi2.hdr", "030979-20x-roi5.hdr", "034080_2-20x-roi6.hdr", "034080-20x-roi5.hdr", "034080_3-20x-roi2.hdr", "034080_2-20x-roi5.hdr", "034080-20x-roi4.hdr", "034080_2-20x-roi1.hdr", "054733-20x-roi4.hdr", "054733-20x-roi3.hdr", "054733-20x-roi2.hdr", "054733-20x-roi6.hdr", "033805a-20x-roi2.hdr", "055308-20x-roi3.hdr", "055308-20x-roi9.hdr", "055308-20x-roi5.hdr", "055308-20x-roi1.hdr", "055308-20x-roi6.hdr", "055308-20x-roi4.hdr", "055308-20x-roi10.hdr", "055308-20x-roi8.hdr", "034429-20x-roi5.hdr", "034429-20x-roi3.hdr", "034429-20x-roi4.hdr", "034429-20x-roi6.hdr", "034429-20x-roi2.hdr", "030406-20x-roi6.hdr", "030406c-20x-roi6.hdr", "030406c-20x-roi4.hdr", "030406-20x-roi2.hdr", "030406-20x-roi5.hdr", "030406c_2-20x-roi2.hdr", "030406c_2-20x-roi5.hdr", "030406c-20x-roi1.hdr", "030406-20x-roi1.hdr", "030406c_2-20x-roi1.hdr", "053832-20x-roi3.hdr", "040126_2-20x-roi4.hdr", "040126_2-20x-roi2.hdr", "040126_2-20x-roi1.hdr", "040126_2-20x-roi5.hdr", "040126_2-20x-roi3.hdr", "040126-20x-roi5.hdr", "040579_2-20x-roi6.hdr", "040579_2-20x-roi8.hdr", "040579_2-20x-roi2.hdr", "040579-20x-roi8.hdr", "040579-20x-roi6.hdr", "040579_2-20x-roi4.hdr", "040579_2-20x-roi10.hdr", "040579_2-20x-roi1.hdr", "040579_2-20x-roi3.hdr", "040579-20x-roi1.hdr", "040579-20x-roi7.hdr", "040579_2-20x-roi5.hdr", "040579-20x-roi2.hdr", "054276t-20x-roi2.hdr", "054276-20x-roi1.hdr", "054276-20x-roi6.hdr", "054276t-20x-roi3.hdr", "054276t-20x-roi6.hdr", "045073_2-20x-roi1.hdr", "045073a-20x-roi5.hdr", "045073-20x-roi9.hdr", "045073a-20x-roi6.hdr", "045073_2-20x-roi4.hdr", "045073_2-20x-roi2.hdr", "045073_2-20x-roi6.hdr", "045073a-20x-roi7.hdr", "045073_2-20x-roi3.hdr", "045073a-20x-roi9.hdr", "045073-20x-roi2.hdr", "052206-20x-roi3.hdr", "052206-20x-roi2.hdr", "052206-20x-roi5.hdr", "052206-20x-roi4.hdr", "051765-20x-roi4.hdr", "051765-20x-roi5.hdr", "041248-20x-roi7.hdr", "041248_2-20x-roi4.hdr", "041248_2-20x-roi5.hdr", "041248-20x-roi5.hdr", "041248_2-20x-roi2.hdr", "041248-20x-roi6.hdr", "041248_2-20x-roi3.hdr", "041248_2-20x-roi6.hdr", "041248_2-20x-roi1.hdr", "041248-20x-roi4.hdr", "032370b_2-20x-roi1.hdr", "032370b-20x-roi4.hdr", "032370b-20x-roi6.hdr", "032370b_2-20x-roi6.hdr", "032370b_2-20x-roi5.hdr", "032370b-20x-roi1.hdr", "032370b_2-20x-roi2.hdr", "032370b-20x-roi2.hdr", "032370b-20x-roi3.hdr", "032370b_2-20x-roi4.hdr", "032370b-20x-roi5.hdr", "050541-20x-roi4.hdr", 
"050541_2-20x-roi6.hdr", "050541-20x-roi10.hdr", "050541-20x-roi7.hdr", "050541-20x-roi5.hdr", "050541_2-20x-roi5.hdr", "050541_2-20x-roi4.hdr", "050541_2-20x-roi1.hdr", "050031-20x-roi3.hdr", "050031-20x-roi10.hdr", "050031_2-20x-roi6.hdr", "050031-20x-roi5.hdr", "050031_2-20x-roi4.hdr", "050031_2-20x-roi10.hdr", "050031_2-20x-roi2.hdr", "050031-20x-roi4.hdr", "050031-20x-roi1.hdr", "050031_2-20x-roi9.hdr", "050031-20x-roi6.hdr", "050031_2-20x-roi5.hdr", "050031-20x-roi8.hdr", "050031-20x-roi9.hdr", "050031_2-20x-roi7.hdr", "050031_2-20x-roi8.hdr", "050031_2-20x-roi3.hdr"]} -------------------------------------------------------------------------------- /model/link.md: -------------------------------------------------------------------------------- 1 | Get this pre-trained model: https://drive.google.com/drive/folders/1BfPt6kvLLWDqZVeppRVKx7nGk1lcQaCA?usp=sharing 2 | --------------------------------------------------------------------------------