├── .gitignore ├── README.md ├── codes ├── __init__.py ├── builder.py ├── dataloaders │ ├── __init__.py │ ├── _base.py │ ├── _samplers.py │ ├── _transforms.py │ └── acdc.py ├── logger.py ├── losses │ ├── __init__.py │ ├── losses.py │ └── pixel_contras_loss.py ├── metrics │ ├── __init__.py │ └── metrics.py ├── models │ ├── __init__.py │ ├── _base.py │ ├── swin_decoder.py │ ├── unet.py │ └── unet_tf.py ├── schedulers │ ├── __init__.py │ └── poly.py ├── trainers │ ├── __init__.py │ ├── _base.py │ ├── mt_trainer.py │ ├── supervised_trainer.py │ └── ugpcl_trainer.py └── utils │ ├── __init__.py │ ├── analyze.py │ ├── init.py │ ├── ramps.py │ └── utils.py ├── configs ├── _datasets │ ├── acdc.yaml │ └── acdc_224.yaml ├── _models │ ├── unet_r50.yaml │ └── unet_tf_r50.yaml ├── _trainers │ ├── mt.yaml │ ├── supervised.yaml │ └── ugpcl.yaml └── comparison_acdc_224_136 │ ├── mt_unet_r50.yaml │ ├── ugpcl_unet_r50.yaml │ ├── unet_r50.yaml │ └── unet_r50_full.yaml ├── pics ├── overview.jpg ├── preds.jpg └── show_feats.jpg ├── requirements.txt ├── test.py ├── train.py └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .git 3 | results 4 | shows 5 | weights 6 | convert.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UGPCL 2 | 3 | > [IJCAI' 22] Uncertainty-Guided Pixel Contrastive Learning for Semi-Supervised Medical Image Segmentation. 4 | > 5 | > Tao Wang, Jianglin Lu, Zhihui Lai, Jiajun Wen and Heng Kong 6 | 7 | ![](pics/overview.jpg) 8 | 9 | # Installation 10 | Please refer to [requirements.txt](requirements.txt) 11 | 12 | # Dataset 13 | *ACDC* dataset can be found in this [Link](https://github.com/HiLab-git/SSL4MIS/tree/master/data/ACDC). 14 | 15 | Please change the dataset root directory in _configs/\_datasets/acdc_224.yaml_. 
16 | 17 | # Results 18 | 19 | ### *ACDC* (136 labels | 10%): 20 | 21 | | Method | Model | Iterations | Batch Size | Label Size | DSC | Ckpt | Config File | 22 | | :----------: | :------: | :--------: | :--------: | :--------: | :---: | :----------------------------------------------------------: | :-------------------------------------------------: | 23 | | UGPCL | UNet-R50 | 6000 | 16 | 8 | 88.11 | [Link](https://drive.google.com/file/d/1T8T6g_xiJWGetQhZeFMNG2q7dzmYyN4s/view?usp=sharing) | configs/comparison_acdc_224_136/ugpcl_unet_r50.yaml | 24 | | Mean Teacher | UNet-R50 | 6000 | 16 | 8 | 85.75 | [Link](https://drive.google.com/file/d/1mWKKoeZbSlf6DNxqnoypr50ialPMqFYL/view?usp=sharing) | configs/comparison_acdc_224_136/mt_unet_r50.yaml | 25 | 26 | ### Visualization 27 | 28 | - Segmentation results: 29 | 30 | 31 | 32 | - Pixel features (t-SNE): 33 | 34 | 35 | 36 | 37 | # Reference 38 | - [https://github.com/HiLab-git/SSL4MIS](https://github.com/HiLab-git/SSL4MIS) 39 | - [https://github.com/tfzhou/ContrastiveSeg](https://github.com/tfzhou/ContrastiveSeg) 40 | 41 | # Citation 42 | ```bibtex 43 | @inproceedings{ijcai2022-201, 44 | title = {Uncertainty-Guided Pixel Contrastive Learning for Semi-Supervised Medical Image Segmentation}, 45 | author = {Wang, Tao and Lu, Jianglin and Lai, Zhihui and Wen, Jiajun and Kong, Heng}, 46 | booktitle = {Proceedings of the Thirty-First International Joint Conference on 47 | Artificial Intelligence, {IJCAI-22}}, 48 | publisher = {International Joint Conferences on Artificial Intelligence Organization}, 49 | editor = {Luc De Raedt}, 50 | pages = {1444--1450}, 51 | year = {2022}, 52 | month = {7}, 53 | note = {Main Track}, 54 | doi = {10.24963/ijcai.2022/201}, 55 | url = {https://doi.org/10.24963/ijcai.2022/201}, 56 | } 57 | ``` 58 | -------------------------------------------------------------------------------- /codes/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/taovv/UGPCL/5980405b15ef1bc139be0447647a213b5a5f30b9/codes/__init__.py -------------------------------------------------------------------------------- /codes/builder.py: -------------------------------------------------------------------------------- 1 | from .dataloaders import * 2 | from .losses import * 3 | from .metrics import * 4 | from .models import * 5 | from .schedulers import * 6 | from .losses import * 7 | from .logger import Logger 8 | 9 | from torch.optim.lr_scheduler import * 10 | from torch.optim import * 11 | 12 | __all__ = ['build_model', 'build_metric', 'build_dataloader', 'build_scheduler', 13 | 'build_optimizer', 'build_criterion', 'build_logger', '_build_from_cfg'] 14 | 15 | 16 | def _build_from_cfg(cfg): 17 | if isinstance(cfg, dict): 18 | name = cfg['name'] 19 | if 'kwargs' in cfg.keys() and cfg['kwargs'] is not None: 20 | return eval(f"{name}")(**cfg['kwargs']) 21 | else: 22 | name = cfg.name 23 | if hasattr(cfg, 'kwargs') and cfg.kwargs is not None: 24 | return eval(f"{name}")(**cfg.kwargs.__dict__) 25 | return eval(f"{name}()") 26 | 27 | 28 | def build_model(cfg): 29 | return _build_from_cfg(cfg) 30 | 31 | 32 | def build_optimizer(model_parameters, cfg): 33 | if isinstance(cfg, dict): 34 | name = cfg['name'] 35 | if 'kwargs' in cfg.keys() and cfg['kwargs'] is not None: 36 | kwargs = cfg['kwargs'] 37 | kwargs['params'] = model_parameters 38 | return eval(f"{name}")(**kwargs) 39 | else: 40 | name = cfg.name 41 | if hasattr(cfg, 'kwargs') and cfg.kwargs is not None: 42 | kwargs = cfg.kwargs.__dict__ 43 | kwargs['params'] = model_parameters 44 | return eval(f"{name}")(**kwargs) 45 | kwargs = {'params': model_parameters} 46 | return eval(f"{name}")(**kwargs) 47 | 48 | 49 | def build_scheduler(optimizer_, cfg): 50 | if isinstance(cfg, dict): 51 | name = cfg['name'] 52 | if 'kwargs' in cfg.keys() and cfg['kwargs'] is not None: 53 | kwargs = cfg['kwargs'] 54 | kwargs['optimizer'] = optimizer_ 55 | return 
eval(f"{name}")(**kwargs) 56 | else: 57 | name = cfg.name 58 | if hasattr(cfg, 'kwargs') and cfg.kwargs is not None: 59 | kwargs = cfg.kwargs.__dict__ 60 | kwargs['optimizer'] = optimizer_ 61 | return eval(f"{name}")(**kwargs) 62 | kwargs = {'optimizer': optimizer_} 63 | return eval(f"{name}")(**kwargs) 64 | 65 | 66 | def build_criterion(cfg): 67 | return _build_from_cfg(cfg) 68 | 69 | 70 | def build_metric(cfg): 71 | return _build_from_cfg(cfg) 72 | 73 | 74 | def build_dataloader(cfg, worker_init_fn): 75 | if not isinstance(cfg, dict): 76 | kwargs = cfg.kwargs.__dict__ 77 | else: 78 | kwargs = cfg['kwargs'] 79 | kwargs['worker_init_fn'] = worker_init_fn 80 | return eval(f"get_{cfg.name}_loaders")(**kwargs) 81 | 82 | def build_logger(cfg): 83 | return eval(f"Logger")(**cfg.__dict__) 84 | -------------------------------------------------------------------------------- /codes/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .acdc import get_acdc_loaders 2 | -------------------------------------------------------------------------------- /codes/dataloaders/_base.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torch.utils.data.dataset import T_co 3 | 4 | 5 | class BaseDataSet(Dataset): 6 | 7 | def __init__(self) -> None: 8 | super().__init__() 9 | 10 | def __getitem__(self, index) -> T_co: 11 | pass 12 | 13 | def __len__(self): 14 | pass 15 | -------------------------------------------------------------------------------- /codes/dataloaders/_samplers.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class TwoStreamBatchSampler(Sampler): 7 | """ 8 | Iterate two sets of indices 9 | 10 | An 'epoch' is one iteration through the primary indices. 
11 | During the epoch, the secondary indices are iterated through 12 | as many times as needed. 13 | """ 14 | 15 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 16 | self.primary_indices = primary_indices 17 | self.secondary_indices = secondary_indices 18 | self.primary_batch_size = batch_size - secondary_batch_size 19 | self.secondary_batch_size = secondary_batch_size 20 | 21 | assert len(self.primary_indices) >= self.primary_batch_size > 0 22 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 23 | 24 | def __iter__(self): 25 | """ 26 | 生成每个batch的数据采样索引:分别从labeled indices和unlabeled indices中选取 27 | 每次调用索引序列都会被打乱,确保随机采样 28 | Returns: (labeled index1,labeled index2,...,unlabeled index1,unlabeled index2,...) 29 | 30 | """ 31 | primary_iter = iterate_once(self.primary_indices) 32 | secondary_iter = iterate_eternally(self.secondary_indices) 33 | return ( 34 | primary_batch + secondary_batch 35 | for (primary_batch, secondary_batch) 36 | in zip(grouper(primary_iter, self.primary_batch_size), 37 | grouper(secondary_iter, self.secondary_batch_size)) 38 | ) 39 | 40 | def __len__(self): 41 | return len(self.primary_indices) // self.primary_batch_size 42 | 43 | 44 | def iterate_once(iterable): 45 | return np.random.permutation(iterable) # 随机排列序列 46 | 47 | 48 | def iterate_eternally(indices): 49 | """ 50 | 异步不断生成随机打乱的indices序列 并由itertools.chain连接后返回 51 | Args: 52 | indices: 53 | 54 | Returns: 55 | 56 | """ 57 | def infinite_shuffles(): 58 | while True: 59 | yield np.random.permutation(indices) 60 | 61 | return itertools.chain.from_iterable(infinite_shuffles()) 62 | 63 | 64 | def grouper(iterable, n): 65 | """ 66 | Collect data into fixed-length chunks or blocks 67 | eg: grouper('ABCDEFG', 3) --> ABC DEF 68 | Args: 69 | iterable: 70 | n: 71 | 72 | Returns: 73 | 74 | """ 75 | args = [iter(iterable)] * n 76 | return zip(*args) 77 | -------------------------------------------------------------------------------- 
/codes/dataloaders/_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision.transforms.functional as F 4 | 5 | from numpy import random 6 | from scipy import ndimage 7 | from scipy.ndimage.interpolation import zoom 8 | from torchvision.transforms.functional import InterpolationMode 9 | from torchvision.transforms import transforms as T 10 | 11 | 12 | def build_transforms(transforms_): 13 | transforms = [] 14 | for transform in transforms_: 15 | if hasattr(transform, 'kwargs') and transform.kwargs is not None: 16 | kwargs = transform.kwargs.__dict__ 17 | transform = eval(f"{transform.name}")(**kwargs) 18 | else: 19 | transform = eval(f"{transform.name}()") 20 | transforms.append(transform) 21 | return Compose(transforms) 22 | 23 | 24 | class Compose: 25 | 26 | def __init__(self, transforms): 27 | self.transforms = transforms 28 | 29 | def __call__(self, img, **kwargs): 30 | for t in self.transforms: 31 | img = t(img, **kwargs) 32 | return img 33 | 34 | def __repr__(self): 35 | format_string = self.__class__.__name__ + '(' 36 | for t in self.transforms: 37 | format_string += '\n' 38 | format_string += ' {0}'.format(t) 39 | format_string += '\n)' 40 | return format_string 41 | 42 | 43 | class ToTensor3D(object): 44 | """Convert ndarrays in sample to Tensors.""" 45 | 46 | def __call__(self, sample, with_sdf=False): 47 | img = torch.from_numpy(sample['image']) 48 | label = torch.from_numpy(sample['label']) 49 | if len(img.shape) == 3: 50 | img = img.unsqueeze(0) 51 | if len(label.shape) == 3: 52 | label = label.unsqueeze(0) 53 | if with_sdf: 54 | sdf = torch.from_numpy(sample['sdf']) 55 | return {'image': img, 'label': label, 'sdf': sdf} 56 | return {'image': img, 'label': label} 57 | 58 | 59 | # 3D transforms 60 | class CenterCrop3D(object): 61 | def __init__(self, output_size): 62 | self.output_size = output_size 63 | 64 | def __call__(self, sample, with_sdf=False): 
65 | image, label = sample['image'], sample['label'] 66 | if with_sdf: 67 | sdf = sample['sdf'] 68 | # pad the sample if necessary 69 | if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \ 70 | self.output_size[2]: 71 | pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0) 72 | ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0) 73 | pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0) 74 | image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 75 | label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 76 | if with_sdf: 77 | for _ in range(sdf.shape[0]): 78 | sdf[_] = np.pad(sdf[_], [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 79 | 80 | (w, h, d) = image.shape 81 | 82 | w1 = int(round((w - self.output_size[0]) / 2.)) 83 | h1 = int(round((h - self.output_size[1]) / 2.)) 84 | d1 = int(round((d - self.output_size[2]) / 2.)) 85 | 86 | label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 87 | image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 88 | if with_sdf: 89 | for _ in range(sdf.shape[0]): 90 | sdf[_] = sdf[_][w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 91 | return {'image': image, 'label': label, 'sdf': sdf} 92 | return {'image': image, 'label': label} 93 | 94 | 95 | class RandomCrop3D(object): 96 | """ 97 | Crop randomly the image in a sample 98 | Args: 99 | output_size (int): Desired output size 100 | """ 101 | 102 | def __init__(self, output_size): 103 | self.output_size = output_size 104 | 105 | def __call__(self, sample, with_sdf=False): 106 | image, label = sample['image'], sample['label'] 107 | if with_sdf: 108 | sdf = sample['sdf'] 109 | # pad the sample if necessary 110 | if label.shape[0] <= self.output_size[0] or label.shape[1] <= 
self.output_size[1] or label.shape[2] <= \ 111 | self.output_size[2]: 112 | pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0) 113 | ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0) 114 | pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0) 115 | image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 116 | label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 117 | if with_sdf: 118 | temp_sdf = [] 119 | for _ in range(sdf.shape[0]): 120 | temp_sdf.append(np.pad(sdf[_], [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)) 121 | sdf = np.stack(temp_sdf) 122 | 123 | (w, h, d) = image.shape 124 | # if np.random.uniform() > 0.33: 125 | # w1 = np.random.randint((w - self.output_size[0])//4, 3*(w - self.output_size[0])//4) 126 | # h1 = np.random.randint((h - self.output_size[1])//4, 3*(h - self.output_size[1])//4) 127 | # else: 128 | w1 = np.random.randint(0, w - self.output_size[0]) 129 | h1 = np.random.randint(0, h - self.output_size[1]) 130 | d1 = np.random.randint(0, d - self.output_size[2]) 131 | 132 | label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 133 | image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 134 | 135 | if with_sdf: 136 | sdf = sdf[:, w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 137 | return {'image': image, 'label': label, 'sdf': sdf} 138 | else: 139 | return {'image': image, 'label': label} 140 | 141 | 142 | class RandomRotFlip3D(object): 143 | """ 144 | Crop randomly flip the dataset in a sample 145 | Args: 146 | output_size (int): Desired output size 147 | """ 148 | 149 | def __call__(self, sample, with_sdf=False): 150 | image, label = sample['image'], sample['label'] 151 | if with_sdf: 152 | sdf = sample['sdf'] 153 | k = np.random.randint(0, 4) 154 | image = np.rot90(image, k) 155 | label = 
np.rot90(label, k) 156 | if with_sdf: 157 | sdf = np.rot90(sdf, k, axes=(1, 2)) 158 | axis = np.random.randint(0, 2) 159 | image = np.flip(image, axis=axis).copy() 160 | label = np.flip(label, axis=axis).copy() 161 | if with_sdf: 162 | sdf = np.flip(sdf, axis=axis + 1).copy() 163 | if with_sdf: 164 | return {'image': image, 'label': label, 'sdf': sdf} 165 | return {'image': image, 'label': label} 166 | 167 | 168 | class RandomNoise3D(object): 169 | def __init__(self, mu=0, sigma=0.1): 170 | self.mu = mu 171 | self.sigma = sigma 172 | 173 | def __call__(self, sample, with_sdf=False): 174 | image, label = sample['image'], sample['label'] 175 | noise = np.clip(self.sigma * np.random.randn(image.shape[0], image.shape[1], image.shape[2]), -2 * self.sigma, 176 | 2 * self.sigma) 177 | noise = noise + self.mu 178 | image = image + noise 179 | if with_sdf: 180 | return {'image': image, 'label': label, 'sdf': sample['sdf']} 181 | return {'image': image, 'label': label} 182 | 183 | 184 | class CreateOnehotLabel3D(object): 185 | def __init__(self, num_classes): 186 | self.num_classes = num_classes 187 | 188 | def __call__(self, sample): 189 | image, label = sample['image'], sample['label'] 190 | onehot_label = np.zeros((self.num_classes, label.shape[0], label.shape[1], label.shape[2]), dtype=np.float32) 191 | for i in range(self.num_classes): 192 | onehot_label[i, :, :, :] = (label == i).astype(np.float32) 193 | return {'image': image, 'label': label, 'onehot_label': onehot_label} 194 | 195 | 196 | class RandomGenerator(object): 197 | def __init__(self, output_size, p_flip=0.5, p_rot=0.5): 198 | self.output_size = output_size 199 | self.p_flip = p_flip 200 | self.p_rot = p_rot 201 | 202 | def __call__(self, sample): 203 | image, label = sample['image'], sample['label'] 204 | if torch.rand(1) < self.p_flip: 205 | image, label = self.random_rot_flip(image, label) 206 | elif torch.rand(1) < self.p_rot: 207 | image, label = self.random_rotate(image, label) 208 | x, y = image.shape 
209 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=0) 210 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 211 | image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) 212 | label = torch.from_numpy(label.astype(np.uint8)).unsqueeze(0) 213 | sample['image'] = image 214 | sample['label'] = label 215 | return sample 216 | 217 | @staticmethod 218 | def random_rot_flip(image, label): 219 | k = np.random.randint(0, 4) 220 | image = np.rot90(image, k) 221 | label = np.rot90(label, k) 222 | axis = np.random.randint(0, 2) 223 | image = np.flip(image, axis=axis).copy() 224 | label = np.flip(label, axis=axis).copy() 225 | return image, label 226 | 227 | @staticmethod 228 | def random_rotate(image, label): 229 | angle = np.random.randint(-20, 20) 230 | image = ndimage.rotate(image, angle, order=0, reshape=False) 231 | label = ndimage.rotate(label, angle, order=0, reshape=False) 232 | return image, label 233 | 234 | 235 | class ToTensor: 236 | """ 237 | Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 
238 | """ 239 | 240 | def __call__(self, sample): 241 | return {'image': F.to_tensor(sample['image']), 242 | 'label': F.to_tensor(sample['label'])} 243 | 244 | def __repr__(self): 245 | return self.__class__.__name__ + '()' 246 | 247 | 248 | class ToRGB: 249 | 250 | def __call__(self, sample): 251 | if sample['image'].shape[0] == 1: 252 | sample['image'] = sample['image'].repeat(3, 1, 1) 253 | return sample 254 | 255 | def __repr__(self): 256 | return self.__class__.__name__ + '()' 257 | 258 | 259 | class ConvertImageDtype(torch.nn.Module): 260 | 261 | def __init__(self, dtype: torch.dtype) -> None: 262 | super().__init__() 263 | self.dtype = dtype 264 | 265 | def forward(self, sample): 266 | sample['image'] = F.convert_image_dtype(sample['image'], self.dtype) 267 | sample['label'] = F.convert_image_dtype(sample['label'], self.dtype) 268 | return sample 269 | 270 | 271 | class ToPILImage: 272 | 273 | def __init__(self, mode=None): 274 | self.mode = mode 275 | 276 | def __call__(self, sample): 277 | sample['image'] = F.to_pil_image(sample['image'], self.mode) 278 | sample['label'] = F.to_pil_image(sample['label'], self.mode) 279 | return sample 280 | 281 | def __repr__(self): 282 | format_string = self.__class__.__name__ + '(' 283 | if self.mode is not None: 284 | format_string += 'mode={0}'.format(self.mode) 285 | format_string += ')' 286 | return format_string 287 | 288 | 289 | # 2D transforms 290 | class Normalize(T.Normalize): 291 | 292 | def __init__(self, mean, std, inplace=False): 293 | super().__init__(mean, std, inplace) 294 | 295 | def forward(self, sample): 296 | sample['image'] = F.normalize(sample['image'], self.mean, self.std, self.inplace) 297 | return sample 298 | 299 | 300 | class Resize(T.Resize): 301 | 302 | def __init__(self, size, interpolation=InterpolationMode.BILINEAR): 303 | super().__init__(size, interpolation) 304 | 305 | def forward(self, sample): 306 | sample['image'] = F.resize(sample['image'], self.size, self.interpolation) 307 | 
sample['label'] = F.resize(sample['label'], self.size, self.interpolation) 308 | return sample 309 | 310 | 311 | class CenterCrop(T.CenterCrop): 312 | 313 | def __init__(self, size): 314 | super().__init__(size) 315 | 316 | def forward(self, sample): 317 | sample['image'] = F.center_crop(sample['image'], self.size) 318 | sample['label'] = F.center_crop(sample['label'], self.size) 319 | return sample 320 | 321 | 322 | class Pad(T.Pad): 323 | 324 | def __init__(self, padding, fill=0, padding_mode="constant"): 325 | super().__init__(padding, fill, padding_mode) 326 | 327 | def forward(self, sample): 328 | sample['label'] = F.pad(sample['image'], self.padding, self.fill, self.padding_mode) 329 | sample['label'] = F.pad(sample['label'], self.padding, self.fill, self.padding_mode) 330 | return sample 331 | 332 | 333 | class RandomCrop(T.RandomCrop): 334 | 335 | def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): 336 | super().__init__(size, padding, pad_if_needed, fill, padding_mode) 337 | 338 | def forward(self, sample): 339 | img = sample['image'] 340 | label = sample['label'] 341 | if self.padding is not None: 342 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 343 | label = F.pad(label, self.padding, self.fill, self.padding_mode) 344 | 345 | width, height = F._get_image_size(img) 346 | # pad the width if needed 347 | if self.pad_if_needed and width < self.size[1]: 348 | padding = [self.size[1] - width, 0] 349 | img = F.pad(img, padding, self.fill, self.padding_mode) 350 | label = F.pad(label, self.padding, self.fill, self.padding_mode) 351 | # pad the height if needed 352 | if self.pad_if_needed and height < self.size[0]: 353 | padding = [0, self.size[0] - height] 354 | img = F.pad(img, padding, self.fill, self.padding_mode) 355 | label = F.pad(label, self.padding, self.fill, self.padding_mode) 356 | 357 | i, j, h, w = self.get_params(img, self.size) 358 | 359 | sample['image'] = F.crop(img, i, j, h, w) 360 | 
sample['label'] = F.crop(label, i, j, h, w) 361 | return sample 362 | 363 | def __repr__(self): 364 | return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding) 365 | 366 | 367 | class RandomFlip(torch.nn.Module): 368 | 369 | def __init__(self, p=0.5, direction='horizontal'): 370 | super().__init__() 371 | assert 0 <= p <= 1 372 | assert direction in ['horizontal', 'vertical', None], 'direction should be horizontal, vertical or None' 373 | self.p = p 374 | self.direction = direction 375 | 376 | def forward(self, sample): 377 | if torch.rand(1) < self.p: 378 | img, label = sample['image'], sample['label'] 379 | if self.direction == 'horizontal': 380 | sample['image'] = F.hflip(img) 381 | sample['label'] = F.hflip(label) 382 | elif self.direction == 'vertical': 383 | sample['image'] = F.vflip(img) 384 | sample['label'] = F.vflip(label) 385 | else: 386 | if torch.rand(1) < 0.5: 387 | sample['image'] = F.hflip(img) 388 | sample['label'] = F.hflip(label) 389 | else: 390 | sample['image'] = F.vflip(img) 391 | sample['label'] = F.vflip(label) 392 | return sample 393 | 394 | def __repr__(self): 395 | return self.__class__.__name__ + '(p={})'.format(self.p) 396 | 397 | 398 | class RandomResizedCrop(T.RandomResizedCrop): 399 | 400 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. 
/ 3.), interpolation=InterpolationMode.BILINEAR): 401 | super().__init__(size, scale, ratio, interpolation) 402 | 403 | 404 | def forward(self, sample): 405 | img, mask = sample['image'], sample['label'] 406 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 407 | sample['image'] = F.resized_crop(img, i, j, h, w, self.size, self.interpolation) 408 | sample['label'] = F.resized_crop(mask, i, j, h, w, self.size, self.interpolation) 409 | return sample 410 | 411 | 412 | class RandomRotation(T.RandomRotation): 413 | 414 | def __init__( 415 | self, 416 | degrees, 417 | interpolation=InterpolationMode.NEAREST, 418 | expand=False, 419 | center=None, 420 | fill=0, 421 | p=0.5, 422 | resample=None 423 | ): 424 | super().__init__(degrees, interpolation, expand, center, fill, resample) 425 | self.p = p 426 | 427 | def forward(self, sample): 428 | if torch.rand(1) > self.p: 429 | return sample 430 | img, label = sample['image'], sample['label'] 431 | fill = self.fill 432 | if isinstance(img, torch.Tensor): 433 | if isinstance(fill, (int, float)): 434 | fill = [float(fill)] * F._get_image_num_channels(img) 435 | else: 436 | fill = [float(f) for f in fill] 437 | label_fill = self.fill 438 | if isinstance(label, torch.Tensor): 439 | if isinstance(label_fill, (int, float)): 440 | label_fill = [float(label_fill)] * F._get_image_num_channels(label) 441 | else: 442 | label_fill = [float(f) for f in label_fill] 443 | angle = self.get_params(self.degrees) 444 | sample['image'] = F.rotate(img, angle, self.resample, self.expand, self.center, fill) 445 | sample['label'] = F.rotate(label, angle, self.resample, self.expand, self.center, label_fill) 446 | return sample 447 | 448 | 449 | class RandomRotation90(torch.nn.Module): 450 | 451 | def __init__(self, p=0.5): 452 | super().__init__() 453 | self.p=p 454 | 455 | def forward(self, sample): 456 | if torch.rand(1) < self.p: 457 | rot_times = random.randint(0, 4) 458 | sample['image'] = torch.rot90(sample['image'], rot_times, [1, 
2]) 459 | sample['label'] = torch.rot90(sample['label'], rot_times, [1, 2]) 460 | return sample 461 | 462 | 463 | class RandomErasing(T.RandomErasing): 464 | 465 | def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): 466 | super().__init__(p, scale, ratio, value, inplace) 467 | 468 | def forward(self, sample): 469 | if torch.rand(1) < self.p: 470 | img, label = sample['image'], sample['label'] 471 | # cast self.value to script acceptable type 472 | if isinstance(self.value, (int, float)): 473 | value = [self.value, ] 474 | elif isinstance(self.value, str): 475 | value = None 476 | elif isinstance(self.value, tuple): 477 | value = list(self.value) 478 | else: 479 | value = self.value 480 | 481 | if value is not None and not (len(value) in (1, img.shape[-3])): 482 | raise ValueError( 483 | "If value is a sequence, it should have either a single value or " 484 | "{} (number of input channels)".format(img.shape[-3]) 485 | ) 486 | 487 | x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value) 488 | sample['image'] = F.erase(img, x, y, h, w, v, self.inplace) 489 | sample['label'] = F.erase(label, x, y, h, w, v, self.inplace) 490 | return sample 491 | 492 | 493 | class GaussianBlur(T.GaussianBlur): 494 | 495 | def __init__(self, kernel_size, sigma=(0.1, 2.0), p=0.5): 496 | super().__init__(kernel_size, sigma) 497 | self.p = p 498 | 499 | def forward(self, sample): 500 | if torch.rand(1) < self.p: 501 | sigma = self.get_params(self.sigma[0], self.sigma[1]) 502 | sample['image'] = F.gaussian_blur(sample['image'], self.kernel_size, [sigma, sigma]) 503 | return sample 504 | 505 | 506 | class RandomGrayscale(T.RandomGrayscale): 507 | 508 | def __init__(self, p=0.1): 509 | super().__init__(p) 510 | 511 | def forward(self, sample): 512 | if torch.rand(1) < self.p: 513 | img = sample['image'] 514 | if len(img.shape) == 4: 515 | img = img.permute(3, 0, 1, 2).contiguous() 516 | if img.size(1) == 1: 517 | img = 
img.repeat(1, 3, 1, 1) 518 | num_output_channels = F._get_image_num_channels(img) 519 | img = F.rgb_to_grayscale(img, num_output_channels=num_output_channels) 520 | if len(img.shape) == 4: 521 | img = img.permute(1, 2, 3, 0).contiguous() 522 | img = img[0].unsqueeze(0) 523 | sample['image'] = img 524 | return sample 525 | 526 | def __repr__(self): 527 | return self.__class__.__name__ + '(p={0})'.format(self.p) 528 | 529 | 530 | class ColorJitter(T.ColorJitter): 531 | 532 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=1.): 533 | super().__init__(brightness, contrast, saturation, hue) 534 | self.p = p 535 | 536 | def forward(self, sample): 537 | if torch.rand(1) < self.p: 538 | img = sample['image'] 539 | if len(img.shape) == 4: 540 | img = img.permute(3, 0, 1, 2).contiguous() 541 | elif img.size(0) == 1: 542 | img = img.repeat(3, 1, 1) 543 | fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ 544 | self.get_params(self.brightness, self.contrast, self.saturation, self.hue) 545 | 546 | for fn_id in fn_idx: 547 | if fn_id == 0 and brightness_factor is not None: 548 | img = F.adjust_brightness(img, brightness_factor) 549 | elif fn_id == 1 and contrast_factor is not None: 550 | img = F.adjust_contrast(img, contrast_factor) 551 | elif fn_id == 2 and saturation_factor is not None: 552 | img = F.adjust_saturation(img, saturation_factor) 553 | elif fn_id == 3 and hue_factor is not None: 554 | img = F.adjust_hue(img, hue_factor) 555 | 556 | if len(img.shape) == 4: 557 | img = img.permute(1, 2, 3, 0).contiguous() 558 | img = img[0].unsqueeze(0) 559 | sample['image'] = img 560 | return sample 561 | 562 | def __repr__(self): 563 | format_string = self.__class__.__name__ + '(' 564 | format_string += 'brightness={0}'.format(self.brightness) 565 | format_string += ', contrast={0}'.format(self.contrast) 566 | format_string += ', saturation={0}'.format(self.saturation) 567 | format_string += ', hue={0})'.format(self.hue) 568 | return 
format_string 569 | -------------------------------------------------------------------------------- /codes/dataloaders/acdc.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | 3 | from torch.utils.data import Dataset, DataLoader 4 | from ._transforms import build_transforms 5 | from ._samplers import TwoStreamBatchSampler 6 | 7 | 8 | class ACDCDataSet(Dataset): 9 | 10 | def __init__(self, root_dir=r'F:/datasets/ACDC/', mode='train', num=None, transforms=None): 11 | 12 | self.root_dir = root_dir 13 | self.mode = mode 14 | 15 | if isinstance(transforms, list): 16 | transforms = build_transforms(transforms) 17 | self.transforms = transforms 18 | 19 | names_file = f'{self.root_dir}/train_slices.list' if self.mode == 'train' else f'{self.root_dir}/val_slices.list' 20 | 21 | with open(names_file, 'r') as f: 22 | self.sample_list = f.readlines() 23 | self.sample_list = [item.replace('\n', '') for item in self.sample_list] 24 | 25 | if num is not None and self.mode == 'train': 26 | self.sample_list = self.sample_list[:num] 27 | 28 | def __len__(self): 29 | return len(self.sample_list) 30 | 31 | def __getitem__(self, idx): 32 | case = self.sample_list[idx] 33 | h5f = h5py.File(f'{self.root_dir}/data/slices/{case}.h5', 'r') 34 | image = h5f['image'][:] 35 | label = h5f['label'][:] 36 | sample = {'image': image, 'label': label} 37 | if self.transforms: 38 | sample = self.transforms(sample) 39 | return sample 40 | 41 | 42 | def get_acdc_loaders(root_dir=r'F:/datasets/ACDC/', labeled_num=7, labeled_bs=12, batch_size=24, batch_size_val=16, 43 | num_workers=4, worker_init_fn=None, train_transforms=None, val_transforms=None): 44 | ref_dict = {"3": 68, "7": 136, "14": 256, "21": 396, "28": 512, "35": 664, "140": 1312} 45 | 46 | db_train = ACDCDataSet(root_dir=root_dir, mode="train", transforms=train_transforms) 47 | db_val = ACDCDataSet(root_dir=root_dir, mode="val", transforms=val_transforms) 48 | 49 | if labeled_bs < 
batch_size: 50 | labeled_slice = ref_dict[str(labeled_num)] 51 | labeled_idxs = list(range(0, labeled_slice)) 52 | unlabeled_idxs = list(range(labeled_slice, len(db_train))) 53 | batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, batch_size - labeled_bs) 54 | train_loader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=num_workers, 55 | pin_memory=True, worker_init_fn=worker_init_fn) 56 | else: 57 | train_loader = DataLoader(db_train, batch_size=batch_size, num_workers=num_workers, 58 | pin_memory=True, worker_init_fn=worker_init_fn) 59 | val_loader = DataLoader(db_val, batch_size=batch_size_val, shuffle=False, num_workers=num_workers) 60 | 61 | return train_loader, val_loader 62 | -------------------------------------------------------------------------------- /codes/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | from torch import Tensor 5 | from tensorboardX import SummaryWriter 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, 10 | log_dir, 11 | logger_name='root', 12 | log_level=logging.INFO, 13 | log_file=True, 14 | file_mode='w', 15 | tensorboard=True): 16 | 17 | assert log_dir is not None, 'log_dir is None!' 
18 | 19 | self.log_dir = log_dir 20 | if not os.path.exists(log_dir): 21 | os.makedirs(self.log_dir) 22 | self.writer = SummaryWriter(log_dir=f'{self.log_dir}/tensorboard_log') if tensorboard else None 23 | self.logger = self.init_logger(logger_name, log_file, log_level, file_mode) 24 | 25 | def init_logger(self, name, log_file, log_level, file_mode): 26 | logger = logging.getLogger(name) 27 | logger.handlers.clear() 28 | stream_handler = logging.StreamHandler() 29 | handlers = [stream_handler] 30 | if log_file: 31 | log_file = os.path.join(self.log_dir, 'info.log') 32 | file_handler = logging.FileHandler(log_file, file_mode) 33 | handlers.append(file_handler) 34 | 35 | date_format = '%Y-%m-%d %H:%M:%S' 36 | # basic_format = '%(asctime)s-%(name)s-%(levelname)s-%(message)s' 37 | basic_format = '%(asctime)s - %(name)s: %(message)s' 38 | formatter = logging.Formatter(basic_format, date_format) 39 | for handler in handlers: 40 | handler.setFormatter(formatter) 41 | handler.setLevel(log_level) 42 | logger.addHandler(handler) 43 | 44 | logger.setLevel(log_level) 45 | return logger 46 | 47 | def update_scalars(self, ordered_dict, step): 48 | for key, value in ordered_dict.items(): 49 | if isinstance(value, Tensor): 50 | ordered_dict[key] = value.item() 51 | self.writer.add_scalar(key, value, step) 52 | 53 | def update_images(self, images_dict, step): 54 | for key, value in images_dict.items(): 55 | self.writer.add_image(key, value, step) 56 | 57 | def info(self, *kwargs): 58 | self.logger.info(*kwargs) 59 | 60 | def close(self): 61 | if self.writer is not None: 62 | self.writer.close() 63 | self.logger.handlers.clear() 64 | del self.logger 65 | -------------------------------------------------------------------------------- /codes/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import * 2 | from .pixel_contras_loss import PixelContrastLoss 3 | 
-------------------------------------------------------------------------------- /codes/losses/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import abstractmethod 3 | from torch.nn import (BCELoss, BCEWithLogitsLoss, CrossEntropyLoss) 4 | 5 | 6 | class BaseWeightLoss: 7 | def __init__(self, name='loss', weight=1.) -> None: 8 | super().__init__() 9 | self.name = name 10 | self.weight = weight 11 | 12 | @abstractmethod 13 | def _cal_loss(self, preds, targets, **kwargs): 14 | pass 15 | 16 | def __call__(self, preds, targets, **kwargs): 17 | return self._cal_loss(preds, targets, **kwargs) * self.weight 18 | 19 | 20 | class BCELoss_(BaseWeightLoss): 21 | def __init__(self, name='loss_bce', weight=1., **kwargs) -> None: 22 | super().__init__(name, weight) 23 | self.loss = BCELoss(**kwargs) 24 | 25 | def _cal_loss(self, preds, targets, **kwargs): 26 | return self.loss(preds, targets) 27 | 28 | 29 | class BCEWithLogitsLoss_(BaseWeightLoss): 30 | def __init__(self, name='loss_bce', weight=1., **kwargs) -> None: 31 | super().__init__(name, weight) 32 | self.loss = BCEWithLogitsLoss(**kwargs) 33 | 34 | def _cal_loss(self, preds, targets, **kwargs): 35 | return self.loss(preds, targets) 36 | 37 | 38 | class CrossEntropyLoss_(BaseWeightLoss): 39 | def __init__(self, name='loss_ce', weight=1., **kwargs) -> None: 40 | super().__init__(name, weight) 41 | self.loss = CrossEntropyLoss(**kwargs) 42 | 43 | def _cal_loss(self, preds, targets, **kwargs): 44 | targets = targets.to(torch.long).squeeze(1) 45 | return self.loss(preds, targets) 46 | 47 | 48 | class BinaryDiceLoss_(BaseWeightLoss): 49 | def __init__(self, name='loss_dice', weight=1., smooth=1e-5, softmax=True, **kwargs) -> None: 50 | super().__init__(name, weight) 51 | self.smooth = smooth 52 | self.softmax = softmax 53 | 54 | def _cal_loss(self, preds, targets, **kwargs): 55 | assert preds.shape[0] == targets.shape[0] 56 | if self.softmax: 57 | preds 
= torch.argmax(torch.softmax(preds, dim=1), dim=1, keepdim=True).to(torch.float32) 58 | intersect = torch.sum(torch.mul(preds, targets)) 59 | loss = 1 - (2 * intersect + self.smooth) / (torch.sum(preds.pow(2)) + torch.sum(targets.pow(2)) + self.smooth) 60 | return loss 61 | 62 | 63 | class DiceLoss_(BaseWeightLoss): 64 | def __init__(self, name='loss_dice', weight=1., smooth=1e-5, n_classes=2, class_weight=None, softmax=True, 65 | **kwargs): 66 | super().__init__(name, weight) 67 | self.n_classes = n_classes 68 | self.smooth = smooth 69 | self.class_weight = [1.] * self.n_classes if class_weight is None else class_weight 70 | self.softmax = softmax 71 | 72 | def _one_hot_encoder(self, targets): 73 | target_list = [] 74 | for _ in range(self.n_classes): 75 | temp_prob = targets == _ * torch.ones_like(targets) 76 | target_list.append(temp_prob) 77 | output_tensor = torch.cat(target_list, dim=1) 78 | return output_tensor.float() 79 | 80 | def _dice_loss(self, pred, target): 81 | assert pred.shape[0] == target.shape[0] 82 | intersect = torch.sum(torch.mul(pred, target)) 83 | loss = 1 - (2 * intersect + self.smooth) / (torch.sum(pred.pow(2)) + torch.sum(target.pow(2)) + self.smooth) 84 | return loss 85 | 86 | def _cal_loss(self, preds, targets, **kwargs): 87 | if self.softmax: 88 | preds = torch.softmax(preds, dim=1) 89 | targets = self._one_hot_encoder(targets) 90 | assert preds.size() == targets.size(), 'pred & target shape do not match' 91 | loss = 0.0 92 | for _ in range(self.n_classes): 93 | dice = self._dice_loss(preds[:, _], targets[:, _]) 94 | loss += dice * self.class_weight[_] 95 | return loss / self.n_classes 96 | -------------------------------------------------------------------------------- /codes/losses/pixel_contras_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from abc import ABC 5 | from torch import nn 6 | 7 | 8 | class PixelContrastLoss(nn.Module, 
ABC): 9 | def __init__(self, 10 | temperature=0.07, 11 | base_temperature=0.07, 12 | max_samples=1024, 13 | max_views=100, 14 | ignore_index=-1, 15 | device='cuda:0'): 16 | super(PixelContrastLoss, self).__init__() 17 | 18 | self.temperature = temperature 19 | self.base_temperature = base_temperature 20 | 21 | self.ignore_label = ignore_index 22 | 23 | self.max_samples = max_samples 24 | self.max_views = max_views 25 | 26 | self.device = device 27 | 28 | def _anchor_sampling(self, X, y_hat, y): 29 | batch_size, feat_dim = X.shape[0], X.shape[-1] 30 | 31 | classes = [] 32 | total_classes = 0 33 | for ii in range(batch_size): 34 | this_y = y_hat[ii] 35 | 36 | this_classes = torch.unique(this_y) 37 | this_classes = [x for x in this_classes if x != self.ignore_label] 38 | this_classes = [x for x in this_classes if (this_y == x).nonzero().shape[0] > self.max_views] 39 | 40 | classes.append(this_classes) 41 | total_classes += len(this_classes) 42 | if total_classes == 0: 43 | return None, None 44 | 45 | n_view = self.max_samples // total_classes 46 | n_view = min(n_view, self.max_views) 47 | 48 | X_ = torch.zeros((total_classes, n_view, feat_dim), dtype=torch.float).to(self.device) 49 | y_ = torch.zeros(total_classes, dtype=torch.float).to(self.device) 50 | 51 | X_ptr = 0 52 | for ii in range(batch_size): 53 | this_y_hat = y_hat[ii] 54 | this_y = y[ii] 55 | this_classes = classes[ii] 56 | 57 | for cls_id in this_classes: 58 | hard_indices = ((this_y_hat == cls_id) & (this_y != cls_id)).nonzero() 59 | easy_indices = ((this_y_hat == cls_id) & (this_y == cls_id)).nonzero() 60 | 61 | num_hard = hard_indices.shape[0] 62 | num_easy = easy_indices.shape[0] 63 | 64 | if num_hard >= n_view / 2 and num_easy >= n_view / 2: 65 | num_hard_keep = n_view // 2 66 | num_easy_keep = n_view - num_hard_keep 67 | elif num_hard >= n_view / 2: 68 | num_easy_keep = num_easy 69 | num_hard_keep = n_view - num_easy_keep 70 | elif num_easy >= n_view / 2: 71 | num_hard_keep = num_hard 72 | 
num_easy_keep = n_view - num_hard_keep 73 | else: 74 | print('this shoud be never touched! {} {} {}'.format(num_hard, num_easy, n_view)) 75 | raise Exception 76 | 77 | perm = torch.randperm(num_hard) 78 | hard_indices = hard_indices[perm[:num_hard_keep]] 79 | perm = torch.randperm(num_easy) 80 | easy_indices = easy_indices[perm[:num_easy_keep]] 81 | indices = torch.cat((hard_indices, easy_indices), dim=0) 82 | 83 | X_[X_ptr, :, :] = X[ii, indices, :].squeeze(1) 84 | y_[X_ptr] = cls_id 85 | X_ptr += 1 86 | return X_, y_ 87 | 88 | def _sample_negative(self, Q): 89 | class_num, memory_size, feat_size = Q.shape 90 | 91 | x_ = torch.zeros((class_num * memory_size, feat_size)).float().to(self.device) 92 | y_ = torch.zeros((class_num * memory_size, 1)).float().to(self.device) 93 | 94 | sample_ptr = 0 95 | for c in range(class_num): 96 | if c == 0: 97 | continue 98 | this_q = Q[c, :memory_size, :] 99 | x_[sample_ptr:sample_ptr + memory_size, ...] = this_q 100 | y_[sample_ptr:sample_ptr + memory_size, ...] 
= c 101 | sample_ptr += memory_size 102 | return x_, y_ 103 | 104 | def _contrastive(self, X_anchor, y_anchor, queue=None): 105 | anchor_num, n_view = X_anchor.shape[0], X_anchor.shape[1] 106 | 107 | y_anchor = y_anchor.contiguous().view(-1, 1) # (anchor_num × n_view) × 1 108 | anchor_count = n_view 109 | anchor_feature = torch.cat(torch.unbind(X_anchor, dim=1), dim=0) # (anchor_num × n_view) × feat_dim 110 | 111 | if queue is not None: 112 | X_contrast, y_contrast = self._sample_negative(queue) 113 | y_contrast = y_contrast.contiguous().view(-1, 1) 114 | contrast_count = 1 115 | contrast_feature = X_contrast 116 | else: 117 | y_contrast = y_anchor 118 | contrast_count = n_view 119 | contrast_feature = torch.cat(torch.unbind(X_anchor, dim=1), dim=0) 120 | 121 | # (anchor_num × n_view) × (anchor_num × n_view) 122 | mask = torch.eq(y_anchor, y_contrast.T).float().to(self.device) 123 | # (anchor_num × n_view) × (anchor_num × n_view) 124 | anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature) 125 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 126 | logits = anchor_dot_contrast - logits_max.detach() 127 | mask = mask.repeat(anchor_count, contrast_count) 128 | neg_mask = 1 - mask 129 | 130 | logits_mask = torch.ones_like(mask).\ 131 | scatter_(1, torch.arange(anchor_num * anchor_count).view(-1, 1).to(self.device), 0) 132 | mask = mask * logits_mask 133 | 134 | neg_logits = torch.exp(logits) * neg_mask 135 | neg_logits = neg_logits.sum(1, keepdim=True) 136 | 137 | exp_logits = torch.exp(logits) 138 | 139 | log_prob = logits - torch.log(exp_logits + neg_logits) 140 | mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-5) 141 | 142 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 143 | loss = loss.mean() 144 | return loss 145 | 146 | def forward(self, feats, labels=None, predict=None, queue=None): 147 | labels = labels.float().clone() 148 | labels = 
torch.nn.functional.interpolate(labels, (feats.shape[2], feats.shape[3]), mode='nearest') 149 | predict = torch.nn.functional.interpolate(predict, (feats.shape[2], feats.shape[3]), mode='nearest') 150 | labels = labels.long() 151 | assert labels.shape[-1] == feats.shape[-1], '{} {}'.format(labels.shape, feats.shape) 152 | 153 | batch_size = feats.shape[0] 154 | 155 | labels = labels.contiguous().view(batch_size, -1) 156 | predict = predict.contiguous().view(batch_size, -1) 157 | feats = feats.permute(0, 2, 3, 1) 158 | feats = feats.contiguous().view(feats.shape[0], -1, feats.shape[-1]) 159 | 160 | # feats: N×(HW)×C 161 | # labels: N×(HW) 162 | # predict: N×(HW) 163 | feats_, labels_ = self._anchor_sampling(feats, labels, predict) 164 | loss = self._contrastive(feats_, labels_, queue=queue) 165 | return loss 166 | -------------------------------------------------------------------------------- /codes/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import Dice, Jaccard, HD95, ASD 2 | -------------------------------------------------------------------------------- /codes/metrics/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | preds.shape: (N,1,H,W) || (N,1,H,W,D) || (N,H,W) || (N,H,W,D) 3 | labels.shape: (N,1,H,W) || (N,1,H,W,D) || (N,H,W) || (N,H,W,D) 4 | """ 5 | import torch 6 | from medpy import metric 7 | 8 | 9 | class Dice: 10 | 11 | def __init__(self, name='Dice', class_indexs=[1], class_names=['xx']) -> None: 12 | super().__init__() 13 | self.name = name 14 | self.class_indexs = class_indexs 15 | self.class_names = class_names 16 | 17 | def __call__(self, preds, labels): 18 | res = {} 19 | for class_index, class_name in zip(self.class_indexs, self.class_names): 20 | preds_ = (preds == class_index).to(torch.int) 21 | labels_ = (labels == class_index).to(torch.int) 22 | intersection = (preds_ * labels_).sum() 23 | try: 24 | 
res[class_name] = (2 * intersection) / (preds_.sum() + labels_.sum()).item() 25 | except ZeroDivisionError: 26 | res[class_name] = 1.0 27 | # res[class_name] = metric.dc(preds_.cpu().numpy(), labels_.cpu().numpy()) 28 | return res 29 | 30 | 31 | class Jaccard: 32 | def __init__(self, name='Jaccard', class_indexs=[1], class_names=['xx']) -> None: 33 | super().__init__() 34 | self.name = name 35 | self.class_indexs = class_indexs 36 | self.class_names = class_names 37 | 38 | def __call__(self, preds, labels): 39 | res = {} 40 | for class_index, class_name in zip(self.class_indexs, self.class_names): 41 | preds_ = (preds == class_index).to(torch.int) 42 | labels_ = (labels == class_index).to(torch.int) 43 | intersection = (preds_ * labels_).sum() 44 | union = ((preds_ + labels_) != 0).sum() 45 | res[class_name] = intersection / union 46 | try: 47 | res[class_name] = intersection / union 48 | except ZeroDivisionError: 49 | res[class_name] = 1.0 50 | # res[class_name] = metric.jc(preds_.cpu().numpy(), labels_.cpu().numpy()) 51 | return res 52 | 53 | 54 | class HD95: 55 | """ 56 | 95th percentile of the Hausdorff Distance. 57 | """ 58 | 59 | def __init__(self, name='95HD', class_indexs=[1], class_names=['xx']) -> None: 60 | super().__init__() 61 | self.name = name 62 | self.class_indexs = class_indexs 63 | self.class_names = class_names 64 | 65 | def __call__(self, preds, labels): 66 | if preds.size(1) == 1: 67 | preds = preds.squeeze(1) 68 | if labels.size(1) == 1: 69 | labels = labels.squeeze(1) 70 | res = {} 71 | for class_index, class_name in zip(self.class_indexs, self.class_names): 72 | res[class_name] = 0. 
73 | for i in range(preds.size(0)): 74 | preds_ = (preds[i] == class_index).to(torch.int) 75 | labels_ = (labels[i] == class_index).to(torch.int) 76 | if preds_.sum() == 0.: 77 | preds_ = (preds_ == 0).to(torch.int) 78 | res[class_name] += torch.tensor(metric.hd95(preds_.cpu().numpy(), labels_.cpu().numpy())) 79 | res[class_name] /= preds.size(0) 80 | return res 81 | 82 | 83 | class ASD: 84 | """ 85 | Average surface distance. 86 | """ 87 | 88 | def __init__(self, name='ASD', class_indexs=[1], class_names=['xx']) -> None: 89 | super().__init__() 90 | self.name = name 91 | self.class_indexs = class_indexs 92 | self.class_names = class_names 93 | 94 | def __call__(self, preds, labels): 95 | if preds.size(1) == 1: 96 | preds = preds.squeeze(1) 97 | if labels.size(1) == 1: 98 | labels = labels.squeeze(1) 99 | res = {} 100 | for class_index, class_name in zip(self.class_indexs, self.class_names): 101 | res[class_name] = 0. 102 | for i in range(preds.size(0)): 103 | preds_ = (preds[i] == class_index).to(torch.int) 104 | labels_ = (labels[i] == class_index).to(torch.int) 105 | if preds_.sum() == 0.: 106 | preds_ = (preds_ == 0).to(torch.int) 107 | res[class_name] += torch.tensor(metric.asd(preds_.cpu().numpy(), labels_.cpu().numpy())) 108 | res[class_name] /= preds.size(0) 109 | return res -------------------------------------------------------------------------------- /codes/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet import UNet 2 | from .unet_tf import UNetTF 3 | -------------------------------------------------------------------------------- /codes/models/_base.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from ..losses import * 3 | 4 | 5 | class BaseModel2D(nn.Module): 6 | 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def _init_weights(self, **kwargs): 11 | for m in self.modules(): 12 | if isinstance(m, nn.Conv2d): 13 | 
torch.nn.init.kaiming_normal_(m.weight) 14 | elif isinstance(m, nn.BatchNorm2d): 15 | nn.init.constant_(m.weight, 1) 16 | nn.init.constant_(m.bias, 0) 17 | 18 | def inference(self, x, **kwargs): 19 | logits = self(x) 20 | preds = torch.argmax(logits['seg'], dim=1, keepdim=True).to(torch.float) 21 | return preds 22 | -------------------------------------------------------------------------------- /codes/models/swin_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.checkpoint as checkpoint 4 | from einops import rearrange 5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 6 | 7 | 8 | class Mlp(nn.Module): 9 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 10 | super().__init__() 11 | out_features = out_features or in_features 12 | hidden_features = hidden_features or in_features 13 | self.fc1 = nn.Linear(in_features, hidden_features) 14 | self.act = act_layer() 15 | self.fc2 = nn.Linear(hidden_features, out_features) 16 | self.drop = nn.Dropout(drop) 17 | 18 | def forward(self, x): 19 | x = self.fc1(x) 20 | x = self.act(x) 21 | x = self.drop(x) 22 | x = self.fc2(x) 23 | x = self.drop(x) 24 | return x 25 | 26 | 27 | def window_partition(x, window_size): 28 | """ 29 | Args: 30 | x: (B, H, W, C) 31 | window_size (int): window size 32 | Returns: 33 | windows: (num_windows*B, window_size, window_size, C) 34 | """ 35 | B, H, W, C = x.shape 36 | x = x.view(B, H // window_size, window_size, 37 | W // window_size, window_size, C) 38 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous( 39 | ).view(-1, window_size, window_size, C) 40 | return windows 41 | 42 | 43 | def window_reverse(windows, window_size, H, W): 44 | """ 45 | Args: 46 | windows: (num_windows*B, window_size, window_size, C) 47 | window_size (int): Window size 48 | H (int): Height of image 49 | W (int): Width of image 50 | Returns: 51 
| x: (B, H, W, C) 52 | """ 53 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 54 | x = windows.view(B, H // window_size, W // window_size, 55 | window_size, window_size, -1) 56 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 57 | return x 58 | 59 | 60 | class WindowAttention(nn.Module): 61 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 62 | It supports both of shifted and non-shifted window. 63 | Args: 64 | dim (int): Number of input channels. 65 | window_size (tuple[int]): The height and width of the window. 66 | num_heads (int): Number of attention heads. 67 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 68 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 69 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 70 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 71 | """ 72 | 73 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 74 | 75 | super().__init__() 76 | self.dim = dim 77 | self.window_size = window_size # Wh, Ww 78 | self.num_heads = num_heads 79 | head_dim = dim // num_heads 80 | self.scale = qk_scale or head_dim ** -0.5 81 | 82 | # define a parameter table of relative position bias 83 | self.relative_position_bias_table = nn.Parameter( 84 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 85 | 86 | # get pair-wise relative position index for each token inside the window 87 | coords_h = torch.arange(self.window_size[0]) 88 | coords_w = torch.arange(self.window_size[1]) 89 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 90 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 91 | relative_coords = coords_flatten[:, :, None] - \ 92 | coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 93 | relative_coords = 
relative_coords.permute( 94 | 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 95 | relative_coords[:, :, 0] += self.window_size[0] - \ 96 | 1 # shift to start from 0 97 | relative_coords[:, :, 1] += self.window_size[1] - 1 98 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 99 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 100 | self.register_buffer("relative_position_index", 101 | relative_position_index) 102 | 103 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 104 | self.attn_drop = nn.Dropout(attn_drop) 105 | self.proj = nn.Linear(dim, dim) 106 | self.proj_drop = nn.Dropout(proj_drop) 107 | 108 | trunc_normal_(self.relative_position_bias_table, std=.02) 109 | self.softmax = nn.Softmax(dim=-1) 110 | 111 | def forward(self, x, mask=None): 112 | """ 113 | Args: 114 | x: input features with shape of (num_windows*B, N, C) 115 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 116 | """ 117 | B_, N, C = x.shape 118 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // 119 | self.num_heads).permute(2, 0, 3, 1, 4) 120 | # make torchscript happy (cannot use tensor as tuple) 121 | q, k, v = qkv[0], qkv[1], qkv[2] 122 | 123 | q = q * self.scale 124 | attn = (q @ k.transpose(-2, -1)) 125 | 126 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 127 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 128 | relative_position_bias = relative_position_bias.permute( 129 | 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 130 | attn = attn + relative_position_bias.unsqueeze(0) 131 | 132 | if mask is not None: 133 | nW = mask.shape[0] 134 | attn = attn.view(B_ // nW, nW, self.num_heads, N, 135 | N) + mask.unsqueeze(1).unsqueeze(0) 136 | attn = attn.view(-1, self.num_heads, N, N) 137 | attn = self.softmax(attn) 138 | else: 139 | attn = self.softmax(attn) 140 | 141 | attn = self.attn_drop(attn) 142 | 143 | x = (attn @ 
v).transpose(1, 2).reshape(B_, N, C) 144 | x = self.proj(x) 145 | x = self.proj_drop(x) 146 | return x 147 | 148 | def extra_repr(self) -> str: 149 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 150 | 151 | def flops(self, N): 152 | # calculate flops for 1 window with token length of N 153 | flops = 0 154 | # qkv = self.qkv(x) 155 | flops += N * self.dim * 3 * self.dim 156 | # attn = (q @ k.transpose(-2, -1)) 157 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 158 | # x = (attn @ v) 159 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 160 | # x = self.proj(x) 161 | flops += N * self.dim * self.dim 162 | return flops 163 | 164 | 165 | class SwinTransformerBlock(nn.Module): 166 | r""" Swin Transformer Block. 167 | Args: 168 | dim (int): Number of input channels. 169 | input_resolution (tuple[int]): Input resulotion. 170 | num_heads (int): Number of attention heads. 171 | window_size (int): Window size. 172 | shift_size (int): Shift size for SW-MSA. 173 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 174 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 175 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 176 | drop (float, optional): Dropout rate. Default: 0.0 177 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 178 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 179 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 180 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 181 | """ 182 | 183 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 184 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 185 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 186 | super().__init__() 187 | self.dim = dim 188 | self.input_resolution = input_resolution 189 | self.num_heads = num_heads 190 | self.window_size = window_size 191 | self.shift_size = shift_size 192 | self.mlp_ratio = mlp_ratio 193 | if min(self.input_resolution) <= self.window_size: 194 | # if window size is larger than input resolution, we don't partition windows 195 | self.shift_size = 0 196 | self.window_size = min(self.input_resolution) 197 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 198 | 199 | self.norm1 = norm_layer(dim) 200 | self.attn = WindowAttention( 201 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 202 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 203 | 204 | self.drop_path = DropPath( 205 | drop_path) if drop_path > 0. 
else nn.Identity() 206 | self.norm2 = norm_layer(dim) 207 | mlp_hidden_dim = int(dim * mlp_ratio) 208 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, 209 | act_layer=act_layer, drop=drop) 210 | 211 | if self.shift_size > 0: 212 | # calculate attention mask for SW-MSA 213 | H, W = self.input_resolution 214 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 215 | h_slices = (slice(0, -self.window_size), 216 | slice(-self.window_size, -self.shift_size), 217 | slice(-self.shift_size, None)) 218 | w_slices = (slice(0, -self.window_size), 219 | slice(-self.window_size, -self.shift_size), 220 | slice(-self.shift_size, None)) 221 | cnt = 0 222 | for h in h_slices: 223 | for w in w_slices: 224 | img_mask[:, h, w, :] = cnt 225 | cnt += 1 226 | 227 | # nW, window_size, window_size, 1 228 | mask_windows = window_partition(img_mask, self.window_size) 229 | mask_windows = mask_windows.view(-1, 230 | self.window_size * self.window_size) 231 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 232 | attn_mask = attn_mask.masked_fill( 233 | attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 234 | else: 235 | attn_mask = None 236 | 237 | self.register_buffer("attn_mask", attn_mask) 238 | 239 | def forward(self, x): 240 | H, W = self.input_resolution 241 | B, L, C = x.shape 242 | assert L == H * W, "input feature has wrong size" 243 | 244 | shortcut = x 245 | x = self.norm1(x) 246 | x = x.view(B, H, W, C) 247 | 248 | # cyclic shift 249 | if self.shift_size > 0: 250 | shifted_x = torch.roll( 251 | x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 252 | else: 253 | shifted_x = x 254 | 255 | # partition windows 256 | # nW*B, window_size, window_size, C 257 | x_windows = window_partition(shifted_x, self.window_size) 258 | # nW*B, window_size*window_size, C 259 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) 260 | 261 | # W-MSA/SW-MSA 262 | # nW*B, window_size*window_size, C 263 | attn_windows = 
self.attn(x_windows, mask=self.attn_mask) 264 | 265 | # merge windows 266 | attn_windows = attn_windows.view(-1, 267 | self.window_size, self.window_size, C) 268 | shifted_x = window_reverse( 269 | attn_windows, self.window_size, H, W) # B H' W' C 270 | 271 | # reverse cyclic shift 272 | if self.shift_size > 0: 273 | x = torch.roll(shifted_x, shifts=( 274 | self.shift_size, self.shift_size), dims=(1, 2)) 275 | else: 276 | x = shifted_x 277 | x = x.view(B, H * W, C) 278 | 279 | # FFN 280 | x = shortcut + self.drop_path(x) 281 | x = x + self.drop_path(self.mlp(self.norm2(x))) 282 | 283 | return x 284 | 285 | def extra_repr(self) -> str: 286 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 287 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 288 | 289 | def flops(self): 290 | flops = 0 291 | H, W = self.input_resolution 292 | # norm1 293 | flops += self.dim * H * W 294 | # W-MSA/SW-MSA 295 | nW = H * W / self.window_size / self.window_size 296 | flops += nW * self.attn.flops(self.window_size * self.window_size) 297 | # mlp 298 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 299 | # norm2 300 | flops += self.dim * H * W 301 | return flops 302 | 303 | 304 | class PatchMerging(nn.Module): 305 | r""" Patch Merging Layer. 306 | Args: 307 | input_resolution (tuple[int]): Resolution of input feature. 308 | dim (int): Number of input channels. 309 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 310 | """ 311 | 312 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 313 | super().__init__() 314 | self.input_resolution = input_resolution 315 | self.dim = dim 316 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 317 | self.norm = norm_layer(4 * dim) 318 | 319 | def forward(self, x): 320 | """ 321 | x: B, H*W, C 322 | """ 323 | H, W = self.input_resolution 324 | B, L, C = x.shape 325 | assert L == H * W, "input feature has wrong size" 326 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 327 | 328 | x = x.view(B, H, W, C) 329 | 330 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 331 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 332 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 333 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 334 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 335 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 336 | 337 | x = self.norm(x) 338 | x = self.reduction(x) 339 | 340 | return x 341 | 342 | def extra_repr(self) -> str: 343 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 344 | 345 | def flops(self): 346 | H, W = self.input_resolution 347 | flops = H * W * self.dim 348 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 349 | return flops 350 | 351 | 352 | class PatchExpand(nn.Module): 353 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): 354 | super().__init__() 355 | self.input_resolution = input_resolution 356 | self.dim = dim 357 | self.expand = nn.Linear( 358 | dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity() 359 | self.norm = norm_layer(dim // dim_scale) 360 | 361 | def forward(self, x): 362 | """ 363 | x: B, H*W, C 364 | """ 365 | H, W = self.input_resolution 366 | x = self.expand(x) 367 | B, L, C = x.shape 368 | assert L == H * W, "input feature has wrong size" 369 | 370 | x = x.view(B, H, W, C) 371 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', 372 | p1=2, p2=2, c=C // 4) 373 | x = x.view(B, 
-1, C // 4) 374 | x = self.norm(x) 375 | 376 | return x 377 | 378 | 379 | class FinalPatchExpand_X4(nn.Module): 380 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): 381 | super().__init__() 382 | self.input_resolution = input_resolution 383 | self.dim = dim 384 | self.dim_scale = dim_scale 385 | self.expand = nn.Linear(dim, 16 * dim, bias=False) 386 | self.output_dim = dim 387 | self.norm = norm_layer(self.output_dim) 388 | 389 | def forward(self, x): 390 | """ 391 | x: B, H*W, C 392 | """ 393 | H, W = self.input_resolution 394 | x = self.expand(x) 395 | B, L, C = x.shape 396 | assert L == H * W, "input feature has wrong size" 397 | 398 | x = x.view(B, H, W, C) 399 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', 400 | p1=self.dim_scale, p2=self.dim_scale, c=C // (self.dim_scale ** 2)) 401 | x = x.view(B, -1, self.output_dim) 402 | x = self.norm(x) 403 | 404 | return x 405 | 406 | 407 | class BasicLayer(nn.Module): 408 | """ A basic Swin Transformer layer for one stage. 409 | Args: 410 | dim (int): Number of input channels. 411 | input_resolution (tuple[int]): Input resolution. 412 | depth (int): Number of blocks. 413 | num_heads (int): Number of attention heads. 414 | window_size (int): Local window size. 415 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 416 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 417 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 418 | drop (float, optional): Dropout rate. Default: 0.0 419 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 420 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 421 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 422 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None 423 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 424 | """ 425 | 426 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 427 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 428 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 429 | 430 | super().__init__() 431 | self.dim = dim 432 | self.input_resolution = input_resolution 433 | self.depth = depth 434 | self.use_checkpoint = use_checkpoint 435 | 436 | # build blocks 437 | self.blocks = nn.ModuleList([ 438 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 439 | num_heads=num_heads, window_size=window_size, 440 | shift_size=0 if ( 441 | i % 2 == 0) else window_size // 2, 442 | mlp_ratio=mlp_ratio, 443 | qkv_bias=qkv_bias, qk_scale=qk_scale, 444 | drop=drop, attn_drop=attn_drop, 445 | drop_path=drop_path[i] if isinstance( 446 | drop_path, list) else drop_path, 447 | norm_layer=norm_layer) 448 | for i in range(depth)]) 449 | 450 | # patch merging layer 451 | if downsample is not None: 452 | self.downsample = downsample( 453 | input_resolution, dim=dim, norm_layer=norm_layer) 454 | else: 455 | self.downsample = None 456 | 457 | def forward(self, x): 458 | for blk in self.blocks: 459 | if self.use_checkpoint: 460 | x = checkpoint.checkpoint(blk, x) 461 | else: 462 | x = blk(x) 463 | if self.downsample is not None: 464 | x = self.downsample(x) 465 | return x 466 | 467 | def extra_repr(self) -> str: 468 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 469 | 470 | def flops(self): 471 | flops = 0 472 | for blk in self.blocks: 473 | flops += blk.flops() 474 | if self.downsample is not None: 475 | flops += self.downsample.flops() 476 | return flops 477 | 478 | 479 | class BasicLayer_up(nn.Module): 480 | """ A basic Swin Transformer layer for one stage. 481 | Args: 482 | dim (int): Number of input channels. 
483 | input_resolution (tuple[int]): Input resolution. 484 | depth (int): Number of blocks. 485 | num_heads (int): Number of attention heads. 486 | window_size (int): Local window size. 487 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 488 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 489 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 490 | drop (float, optional): Dropout rate. Default: 0.0 491 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 492 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 493 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 494 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 495 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 496 | """ 497 | 498 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 499 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 500 | drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): 501 | 502 | super().__init__() 503 | self.dim = dim 504 | self.input_resolution = input_resolution 505 | self.depth = depth 506 | self.use_checkpoint = use_checkpoint 507 | 508 | # build blocks 509 | self.blocks = nn.ModuleList([ 510 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 511 | num_heads=num_heads, window_size=window_size, 512 | shift_size=0 if ( 513 | i % 2 == 0) else window_size // 2, 514 | mlp_ratio=mlp_ratio, 515 | qkv_bias=qkv_bias, qk_scale=qk_scale, 516 | drop=drop, attn_drop=attn_drop, 517 | drop_path=drop_path[i] if isinstance( 518 | drop_path, list) else drop_path, 519 | norm_layer=norm_layer) 520 | for i in range(depth)]) 521 | 522 | # patch merging layer 523 | if upsample is not None: 524 | self.upsample = PatchExpand( 525 | input_resolution, dim=dim, 
dim_scale=2, norm_layer=norm_layer) 526 | else: 527 | self.upsample = None 528 | 529 | def forward(self, x): 530 | for blk in self.blocks: 531 | if self.use_checkpoint: 532 | x = checkpoint.checkpoint(blk, x) 533 | else: 534 | x = blk(x) 535 | if self.upsample is not None: 536 | x = self.upsample(x) 537 | return x 538 | 539 | 540 | class PatchEmbed(nn.Module): 541 | r""" Image to Patch Embedding 542 | Args: 543 | img_size (int): Image size. Default: 224. 544 | patch_size (int): Patch token size. Default: 4. 545 | in_chans (int): Number of input image channels. Default: 3. 546 | embed_dim (int): Number of linear projection output channels. Default: 96. 547 | norm_layer (nn.Module, optional): Normalization layer. Default: None 548 | """ 549 | 550 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 551 | super().__init__() 552 | img_size = to_2tuple(img_size) 553 | patch_size = to_2tuple(patch_size) 554 | patches_resolution = [img_size[0] // 555 | patch_size[0], img_size[1] // patch_size[1]] 556 | self.img_size = img_size 557 | self.patch_size = patch_size 558 | self.patches_resolution = patches_resolution 559 | self.num_patches = patches_resolution[0] * patches_resolution[1] 560 | 561 | self.in_chans = in_chans 562 | self.embed_dim = embed_dim 563 | 564 | self.proj = nn.Conv2d(in_chans, embed_dim, 565 | kernel_size=patch_size, stride=patch_size) 566 | if norm_layer is not None: 567 | self.norm = norm_layer(embed_dim) 568 | else: 569 | self.norm = None 570 | 571 | def forward(self, x): 572 | B, C, H, W = x.shape 573 | # FIXME look at relaxing size constraints 574 | assert H == self.img_size[0] and W == self.img_size[1], \ 575 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
576 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 577 | if self.norm is not None: 578 | x = self.norm(x) 579 | return x 580 | 581 | def flops(self): 582 | Ho, Wo = self.patches_resolution 583 | flops = Ho * Wo * self.embed_dim * self.in_chans * \ 584 | (self.patch_size[0] * self.patch_size[1]) 585 | if self.norm is not None: 586 | flops += Ho * Wo * self.embed_dim 587 | return flops 588 | 589 | 590 | class SwinTransformerSys(nn.Module): 591 | r""" Swin Transformer 592 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 593 | https://arxiv.org/pdf/2103.14030 594 | Args: 595 | img_size (int | tuple(int)): Input image size. Default 224 596 | patch_size (int | tuple(int)): Patch size. Default: 4 597 | in_chans (int): Number of input image channels. Default: 3 598 | num_classes (int): Number of classes for classification head. Default: 1000 599 | embed_dim (int): Patch embedding dimension. Default: 96 600 | depths (tuple(int)): Depth of each Swin Transformer layer. 601 | num_heads (tuple(int)): Number of attention heads in different layers. 602 | window_size (int): Window size. Default: 7 603 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 604 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 605 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 606 | drop_rate (float): Dropout rate. Default: 0 607 | attn_drop_rate (float): Attention dropout rate. Default: 0 608 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 609 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 610 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 611 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 612 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 613 | """ 614 | 615 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 616 | embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24], 617 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 618 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 619 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 620 | use_checkpoint=False, final_upsample="expand_first", **kwargs): 621 | super().__init__() 622 | 623 | print( 624 | "SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format( 625 | depths, 626 | depths_decoder, drop_path_rate, num_classes)) 627 | 628 | self.num_classes = num_classes 629 | self.num_layers = len(depths) 630 | self.embed_dim = embed_dim 631 | self.ape = ape 632 | self.patch_norm = patch_norm 633 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 634 | self.num_features_up = int(embed_dim * 2) 635 | self.mlp_ratio = mlp_ratio 636 | self.final_upsample = final_upsample 637 | 638 | # split image into non-overlapping patches 639 | self.patch_embed = PatchEmbed( 640 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 641 | norm_layer=norm_layer if self.patch_norm else None) 642 | num_patches = self.patch_embed.num_patches 643 | patches_resolution = self.patch_embed.patches_resolution 644 | self.patches_resolution = patches_resolution 645 | 646 | # absolute position embedding 647 | if self.ape: 648 | self.absolute_pos_embed = nn.Parameter( 649 | torch.zeros(1, num_patches, embed_dim)) 650 | trunc_normal_(self.absolute_pos_embed, std=.02) 651 | 652 | self.pos_drop = nn.Dropout(p=drop_rate) 653 | 654 | # stochastic depth 655 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 656 | sum(depths))] # stochastic depth decay rule 657 | 658 | # build encoder and bottleneck layers 659 | self.layers = nn.ModuleList() 660 | for i_layer in range(self.num_layers): 661 | layer = 
BasicLayer(dim=int(embed_dim * 2 ** i_layer), 662 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 663 | patches_resolution[1] // (2 ** i_layer)), 664 | depth=depths[i_layer], 665 | num_heads=num_heads[i_layer], 666 | window_size=window_size, 667 | mlp_ratio=self.mlp_ratio, 668 | qkv_bias=qkv_bias, qk_scale=qk_scale, 669 | drop=drop_rate, attn_drop=attn_drop_rate, 670 | drop_path=dpr[sum(depths[:i_layer]):sum( 671 | depths[:i_layer + 1])], 672 | norm_layer=norm_layer, 673 | downsample=PatchMerging if ( 674 | i_layer < self.num_layers - 1) else None, 675 | use_checkpoint=use_checkpoint) 676 | self.layers.append(layer) 677 | 678 | # build decoder layers 679 | self.layers_up = nn.ModuleList() 680 | self.concat_back_dim = nn.ModuleList() 681 | for i_layer in range(self.num_layers): 682 | concat_linear = nn.Linear(2 * int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), 683 | int(embed_dim * 2 ** ( 684 | self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity() 685 | if i_layer == 0: 686 | layer_up = PatchExpand( 687 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)), 688 | patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))), 689 | dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), dim_scale=2, norm_layer=norm_layer) 690 | else: 691 | layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), 692 | input_resolution=( 693 | patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)), 694 | patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))), 695 | depth=depths[( 696 | self.num_layers - 1 - i_layer)], 697 | num_heads=num_heads[( 698 | self.num_layers - 1 - i_layer)], 699 | window_size=window_size, 700 | mlp_ratio=self.mlp_ratio, 701 | qkv_bias=qkv_bias, qk_scale=qk_scale, 702 | drop=drop_rate, attn_drop=attn_drop_rate, 703 | drop_path=dpr[sum(depths[:( 704 | self.num_layers - 1 - i_layer)]):sum( 705 | depths[:(self.num_layers - 1 - i_layer) + 1])], 706 | 
norm_layer=norm_layer, 707 | upsample=PatchExpand if ( 708 | i_layer < self.num_layers - 1) else None, 709 | use_checkpoint=use_checkpoint) 710 | self.layers_up.append(layer_up) 711 | self.concat_back_dim.append(concat_linear) 712 | 713 | self.norm = norm_layer(self.num_features) 714 | self.norm_up = norm_layer(self.embed_dim) 715 | 716 | if self.final_upsample == "expand_first": 717 | # print("---final upsample expand_first---") 718 | self.up = FinalPatchExpand_X4(input_resolution=( 719 | img_size // patch_size, img_size // patch_size), dim_scale=4, dim=embed_dim) 720 | self.output = nn.Conv2d( 721 | in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False) 722 | 723 | self.apply(self._init_weights) 724 | 725 | def _init_weights(self, m): 726 | if isinstance(m, nn.Linear): 727 | trunc_normal_(m.weight, std=.02) 728 | if m.bias is not None: 729 | nn.init.constant_(m.bias, 0) 730 | elif isinstance(m, nn.LayerNorm): 731 | nn.init.constant_(m.bias, 0) 732 | nn.init.constant_(m.weight, 1.0) 733 | 734 | @torch.jit.ignore 735 | def no_weight_decay(self): 736 | return {'absolute_pos_embed'} 737 | 738 | @torch.jit.ignore 739 | def no_weight_decay_keywords(self): 740 | return {'relative_position_bias_table'} 741 | 742 | # Encoder and Bottleneck 743 | def forward_features(self, x): 744 | x = self.patch_embed(x) 745 | if self.ape: 746 | x = x + self.absolute_pos_embed 747 | x = self.pos_drop(x) 748 | x_downsample = [] 749 | 750 | for layer in self.layers: 751 | x_downsample.append(x) 752 | x = layer(x) 753 | 754 | x = self.norm(x) # B L C 755 | 756 | return x, x_downsample 757 | 758 | # Dencoder and Skip connection 759 | def forward_up_features(self, x, x_downsample): 760 | for inx, layer_up in enumerate(self.layers_up): 761 | if inx == 0: 762 | x = layer_up(x) 763 | else: 764 | x = torch.cat([x, x_downsample[3 - inx]], -1) 765 | x = self.concat_back_dim[inx](x) 766 | x = layer_up(x) 767 | 768 | x = self.norm_up(x) # B L C 769 | 770 | return x 771 | 772 
| def up_x4(self, x): 773 | H, W = self.patches_resolution 774 | B, L, C = x.shape 775 | assert L == H * W, "input features has wrong size" 776 | 777 | if self.final_upsample == "expand_first": 778 | x = self.up(x) 779 | x = x.view(B, 4 * H, 4 * W, -1) 780 | x = x.permute(0, 3, 1, 2) # B,C,H,W 781 | x = self.output(x) 782 | 783 | return x 784 | 785 | def forward(self, x): 786 | x, x_downsample = self.forward_features(x) 787 | print(x.shape) 788 | for _ in x_downsample: 789 | print(_.shape) 790 | x = self.forward_up_features(x, x_downsample) 791 | x = self.up_x4(x) 792 | 793 | return x 794 | 795 | def flops(self): 796 | flops = 0 797 | flops += self.patch_embed.flops() 798 | for i, layer in enumerate(self.layers): 799 | flops += layer.flops() 800 | flops += self.num_features * \ 801 | self.patches_resolution[0] * \ 802 | self.patches_resolution[1] // (2 ** self.num_layers) 803 | flops += self.num_features * self.num_classes 804 | return flops 805 | 806 | 807 | class SwinTransDecoder(nn.Module): 808 | def __init__(self, 809 | classes=2, 810 | embed_dim=96, 811 | norm_layer=nn.LayerNorm, 812 | img_size=224, 813 | patch_size=4, 814 | depths=[2, 2, 2, 2], 815 | num_heads=[3, 6, 12, 24], 816 | window_size=7, 817 | qkv_bias=True, 818 | qk_scale=None, 819 | drop_rate=0., 820 | attn_drop_rate=0., 821 | use_checkpoint=False, 822 | ape=True, 823 | mlp_ratio=4., 824 | drop_path_rate=0.1, 825 | final_upsample="expand_first", 826 | patches_resolution=[56, 56], 827 | encoder_channels=[256, 512, 1024, 2048]): 828 | super().__init__( 829 | ) 830 | self.patches_resolution = patches_resolution 831 | self.ape = ape 832 | self.final_upsample = final_upsample 833 | self.num_layers = len(depths) 834 | self.mlp_ratio = mlp_ratio 835 | 836 | if self.ape: 837 | self.absolute_pos_embeds = [ 838 | nn.Parameter(torch.zeros(1, 3136, 96)), 839 | nn.Parameter(torch.zeros(1, 784, 192)), 840 | nn.Parameter(torch.zeros(1, 196, 384)), 841 | nn.Parameter(torch.zeros(1, 49, 768)) 842 | ] 843 | for abs 
in self.absolute_pos_embeds: 844 | trunc_normal_(abs, std=.02) 845 | self.pos_drop = nn.Dropout(p=drop_rate) 846 | 847 | # build decoder layers 848 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 849 | self.layers_up = nn.ModuleList() 850 | self.concat_back_dim = nn.ModuleList() 851 | for i_layer in range(self.num_layers): 852 | concat_linear = nn.Linear(2 * int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), 853 | int(embed_dim * 2 ** ( 854 | self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity() 855 | if i_layer == 0: 856 | layer_up = PatchExpand( 857 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)), 858 | patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))), 859 | dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), dim_scale=2, norm_layer=norm_layer) 860 | else: 861 | layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), 862 | input_resolution=( 863 | patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)), 864 | patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))), 865 | depth=depths[(self.num_layers - 1 - i_layer)], 866 | num_heads=num_heads[(self.num_layers - 1 - i_layer)], 867 | window_size=window_size, 868 | mlp_ratio=self.mlp_ratio, 869 | qkv_bias=qkv_bias, qk_scale=qk_scale, 870 | drop=drop_rate, attn_drop=attn_drop_rate, 871 | drop_path=dpr[sum(depths[:(self.num_layers - 1 - i_layer)]):sum( 872 | depths[:(self.num_layers - 1 - i_layer) + 1])], 873 | norm_layer=norm_layer, 874 | upsample=PatchExpand if (i_layer < self.num_layers - 1) else None, 875 | use_checkpoint=use_checkpoint) 876 | self.layers_up.append(layer_up) 877 | self.concat_back_dim.append(concat_linear) 878 | 879 | self.patches_embed = nn.ModuleList([ 880 | PatchEmbed(img_size=56, patch_size=1, in_chans=encoder_channels[-4], embed_dim=96, norm_layer=norm_layer), 881 | PatchEmbed(img_size=28, patch_size=1, in_chans=encoder_channels[-3], 
embed_dim=192, norm_layer=norm_layer), 882 | PatchEmbed(img_size=14, patch_size=1, in_chans=encoder_channels[-2], embed_dim=384, norm_layer=norm_layer), 883 | PatchEmbed(img_size=7, patch_size=1, in_chans=encoder_channels[-1], embed_dim=768, norm_layer=norm_layer), 884 | ]) 885 | 886 | self.norm = norm_layer(768) 887 | self.norm_up = norm_layer(96) 888 | 889 | if self.final_upsample == "expand_first": 890 | self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size), 891 | dim_scale=4, dim=embed_dim) 892 | self.output = nn.Conv2d(in_channels=embed_dim, out_channels=classes, kernel_size=1, bias=False) 893 | 894 | self.apply(self._init_weights) 895 | 896 | def _init_weights(self, m): 897 | if isinstance(m, nn.Linear): 898 | trunc_normal_(m.weight, std=.02) 899 | if isinstance(m, nn.Linear) and m.bias is not None: 900 | nn.init.constant_(m.bias, 0) 901 | elif isinstance(m, nn.LayerNorm): 902 | nn.init.constant_(m.bias, 0) 903 | nn.init.constant_(m.weight, 1.0) 904 | 905 | def forward_up_features(self, x, x_downsample): 906 | for inx, layer_up in enumerate(self.layers_up): 907 | if inx == 0: 908 | x = layer_up(x) 909 | else: 910 | x = torch.cat([x, x_downsample[3 - inx]], -1) 911 | x = self.concat_back_dim[inx](x) 912 | x = layer_up(x) 913 | x = self.norm_up(x) # B L C 914 | return x 915 | 916 | def up_x4(self, x): 917 | H, W = self.patches_resolution 918 | B, L, C = x.shape 919 | assert L == H * W, "input features has wrong size" 920 | 921 | if self.final_upsample == "expand_first": 922 | x = self.up(x) 923 | x = x.view(B, 4 * H, 4 * W, -1) 924 | x = x.permute(0, 3, 1, 2) # B,C,H,W 925 | x = self.output(x) 926 | return x 927 | 928 | def forward(self, features, device): 929 | patches = [] 930 | for i, f in enumerate(features[2:]): 931 | p_embed = self.patches_embed[i](f) 932 | if self.ape: 933 | p_embed = p_embed + self.absolute_pos_embeds[i].to(device) 934 | p_embed = self.pos_drop(p_embed) 935 | patches.append(p_embed) 936 | p = 
self.forward_up_features(patches[-1], patches) 937 | return self.up_x4(p) 938 | -------------------------------------------------------------------------------- /codes/models/unet.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseModel2D 2 | from typing import Optional, Union, List 3 | from segmentation_models_pytorch.encoders import get_encoder 4 | from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead 5 | from segmentation_models_pytorch.unet.decoder import UnetDecoder 6 | 7 | 8 | class Unet(SegmentationModel): 9 | 10 | def __init__( 11 | self, 12 | encoder_name: str = "resnet34", 13 | encoder_depth: int = 5, 14 | encoder_weights: Optional[str] = "imagenet", 15 | decoder_use_batchnorm: bool = True, 16 | decoder_channels: List[int] = (256, 128, 64, 32, 16), 17 | decoder_attention_type: Optional[str] = None, 18 | in_channels: int = 3, 19 | classes: int = 1, 20 | activation: Optional[Union[str, callable]] = None, 21 | aux_params: Optional[dict] = None, 22 | ): 23 | super().__init__() 24 | 25 | self.encoder = get_encoder( 26 | encoder_name, 27 | in_channels=in_channels, 28 | depth=encoder_depth, 29 | weights=encoder_weights, 30 | ) 31 | 32 | self.decoder = UnetDecoder( 33 | encoder_channels=self.encoder.out_channels, 34 | decoder_channels=decoder_channels, 35 | n_blocks=encoder_depth, 36 | use_batchnorm=decoder_use_batchnorm, 37 | center=True if encoder_name.startswith("vgg") else False, 38 | attention_type=decoder_attention_type, 39 | ) 40 | 41 | self.segmentation_head = SegmentationHead( 42 | in_channels=decoder_channels[-1], 43 | out_channels=classes, 44 | activation=activation, 45 | kernel_size=3, 46 | ) 47 | 48 | if aux_params is not None: 49 | self.classification_head = ClassificationHead( 50 | in_channels=self.encoder.out_channels[-1], **aux_params 51 | ) 52 | else: 53 | self.classification_head = None 54 | 55 | self.name = "u-{}".format(encoder_name) 56 
| self.initialize() 57 | 58 | def forward_features(self, x): 59 | features = self.encoder(x) 60 | return features 61 | 62 | 63 | class UNet(BaseModel2D): 64 | 65 | def __init__(self, 66 | encoder_name="resnet34", 67 | encoder_depth=5, 68 | encoder_weights="imagenet", 69 | decoder_use_batchnorm: bool = True, 70 | decoder_channels=(256, 128, 64, 32, 16), 71 | decoder_attention_type=None, 72 | in_channels=3, 73 | classes=1, 74 | activation=None, 75 | aux_params=None): 76 | super().__init__() 77 | 78 | self.segmentor = Unet(encoder_name=encoder_name, 79 | encoder_depth=encoder_depth, 80 | encoder_weights=encoder_weights, 81 | decoder_use_batchnorm=decoder_use_batchnorm, 82 | decoder_channels=decoder_channels, 83 | decoder_attention_type=decoder_attention_type, 84 | in_channels=in_channels, 85 | classes=classes, 86 | activation=activation, 87 | aux_params=aux_params) 88 | self.num_classes = classes 89 | 90 | def forward(self, x): 91 | return {'seg': self.segmentor(x)} 92 | 93 | def inference_features(self, x, **kwargs): 94 | feats = self.segmentor.forward_features(x) 95 | return {'feats': feats} 96 | -------------------------------------------------------------------------------- /codes/models/unet_tf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from codes.models.swin_decoder import SwinTransDecoder 6 | from codes.models._base import BaseModel2D 7 | from codes.utils.init import kaiming_normal_init_weight 8 | from segmentation_models_pytorch.encoders import get_encoder 9 | from segmentation_models_pytorch.unet.decoder import UnetDecoder 10 | from segmentation_models_pytorch.base import SegmentationHead, ClassificationHead 11 | 12 | 13 | class EmbeddingHead(nn.Module): 14 | def __init__(self, dim_in, embed_dim=256, embed='convmlp'): 15 | super(EmbeddingHead, self).__init__() 16 | 17 | if embed == 'linear': 18 | self.embed = nn.Conv2d(dim_in, 
embed_dim, kernel_size=1) 19 | elif embed == 'convmlp': 20 | self.embed = nn.Sequential( 21 | nn.Conv2d(dim_in, dim_in, kernel_size=1), 22 | nn.BatchNorm2d(dim_in), 23 | nn.ReLU(), 24 | nn.Conv2d(dim_in, embed_dim, kernel_size=1) 25 | ) 26 | 27 | def forward(self, x): 28 | return F.normalize(self.embed(x), p=2, dim=1) 29 | 30 | 31 | class UNetTF(BaseModel2D): 32 | 33 | def __init__(self, 34 | encoder_name="resnet50", 35 | encoder_depth=5, 36 | encoder_weights="imagenet", 37 | decoder_use_batchnorm=True, 38 | decoder_channels=(256, 128, 64, 32, 16), 39 | decoder_attention_type=None, 40 | in_channels=3, 41 | classes=2, 42 | activation=None, 43 | embed_dim=96, 44 | norm_layer=nn.LayerNorm, 45 | img_size=224, 46 | patch_size=4, 47 | depths=[2, 2, 2, 2], 48 | num_heads=[3, 6, 12, 24], 49 | window_size=7, 50 | qkv_bias=True, 51 | qk_scale=None, 52 | drop_rate=0., 53 | attn_drop_rate=0., 54 | use_checkpoint=False, 55 | ape=True, 56 | cls=True, 57 | contrast_embed=False, 58 | contrast_embed_dim=256, 59 | contrast_embed_index=-3, 60 | mlp_ratio=4., 61 | drop_path_rate=0.1, 62 | final_upsample="expand_first", 63 | patches_resolution=[56, 56] 64 | ): 65 | super().__init__() 66 | self.cls = cls 67 | self.contrast_embed_index = contrast_embed_index 68 | 69 | self.encoder = get_encoder( 70 | encoder_name, 71 | in_channels=in_channels, 72 | depth=encoder_depth, 73 | weights=encoder_weights, 74 | ) 75 | encoder_channels = self.encoder.out_channels 76 | self.cnn_decoder = UnetDecoder( 77 | encoder_channels=encoder_channels, 78 | decoder_channels=decoder_channels, 79 | n_blocks=encoder_depth, 80 | use_batchnorm=decoder_use_batchnorm, 81 | center=True if encoder_name.startswith("vgg") else False, 82 | attention_type=decoder_attention_type, 83 | ) 84 | self.seg_head = SegmentationHead( 85 | in_channels=decoder_channels[-1], 86 | out_channels=classes, 87 | activation=activation, 88 | kernel_size=3, 89 | ) 90 | self.swin_decoder = SwinTransDecoder(classes, embed_dim, norm_layer, 
img_size, patch_size, depths, num_heads, 91 | window_size, qkv_bias, qk_scale, drop_rate, attn_drop_rate, use_checkpoint, 92 | ape, mlp_ratio, drop_path_rate, final_upsample, patches_resolution, 93 | encoder_channels) 94 | 95 | self.cls_head = ClassificationHead(in_channels=encoder_channels[-1], classes=4) if cls else None 96 | self.embed_head = EmbeddingHead(dim_in=encoder_channels[contrast_embed_index], 97 | embed_dim=contrast_embed_dim) if contrast_embed else None 98 | self._init_weights() 99 | 100 | def _init_weights(self): 101 | kaiming_normal_init_weight(self.cnn_decoder) 102 | kaiming_normal_init_weight(self.seg_head) 103 | if self.cls_head is not None: 104 | kaiming_normal_init_weight(self.cls_head) 105 | if self.embed_head is not None: 106 | kaiming_normal_init_weight(self.embed_head.embed) 107 | 108 | def forward(self, x, device): 109 | features = self.encoder(x) 110 | seg = self.seg_head(self.cnn_decoder(*features)) 111 | seg_tf = self.swin_decoder(features, device) 112 | 113 | embedding = self.embed_head(features[self.contrast_embed_index]) if self.embed_head else None 114 | cls = self.cls_head(features[-1]) if self.cls_head else None 115 | return {'seg': seg, 'seg_tf': seg_tf, 'cls': cls, 'embed': embedding} 116 | 117 | def inference(self, x, **kwargs): 118 | features = self.encoder(x) 119 | seg = self.seg_head(self.cnn_decoder(*features)) 120 | preds = torch.argmax(seg, dim=1, keepdim=True).to(torch.float) 121 | return preds 122 | 123 | def inference_features(self, x, **kwargs): 124 | features = self.encoder(x) 125 | embedding = self.embed_head(features[self.contrast_embed_index]) if self.embed_head else None 126 | return {'feats': features, 'embed': embedding} 127 | 128 | def inference_tf(self, x, device, **kwargs): 129 | features = self.encoder(x) 130 | seg_tf = self.swin_decoder(features, device) 131 | preds = torch.argmax(seg_tf, dim=1, keepdim=True).to(torch.float) 132 | return preds 133 | 
-------------------------------------------------------------------------------- /codes/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .poly import PolyLR 2 | -------------------------------------------------------------------------------- /codes/schedulers/poly.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | 3 | 4 | class PolyLR(_LRScheduler): 5 | def __init__(self, optimizer, max_iters, power=0.9, last_epoch=-1, min_lr=1e-6): 6 | self.power = power 7 | self.max_iters = max_iters + 1 # avoid zero lr 8 | self.min_lr = min_lr 9 | self.last_epoch = last_epoch 10 | super(PolyLR, self).__init__(optimizer, last_epoch) 11 | 12 | def get_lr(self): 13 | ''' 14 | factor = pow(1 - self.last_epoch / self.max_iters, self.power) 15 | return [(base_lr) * factor for base_lr in self.base_lrs] 16 | ''' 17 | return [max(base_lr * (1.0 - self.last_epoch / self.max_iters) ** self.power, self.min_lr) 18 | for base_lr in self.base_lrs] 19 | 20 | def __str__(self): 21 | return f'PolyLR(' \ 22 | f'\n\tpower: {self.power}' \ 23 | f'\n\tmax_iters: {self.max_iters}' \ 24 | f'\n\tmin_lr: {self.min_lr}' \ 25 | f'\n\tlast_epoch: {self.last_epoch}\n)' 26 | -------------------------------------------------------------------------------- /codes/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .mt_trainer import MeanTeacherTrainer 2 | from .supervised_trainer import SupervisedTrainer 3 | from .ugpcl_trainer import UGPCLTrainer 4 | -------------------------------------------------------------------------------- /codes/trainers/_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import abstractmethod 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from tqdm import tqdm 8 | from prettytable import PrettyTable 9 
| from colorama import Fore 10 | from ..utils import ramps 11 | from ..builder import _build_from_cfg, build_model, build_optimizer, build_scheduler 12 | 13 | 14 | class BaseTrainer: 15 | 16 | def __init__(self, 17 | model=None, 18 | optimizer=None, 19 | scheduler=None, 20 | criterions=None, 21 | metrics=None, 22 | logger=None, 23 | device='cuda', 24 | resume_from=None, 25 | labeled_bs=12, 26 | consistency=1.0, 27 | consistency_rampup=40.0, 28 | data_parallel=False, 29 | ckpt_save_path=None, 30 | max_iter=10000, 31 | eval_interval=1000, 32 | save_image_interval=50, 33 | save_ckpt_interval=2000) -> None: 34 | super(BaseTrainer, self).__init__() 35 | self.model = None 36 | # build cfg 37 | if model is not None: 38 | self.model = build_model(model).to(device) 39 | if optimizer is not None: 40 | self.optimizer = build_optimizer(self.model.parameters(), optimizer) 41 | if scheduler is not None: 42 | self.scheduler = build_scheduler(self.optimizer, scheduler) 43 | self.criterions = [] 44 | if criterions is not None: 45 | for criterion_cfg in criterions: 46 | self.criterions.append(_build_from_cfg(criterion_cfg)) 47 | self.metrics = [] 48 | if metrics is not None: 49 | for metric_cfg in metrics: 50 | self.metrics.append(_build_from_cfg(metric_cfg)) 51 | 52 | # semi-supervised params 53 | self.labeled_bs = labeled_bs 54 | self.consistency = consistency 55 | self.consistency_rampup = consistency_rampup 56 | 57 | # train params 58 | self.logger = logger 59 | self.device = device 60 | self.data_parallel = data_parallel 61 | self.ckpt_save_path = ckpt_save_path 62 | 63 | self.max_iter = max_iter 64 | self.eval_interval = eval_interval 65 | self.save_image_interval = save_image_interval 66 | self.save_ckpt_interval = save_ckpt_interval 67 | 68 | if self.data_parallel: 69 | self.model = torch.nn.DataParallel(self.model, device_ids=[0, 1]) 70 | 71 | if resume_from is not None: 72 | ckpt = torch.load(resume_from) 73 | model.load_state_dict(ckpt['state_dict']) 74 | 
optimizer.load_state_dict(ckpt['optimizer']) 75 | for state in optimizer.state.values(): 76 | for k, v in state.items(): 77 | if isinstance(v, torch.Tensor): 78 | state[k] = v.cuda() 79 | scheduler.load_state_dict(ckpt['scheduler']) 80 | self.start_step = ckpt['step'] 81 | 82 | logger.info(f'Resume from {resume_from}.') 83 | logger.info(f'Train from step {self.start_step}.') 84 | else: 85 | self.start_step = 0 86 | if self.model is not None: 87 | logger.info(f'\n{self.model}\n') 88 | 89 | logger.info(f'start training...') 90 | 91 | @abstractmethod 92 | def train_step(self, batch_data, step, save_image): 93 | loss = 0. 94 | log_infos, scalars, images = {}, {}, {} 95 | return loss, log_infos, scalars, images 96 | 97 | def val_step(self, batch_data): 98 | data, labels = batch_data['image'].to(self.device), batch_data['label'].to(self.device) 99 | preds = self.model.inference(data) 100 | metric_total_res = {} 101 | for metric in self.metrics: 102 | metric_total_res[metric.name] = metric(preds, labels) 103 | return metric_total_res 104 | 105 | def train(self, train_loader, val_loader): 106 | # iter_train_loader = iter(train_loader) 107 | max_epoch = self.max_iter // len(train_loader) + 1 108 | step = self.start_step 109 | self.model.train() 110 | with tqdm(total=self.max_iter - self.start_step, bar_format='[{elapsed}<{remaining}] ') as pbar: 111 | for _ in range(max_epoch): 112 | for batch_data in train_loader: 113 | save_image = True if (step + 1) % self.save_image_interval == 0 else False 114 | 115 | loss, log_infos, scalars, images = self.train_step(batch_data, step, save_image) 116 | 117 | self.optimizer.zero_grad() 118 | loss.backward() 119 | self.optimizer.step() 120 | self.scheduler.step() 121 | 122 | if (step + 1) % 10 == 0: 123 | scalars.update({'lr': self.scheduler.get_lr()[0]}) 124 | log_infos.update({'lr': self.scheduler.get_lr()[0]}) 125 | self.logger.update_scalars(scalars, step + 1) 126 | self.logger.info(f'[{step + 1}/{self.max_iter}] {log_infos}') 127 | 
128 | if save_image: 129 | self.logger.update_images(images, step + 1) 130 | 131 | if (step + 1) % self.eval_interval == 0: 132 | if val_loader is not None: 133 | val_res, val_scalars, val_table = self.val(val_loader) 134 | self.logger.info(f'val result:\n{val_table.get_string()}') 135 | self.logger.update_scalars(val_scalars, step + 1) 136 | self.model.train()  # val() leaves the model in eval mode 137 | 138 | if (step + 1) % self.save_ckpt_interval == 0: 139 | if not os.path.exists(self.ckpt_save_path): 140 | os.makedirs(self.ckpt_save_path) 141 | self.save_ckpt(step + 1, f'{self.ckpt_save_path}/iter_{step + 1}.pth') 142 | step += 1 143 | pbar.update(1) 144 | if step >= self.max_iter: 145 | break 146 | if step >= self.max_iter: 147 | break 148 | 149 | if not os.path.exists(self.ckpt_save_path): 150 | os.makedirs(self.ckpt_save_path) 151 | torch.save(self.model.state_dict(), f'{self.ckpt_save_path}/ckpt_final.pth')  # weights only; no optimizer/scheduler state 152 | 153 | @torch.no_grad() 154 | def val(self, val_loader, test=False):  # returns (val_res, val_scalars, val_table); test=True wraps the loader in a tqdm bar 155 | self.model.eval() 156 | val_res = None 157 | val_scalars = {} 158 | if self.logger is not None: 159 | self.logger.info('Evaluating...') 160 | if test: 161 | val_loader = tqdm(val_loader, desc='Testing', unit='batch', 162 | bar_format='%s{l_bar}{bar}{r_bar}%s' % (Fore.LIGHTCYAN_EX, Fore.RESET)) 163 | for batch_data in val_loader: 164 | batch_res = self.val_step(batch_data) # {'Dice':{'c1':0.1, 'c2':0.1, ...}, ...} 165 | if val_res is None: 166 | val_res = batch_res 167 | else: 168 | for metric_name in val_res.keys(): 169 | for key in val_res[metric_name].keys(): 170 | val_res[metric_name][key] += batch_res[metric_name][key] 171 | for metric_name in val_res.keys():  # NOTE(review): val_res is still None here if val_loader yields no batches 172 | for key in val_res[metric_name].keys(): 173 | val_res[metric_name][key] = val_res[metric_name][key] / len(val_loader)  # mean of the per-batch values 174 | val_scalars[f'val/{metric_name}.{key}'] = val_res[metric_name][key] 175 | 176 | val_res_list = [_.cpu() for _ in val_res[metric_name].values()]  # assumes metric values are tensors 177 | val_res[metric_name]['Mean'] = np.mean(val_res_list[1:])  # skips the first entry (presumably background) - TODO confirm 178 |
val_scalars[f'val/{metric_name}.Mean'] = val_res[metric_name]['Mean'] 179 | 180 | val_table = PrettyTable() 181 | val_table.field_names = ['Metirc'] + list(list(val_res.values())[0].keys())  # NOTE(review): 'Metirc' is a typo of 'Metric' in the printed header 182 | for metric_name in val_res.keys(): 183 | if metric_name in ['Dice', 'Jaccard', 'Acc', 'IoU', 'Recall', 'Precision']: 184 | temp = [float(format(_ * 100, '.2f')) for _ in val_res[metric_name].values()]  # the listed metrics are displayed as percentages 185 | else: 186 | temp = [float(format(_, '.2f')) for _ in val_res[metric_name].values()] 187 | val_table.add_row([metric_name] + temp) 188 | return val_res, val_scalars, val_table 189 | 190 | def get_current_consistency_weight(self, epoch):  # consistency * sigmoid ramp-up; subclasses pass step // 100 or step // 150 as 'epoch' 191 | # Consistency ramp-up from https://arxiv.org/abs/1610.02242 192 | return self.consistency * ramps.sigmoid_rampup(epoch, self.consistency_rampup) 193 | 194 | def save_ckpt(self, step, save_path):  # full checkpoint (model, optimizer, scheduler, step) consumed by resume_from 195 | ckpt = {'state_dict': self.model.state_dict(), 196 | 'optimizer': self.optimizer.state_dict(), 197 | 'scheduler': self.scheduler.state_dict(), 198 | 'step': step} 199 | torch.save(ckpt, save_path) 200 | self.logger.info('Checkpoint saved!') 201 | -------------------------------------------------------------------------------- /codes/trainers/mt_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results.
3 | Link: https://proceedings.neurips.cc/paper/2017/file/68053af2923e00204c3ca7c6a3150cf7-Paper.pdf 4 | Code modified from: https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_mean_teacher_2D.py 5 | """ 6 | 7 | import os 8 | import torch 9 | 10 | from tqdm import tqdm 11 | from torchvision.utils import make_grid 12 | from ._base import BaseTrainer 13 | from ..builder import build_model 14 | 15 | 16 | class MeanTeacherTrainer(BaseTrainer):  # student = self.model, teacher = self.ema_model (EMA of the student's parameters) 17 | 18 | def __init__(self, model=None, 19 | ema_model=None, 20 | optimizer=None, 21 | scheduler=None, 22 | criterions=None, 23 | metrics=None, 24 | logger=None, 25 | device='cuda', 26 | resume_from=None, 27 | labeled_bs=12, 28 | alpha=0.99, 29 | consistency=1.0, 30 | consistency_rampup=40.0, 31 | data_parallel=False, 32 | ckpt_save_path=None, 33 | max_iter=60000, 34 | eval_interval=1000, 35 | save_image_interval=50, 36 | save_ckpt_interval=2000) -> None: 37 | 38 | super().__init__(model, optimizer, scheduler, criterions, metrics, logger, device, resume_from, labeled_bs, 39 | consistency, consistency_rampup, data_parallel, ckpt_save_path, max_iter, eval_interval, 40 | save_image_interval, save_ckpt_interval) 41 | 42 | self.ema_model = build_model(ema_model).to(device) 43 | for param in self.ema_model.parameters(): 44 | param.detach_()  # the teacher receives no gradients; it changes only through update_ema_variables 45 | 46 | self.alpha = alpha 47 | 48 | def train_step(self, batch_data, step, save_image): 49 | log_infos, scalars, images = {}, {}, {} 50 | data, label = batch_data['image'].to(self.device), batch_data['label'].to(self.device) 51 | 52 | unlabeled_data = data[self.labeled_bs:]  # the first labeled_bs samples of the batch are the labeled ones 53 | 54 | noise = torch.clamp(torch.randn_like(unlabeled_data) * 0.1, -0.2, 0.2)  # N(0, 0.1^2) noise clipped to [-0.2, 0.2] 55 | ema_inputs = unlabeled_data + noise 56 | 57 | outputs = self.model(data) 58 | outputs_soft = torch.softmax(outputs['seg'], dim=1) 59 | with torch.no_grad(): 60 | ema_output = self.ema_model(ema_inputs) 61 | ema_output_soft = torch.softmax(ema_output['seg'], dim=1) 62 | 63 | supervised_loss = 0.
64 | for criterion in self.criterions: 65 | loss_ = criterion(outputs['seg'][:self.labeled_bs], label[:self.labeled_bs]) 66 | supervised_loss += loss_ 67 | log_infos[criterion.name] = float(format(loss_, '.5f')) 68 | scalars[f'loss/{criterion.name}'] = loss_ 69 | 70 | consistency_weight = self.get_current_consistency_weight(step // 150)  # ramp-up advances once every 150 iterations 71 | if step < 1000:  # no consistency term during the first 1000 iterations 72 | consistency_loss = 0.0 73 | else: 74 | consistency_loss = torch.mean((outputs_soft[self.labeled_bs:] - ema_output_soft) ** 2)  # MSE between student and teacher probabilities on unlabeled samples 75 | 76 | loss = supervised_loss + consistency_weight * consistency_loss 77 | 78 | log_infos['con_weight'] = float(format(consistency_weight, '.5f')) 79 | log_infos['loss_con'] = float(format(consistency_loss, '.5f')) 80 | log_infos['loss'] = float(format(loss, '.5f')) 81 | scalars['consistency_weight'] = consistency_weight 82 | scalars['loss/loss_consistency'] = consistency_loss 83 | scalars['loss/total'] = loss 84 | 85 | preds = torch.argmax(outputs['seg'], dim=1, keepdim=True).to(torch.float) 86 | 87 | metric_res = self.metrics[0](preds, label)  # only the first configured metric is logged during training 88 | for key in metric_res.keys(): 89 | log_infos[f'{self.metrics[0].name}.{key}'] = float(format(metric_res[key], '.5f')) 90 | scalars[f'train/{self.metrics[0].name}.{key}'] = metric_res[key] 91 | 92 | if save_image: 93 | grid_image = make_grid(data, 4, normalize=True) 94 | images['train/images'] = grid_image 95 | grid_image = make_grid(preds * 50., 4, normalize=False) 96 | images['train/preds'] = grid_image 97 | grid_image = make_grid(label * 50., 4, normalize=False) 98 | images['train/ground_truth'] = grid_image 99 | return loss, log_infos, scalars, images 100 | 101 | def train(self, train_loader, val_loader):  # mirrors BaseTrainer.train (minus the initial self.model.train()) and adds the EMA update 102 | max_epoch = self.max_iter // len(train_loader) + 1 103 | step = self.start_step 104 | with tqdm(total=self.max_iter - self.start_step, bar_format='[{elapsed}<{remaining}] ') as pbar: 105 | for _ in range(max_epoch): 106 | for batch_data in train_loader: 107 | save_image = True if (step + 1) % self.save_image_interval == 0
else False 108 | 109 | loss, log_infos, scalars, images = self.train_step(batch_data, step, save_image) 110 | 111 | self.optimizer.zero_grad() 112 | loss.backward() 113 | self.optimizer.step() 114 | self.scheduler.step() 115 | 116 | self.update_ema_variables(step) 117 | 118 | if (step + 1) % 10 == 0: 119 | scalars.update({'lr': self.scheduler.get_lr()[0]}) 120 | log_infos.update({'lr': self.scheduler.get_lr()[0]}) 121 | self.logger.update_scalars(scalars, step + 1) 122 | self.logger.info(f'[{step + 1}/{self.max_iter}] {log_infos}') 123 | 124 | if save_image: 125 | self.logger.update_images(images, step + 1) 126 | 127 | if (step + 1) % self.eval_interval == 0: 128 | if val_loader is not None: 129 | val_res, val_scalars, val_table = self.val(val_loader) 130 | self.logger.info(f'val result:\n{val_table.get_string()}') 131 | self.logger.update_scalars(val_scalars, step + 1) 132 | self.model.train() 133 | 134 | if (step + 1) % self.save_ckpt_interval == 0: 135 | if not os.path.exists(self.ckpt_save_path): 136 | os.makedirs(self.ckpt_save_path) 137 | self.save_ckpt(step + 1, f'{self.ckpt_save_path}/iter_{step + 1}.pth') 138 | step += 1 139 | pbar.update(1) 140 | if step >= self.max_iter: 141 | break 142 | if step >= self.max_iter: 143 | break 144 | 145 | if not os.path.exists(self.ckpt_save_path): 146 | os.makedirs(self.ckpt_save_path) 147 | torch.save(self.model.state_dict(), f'{self.ckpt_save_path}/ckpt_final.pth') 148 | 149 | def update_ema_variables(self, global_step): 150 | alpha = min(1 - 1 / (global_step + 1), self.alpha) 151 | for ema_param, param in zip(self.ema_model.parameters(), self.model.parameters()): 152 | ema_param.data.mul_(alpha).add_(1 - alpha, param.data) 153 | -------------------------------------------------------------------------------- /codes/trainers/supervised_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ._base import BaseTrainer 3 | from torchvision.utils import 
make_grid 4 | 5 | 6 | class SupervisedTrainer(BaseTrainer):  # supervised baseline: the loss uses only the first labeled_bs samples 7 | 8 | def __init__(self, 9 | model=None, 10 | optimizer=None, 11 | scheduler=None, 12 | criterions=None, 13 | metrics=None, 14 | logger=None, 15 | device='cuda', 16 | resume_from=None, 17 | labeled_bs=12, 18 | data_parallel=False, 19 | ckpt_save_path=None, 20 | max_iter=10000, 21 | eval_interval=1000, 22 | save_image_interval=50, 23 | save_ckpt_interval=2000) -> None: 24 | 25 | super().__init__(model, optimizer, scheduler, criterions, metrics, logger, device, resume_from, labeled_bs, 26 | 1.0, 40.0, data_parallel, ckpt_save_path, max_iter, eval_interval,  # fixed consistency / consistency_rampup; not used by this trainer 27 | save_image_interval, save_ckpt_interval) 28 | 29 | def train_step(self, batch_data, step, save_image): 30 | log_infos, scalars, images = {}, {}, {} 31 | data, label = batch_data['image'].to(self.device), batch_data['label'].to(self.device) 32 | logits = self.model(data) 33 | loss = 0. 34 | for criterion in self.criterions: 35 | loss_ = criterion(logits['seg'][:self.labeled_bs], label[:self.labeled_bs]) 36 | loss += loss_ 37 | log_infos[criterion.name] = float(format(loss_, '.5f')) 38 | scalars[f'loss/{criterion.name}'] = loss_ 39 | 40 | log_infos['loss'] = float(format(loss, '.5f')) 41 | scalars['loss/total'] = loss 42 | preds = torch.argmax(torch.softmax(logits['seg'], dim=1), dim=1, keepdim=True).to(torch.float) 43 | 44 | metric_res = self.metrics[0](preds, label) 45 | for key in metric_res.keys(): 46 | log_infos[f'{self.metrics[0].name}.{key}'] = float(format(metric_res[key], '.5f')) 47 | scalars[f'train/{self.metrics[0].name}.{key}'] = metric_res[key] 48 | 49 | images = {}  # redundant: already initialised at the top of train_step 50 | if save_image: 51 | grid_image = make_grid(data[:self.labeled_bs], 4, normalize=True) 52 | images['train/image'] = grid_image 53 | grid_image = make_grid(preds[:self.labeled_bs] * 50., 4, normalize=False) 54 | images['train/pred'] = grid_image 55 | grid_image = make_grid(label[:self.labeled_bs] * 50., 4, normalize=False) 56 | images['train/label'] = grid_image 57 | return
loss, log_infos, scalars, images -------------------------------------------------------------------------------- /codes/trainers/ugpcl_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import numpy as np 5 | import torch.nn.functional as F 6 | import torchvision.transforms.functional as TF 7 | 8 | from torch import nn 9 | from tqdm import tqdm 10 | from prettytable import PrettyTable 11 | from colorama import Fore 12 | from torchvision.utils import make_grid 13 | from ._base import BaseTrainer 14 | from ..utils import ramps 15 | from ..losses import PixelContrastLoss 16 | 17 | 18 | class UGPCLTrainer(BaseTrainer): 19 | 20 | def __init__(self, 21 | model=None, 22 | optimizer=None, 23 | scheduler=None, 24 | criterions=None, 25 | metrics=None, 26 | logger=None, 27 | device='cuda', 28 | resume_from=None, 29 | labeled_bs=8, 30 | data_parallel=False, 31 | ckpt_save_path=None, 32 | max_iter=6000, 33 | eval_interval=1000, 34 | save_image_interval=50, 35 | save_ckpt_interval=2000, 36 | consistency=0.1, 37 | consistency_rampup=40.0, 38 | tf_decoder_weight=0.4,  # weight of the supervised loss on the 'seg_tf' head 39 | cls_weight=0.1,  # weight of the rotation-classification loss 40 | contrast_type='ugpcl', # ugpcl, pseudo, sup, none 41 | contrast_weight=0.1, 42 | temperature=0.1, 43 | base_temperature=0.07, 44 | max_samples=1024, 45 | max_views=1, 46 | memory=True, 47 | memory_size=100, 48 | pixel_update_freq=10, 49 | pixel_classes=4,  # number of per-class rows in the memory queues 50 | dim=256) -> None:  # embedding size of each queue entry 51 | 52 | super(UGPCLTrainer, self).__init__(model, optimizer, scheduler, criterions, metrics, logger, device, 53 | resume_from, labeled_bs, consistency, consistency_rampup, data_parallel, 54 | ckpt_save_path, max_iter, eval_interval, save_image_interval, 55 | save_ckpt_interval) 56 | 57 | self.tf_decoder_weight = tf_decoder_weight 58 | self.cls_weight = cls_weight 59 | self.cls_criterion = torch.nn.CrossEntropyLoss()  # rotation-classification loss on outputs['cls'] 60 | 61 | self.contrast_type = contrast_type 62 | self.contrast_weight = contrast_weight 63 |
self.contrast_criterion = PixelContrastLoss(temperature=temperature, 64 | base_temperature=base_temperature, 65 | max_samples=max_samples, 66 | max_views=max_views, 67 | device=device) 68 | # memory param 69 | self.memory = memory 70 | self.memory_size = memory_size 71 | self.pixel_update_freq = pixel_update_freq 72 | 73 | if self.memory: 74 | self.segment_queue = torch.randn(pixel_classes, self.memory_size, dim) 75 | self.segment_queue = nn.functional.normalize(self.segment_queue, p=2, dim=2) 76 | self.segment_queue_ptr = torch.zeros(pixel_classes, dtype=torch.long) 77 | self.pixel_queue = torch.zeros(pixel_classes, self.memory_size, dim) 78 | self.pixel_queue = nn.functional.normalize(self.pixel_queue, p=2, dim=2) 79 | self.pixel_queue_ptr = torch.zeros(pixel_classes, dtype=torch.long) 80 | 81 | def _dequeue_and_enqueue(self, keys, labels): 82 | batch_size = keys.shape[0] 83 | feat_dim = keys.shape[1] 84 | 85 | labels = torch.nn.functional.interpolate(labels, (keys.shape[2], keys.shape[3]), mode='nearest') 86 | 87 | for bs in range(batch_size): 88 | this_feat = keys[bs].contiguous().view(feat_dim, -1) 89 | this_label = labels[bs].contiguous().view(-1) 90 | this_label_ids = torch.unique(this_label) 91 | this_label_ids = [x for x in this_label_ids if x > 0] 92 | for lb in this_label_ids: 93 | idxs = (this_label == lb).nonzero() 94 | lb = int(lb.item()) 95 | # segment enqueue and dequeue 96 | feat = torch.mean(this_feat[:, idxs], dim=1).squeeze(1) 97 | ptr = int(self.segment_queue_ptr[lb]) 98 | self.segment_queue[lb, ptr, :] = nn.functional.normalize(feat.view(-1), p=2, dim=0) 99 | self.segment_queue_ptr[lb] = (self.segment_queue_ptr[lb] + 1) % self.memory_size 100 | 101 | # pixel enqueue and dequeue 102 | num_pixel = idxs.shape[0] 103 | perm = torch.randperm(num_pixel) 104 | K = min(num_pixel, self.pixel_update_freq) 105 | feat = this_feat[:, perm[:K]] 106 | feat = torch.transpose(feat, 0, 1) 107 | ptr = int(self.pixel_queue_ptr[lb]) 108 | 109 | if ptr + K >= 
self.memory_size: 110 | self.pixel_queue[lb, -K:, :] = nn.functional.normalize(feat, p=2, dim=1) 111 | self.pixel_queue_ptr[lb] = 0 112 | else: 113 | self.pixel_queue[lb, ptr:ptr + K, :] = nn.functional.normalize(feat, p=2, dim=1) 114 | self.pixel_queue_ptr[lb] = (self.pixel_queue_ptr[lb] + 1) % self.memory_size 115 | 116 | @staticmethod 117 | def _random_rotate(image, label): 118 | angle = float(torch.empty(1).uniform_(-20., 20.).item()) 119 | image = TF.rotate(image, angle) 120 | label = TF.rotate(label, angle) 121 | return image, label 122 | 123 | def train_step(self, batch_data, step, save_image): 124 | log_infos, scalars = {}, {} 125 | images = {} 126 | data_, label_ = batch_data['image'].to(self.device), batch_data['label'].to(self.device) 127 | # data, label = self._random_aug(data_, label_) 128 | if self.cls_weight >= 0.: 129 | images_, labels_ = [], [] 130 | cls_label = [] 131 | for image, label in zip(data_, label_): 132 | rot_times = random.randrange(0, 4) 133 | cls_label.append(rot_times) 134 | image = torch.rot90(image, rot_times, [1, 2]) 135 | label = torch.rot90(label, rot_times, [1, 2]) 136 | image, label = self._random_rotate(image, label) 137 | images_.append(image) 138 | labels_.append(label) 139 | cls_label = torch.tensor(cls_label).to(self.device) 140 | data = torch.stack(images_, dim=0).to(self.device) 141 | label = torch.stack(labels_, dim=0).to(self.device) 142 | else: 143 | data = data_ 144 | label = label_ 145 | cls_label = None 146 | 147 | outputs = self.model(data, self.device) 148 | seg = outputs['seg'] 149 | seg_tf = outputs['seg_tf'] 150 | 151 | supervised_loss = 0. 
152 | for criterion in self.criterions: 153 | loss_ = criterion(seg[:self.labeled_bs], label[:self.labeled_bs]) + \ 154 | self.tf_decoder_weight * criterion(seg_tf[:self.labeled_bs], label[:self.labeled_bs]) 155 | supervised_loss += loss_ 156 | log_infos[criterion.name] = float(format(loss_, '.5f')) 157 | scalars[f'loss/{criterion.name}'] = loss_ 158 | 159 | loss_cls = self.cls_criterion(outputs['cls'], cls_label) if self.cls_weight > 0. else 0.  # rotation-classification loss 160 | 161 | seg_soft = torch.softmax(seg, dim=1) 162 | seg_tf_soft = torch.softmax(seg_tf, dim=1) 163 | consistency_weight = self.get_current_consistency_weight(step // 100)  # ramp-up advances once every 100 iterations 164 | consistency_loss = torch.mean((seg_soft[self.labeled_bs:] - seg_tf_soft[self.labeled_bs:]) ** 2)  # agreement between the two heads on unlabeled samples 165 | 166 | loss = supervised_loss + consistency_weight * consistency_loss + self.cls_weight * loss_cls 167 | 168 | log_infos['loss_cls'] = float(format(loss_cls, '.5f')) 169 | log_infos['con_weight'] = float(format(consistency_weight, '.5f')) 170 | log_infos['loss_con'] = float(format(consistency_loss, '.5f')) 171 | log_infos['loss'] = float(format(loss, '.5f')) 172 | scalars['loss/loss_cls'] = loss_cls 173 | scalars['consistency_weight'] = consistency_weight 174 | scalars['loss/loss_consistency'] = consistency_loss 175 | scalars['loss/total'] = loss 176 | 177 | preds = torch.argmax(seg_soft, dim=1, keepdim=True).to(torch.float)  # hard predictions of the main 'seg' head 178 | 179 | log_infos['loss_contrast'] = 0. 180 | scalars['loss/contrast'] = 0.
181 | if step > 1000 and self.contrast_weight > 0.:  # contrastive term starts after 1000 iterations 182 | # queue = torch.cat((self.segment_queue, self.pixel_queue), dim=1) if self.memory else None 183 | queue = self.segment_queue if self.memory else None 184 | if self.contrast_type == 'ugpcl': 185 | seg_mean = torch.mean(torch.stack([F.softmax(seg, dim=1), F.softmax(seg_tf, dim=1)]), dim=0)  # mean class probabilities of the two heads 186 | uncertainty = -1.0 * torch.sum(seg_mean * torch.log(seg_mean + 1e-6), dim=1, keepdim=True)  # per-pixel entropy 187 | threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(step, self.max_iter)) * np.log(2)  # rises from about 0.75*ln2 to ln2 over max_iter steps 188 | uncertainty_mask = (uncertainty > threshold) 189 | mean_preds = torch.argmax(F.softmax(seg_mean, dim=1).detach(), dim=1, keepdim=True).float()  # NOTE(review): seg_mean is already a probability map; the extra softmax does not change the argmax 190 | certainty_pseudo = mean_preds.clone() 191 | certainty_pseudo[uncertainty_mask] = -1  # uncertain pixels get id -1 (_dequeue_and_enqueue only keeps ids > 0) 192 | certainty_pseudo[:self.labeled_bs] = label[:self.labeled_bs]  # labeled samples use ground truth 193 | contrast_loss = self.contrast_criterion(outputs['embed'], certainty_pseudo, preds, queue=queue) 194 | scalars['uncertainty_rate'] = torch.sum(uncertainty_mask == True) / \ 195 | (torch.sum(uncertainty_mask == True) + torch.sum( 196 | uncertainty_mask == False)) 197 | if self.memory: 198 | self._dequeue_and_enqueue(outputs['embed'].detach(), certainty_pseudo.detach()) 199 | if save_image: 200 | grid_image = make_grid(mean_preds * 50., 4, normalize=False) 201 | images['train/mean_preds'] = grid_image 202 | grid_image = make_grid(certainty_pseudo * 50., 4, normalize=False) 203 | images['train/certainty_pseudo'] = grid_image 204 | grid_image = make_grid(uncertainty, 4, normalize=False) 205 | images['train/uncertainty'] = grid_image 206 | grid_image = make_grid(uncertainty_mask.float(), 4, normalize=False) 207 | images['train/uncertainty_mask'] = grid_image 208 | elif self.contrast_type == 'pseudo': 209 | contrast_loss = self.contrast_criterion(outputs['embed'], preds.detach(), preds, queue=queue) 210 | if self.memory: 211 | self._dequeue_and_enqueue(outputs['embed'].detach(), preds.detach()) 212 | elif self.contrast_type == 'sup': 213 |
contrast_loss = self.contrast_criterion(outputs['embed'][:self.labeled_bs], label[:self.labeled_bs], 214 | preds[:self.labeled_bs], queue=queue) 215 | if self.memory: 216 | self._dequeue_and_enqueue(outputs['embed'].detach()[:self.labeled_bs], 217 | label.detach()[:self.labeled_bs]) 218 | else: 219 | contrast_loss = 0. 220 | loss += self.contrast_weight * contrast_loss 221 | log_infos['loss_contrast'] = float(format(contrast_loss, '.5f')) 222 | scalars['loss/contrast'] = contrast_loss 223 | 224 | tf_preds = torch.argmax(seg_tf_soft, dim=1, keepdim=True).to(torch.float)  # hard predictions of the 'seg_tf' head (used for images only) 225 | metric_res = self.metrics[0](preds, label) 226 | for key in metric_res.keys(): 227 | log_infos[f'{self.metrics[0].name}.{key}'] = float(format(metric_res[key], '.5f')) 228 | scalars[f'train/{self.metrics[0].name}.{key}'] = metric_res[key] 229 | 230 | if save_image: 231 | grid_image = make_grid(data, 4, normalize=True) 232 | images['train/images'] = grid_image 233 | grid_image = make_grid(preds * 50., 4, normalize=False) 234 | images['train/preds'] = grid_image 235 | grid_image = make_grid(tf_preds * 50., 4, normalize=False) 236 | images['train/tf_preds'] = grid_image 237 | grid_image = make_grid(label * 50., 4, normalize=False) 238 | images['train/labels'] = grid_image 239 | 240 | return loss, log_infos, scalars, images 241 | 242 | def val_step(self, batch_data):  # identical to BaseTrainer.val_step 243 | data, labels = batch_data['image'].to(self.device), batch_data['label'].to(self.device) 244 | preds = self.model.inference(data) 245 | metric_total_res = {} 246 | for metric in self.metrics: 247 | metric_total_res[metric.name] = metric(preds, labels) 248 | return metric_total_res 249 | 250 | def val_step_tf(self, batch_data):  # like val_step, but predicts with model.inference_tf 251 | data, labels = batch_data['image'].to(self.device), batch_data['label'].to(self.device) 252 | preds = self.model.inference_tf(data, self.device) 253 | metric_total_res = {} 254 | for metric in self.metrics: 255 | metric_total_res[metric.name] = metric(preds, labels) 256 | return metric_total_res 257 |
258 | @torch.no_grad() 259 | def val_tf(self, val_loader, test=False): 260 | self.model.eval() 261 | val_res = None 262 | val_scalars = {} 263 | if self.logger is not None: 264 | self.logger.info('Evaluating...') 265 | if test: 266 | val_loader = tqdm(val_loader, desc='Testing', unit='batch', 267 | bar_format='%s{l_bar}{bar}{r_bar}%s' % (Fore.LIGHTCYAN_EX, Fore.RESET)) 268 | for batch_data in val_loader: 269 | batch_res = self.val_step_tf(batch_data) # {'Dice':{'c1':0.1, 'c2':0.1, ...}, ...} 270 | if val_res is None: 271 | val_res = batch_res 272 | else: 273 | for metric_name in val_res.keys(): 274 | for key in val_res[metric_name].keys(): 275 | val_res[metric_name][key] += batch_res[metric_name][key] 276 | for metric_name in val_res.keys(): 277 | for key in val_res[metric_name].keys(): 278 | val_res[metric_name][key] = val_res[metric_name][key] / len(val_loader) 279 | val_scalars[f'val_tf/{metric_name}.{key}'] = val_res[metric_name][key] 280 | 281 | val_res_list = [_.cpu() for _ in val_res[metric_name].values()] 282 | val_res[metric_name]['Mean'] = np.mean(val_res_list[1:]) 283 | val_scalars[f'val_tf/{metric_name}.Mean'] = val_res[metric_name]['Mean'] 284 | 285 | val_table = PrettyTable() 286 | val_table.field_names = ['Metirc'] + list(list(val_res.values())[0].keys()) 287 | for metric_name in val_res.keys(): 288 | if metric_name in ['Dice', 'Jaccard', 'Acc', 'IoU', 'Recall', 'Precision']: 289 | temp = [float(format(_ * 100, '.2f')) for _ in val_res[metric_name].values()] 290 | else: 291 | temp = [float(format(_, '.2f')) for _ in val_res[metric_name].values()] 292 | val_table.add_row([metric_name] + temp) 293 | return val_res, val_scalars, val_table 294 | 295 | def train(self, train_loader, val_loader): 296 | # iter_train_loader = iter(train_loader) 297 | max_epoch = self.max_iter // len(train_loader) + 1 298 | step = self.start_step 299 | self.model.train() 300 | with tqdm(total=self.max_iter - self.start_step, bar_format='[{elapsed}<{remaining}] ') as pbar: 301 
| for _ in range(max_epoch): 302 | for batch_data in train_loader: 303 | save_image = True if (step + 1) % self.save_image_interval == 0 else False 304 | 305 | loss, log_infos, scalars, images = self.train_step(batch_data, step, save_image) 306 | 307 | self.optimizer.zero_grad() 308 | loss.backward() 309 | self.optimizer.step() 310 | self.scheduler.step() 311 | 312 | if (step + 1) % 10 == 0: 313 | scalars.update({'lr': self.scheduler.get_lr()[0]}) 314 | log_infos.update({'lr': self.scheduler.get_lr()[0]}) 315 | self.logger.update_scalars(scalars, step + 1) 316 | self.logger.info(f'[{step + 1}/{self.max_iter}] {log_infos}') 317 | 318 | if save_image: 319 | self.logger.update_images(images, step + 1) 320 | 321 | if (step + 1) % self.eval_interval == 0: 322 | if val_loader is not None: 323 | val_res, val_scalars, val_table = self.val(val_loader) 324 | self.logger.info(f'val result:\n{val_table.get_string()}') 325 | self.logger.update_scalars(val_scalars, step + 1) 326 | self.model.train() 327 | 328 | val_res, val_scalars, val_table = self.val_tf(val_loader) 329 | self.logger.info(f'val_tf result:\n{val_table.get_string()}') 330 | self.logger.update_scalars(val_scalars, step + 1) 331 | self.model.train() 332 | 333 | if (step + 1) % self.save_ckpt_interval == 0: 334 | if not os.path.exists(self.ckpt_save_path): 335 | os.makedirs(self.ckpt_save_path) 336 | self.save_ckpt(step + 1, f'{self.ckpt_save_path}/iter_{step + 1}.pth') 337 | step += 1 338 | pbar.update(1) 339 | if step >= self.max_iter: 340 | break 341 | if step >= self.max_iter: 342 | break 343 | 344 | if not os.path.exists(self.ckpt_save_path): 345 | os.makedirs(self.ckpt_save_path) 346 | torch.save(self.model.state_dict(), f'{self.ckpt_save_path}/ckpt_final.pth') 347 | -------------------------------------------------------------------------------- /codes/utils/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/taovv/UGPCL/5980405b15ef1bc139be0447647a213b5a5f30b9/codes/utils/__init__.py -------------------------------------------------------------------------------- /codes/utils/analyze.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import cv2 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from tqdm import tqdm 7 | from torchvision.transforms import transforms 8 | from codes.builder import build_model 9 | from codes.utils.utils import parse_yaml, Namespace 10 | 11 | color1 = (102, 253, 204, 1.) 12 | color2 = (255, 255, 102, 1.) 13 | color3 = (255, 255, 255, 1.) 14 | 15 | 16 | def _color(img, rgba=(102, 253, 204, 1.)):  # returns a new float array where pixels equal to [1, 1, 1, 1] become rgba (RGB scaled from 0-255) 17 | h, w, c = img.shape 18 | rgba = list(rgba) 19 | rgba[0] /= 255. 20 | rgba[1] /= 255. 21 | rgba[2] /= 255. 22 | img = img.tolist()  # per-pixel list comparison below; only 4-channel pixels can match 23 | for i in range(h): 24 | for j in range(w): 25 | if img[i][j] == [1., 1., 1., 1.]: 26 | img[i][j] = rgba 27 | return np.array(img) 28 | 29 | 30 | def mixed_color(img1, img2):  # per-pixel blend of img2 into img1 (writes into img1) 31 | h, w, c = img1.shape 32 | for i in range(h): 33 | for j in range(w): 34 | if (img1[i][j] == [0., 0., 0., 1.]).all() or (img2[i][j] == [0., 0., 0., 1.]).all(): 35 | img1[i][j] = img1[i][j] + img2[i][j] 36 | img1[i][j][-1] = 1.
37 | else: 38 | img1[i][j] = img1[i][j] * 0.5 + img2[i][j] * 0.5 39 | return np.array(img1) 40 | 41 | 42 | def cal_dice(pred, gt): 43 | intersection = (pred * gt).sum() 44 | return (2 * intersection) / (pred.sum() + gt.sum()).item() 45 | 46 | 47 | def show_pred_gt(model, 48 | device='cuda', 49 | img_size=(3, 256, 256), 50 | datasets=r'D:\datasets\Breast\large\val', 51 | output_path=r'../analyze_results/hardmseg'): 52 | mean = [123.675, 116.28, 103.53] 53 | std = [58.395, 57.12, 57.375] 54 | if not os.path.exists(output_path): 55 | os.makedirs(output_path) 56 | model = model.to(device) 57 | img_path = os.path.join(datasets, 'images') 58 | mask_path = os.path.join(datasets, 'masks') 59 | names = os.listdir(img_path) 60 | dice_list = [] 61 | for name in tqdm(names): 62 | if img_size[0] < 3: 63 | img = cv2.imread(os.path.join(img_path, name), cv2.IMREAD_GRAYSCALE) 64 | img = img[:np.newaxis] 65 | else: 66 | img = cv2.imread(os.path.join(img_path, name), cv2.IMREAD_COLOR) 67 | h, w, _ = img.shape 68 | img = cv2.resize(img, (img_size[1], img_size[2])) 69 | img = torch.tensor(img.transpose(2, 0, 1)).unsqueeze(0).to(device, torch.float32) 70 | img = transforms.Normalize(std=std, mean=mean)(img) 71 | img = torch.cat([img, img], dim=0) 72 | with torch.no_grad(): 73 | pred = model.inference(img)[0].argmax(0) 74 | pred = pred.view(img_size[1], img_size[1]).cpu().detach().numpy().astype(np.float32) 75 | mask = cv2.imread(os.path.join(mask_path, name.replace('jpg', 'png')), cv2.IMREAD_GRAYSCALE) 76 | dice = cal_dice(pred, mask / 255) 77 | dice_list.append(dice) 78 | mask = cv2.resize(mask, (img_size[1], img_size[2])) 79 | mask = (mask != 0).astype(np.float32) 80 | pred = cv2.cvtColor(pred, cv2.COLOR_GRAY2BGRA) 81 | mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGRA) 82 | plt.imsave(os.path.join(output_path, f'{int(dice * 10000)}_{name}'), 83 | _color(mask, color3) * 0.5 + _color(pred, color1) * 0.5) 84 | # plt.imsave(os.path.join(output_path, name), mixed_color(_color(mask, color1), 
_color(pred, color3))) 85 | f = open(f'{output_path}/result.txt', 'w') 86 | for dice, name in zip(dice_list, names): 87 | f.write(f'{name}:\t{dice * 100:.2f}\n') 88 | f.write(f'average:\t{np.mean(dice_list) * 100:.2f}\n') 89 | f.close() 90 | 91 | 92 | if __name__ == '__main__': 93 | args_dict = parse_yaml(r'F:\projects\PyCharmProjects\SSLSeg\configs\unet_breast_mri.yaml') 94 | model = build_model(args_dict['model']) 95 | # model = torch.nn.DataParallel(model) 96 | model.load_state_dict( 97 | torch.load(r'F:\projects\PyCharmProjects\SSLSeg\results\breast_mri\unet_breast_mri_0924194128\iter_6000.pth')[ 98 | 'state_dict']) 99 | show_pred_gt(model, 100 | device='cuda', 101 | img_size=(3, 256, 256), 102 | datasets=r'F:\datasets\breast_mri\256x256\val', 103 | output_path=r'F:\DeskTop\preds_unet_') 104 | -------------------------------------------------------------------------------- /codes/utils/init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from timm.models.layers import trunc_normal_ 4 | 5 | 6 | def kaiming_normal_init_weight(model): 7 | for m in model.modules(): 8 | if isinstance(m, nn.Conv2d): 9 | torch.nn.init.kaiming_normal_(m.weight) 10 | elif isinstance(m, nn.BatchNorm2d): 11 | m.weight.data.fill_(1) 12 | m.bias.data.zero_() 13 | elif isinstance(m, nn.Linear): 14 | trunc_normal_(m.weight, std=.02) 15 | if m.bias is not None: 16 | nn.init.constant_(m.bias, 0) 17 | elif isinstance(m, nn.LayerNorm): 18 | nn.init.constant_(m.bias, 0) 19 | nn.init.constant_(m.weight, 1.0) 20 | 21 | 22 | def xavier_normal_init_weight(model): 23 | for m in model.modules(): 24 | if isinstance(m, nn.Conv2d): 25 | torch.nn.init.xavier_normal_(m.weight) 26 | elif isinstance(m, nn.BatchNorm2d): 27 | m.weight.data.fill_(1) 28 | m.bias.data.zero_() -------------------------------------------------------------------------------- /codes/utils/ramps.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, Curious AI Ltd. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Functions for ramping hyperparameters up or down 9 | 10 | Each function takes the current training step or epoch, and the 11 | ramp length in the same format, and returns a multiplier between 12 | 0 and 1. 13 | """ 14 | 15 | 16 | import numpy as np 17 | 18 | 19 | def sigmoid_rampup(current, rampup_length): 20 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 21 | if rampup_length == 0: 22 | return 1.0 23 | else: 24 | current = np.clip(current, 0.0, rampup_length) 25 | phase = 1.0 - current / rampup_length 26 | return float(np.exp(-5.0 * phase * phase)) 27 | 28 | 29 | def linear_rampup(current, rampup_length): 30 | """Linear rampup""" 31 | assert current >= 0 and rampup_length >= 0 32 | if current >= rampup_length: 33 | return 1.0 34 | else: 35 | return current / rampup_length 36 | 37 | 38 | def cosine_rampdown(current, rampdown_length): 39 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 40 | assert 0 <= current <= rampdown_length 41 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 42 | -------------------------------------------------------------------------------- /codes/utils/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import yaml 4 | 5 | 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value""" 8 | 9 | def __init__(self): 10 | self.reset() 11 | 12 | def reset(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def 
update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | 24 | 25 | class Namespace(object): 26 | def __init__(self, some_dict): 27 | for key, value in some_dict.items(): 28 | assert isinstance(key, str) and re.match("[A-Za-z_-]", key) 29 | if isinstance(value, dict): 30 | self.__dict__[key] = Namespace(value) 31 | elif isinstance(value, list): 32 | value = value.copy() 33 | if isinstance(value[0], dict): 34 | for i in range(len(value)): 35 | value[i] = Namespace(value[i]) 36 | self.__dict__[key] = value 37 | else: 38 | self.__dict__[key] = value 39 | 40 | def __getattr__(self, attribute): 41 | 42 | raise AttributeError( 43 | f"Can not find {attribute} in namespace. Please write {attribute} in your config file(xxx.yaml)!") 44 | 45 | 46 | def deep_update_dict(res_dict, in_dict): 47 | for key in in_dict.keys(): 48 | if key in res_dict and isinstance(in_dict[key], dict) and isinstance(res_dict[key], dict) and \ 49 | 'name' in in_dict[key].keys() and 'kwargs' in in_dict[key].keys() and \ 50 | 'name' in res_dict[key].keys() and 'kwargs' in res_dict[key].keys() and \ 51 | in_dict[key]['name'] == res_dict[key]['name']: 52 | deep_update_dict(res_dict[key]['kwargs'], in_dict[key]['kwargs']) 53 | else: 54 | res_dict[key] = in_dict[key] 55 | 56 | 57 | def parse_yaml(yaml_file_path): 58 | res = {} 59 | with open(yaml_file_path, 'r') as yaml_file: 60 | f = yaml.load(yaml_file, Loader=yaml.FullLoader) 61 | if 'include' in f: 62 | abs_path = os.path.abspath(yaml_file_path) 63 | abs_path = abs_path.replace('\\', '/') 64 | abs_path_list = abs_path.split('/') 65 | for include_file_path in f['include']: 66 | include_file_path = include_file_path.replace('\\', '/') 67 | include_path_list = include_file_path.split('/') 68 | if '' in include_path_list: 69 | include_path_list.remove('') 70 | if '.' 
in include_path_list: 71 | include_path_list.remove('.') 72 | n = include_path_list.count('..') 73 | include_file_path = '/'.join(abs_path_list[:-(n + 1)] + include_path_list[n:]) 74 | with open(include_file_path, 'r') as include_file: 75 | deep_update_dict(res, yaml.load(include_file, Loader=yaml.FullLoader)) 76 | # res.update(yaml.load(include_file, Loader=yaml.FullLoader)) 77 | deep_update_dict(res, f) 78 | # res.update(f) 79 | if 'include' in res.keys(): 80 | res.pop('include') 81 | return res -------------------------------------------------------------------------------- /configs/_datasets/acdc.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: acdc 3 | kwargs: 4 | root_dir: F:/datasets/ACDC/ 5 | labeled_num: 7 6 | labeled_bs: 12 7 | batch_size: 24 8 | batch_size_val: 16 9 | num_workers: 0 10 | train_transforms: 11 | - name: RandomGenerator 12 | kwargs: { output_size: [ 256, 256 ] } 13 | - name: ToRGB 14 | - name: RandomCrop 15 | kwargs: { size: [ 256, 256 ] } 16 | - name: RandomFlip 17 | kwargs: { p: 0.5 } 18 | - name: ColorJitter 19 | kwargs: { brightness: 0.4,contrast: 0.4, saturation: 0.4, hue: 0.1, p: 0.8 } 20 | val_transforms: 21 | - name: RandomGenerator 22 | kwargs: 23 | p_flip: 0.0 24 | p_rot: 0.0 25 | output_size: [ 256, 256 ] 26 | - name: ToRGB -------------------------------------------------------------------------------- /configs/_datasets/acdc_224.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: acdc 3 | kwargs: 4 | root_dir: F:/datasets/ACDC/ 5 | labeled_num: 7 6 | labeled_bs: 8 7 | batch_size: 16 8 | batch_size_val: 16 9 | num_workers: 0 10 | train_transforms: 11 | - name: RandomGenerator 12 | kwargs: { output_size: [ 224, 224 ], p_flip: 0.5, p_rot: 0.5 } 13 | - name: ToRGB 14 | - name: RandomCrop 15 | kwargs: { size: [ 224, 224 ] } 16 | - name: RandomFlip 17 | kwargs: { p: 0.5 } 18 | - name: ColorJitter 19 | kwargs: { 
brightness: 0.4,contrast: 0.4, saturation: 0.4, hue: 0.1, p: 0.8 } 20 | val_transforms: 21 | - name: RandomGenerator 22 | kwargs: { output_size: [ 224, 224 ], p_flip: 0.0, p_rot: 0.0 }  # deterministic validation, as in acdc.yaml 23 | - name: ToRGB -------------------------------------------------------------------------------- /configs/_models/unet_r50.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: UNet 3 | kwargs: 4 | in_channels: 3  # ToRGB in the dataset pipeline yields 3-channel input 5 | classes: 4 6 | encoder_name: resnet50 7 | encoder_weights: imagenet -------------------------------------------------------------------------------- /configs/_models/unet_tf_r50.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: UNetTF 3 | kwargs: 4 | in_channels: 3 5 | classes: 4 6 | encoder_name: resnet50 7 | encoder_weights: imagenet 8 | contrast_embed: False 9 | contrast_embed_dim: 256 10 | contrast_embed_index: -3 -------------------------------------------------------------------------------- /configs/_trainers/mt.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | name: MeanTeacherTrainer 3 | kwargs: 4 | criterions: 5 | - name: DiceLoss_ 6 | kwargs: { weight: 0.5 ,n_classes: 4, class_weight: [ 0.0, 1.0, 1.0, 1.0 ] } 7 | - name: CrossEntropyLoss_ 8 | kwargs: { weight: 0.5 } 9 | metrics: # metrics config list 10 | - name: Dice 11 | kwargs: 12 | name: Dice 13 | class_indexs: [ 0, 1, 2, 3 ] 14 | class_names: [ bg, c1, c2, c3 ] 15 | - name: Jaccard 16 | kwargs: 17 | name: Jaccard 18 | class_indexs: [ 0, 1, 2, 3 ] 19 | class_names: [ bg, c1, c2, c3 ] 20 | labeled_bs: 8 21 | alpha: 0.99 22 | consistency: 0.1 23 | consistency_rampup: 40.0 24 | max_iter: 6000 25 | eval_interval: 1000 26 | save_image_interval: 50 27 | save_ckpt_interval: 1000 28 | 29 | ema_model: 30 | name: UNet 31 | kwargs: 32 | in_channels: 3 33 | classes: 4  # EMA teacher must match the 4-class student 34 | encoder_name: resnet50 35 | encoder_weights: imagenet 36 | 37 | optimizer: 38 | 
name: SGD 39 | kwargs: 40 | lr: 0.01 41 | weight_decay: 0.0001 42 | momentum: 0.9 43 | 44 | scheduler: 45 | name: StepLR 46 | kwargs: 47 | step_size: 2500 48 | gamma: 0.1 49 | 50 | #scheduler: 51 | # name: PolyLR 52 | # kwargs: 53 | # max_iters: 10000 54 | # power: 0.9 55 | # last_epoch: -1 56 | # min_lr: 0.0001 57 | 58 | logger: 59 | log_file: True 60 | tensorboard: True 61 | -------------------------------------------------------------------------------- /configs/_trainers/supervised.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | name: SupervisedTrainer 3 | kwargs: 4 | criterions: 5 | - name: DiceLoss_ 6 | kwargs: { weight: 0.5 ,n_classes: 4, class_weight: [ 0.0, 1.0, 1.0, 1.0 ] } 7 | - name: CrossEntropyLoss_ 8 | kwargs: { weight: 0.5 } 9 | metrics: # metrics config list 10 | - name: Dice 11 | kwargs: 12 | name: Dice 13 | class_indexs: [ 0, 1, 2, 3 ] 14 | class_names: [ bg, c1, c2, c3 ] 15 | - name: Jaccard 16 | kwargs: 17 | name: Jaccard 18 | class_indexs: [ 0, 1, 2, 3 ] 19 | class_names: [ bg, c1, c2, c3 ] 20 | labeled_bs: 8 21 | max_iter: 6000 22 | eval_interval: 1000 23 | save_image_interval: 50 24 | save_ckpt_interval: 1000 25 | 26 | optimizer: 27 | name: SGD 28 | kwargs: 29 | lr: 0.01 30 | weight_decay: 0.0001 31 | momentum: 0.9 32 | 33 | scheduler: 34 | name: StepLR 35 | kwargs: 36 | step_size: 2500 37 | gamma: 0.1 38 | 39 | logger: 40 | log_file: True 41 | tensorboard: True -------------------------------------------------------------------------------- /configs/_trainers/ugpcl.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | name: UGPCLTrainer 3 | kwargs: 4 | criterions: 5 | - name: DiceLoss_ 6 | kwargs: { weight: 0.5 ,n_classes: 4, class_weight: [ 0.0, 1.0, 1.0, 1.0 ] } 7 | - name: CrossEntropyLoss_ 8 | kwargs: { weight: 0.5 } 9 | metrics: # metrics config list 10 | - name: Dice 11 | kwargs: 12 | name: Dice 13 | class_indexs: [ 0, 1, 2, 3 ] 14 | 
class_names: [ bg, c1, c2, c3 ] 15 | - name: Jaccard 16 | kwargs: 17 | name: Jaccard 18 | class_indexs: [ 0, 1, 2, 3 ] 19 | class_names: [ bg, c1, c2, c3 ] 20 | labeled_bs: 8 21 | consistency: 0.1 22 | consistency_rampup: 60.0 23 | contrast_weight: 0.1 24 | temperature: 0.07 25 | base_temperature: 0.07 26 | memory: true 27 | max_samples: 1024 28 | max_views: 1 29 | memory_size: 2000 30 | pixel_update_freq: 10 31 | pixel_classes: 4 32 | max_iter: 6000 33 | eval_interval: 1000 34 | save_image_interval: 50 35 | save_ckpt_interval: 1000 36 | 37 | optimizer: 38 | name: SGD 39 | kwargs: 40 | lr: 0.01 41 | weight_decay: 0.0005 42 | momentum: 0.9 43 | 44 | scheduler: 45 | name: PolyLR 46 | kwargs: 47 | max_iters: 6000 48 | power: 0.9 49 | last_epoch: -1 50 | min_lr: 0.0001 51 | 52 | logger: 53 | log_file: True 54 | tensorboard: True 55 | -------------------------------------------------------------------------------- /configs/comparison_acdc_224_136/mt_unet_r50.yaml: -------------------------------------------------------------------------------- 1 | name: mt_unet_r50 2 | 3 | include: 4 | - ../_datasets/acdc_224.yaml 5 | - ../_models/unet_r50.yaml 6 | - ../_trainers/mt.yaml 7 | 8 | train: 9 | name: MeanTeacherTrainer 10 | kwargs: 11 | consistency: 0.1 12 | consistency_rampup: 100.0 # max_iter // 100 13 | max_iter: 10000 14 | eval_interval: 1000 15 | save_image_interval: 50 16 | save_ckpt_interval: 1000 17 | 18 | scheduler: 19 | name: PolyLR 20 | kwargs: 21 | max_iters: 10000 22 | 23 | model: 24 | name: UNet 25 | kwargs: 26 | in_channels: 3 27 | classes: 4 28 | 29 | ema_model: 30 | name: UNet 31 | kwargs: 32 | in_channels: 3 33 | classes: 4 -------------------------------------------------------------------------------- /configs/comparison_acdc_224_136/ugpcl_unet_r50.yaml: -------------------------------------------------------------------------------- 1 | name: ugpcl_unet_r50 2 | 3 | include: 4 | - ../_datasets/acdc_224.yaml 5 | - ../_models/unet_tf_r50.yaml 6 | - 
../_trainers/ugpcl.yaml 7 | 8 | dataset: 9 | name: acdc 10 | kwargs: 11 | labeled_bs: 8 12 | batch_size: 16 13 | train_transforms: 14 | - name: RandomGenerator 15 | kwargs: { output_size: [ 224, 224 ], p_flip: 0.0, p_rot: 0.0 } # Remove random rotation 16 | - name: ToRGB 17 | - name: RandomCrop 18 | kwargs: { size: [ 224, 224 ] } 19 | - name: RandomFlip 20 | kwargs: { p: 0.5 } 21 | - name: ColorJitter 22 | kwargs: { brightness: 0.4,contrast: 0.4, saturation: 0.4, hue: 0.1, p: 0.8 } 23 | 24 | train: 25 | name: UGPCLTrainer 26 | kwargs: 27 | contrast_weight: 0.1 28 | labeled_bs: 8 29 | consistency: 0.01 30 | consistency_rampup: 100.0 # max_iter // 100 31 | memory: true 32 | max_samples: 1024 33 | max_views: 1 34 | memory_size: 500 35 | pixel_update_freq: 10 36 | pixel_classes: 4 37 | max_iter: 10000 38 | 39 | model: 40 | name: UNetTF 41 | kwargs: 42 | contrast_embed: True 43 | 44 | scheduler: 45 | name: PolyLR 46 | kwargs: 47 | max_iters: 10000 48 | -------------------------------------------------------------------------------- /configs/comparison_acdc_224_136/unet_r50.yaml: -------------------------------------------------------------------------------- 1 | name: unet_r50 2 | 3 | include: 4 | - ../_datasets/acdc_224.yaml 5 | - ../_models/unet_r50.yaml 6 | - ../_trainers/supervised.yaml 7 | 8 | train: 9 | name: SupervisedTrainer 10 | kwargs: 11 | max_iter: 10000 12 | eval_interval: 1000 13 | save_image_interval: 50 14 | save_ckpt_interval: 1000 15 | 16 | model: 17 | name: UNet 18 | kwargs: 19 | in_channels: 3 20 | classes: 4 21 | 22 | scheduler: 23 | name: PolyLR 24 | kwargs: 25 | max_iters: 10000 -------------------------------------------------------------------------------- /configs/comparison_acdc_224_136/unet_r50_full.yaml: -------------------------------------------------------------------------------- 1 | name: unet_r50 2 | 3 | include: 4 | - ../_datasets/acdc_224.yaml 5 | - ../_models/unet_r50.yaml 6 | - ../_trainers/supervised.yaml 7 | 8 | dataset: 9 | 
name: acdc 10 | kwargs: 11 | labeled_bs: 16 12 | 13 | train: 14 | name: SupervisedTrainer 15 | kwargs: 16 | labeled_bs: 16 17 | max_iter: 10000 18 | eval_interval: 1000 19 | save_image_interval: 50 20 | save_ckpt_interval: 1000 21 | 22 | model: 23 | name: UNet 24 | kwargs: 25 | in_channels: 3 26 | classes: 4 27 | 28 | scheduler: 29 | name: PolyLR 30 | kwargs: 31 | max_iters: 10000 -------------------------------------------------------------------------------- /pics/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taovv/UGPCL/5980405b15ef1bc139be0447647a213b5a5f30b9/pics/overview.jpg -------------------------------------------------------------------------------- /pics/preds.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taovv/UGPCL/5980405b15ef1bc139be0447647a213b5a5f30b9/pics/preds.jpg -------------------------------------------------------------------------------- /pics/show_feats.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taovv/UGPCL/5980405b15ef1bc139be0447647a213b5a5f30b9/pics/show_feats.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -f https://download.pytorch.org/whl/torch_stable.html 2 | colorama==0.4.4 3 | h5py==3.4.0 4 | matplotlib==3.4.2 5 | MedPy==0.4.0 6 | numpy==1.21.1 7 | prettytable==2.1.0 8 | scipy==1.7.1 9 | segmentation-models-pytorch==0.2.0 10 | scikit-learn==0.24.2 11 | tensorboardX==2.4 12 | timm==0.4.12 13 | torch==1.9.1+cu102 14 | torchvision==0.10.1+cu102 15 | tqdm==4.62.0 -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 
4 | import numpy as np 5 | import cv2 6 | 7 | from tqdm import tqdm 8 | from colorama import Fore 9 | from prettytable import PrettyTable 10 | from codes.builder import * 11 | from codes.utils.utils import Namespace, parse_yaml 12 | 13 | from sklearn.manifold import TSNE 14 | import matplotlib.pyplot as plt 15 | 16 | 17 | def get_args(config_file): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--config', type=str, 20 | default=config_file, 21 | help='train config file path: xxx.yaml') 22 | parser.add_argument('--device', type=str, default='cuda:1' if torch.cuda.is_available() else 'cpu') 23 | args = parser.parse_args() 24 | args_dict = parse_yaml(args.config) 25 | for key, value in Namespace(args_dict).__dict__.items(): 26 | vars(args)[key] = value 27 | return args 28 | 29 | 30 | def save_result(img, 31 | seg, 32 | file_path, 33 | opacity=0.5, 34 | palette=[[0, 0, 0], [0, 255, 255], [255, 106, 106], [255, 250, 240]]): 35 | color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) 36 | for label, color in enumerate(palette): 37 | color_seg[seg == label, :] = color 38 | # convert to BGR 39 | color_seg = color_seg[..., ::-1] 40 | 41 | img = img * (1 - opacity) + color_seg * opacity 42 | img = img.astype(np.uint8) 43 | 44 | dir_name = os.path.abspath(os.path.dirname(file_path)) 45 | os.makedirs(dir_name, exist_ok=True) 46 | cv2.imwrite(file_path, img) 47 | 48 | 49 | def test(config_file, weights, save_pred=True): 50 | args = get_args(config_file) 51 | 52 | _, test_loader = build_dataloader(args.dataset, None) 53 | 54 | model = build_model(args.model) 55 | state_dict = torch.load(weights) 56 | if 'state_dict' in state_dict.keys(): 57 | state_dict = state_dict['state_dict'] 58 | model.load_state_dict(state_dict) 59 | 60 | print(F'\nModel: {model.__class__.__name__}') 61 | print(F'Weights file: {weights}') 62 | 63 | test_loader = tqdm(test_loader, desc='Testing', unit='batch', 64 | bar_format='%s{l_bar}{bar}{r_bar}%s' % (Fore.LIGHTCYAN_EX, 
Fore.RESET)) 65 | 66 | metrics = [] 67 | if args.train.kwargs.metrics is not None: 68 | for metric_cfg in args.train.kwargs.metrics: 69 | metrics.append(_build_from_cfg(metric_cfg)) 70 | 71 | test_res = None 72 | model.to(args.device) 73 | model.eval() 74 | i = 0 75 | for batch_data in test_loader: 76 | data, label = batch_data['image'].to(args.device), batch_data['label'].to(args.device) 77 | preds = model.inference(data) 78 | batch_metric_res = {} 79 | for metric in metrics: 80 | batch_metric_res[metric.name] = metric(preds, label) 81 | 82 | if save_pred: 83 | for j in range(preds.size(0)): 84 | save_result( 85 | data[j].permute(1, 2, 0).cpu().numpy() * 255., 86 | preds[j][0].cpu().numpy(), 87 | f'./shows/{i}.png', 88 | opacity=1.0) 89 | i += 1 90 | 91 | if test_res is None: 92 | test_res = batch_metric_res 93 | else: 94 | for metric_name in test_res.keys(): 95 | for key in test_res[metric_name].keys(): 96 | test_res[metric_name][key] += batch_metric_res[metric_name][key] 97 | for metric_name in test_res.keys(): 98 | for key in test_res[metric_name].keys(): 99 | test_res[metric_name][key] = test_res[metric_name][key] / len(test_loader) 100 | 101 | test_res_list = [_.cpu() for _ in test_res[metric_name].values()] 102 | test_res[metric_name]['Mean'] = np.mean(test_res_list[1:]) 103 | 104 | test_table = PrettyTable() 105 | test_table.field_names = ['Metirc'] + list(list(test_res.values())[0].keys()) 106 | for metric_name in test_res.keys(): 107 | if metric_name in ['Dice', 'Jaccard', 'Acc', 'IoU', 'Recall', 'Precision']: 108 | temp = [float(format(_ * 100, '.2f')) for _ in test_res[metric_name].values()] 109 | else: 110 | temp = [float(format(_, '.2f')) for _ in test_res[metric_name].values()] 111 | test_table.add_row([metric_name] + temp) 112 | print(test_table.get_string()) 113 | 114 | 115 | def show_features(config_file, weights): 116 | args = get_args(config_file) 117 | 118 | _, test_loader = build_dataloader(args.dataset, None) 119 | 120 | model = 
build_model(args.model) 121 | state_dict = torch.load(weights) 122 | if 'state_dict' in state_dict.keys(): 123 | state_dict = state_dict['state_dict'] 124 | model.load_state_dict(state_dict) 125 | 126 | print(F'\nModel: {model.__class__.__name__}') 127 | print(F'Weights file: {weights}') 128 | 129 | model.to(args.device) 130 | model.eval() 131 | for batch_data in test_loader: 132 | data, label = batch_data['image'].to(args.device), batch_data['label'].to(args.device) 133 | preds = model.inference_features(data) 134 | features_ = preds['feats'][-3].permute(0, 2, 3, 1).contiguous().view(16 * 28 * 28, -1).cpu().detach().numpy() 135 | labels_ = torch.nn.functional.interpolate(label, (28, 28), mode='nearest') 136 | labels_ = labels_.permute(0, 2, 3, 1).contiguous().view(16 * 28 * 28, -1).squeeze(1).cpu().detach().numpy() 137 | 138 | features = features_[labels_ != 0] 139 | labels = labels_[labels_ != 0] 140 | tsne = TSNE(n_components=2) 141 | X_tsne = tsne.fit_transform(features) 142 | x_min, x_max = X_tsne.min(0), X_tsne.max(0) 143 | X_norm = (X_tsne - x_min) / (x_max - x_min) # 归一化 144 | plt.figure(figsize=(8, 8)) 145 | for i in range(X_norm.shape[0]): 146 | plt.text(X_norm[i, 0], X_norm[i, 1], str(labels[i]), color=plt.cm.Set1(labels[i]), 147 | fontdict={'weight': 'bold', 'size': 9}) 148 | plt.xticks([]) 149 | plt.yticks([]) 150 | plt.show() 151 | exit() 152 | 153 | 154 | if __name__ == '__main__': 155 | # test(r'weights/ugpcl_acdc_136/ugpcl_unet_r50.yaml', 156 | # r'F:\DeskTop\UGPCL\weights\ugpcl_acdc_136\ugpcl_88.11.pth') 157 | show_features(r'weights/ugpcl_acdc_136/ugpcl_unet_r50.yaml', 158 | r'F:\DeskTop\UGPCL\weights\ugpcl_acdc_136\ugpcl_88.11.pth') 159 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import argparse 4 | import warnings 5 | import random 6 | 7 | import torch 8 | import numpy as np 9 | 10 
| from datetime import datetime 11 | from codes.builder import build_dataloader, build_logger 12 | from codes.utils.utils import Namespace, parse_yaml 13 | from codes.trainers import * 14 | 15 | 16 | def get_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--config', type=str, default='configs/comparison_acdc_224_136/ugpcl_unet_r50.yaml', 19 | help='train config file path: xxx.yaml') 20 | parser.add_argument('--work_dir', type=str, 21 | default=f'results/comparison_acdc_224_136', 22 | help='the dir to save logs and models') 23 | parser.add_argument('--resume_from', type=str, 24 | # default='results/comparison_acdc_224_136/ugcl_mem_unet_r50_0430155558/iter_1000.pth', 25 | default=None, 26 | help='the checkpoint file to resume from') 27 | parser.add_argument('--start_step', type=int, default=0) 28 | parser.add_argument('--device', type=str, default='cuda:0' if torch.cuda.is_available() else 'cpu') 29 | parser.add_argument('--data_parallel', type=bool, default=False) 30 | parser.add_argument('--seed', type=int, default=1337, help='random seed') 31 | parser.add_argument('--deterministic', type=bool, default=True, 32 | help='whether to set deterministic options for CUDNN backend.') 33 | args = parser.parse_args() 34 | 35 | args_dict = parse_yaml(args.config) 36 | 37 | for key, value in Namespace(args_dict).__dict__.items(): 38 | if key in ['name', 'dataset', 'train', 'logger']: 39 | vars(args)[key] = value 40 | 41 | for key, value in Namespace(args_dict).__dict__.items(): 42 | if key not in ['name', 'dataset', 'train', 'logger']: 43 | vars(args.train.kwargs)[key] = value 44 | 45 | if args.work_dir is None: 46 | args.work_dir = f'results/{args.dataset.name}' 47 | if args.resume_from is not None: 48 | args.logger.log_dir = os.path.split(os.path.abspath(args.resume_from))[0] 49 | args.logger.file_mode = 'a' 50 | else: 51 | args.logger.log_dir = f'{args.work_dir}/{args.name}_{datetime.now().strftime("%m%d%H%M%S")}' 52 | args.ckpt_save_path = 
args.logger.log_dir 53 | 54 | for key in args.__dict__.keys(): 55 | if key not in args_dict.keys(): 56 | args_dict[key] = args.__dict__[key] 57 | 58 | return args, args_dict 59 | 60 | 61 | def set_deterministic(seed): 62 | if seed is not None: 63 | random.seed(seed) 64 | np.random.seed(seed) 65 | torch.manual_seed(seed) 66 | torch.cuda.manual_seed(seed) 67 | torch.backends.cudnn.deterministic = True 68 | torch.backends.cudnn.benchmark = False 69 | 70 | 71 | def build_trainer(name, 72 | logger=None, 73 | device='cuda', 74 | data_parallel=False, 75 | ckpt_save_path=None, 76 | resume_from=None, 77 | **kwargs): 78 | return eval(f'{name}')(logger=logger, device=device, data_parallel=data_parallel, ckpt_save_path=ckpt_save_path, 79 | resume_from=resume_from, **kwargs) 80 | 81 | 82 | def train(): 83 | args, args_dict = get_args() 84 | set_deterministic(args.seed) 85 | 86 | def worker_init_fn(worker_id): 87 | random.seed(worker_id + args.seed) 88 | 89 | train_loader, val_loader = build_dataloader(args.dataset, worker_init_fn) 90 | logger = build_logger(args.logger) 91 | 92 | args_yaml_info = yaml.dump(args_dict, sort_keys=False, default_flow_style=None) 93 | yaml_file_name = os.path.split(args.config)[-1] 94 | with open(os.path.join(args.ckpt_save_path, yaml_file_name), 'w') as f: 95 | f.write(args_yaml_info) 96 | f.close() 97 | 98 | logger.info(f'\n{args_yaml_info}\n') 99 | 100 | trainer = build_trainer(name=args.train.name, 101 | logger=logger, 102 | device=args.device, 103 | data_parallel=args.data_parallel, 104 | ckpt_save_path=args.ckpt_save_path, 105 | resume_from=args.resume_from, 106 | **args.train.kwargs.__dict__) 107 | trainer.train(train_loader, val_loader) 108 | logger.close() 109 | 110 | 111 | if __name__ == '__main__': 112 | warnings.filterwarnings("ignore") 113 | train() 114 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | 
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 2 | 3 | 4 | python train.py --config='configs/comparison_acdc_224_136/ugpcl_unet_r50.yaml' --device='cuda:1' \ 5 | --work_dir='results/comparison_acdc_224_136' 6 | 7 | --------------------------------------------------------------------------------