├── .gitignore
├── README.md
├── codes
│   ├── __init__.py
│   ├── builder.py
│   ├── dataloaders
│   │   ├── __init__.py
│   │   ├── _base.py
│   │   ├── _samplers.py
│   │   ├── _transforms.py
│   │   └── acdc.py
│   ├── logger.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── losses.py
│   │   └── pixel_contras_loss.py
│   ├── metrics
│   │   ├── __init__.py
│   │   └── metrics.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── _base.py
│   │   ├── swin_decoder.py
│   │   ├── unet.py
│   │   └── unet_tf.py
│   ├── schedulers
│   │   ├── __init__.py
│   │   └── poly.py
│   ├── trainers
│   │   ├── __init__.py
│   │   ├── _base.py
│   │   ├── mt_trainer.py
│   │   ├── supervised_trainer.py
│   │   └── ugpcl_trainer.py
│   └── utils
│       ├── __init__.py
│       ├── analyze.py
│       ├── init.py
│       ├── ramps.py
│       └── utils.py
├── configs
│   ├── _datasets
│   │   ├── acdc.yaml
│   │   └── acdc_224.yaml
│   ├── _models
│   │   ├── unet_r50.yaml
│   │   └── unet_tf_r50.yaml
│   ├── _trainers
│   │   ├── mt.yaml
│   │   ├── supervised.yaml
│   │   └── ugpcl.yaml
│   └── comparison_acdc_224_136
│       ├── mt_unet_r50.yaml
│       ├── ugpcl_unet_r50.yaml
│       ├── unet_r50.yaml
│       └── unet_r50_full.yaml
├── pics
│   ├── overview.jpg
│   ├── preds.jpg
│   └── show_feats.jpg
├── requirements.txt
├── test.py
├── train.py
└── train.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .git
3 | results
4 | shows
5 | weights
6 | convert.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UGPCL
2 |
3 | > [IJCAI' 22] Uncertainty-Guided Pixel Contrastive Learning for Semi-Supervised Medical Image Segmentation.
4 | >
5 | > Tao Wang, Jianglin Lu, Zhihui Lai, Jiajun Wen and Heng Kong
6 |
7 | ![overview](pics/overview.jpg)
8 |
9 | # Installation
10 | Please refer to [requirements.txt](requirements.txt) for the required packages (e.g., install them with `pip install -r requirements.txt`).
11 |
12 | # Dataset
13 | *ACDC* dataset can be found in this [Link](https://github.com/HiLab-git/SSL4MIS/tree/master/data/ACDC).
14 |
15 | Please change the dataset root directory in _configs/\_datasets/acdc_224.yaml_.
16 |
17 | # Results
18 |
19 | ### *ACDC* (136 labels | 10%):
20 |
21 | | Method | Model | Iterations | Batch Size | Labeled Batch Size | DSC | Ckpt | Config File |
22 | | :----------: | :------: | :--------: | :--------: | :--------: | :---: | :----------------------------------------------------------: | :-------------------------------------------------: |
23 | | UGPCL | UNet-R50 | 6000 | 16 | 8 | 88.11 | [Link](https://drive.google.com/file/d/1T8T6g_xiJWGetQhZeFMNG2q7dzmYyN4s/view?usp=sharing) | configs/comparison_acdc_224_136/ugpcl_unet_r50.yaml |
24 | | Mean Teacher | UNet-R50 | 6000 | 16 | 8 | 85.75 | [Link](https://drive.google.com/file/d/1mWKKoeZbSlf6DNxqnoypr50ialPMqFYL/view?usp=sharing) | configs/comparison_acdc_224_136/mt_unet_r50.yaml |
25 |
26 | ### Visualization
27 |
28 | - Segmentation results:
29 |
30 | ![preds](pics/preds.jpg)
31 |
32 | - Pixel features (t-SNE):
33 |
34 | ![show_feats](pics/show_feats.jpg)
35 |
36 |
37 | # Reference
38 | - [https://github.com/HiLab-git/SSL4MIS](https://github.com/HiLab-git/SSL4MIS)
39 | - [https://github.com/tfzhou/ContrastiveSeg](https://github.com/tfzhou/ContrastiveSeg)
40 |
41 | # Citation
42 | ```bibtex
43 | @inproceedings{ijcai2022-201,
44 | title = {Uncertainty-Guided Pixel Contrastive Learning for Semi-Supervised Medical Image Segmentation},
45 | author = {Wang, Tao and Lu, Jianglin and Lai, Zhihui and Wen, Jiajun and Kong, Heng},
46 | booktitle = {Proceedings of the Thirty-First International Joint Conference on
47 | Artificial Intelligence, {IJCAI-22}},
48 | publisher = {International Joint Conferences on Artificial Intelligence Organization},
49 | editor = {Luc De Raedt},
50 | pages = {1444--1450},
51 | year = {2022},
52 | month = {7},
53 | note = {Main Track},
54 | doi = {10.24963/ijcai.2022/201},
55 | url = {https://doi.org/10.24963/ijcai.2022/201},
56 | }
57 | ```
58 |
--------------------------------------------------------------------------------
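Note on the Dataset section above: `codes/dataloaders/acdc.py` reads `train_slices.list`, `val_slices.list`, and HDF5 slices under `data/slices/` from the configured root directory. A minimal sketch (the root path below is a placeholder) to check that layout before training:

```python
import os

# Placeholder path; use the root_dir set in configs/_datasets/acdc_224.yaml.
root_dir = '/path/to/ACDC'

# Files and folders ACDCDataSet expects (see codes/dataloaders/acdc.py).
for rel in ['train_slices.list', 'val_slices.list', 'data/slices']:
    path = os.path.join(root_dir, rel)
    print(f"{path}: {'found' if os.path.exists(path) else 'MISSING'}")
```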
/codes/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taovv/UGPCL/5980405b15ef1bc139be0447647a213b5a5f30b9/codes/__init__.py
--------------------------------------------------------------------------------
/codes/builder.py:
--------------------------------------------------------------------------------
1 | from .dataloaders import *
2 | from .losses import *
3 | from .metrics import *
4 | from .models import *
5 | from .schedulers import *
6 | from .losses import *
7 | from .logger import Logger
8 |
9 | from torch.optim.lr_scheduler import *
10 | from torch.optim import *
11 |
12 | __all__ = ['build_model', 'build_metric', 'build_dataloader', 'build_scheduler',
13 | 'build_optimizer', 'build_criterion', 'build_logger', '_build_from_cfg']
14 |
15 |
16 | def _build_from_cfg(cfg):
17 | if isinstance(cfg, dict):
18 | name = cfg['name']
19 | if 'kwargs' in cfg.keys() and cfg['kwargs'] is not None:
20 | return eval(f"{name}")(**cfg['kwargs'])
21 | else:
22 | name = cfg.name
23 | if hasattr(cfg, 'kwargs') and cfg.kwargs is not None:
24 | return eval(f"{name}")(**cfg.kwargs.__dict__)
25 | return eval(f"{name}()")
26 |
27 |
28 | def build_model(cfg):
29 | return _build_from_cfg(cfg)
30 |
31 |
32 | def build_optimizer(model_parameters, cfg):
33 | if isinstance(cfg, dict):
34 | name = cfg['name']
35 | if 'kwargs' in cfg.keys() and cfg['kwargs'] is not None:
36 | kwargs = cfg['kwargs']
37 | kwargs['params'] = model_parameters
38 | return eval(f"{name}")(**kwargs)
39 | else:
40 | name = cfg.name
41 | if hasattr(cfg, 'kwargs') and cfg.kwargs is not None:
42 | kwargs = cfg.kwargs.__dict__
43 | kwargs['params'] = model_parameters
44 | return eval(f"{name}")(**kwargs)
45 | kwargs = {'params': model_parameters}
46 | return eval(f"{name}")(**kwargs)
47 |
48 |
49 | def build_scheduler(optimizer_, cfg):
50 | if isinstance(cfg, dict):
51 | name = cfg['name']
52 | if 'kwargs' in cfg.keys() and cfg['kwargs'] is not None:
53 | kwargs = cfg['kwargs']
54 | kwargs['optimizer'] = optimizer_
55 | return eval(f"{name}")(**kwargs)
56 | else:
57 | name = cfg.name
58 | if hasattr(cfg, 'kwargs') and cfg.kwargs is not None:
59 | kwargs = cfg.kwargs.__dict__
60 | kwargs['optimizer'] = optimizer_
61 | return eval(f"{name}")(**kwargs)
62 | kwargs = {'optimizer': optimizer_}
63 | return eval(f"{name}")(**kwargs)
64 |
65 |
66 | def build_criterion(cfg):
67 | return _build_from_cfg(cfg)
68 |
69 |
70 | def build_metric(cfg):
71 | return _build_from_cfg(cfg)
72 |
73 |
74 | def build_dataloader(cfg, worker_init_fn):
75 |     if not isinstance(cfg, dict):
76 |         name, kwargs = cfg.name, cfg.kwargs.__dict__
77 |     else:
78 |         name, kwargs = cfg['name'], cfg['kwargs']
79 |     kwargs['worker_init_fn'] = worker_init_fn
80 |     return eval(f"get_{name}_loaders")(**kwargs)
81 |
82 | def build_logger(cfg):
83 |     return Logger(**cfg.__dict__)
84 |
--------------------------------------------------------------------------------
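Usage note for `codes/builder.py` above: the builders resolve class names with `eval` against the star-imported modules, so a config only needs a `name` plus optional `kwargs`. A minimal sketch, with illustrative values rather than ones taken from the shipped YAML configs:

```python
import torch
from codes.builder import build_metric, build_optimizer

# Illustrative dict-style configs; the real ones are parsed from YAML under configs/.
metric_cfg = {'name': 'Dice', 'kwargs': {'class_indexs': [1, 2, 3], 'class_names': ['c1', 'c2', 'c3']}}
optim_cfg = {'name': 'Adam', 'kwargs': {'lr': 1e-3, 'weight_decay': 1e-4}}

metric = build_metric(metric_cfg)                           # -> codes.metrics.Dice instance
model = torch.nn.Conv2d(1, 4, kernel_size=3, padding=1)     # stand-in model
optimizer = build_optimizer(model.parameters(), optim_cfg)  # -> torch.optim.Adam
```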
/codes/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 | from .acdc import get_acdc_loaders
2 |
--------------------------------------------------------------------------------
/codes/dataloaders/_base.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from torch.utils.data.dataset import T_co
3 |
4 |
5 | class BaseDataSet(Dataset):
6 |
7 | def __init__(self) -> None:
8 | super().__init__()
9 |
10 | def __getitem__(self, index) -> T_co:
11 | pass
12 |
13 | def __len__(self):
14 | pass
15 |
--------------------------------------------------------------------------------
/codes/dataloaders/_samplers.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import numpy as np
3 | from torch.utils.data.sampler import Sampler
4 |
5 |
6 | class TwoStreamBatchSampler(Sampler):
7 | """
8 | Iterate two sets of indices
9 |
10 | An 'epoch' is one iteration through the primary indices.
11 | During the epoch, the secondary indices are iterated through
12 | as many times as needed.
13 | """
14 |
15 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
16 | self.primary_indices = primary_indices
17 | self.secondary_indices = secondary_indices
18 | self.primary_batch_size = batch_size - secondary_batch_size
19 | self.secondary_batch_size = secondary_batch_size
20 |
21 | assert len(self.primary_indices) >= self.primary_batch_size > 0
22 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0
23 |
24 | def __iter__(self):
25 | """
26 | 生成每个batch的数据采样索引:分别从labeled indices和unlabeled indices中选取
27 | 每次调用索引序列都会被打乱,确保随机采样
28 | Returns: (labeled index1,labeled index2,...,unlabeled index1,unlabeled index2,...)
29 |
30 | """
31 | primary_iter = iterate_once(self.primary_indices)
32 | secondary_iter = iterate_eternally(self.secondary_indices)
33 | return (
34 | primary_batch + secondary_batch
35 | for (primary_batch, secondary_batch)
36 | in zip(grouper(primary_iter, self.primary_batch_size),
37 | grouper(secondary_iter, self.secondary_batch_size))
38 | )
39 |
40 | def __len__(self):
41 | return len(self.primary_indices) // self.primary_batch_size
42 |
43 |
44 | def iterate_once(iterable):
45 | return np.random.permutation(iterable) # 随机排列序列
46 |
47 |
48 | def iterate_eternally(indices):
49 | """
50 | 异步不断生成随机打乱的indices序列 并由itertools.chain连接后返回
51 | Args:
52 | indices:
53 |
54 | Returns:
55 |
56 | """
57 | def infinite_shuffles():
58 | while True:
59 | yield np.random.permutation(indices)
60 |
61 | return itertools.chain.from_iterable(infinite_shuffles())
62 |
63 |
64 | def grouper(iterable, n):
65 | """
66 | Collect data into fixed-length chunks or blocks
67 | eg: grouper('ABCDEFG', 3) --> ABC DEF
68 | Args:
69 | iterable:
70 | n:
71 |
72 | Returns:
73 |
74 | """
75 | args = [iter(iterable)] * n
76 | return zip(*args)
77 |
--------------------------------------------------------------------------------
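Usage note for `codes/dataloaders/_samplers.py` above: each batch concatenates a fixed number of labeled (primary) and unlabeled (secondary) indices. A small sketch with toy index ranges:

```python
from codes.dataloaders._samplers import TwoStreamBatchSampler

labeled_idxs = list(range(0, 8))       # toy labeled slice indices
unlabeled_idxs = list(range(8, 40))    # toy unlabeled slice indices

# batch_size=4, secondary_batch_size=2 -> each batch holds 2 labeled + 2 unlabeled indices.
sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size=4, secondary_batch_size=2)
for batch in sampler:
    print(batch)   # labeled indices first, then unlabeled
    break
```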
/codes/dataloaders/_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torchvision.transforms.functional as F
4 |
5 | from numpy import random
6 | from scipy import ndimage
7 | from scipy.ndimage.interpolation import zoom
8 | from torchvision.transforms.functional import InterpolationMode
9 | from torchvision.transforms import transforms as T
10 |
11 |
12 | def build_transforms(transforms_):
13 | transforms = []
14 | for transform in transforms_:
15 | if hasattr(transform, 'kwargs') and transform.kwargs is not None:
16 | kwargs = transform.kwargs.__dict__
17 | transform = eval(f"{transform.name}")(**kwargs)
18 | else:
19 | transform = eval(f"{transform.name}()")
20 | transforms.append(transform)
21 | return Compose(transforms)
22 |
23 |
24 | class Compose:
25 |
26 | def __init__(self, transforms):
27 | self.transforms = transforms
28 |
29 | def __call__(self, img, **kwargs):
30 | for t in self.transforms:
31 | img = t(img, **kwargs)
32 | return img
33 |
34 | def __repr__(self):
35 | format_string = self.__class__.__name__ + '('
36 | for t in self.transforms:
37 | format_string += '\n'
38 | format_string += ' {0}'.format(t)
39 | format_string += '\n)'
40 | return format_string
41 |
42 |
43 | class ToTensor3D(object):
44 | """Convert ndarrays in sample to Tensors."""
45 |
46 | def __call__(self, sample, with_sdf=False):
47 | img = torch.from_numpy(sample['image'])
48 | label = torch.from_numpy(sample['label'])
49 | if len(img.shape) == 3:
50 | img = img.unsqueeze(0)
51 | if len(label.shape) == 3:
52 | label = label.unsqueeze(0)
53 | if with_sdf:
54 | sdf = torch.from_numpy(sample['sdf'])
55 | return {'image': img, 'label': label, 'sdf': sdf}
56 | return {'image': img, 'label': label}
57 |
58 |
59 | # 3D transforms
60 | class CenterCrop3D(object):
61 | def __init__(self, output_size):
62 | self.output_size = output_size
63 |
64 | def __call__(self, sample, with_sdf=False):
65 | image, label = sample['image'], sample['label']
66 | if with_sdf:
67 | sdf = sample['sdf']
68 | # pad the sample if necessary
69 | if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \
70 | self.output_size[2]:
71 | pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0)
72 | ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0)
73 | pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0)
74 | image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
75 | label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
76 | if with_sdf:
77 | for _ in range(sdf.shape[0]):
78 | sdf[_] = np.pad(sdf[_], [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
79 |
80 | (w, h, d) = image.shape
81 |
82 | w1 = int(round((w - self.output_size[0]) / 2.))
83 | h1 = int(round((h - self.output_size[1]) / 2.))
84 | d1 = int(round((d - self.output_size[2]) / 2.))
85 |
86 | label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
87 | image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
88 | if with_sdf:
89 | for _ in range(sdf.shape[0]):
90 | sdf[_] = sdf[_][w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
91 | return {'image': image, 'label': label, 'sdf': sdf}
92 | return {'image': image, 'label': label}
93 |
94 |
95 | class RandomCrop3D(object):
96 | """
97 | Crop randomly the image in a sample
98 | Args:
99 | output_size (int): Desired output size
100 | """
101 |
102 | def __init__(self, output_size):
103 | self.output_size = output_size
104 |
105 | def __call__(self, sample, with_sdf=False):
106 | image, label = sample['image'], sample['label']
107 | if with_sdf:
108 | sdf = sample['sdf']
109 | # pad the sample if necessary
110 | if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \
111 | self.output_size[2]:
112 | pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0)
113 | ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0)
114 | pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0)
115 | image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
116 | label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
117 | if with_sdf:
118 | temp_sdf = []
119 | for _ in range(sdf.shape[0]):
120 | temp_sdf.append(np.pad(sdf[_], [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0))
121 | sdf = np.stack(temp_sdf)
122 |
123 | (w, h, d) = image.shape
124 | # if np.random.uniform() > 0.33:
125 | # w1 = np.random.randint((w - self.output_size[0])//4, 3*(w - self.output_size[0])//4)
126 | # h1 = np.random.randint((h - self.output_size[1])//4, 3*(h - self.output_size[1])//4)
127 | # else:
128 | w1 = np.random.randint(0, w - self.output_size[0])
129 | h1 = np.random.randint(0, h - self.output_size[1])
130 | d1 = np.random.randint(0, d - self.output_size[2])
131 |
132 | label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
133 | image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
134 |
135 | if with_sdf:
136 | sdf = sdf[:, w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
137 | return {'image': image, 'label': label, 'sdf': sdf}
138 | else:
139 | return {'image': image, 'label': label}
140 |
141 |
142 | class RandomRotFlip3D(object):
143 | """
144 | Crop randomly flip the dataset in a sample
145 | Args:
146 | output_size (int): Desired output size
147 | """
148 |
149 | def __call__(self, sample, with_sdf=False):
150 | image, label = sample['image'], sample['label']
151 | if with_sdf:
152 | sdf = sample['sdf']
153 | k = np.random.randint(0, 4)
154 | image = np.rot90(image, k)
155 | label = np.rot90(label, k)
156 | if with_sdf:
157 | sdf = np.rot90(sdf, k, axes=(1, 2))
158 | axis = np.random.randint(0, 2)
159 | image = np.flip(image, axis=axis).copy()
160 | label = np.flip(label, axis=axis).copy()
161 | if with_sdf:
162 | sdf = np.flip(sdf, axis=axis + 1).copy()
163 | if with_sdf:
164 | return {'image': image, 'label': label, 'sdf': sdf}
165 | return {'image': image, 'label': label}
166 |
167 |
168 | class RandomNoise3D(object):
169 | def __init__(self, mu=0, sigma=0.1):
170 | self.mu = mu
171 | self.sigma = sigma
172 |
173 | def __call__(self, sample, with_sdf=False):
174 | image, label = sample['image'], sample['label']
175 | noise = np.clip(self.sigma * np.random.randn(image.shape[0], image.shape[1], image.shape[2]), -2 * self.sigma,
176 | 2 * self.sigma)
177 | noise = noise + self.mu
178 | image = image + noise
179 | if with_sdf:
180 | return {'image': image, 'label': label, 'sdf': sample['sdf']}
181 | return {'image': image, 'label': label}
182 |
183 |
184 | class CreateOnehotLabel3D(object):
185 | def __init__(self, num_classes):
186 | self.num_classes = num_classes
187 |
188 | def __call__(self, sample):
189 | image, label = sample['image'], sample['label']
190 | onehot_label = np.zeros((self.num_classes, label.shape[0], label.shape[1], label.shape[2]), dtype=np.float32)
191 | for i in range(self.num_classes):
192 | onehot_label[i, :, :, :] = (label == i).astype(np.float32)
193 | return {'image': image, 'label': label, 'onehot_label': onehot_label}
194 |
195 |
196 | class RandomGenerator(object):
197 | def __init__(self, output_size, p_flip=0.5, p_rot=0.5):
198 | self.output_size = output_size
199 | self.p_flip = p_flip
200 | self.p_rot = p_rot
201 |
202 | def __call__(self, sample):
203 | image, label = sample['image'], sample['label']
204 | if torch.rand(1) < self.p_flip:
205 | image, label = self.random_rot_flip(image, label)
206 | elif torch.rand(1) < self.p_rot:
207 | image, label = self.random_rotate(image, label)
208 | x, y = image.shape
209 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=0)
210 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)
211 | image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
212 | label = torch.from_numpy(label.astype(np.uint8)).unsqueeze(0)
213 | sample['image'] = image
214 | sample['label'] = label
215 | return sample
216 |
217 | @staticmethod
218 | def random_rot_flip(image, label):
219 | k = np.random.randint(0, 4)
220 | image = np.rot90(image, k)
221 | label = np.rot90(label, k)
222 | axis = np.random.randint(0, 2)
223 | image = np.flip(image, axis=axis).copy()
224 | label = np.flip(label, axis=axis).copy()
225 | return image, label
226 |
227 | @staticmethod
228 | def random_rotate(image, label):
229 | angle = np.random.randint(-20, 20)
230 | image = ndimage.rotate(image, angle, order=0, reshape=False)
231 | label = ndimage.rotate(label, angle, order=0, reshape=False)
232 | return image, label
233 |
234 |
235 | class ToTensor:
236 | """
237 | Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
238 | """
239 |
240 | def __call__(self, sample):
241 | return {'image': F.to_tensor(sample['image']),
242 | 'label': F.to_tensor(sample['label'])}
243 |
244 | def __repr__(self):
245 | return self.__class__.__name__ + '()'
246 |
247 |
248 | class ToRGB:
249 |
250 | def __call__(self, sample):
251 | if sample['image'].shape[0] == 1:
252 | sample['image'] = sample['image'].repeat(3, 1, 1)
253 | return sample
254 |
255 | def __repr__(self):
256 | return self.__class__.__name__ + '()'
257 |
258 |
259 | class ConvertImageDtype(torch.nn.Module):
260 |
261 | def __init__(self, dtype: torch.dtype) -> None:
262 | super().__init__()
263 | self.dtype = dtype
264 |
265 | def forward(self, sample):
266 | sample['image'] = F.convert_image_dtype(sample['image'], self.dtype)
267 | sample['label'] = F.convert_image_dtype(sample['label'], self.dtype)
268 | return sample
269 |
270 |
271 | class ToPILImage:
272 |
273 | def __init__(self, mode=None):
274 | self.mode = mode
275 |
276 | def __call__(self, sample):
277 | sample['image'] = F.to_pil_image(sample['image'], self.mode)
278 | sample['label'] = F.to_pil_image(sample['label'], self.mode)
279 | return sample
280 |
281 | def __repr__(self):
282 | format_string = self.__class__.__name__ + '('
283 | if self.mode is not None:
284 | format_string += 'mode={0}'.format(self.mode)
285 | format_string += ')'
286 | return format_string
287 |
288 |
289 | # 2D transforms
290 | class Normalize(T.Normalize):
291 |
292 | def __init__(self, mean, std, inplace=False):
293 | super().__init__(mean, std, inplace)
294 |
295 | def forward(self, sample):
296 | sample['image'] = F.normalize(sample['image'], self.mean, self.std, self.inplace)
297 | return sample
298 |
299 |
300 | class Resize(T.Resize):
301 |
302 | def __init__(self, size, interpolation=InterpolationMode.BILINEAR):
303 | super().__init__(size, interpolation)
304 |
305 | def forward(self, sample):
306 | sample['image'] = F.resize(sample['image'], self.size, self.interpolation)
307 | sample['label'] = F.resize(sample['label'], self.size, self.interpolation)
308 | return sample
309 |
310 |
311 | class CenterCrop(T.CenterCrop):
312 |
313 | def __init__(self, size):
314 | super().__init__(size)
315 |
316 | def forward(self, sample):
317 | sample['image'] = F.center_crop(sample['image'], self.size)
318 | sample['label'] = F.center_crop(sample['label'], self.size)
319 | return sample
320 |
321 |
322 | class Pad(T.Pad):
323 |
324 | def __init__(self, padding, fill=0, padding_mode="constant"):
325 | super().__init__(padding, fill, padding_mode)
326 |
327 | def forward(self, sample):
328 |         sample['image'] = F.pad(sample['image'], self.padding, self.fill, self.padding_mode)
329 | sample['label'] = F.pad(sample['label'], self.padding, self.fill, self.padding_mode)
330 | return sample
331 |
332 |
333 | class RandomCrop(T.RandomCrop):
334 |
335 | def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
336 | super().__init__(size, padding, pad_if_needed, fill, padding_mode)
337 |
338 | def forward(self, sample):
339 | img = sample['image']
340 | label = sample['label']
341 | if self.padding is not None:
342 | img = F.pad(img, self.padding, self.fill, self.padding_mode)
343 | label = F.pad(label, self.padding, self.fill, self.padding_mode)
344 |
345 | width, height = F._get_image_size(img)
346 | # pad the width if needed
347 | if self.pad_if_needed and width < self.size[1]:
348 | padding = [self.size[1] - width, 0]
349 | img = F.pad(img, padding, self.fill, self.padding_mode)
350 |             label = F.pad(label, padding, self.fill, self.padding_mode)
351 | # pad the height if needed
352 | if self.pad_if_needed and height < self.size[0]:
353 | padding = [0, self.size[0] - height]
354 | img = F.pad(img, padding, self.fill, self.padding_mode)
355 |             label = F.pad(label, padding, self.fill, self.padding_mode)
356 |
357 | i, j, h, w = self.get_params(img, self.size)
358 |
359 | sample['image'] = F.crop(img, i, j, h, w)
360 | sample['label'] = F.crop(label, i, j, h, w)
361 | return sample
362 |
363 | def __repr__(self):
364 | return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding)
365 |
366 |
367 | class RandomFlip(torch.nn.Module):
368 |
369 | def __init__(self, p=0.5, direction='horizontal'):
370 | super().__init__()
371 | assert 0 <= p <= 1
372 | assert direction in ['horizontal', 'vertical', None], 'direction should be horizontal, vertical or None'
373 | self.p = p
374 | self.direction = direction
375 |
376 | def forward(self, sample):
377 | if torch.rand(1) < self.p:
378 | img, label = sample['image'], sample['label']
379 | if self.direction == 'horizontal':
380 | sample['image'] = F.hflip(img)
381 | sample['label'] = F.hflip(label)
382 | elif self.direction == 'vertical':
383 | sample['image'] = F.vflip(img)
384 | sample['label'] = F.vflip(label)
385 | else:
386 | if torch.rand(1) < 0.5:
387 | sample['image'] = F.hflip(img)
388 | sample['label'] = F.hflip(label)
389 | else:
390 | sample['image'] = F.vflip(img)
391 | sample['label'] = F.vflip(label)
392 | return sample
393 |
394 | def __repr__(self):
395 | return self.__class__.__name__ + '(p={})'.format(self.p)
396 |
397 |
398 | class RandomResizedCrop(T.RandomResizedCrop):
399 |
400 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationMode.BILINEAR):
401 | super().__init__(size, scale, ratio, interpolation)
402 |
403 |
404 | def forward(self, sample):
405 | img, mask = sample['image'], sample['label']
406 | i, j, h, w = self.get_params(img, self.scale, self.ratio)
407 | sample['image'] = F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
408 | sample['label'] = F.resized_crop(mask, i, j, h, w, self.size, self.interpolation)
409 | return sample
410 |
411 |
412 | class RandomRotation(T.RandomRotation):
413 |
414 | def __init__(
415 | self,
416 | degrees,
417 | interpolation=InterpolationMode.NEAREST,
418 | expand=False,
419 | center=None,
420 | fill=0,
421 | p=0.5,
422 | resample=None
423 | ):
424 | super().__init__(degrees, interpolation, expand, center, fill, resample)
425 | self.p = p
426 |
427 | def forward(self, sample):
428 | if torch.rand(1) > self.p:
429 | return sample
430 | img, label = sample['image'], sample['label']
431 | fill = self.fill
432 | if isinstance(img, torch.Tensor):
433 | if isinstance(fill, (int, float)):
434 | fill = [float(fill)] * F._get_image_num_channels(img)
435 | else:
436 | fill = [float(f) for f in fill]
437 | label_fill = self.fill
438 | if isinstance(label, torch.Tensor):
439 | if isinstance(label_fill, (int, float)):
440 | label_fill = [float(label_fill)] * F._get_image_num_channels(label)
441 | else:
442 | label_fill = [float(f) for f in label_fill]
443 | angle = self.get_params(self.degrees)
444 | sample['image'] = F.rotate(img, angle, self.resample, self.expand, self.center, fill)
445 | sample['label'] = F.rotate(label, angle, self.resample, self.expand, self.center, label_fill)
446 | return sample
447 |
448 |
449 | class RandomRotation90(torch.nn.Module):
450 |
451 | def __init__(self, p=0.5):
452 | super().__init__()
453 | self.p=p
454 |
455 | def forward(self, sample):
456 | if torch.rand(1) < self.p:
457 | rot_times = random.randint(0, 4)
458 | sample['image'] = torch.rot90(sample['image'], rot_times, [1, 2])
459 | sample['label'] = torch.rot90(sample['label'], rot_times, [1, 2])
460 | return sample
461 |
462 |
463 | class RandomErasing(T.RandomErasing):
464 |
465 | def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
466 | super().__init__(p, scale, ratio, value, inplace)
467 |
468 | def forward(self, sample):
469 | if torch.rand(1) < self.p:
470 | img, label = sample['image'], sample['label']
471 | # cast self.value to script acceptable type
472 | if isinstance(self.value, (int, float)):
473 | value = [self.value, ]
474 | elif isinstance(self.value, str):
475 | value = None
476 | elif isinstance(self.value, tuple):
477 | value = list(self.value)
478 | else:
479 | value = self.value
480 |
481 | if value is not None and not (len(value) in (1, img.shape[-3])):
482 | raise ValueError(
483 | "If value is a sequence, it should have either a single value or "
484 | "{} (number of input channels)".format(img.shape[-3])
485 | )
486 |
487 | x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
488 | sample['image'] = F.erase(img, x, y, h, w, v, self.inplace)
489 | sample['label'] = F.erase(label, x, y, h, w, v, self.inplace)
490 | return sample
491 |
492 |
493 | class GaussianBlur(T.GaussianBlur):
494 |
495 | def __init__(self, kernel_size, sigma=(0.1, 2.0), p=0.5):
496 | super().__init__(kernel_size, sigma)
497 | self.p = p
498 |
499 | def forward(self, sample):
500 | if torch.rand(1) < self.p:
501 | sigma = self.get_params(self.sigma[0], self.sigma[1])
502 | sample['image'] = F.gaussian_blur(sample['image'], self.kernel_size, [sigma, sigma])
503 | return sample
504 |
505 |
506 | class RandomGrayscale(T.RandomGrayscale):
507 |
508 | def __init__(self, p=0.1):
509 | super().__init__(p)
510 |
511 | def forward(self, sample):
512 | if torch.rand(1) < self.p:
513 | img = sample['image']
514 | if len(img.shape) == 4:
515 | img = img.permute(3, 0, 1, 2).contiguous()
516 | if img.size(1) == 1:
517 | img = img.repeat(1, 3, 1, 1)
518 | num_output_channels = F._get_image_num_channels(img)
519 | img = F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
520 | if len(img.shape) == 4:
521 | img = img.permute(1, 2, 3, 0).contiguous()
522 | img = img[0].unsqueeze(0)
523 | sample['image'] = img
524 | return sample
525 |
526 | def __repr__(self):
527 | return self.__class__.__name__ + '(p={0})'.format(self.p)
528 |
529 |
530 | class ColorJitter(T.ColorJitter):
531 |
532 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=1.):
533 | super().__init__(brightness, contrast, saturation, hue)
534 | self.p = p
535 |
536 | def forward(self, sample):
537 | if torch.rand(1) < self.p:
538 | img = sample['image']
539 | if len(img.shape) == 4:
540 | img = img.permute(3, 0, 1, 2).contiguous()
541 | elif img.size(0) == 1:
542 | img = img.repeat(3, 1, 1)
543 | fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
544 | self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
545 |
546 | for fn_id in fn_idx:
547 | if fn_id == 0 and brightness_factor is not None:
548 | img = F.adjust_brightness(img, brightness_factor)
549 | elif fn_id == 1 and contrast_factor is not None:
550 | img = F.adjust_contrast(img, contrast_factor)
551 | elif fn_id == 2 and saturation_factor is not None:
552 | img = F.adjust_saturation(img, saturation_factor)
553 | elif fn_id == 3 and hue_factor is not None:
554 | img = F.adjust_hue(img, hue_factor)
555 |
556 | if len(img.shape) == 4:
557 | img = img.permute(1, 2, 3, 0).contiguous()
558 | img = img[0].unsqueeze(0)
559 | sample['image'] = img
560 | return sample
561 |
562 | def __repr__(self):
563 | format_string = self.__class__.__name__ + '('
564 | format_string += 'brightness={0}'.format(self.brightness)
565 | format_string += ', contrast={0}'.format(self.contrast)
566 | format_string += ', saturation={0}'.format(self.saturation)
567 | format_string += ', hue={0})'.format(self.hue)
568 | return format_string
569 |
--------------------------------------------------------------------------------
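Usage note for `codes/dataloaders/_transforms.py` above: `build_transforms` expects entries exposing `.name` and an optional `.kwargs` namespace (here mocked with `SimpleNamespace`; in the project they come from the parsed YAML configs):

```python
import numpy as np
from types import SimpleNamespace
from codes.dataloaders._transforms import build_transforms

cfg = [
    SimpleNamespace(name='ToTensor', kwargs=None),
    SimpleNamespace(name='RandomFlip', kwargs=SimpleNamespace(p=0.5, direction='horizontal')),
]
transforms = build_transforms(cfg)

sample = {'image': np.random.rand(224, 224).astype(np.float32),
          'label': np.random.randint(0, 4, (224, 224)).astype(np.uint8)}
out = transforms(sample)
print(out['image'].shape, out['label'].shape)   # both torch tensors of shape (1, 224, 224)
```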
/codes/dataloaders/acdc.py:
--------------------------------------------------------------------------------
1 | import h5py
2 |
3 | from torch.utils.data import Dataset, DataLoader
4 | from ._transforms import build_transforms
5 | from ._samplers import TwoStreamBatchSampler
6 |
7 |
8 | class ACDCDataSet(Dataset):
9 |
10 | def __init__(self, root_dir=r'F:/datasets/ACDC/', mode='train', num=None, transforms=None):
11 |
12 | self.root_dir = root_dir
13 | self.mode = mode
14 |
15 | if isinstance(transforms, list):
16 | transforms = build_transforms(transforms)
17 | self.transforms = transforms
18 |
19 | names_file = f'{self.root_dir}/train_slices.list' if self.mode == 'train' else f'{self.root_dir}/val_slices.list'
20 |
21 | with open(names_file, 'r') as f:
22 | self.sample_list = f.readlines()
23 | self.sample_list = [item.replace('\n', '') for item in self.sample_list]
24 |
25 | if num is not None and self.mode == 'train':
26 | self.sample_list = self.sample_list[:num]
27 |
28 | def __len__(self):
29 | return len(self.sample_list)
30 |
31 | def __getitem__(self, idx):
32 | case = self.sample_list[idx]
33 | h5f = h5py.File(f'{self.root_dir}/data/slices/{case}.h5', 'r')
34 | image = h5f['image'][:]
35 | label = h5f['label'][:]
36 | sample = {'image': image, 'label': label}
37 | if self.transforms:
38 | sample = self.transforms(sample)
39 | return sample
40 |
41 |
42 | def get_acdc_loaders(root_dir=r'F:/datasets/ACDC/', labeled_num=7, labeled_bs=12, batch_size=24, batch_size_val=16,
43 | num_workers=4, worker_init_fn=None, train_transforms=None, val_transforms=None):
44 | ref_dict = {"3": 68, "7": 136, "14": 256, "21": 396, "28": 512, "35": 664, "140": 1312}
45 |
46 | db_train = ACDCDataSet(root_dir=root_dir, mode="train", transforms=train_transforms)
47 | db_val = ACDCDataSet(root_dir=root_dir, mode="val", transforms=val_transforms)
48 |
49 | if labeled_bs < batch_size:
50 | labeled_slice = ref_dict[str(labeled_num)]
51 | labeled_idxs = list(range(0, labeled_slice))
52 | unlabeled_idxs = list(range(labeled_slice, len(db_train)))
53 | batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, batch_size - labeled_bs)
54 | train_loader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=num_workers,
55 | pin_memory=True, worker_init_fn=worker_init_fn)
56 | else:
57 | train_loader = DataLoader(db_train, batch_size=batch_size, num_workers=num_workers,
58 | pin_memory=True, worker_init_fn=worker_init_fn)
59 | val_loader = DataLoader(db_val, batch_size=batch_size_val, shuffle=False, num_workers=num_workers)
60 |
61 | return train_loader, val_loader
62 |
--------------------------------------------------------------------------------
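Usage note for `codes/dataloaders/acdc.py` above: a sketch of building the semi-supervised loaders directly. The root path is a placeholder, and `labeled_num=7` maps to 136 labeled slices via `ref_dict`, matching the README setting (batch size 16 with 8 labeled samples per batch); in practice the transform lists come from the dataset YAML and resize slices to a fixed shape:

```python
from codes.dataloaders.acdc import get_acdc_loaders

train_loader, val_loader = get_acdc_loaders(
    root_dir='/path/to/ACDC',   # placeholder
    labeled_num=7,              # 7 labeled patients -> 136 labeled slices (ref_dict)
    labeled_bs=8,               # labeled samples per batch
    batch_size=16,              # total batch size (8 labeled + 8 unlabeled)
    batch_size_val=16,
    train_transforms=None,      # normally a list of transform configs for build_transforms
    val_transforms=None)

print(len(train_loader.dataset), len(val_loader.dataset))
```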
/codes/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 |
4 | from torch import Tensor
5 | from tensorboardX import SummaryWriter
6 |
7 |
8 | class Logger(object):
9 | def __init__(self,
10 | log_dir,
11 | logger_name='root',
12 | log_level=logging.INFO,
13 | log_file=True,
14 | file_mode='w',
15 | tensorboard=True):
16 |
17 | assert log_dir is not None, 'log_dir is None!'
18 |
19 | self.log_dir = log_dir
20 | if not os.path.exists(log_dir):
21 | os.makedirs(self.log_dir)
22 | self.writer = SummaryWriter(log_dir=f'{self.log_dir}/tensorboard_log') if tensorboard else None
23 | self.logger = self.init_logger(logger_name, log_file, log_level, file_mode)
24 |
25 | def init_logger(self, name, log_file, log_level, file_mode):
26 | logger = logging.getLogger(name)
27 | logger.handlers.clear()
28 | stream_handler = logging.StreamHandler()
29 | handlers = [stream_handler]
30 | if log_file:
31 | log_file = os.path.join(self.log_dir, 'info.log')
32 | file_handler = logging.FileHandler(log_file, file_mode)
33 | handlers.append(file_handler)
34 |
35 | date_format = '%Y-%m-%d %H:%M:%S'
36 | # basic_format = '%(asctime)s-%(name)s-%(levelname)s-%(message)s'
37 | basic_format = '%(asctime)s - %(name)s: %(message)s'
38 | formatter = logging.Formatter(basic_format, date_format)
39 | for handler in handlers:
40 | handler.setFormatter(formatter)
41 | handler.setLevel(log_level)
42 | logger.addHandler(handler)
43 |
44 | logger.setLevel(log_level)
45 | return logger
46 |
47 | def update_scalars(self, ordered_dict, step):
48 | for key, value in ordered_dict.items():
49 | if isinstance(value, Tensor):
50 | ordered_dict[key] = value.item()
51 | self.writer.add_scalar(key, value, step)
52 |
53 | def update_images(self, images_dict, step):
54 | for key, value in images_dict.items():
55 | self.writer.add_image(key, value, step)
56 |
57 |     def info(self, *args):
58 |         self.logger.info(*args)
59 |
60 | def close(self):
61 | if self.writer is not None:
62 | self.writer.close()
63 | self.logger.handlers.clear()
64 | del self.logger
65 |
--------------------------------------------------------------------------------
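Usage note for `codes/logger.py` above: a minimal sketch (the log directory is a placeholder); `update_scalars` converts tensors with `.item()` and forwards values to TensorBoard:

```python
import torch
from codes.logger import Logger

logger = Logger(log_dir='results/demo_run')   # placeholder directory
logger.info('iter %d, loss %.4f', 100, 0.4213)
logger.update_scalars({'train/loss': torch.tensor(0.4213), 'train/lr': 0.01}, step=100)
logger.update_images({'train/pred': torch.rand(3, 224, 224)}, step=100)
logger.close()
```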
/codes/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .losses import *
2 | from .pixel_contras_loss import PixelContrastLoss
3 |
--------------------------------------------------------------------------------
/codes/losses/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from abc import abstractmethod
3 | from torch.nn import (BCELoss, BCEWithLogitsLoss, CrossEntropyLoss)
4 |
5 |
6 | class BaseWeightLoss:
7 | def __init__(self, name='loss', weight=1.) -> None:
8 | super().__init__()
9 | self.name = name
10 | self.weight = weight
11 |
12 | @abstractmethod
13 | def _cal_loss(self, preds, targets, **kwargs):
14 | pass
15 |
16 | def __call__(self, preds, targets, **kwargs):
17 | return self._cal_loss(preds, targets, **kwargs) * self.weight
18 |
19 |
20 | class BCELoss_(BaseWeightLoss):
21 | def __init__(self, name='loss_bce', weight=1., **kwargs) -> None:
22 | super().__init__(name, weight)
23 | self.loss = BCELoss(**kwargs)
24 |
25 | def _cal_loss(self, preds, targets, **kwargs):
26 | return self.loss(preds, targets)
27 |
28 |
29 | class BCEWithLogitsLoss_(BaseWeightLoss):
30 | def __init__(self, name='loss_bce', weight=1., **kwargs) -> None:
31 | super().__init__(name, weight)
32 | self.loss = BCEWithLogitsLoss(**kwargs)
33 |
34 | def _cal_loss(self, preds, targets, **kwargs):
35 | return self.loss(preds, targets)
36 |
37 |
38 | class CrossEntropyLoss_(BaseWeightLoss):
39 | def __init__(self, name='loss_ce', weight=1., **kwargs) -> None:
40 | super().__init__(name, weight)
41 | self.loss = CrossEntropyLoss(**kwargs)
42 |
43 | def _cal_loss(self, preds, targets, **kwargs):
44 | targets = targets.to(torch.long).squeeze(1)
45 | return self.loss(preds, targets)
46 |
47 |
48 | class BinaryDiceLoss_(BaseWeightLoss):
49 | def __init__(self, name='loss_dice', weight=1., smooth=1e-5, softmax=True, **kwargs) -> None:
50 | super().__init__(name, weight)
51 | self.smooth = smooth
52 | self.softmax = softmax
53 |
54 | def _cal_loss(self, preds, targets, **kwargs):
55 | assert preds.shape[0] == targets.shape[0]
56 | if self.softmax:
57 | preds = torch.argmax(torch.softmax(preds, dim=1), dim=1, keepdim=True).to(torch.float32)
58 | intersect = torch.sum(torch.mul(preds, targets))
59 | loss = 1 - (2 * intersect + self.smooth) / (torch.sum(preds.pow(2)) + torch.sum(targets.pow(2)) + self.smooth)
60 | return loss
61 |
62 |
63 | class DiceLoss_(BaseWeightLoss):
64 | def __init__(self, name='loss_dice', weight=1., smooth=1e-5, n_classes=2, class_weight=None, softmax=True,
65 | **kwargs):
66 | super().__init__(name, weight)
67 | self.n_classes = n_classes
68 | self.smooth = smooth
69 | self.class_weight = [1.] * self.n_classes if class_weight is None else class_weight
70 | self.softmax = softmax
71 |
72 | def _one_hot_encoder(self, targets):
73 | target_list = []
74 | for _ in range(self.n_classes):
75 | temp_prob = targets == _ * torch.ones_like(targets)
76 | target_list.append(temp_prob)
77 | output_tensor = torch.cat(target_list, dim=1)
78 | return output_tensor.float()
79 |
80 | def _dice_loss(self, pred, target):
81 | assert pred.shape[0] == target.shape[0]
82 | intersect = torch.sum(torch.mul(pred, target))
83 | loss = 1 - (2 * intersect + self.smooth) / (torch.sum(pred.pow(2)) + torch.sum(target.pow(2)) + self.smooth)
84 | return loss
85 |
86 | def _cal_loss(self, preds, targets, **kwargs):
87 | if self.softmax:
88 | preds = torch.softmax(preds, dim=1)
89 | targets = self._one_hot_encoder(targets)
90 | assert preds.size() == targets.size(), 'pred & target shape do not match'
91 | loss = 0.0
92 | for _ in range(self.n_classes):
93 | dice = self._dice_loss(preds[:, _], targets[:, _])
94 | loss += dice * self.class_weight[_]
95 | return loss / self.n_classes
96 |
--------------------------------------------------------------------------------
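Usage note for `codes/losses/losses.py` above: the wrappers pre-scale each term by its `weight`, so a combined objective is just a sum. An illustrative sketch with four classes (ACDC has three foreground structures plus background); the weights are placeholders:

```python
import torch
from codes.losses import CrossEntropyLoss_, DiceLoss_

ce = CrossEntropyLoss_(weight=0.5)
dice = DiceLoss_(weight=0.5, n_classes=4)

logits = torch.randn(2, 4, 224, 224)             # N x C x H x W
targets = torch.randint(0, 4, (2, 1, 224, 224))  # N x 1 x H x W integer labels

loss = ce(logits, targets) + dice(logits, targets)   # each term already scaled by its weight
print(loss.item())
```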
/codes/losses/pixel_contras_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from abc import ABC
5 | from torch import nn
6 |
7 |
8 | class PixelContrastLoss(nn.Module, ABC):
9 | def __init__(self,
10 | temperature=0.07,
11 | base_temperature=0.07,
12 | max_samples=1024,
13 | max_views=100,
14 | ignore_index=-1,
15 | device='cuda:0'):
16 | super(PixelContrastLoss, self).__init__()
17 |
18 | self.temperature = temperature
19 | self.base_temperature = base_temperature
20 |
21 | self.ignore_label = ignore_index
22 |
23 | self.max_samples = max_samples
24 | self.max_views = max_views
25 |
26 | self.device = device
27 |
28 | def _anchor_sampling(self, X, y_hat, y):
29 | batch_size, feat_dim = X.shape[0], X.shape[-1]
30 |
31 | classes = []
32 | total_classes = 0
33 | for ii in range(batch_size):
34 | this_y = y_hat[ii]
35 |
36 | this_classes = torch.unique(this_y)
37 | this_classes = [x for x in this_classes if x != self.ignore_label]
38 | this_classes = [x for x in this_classes if (this_y == x).nonzero().shape[0] > self.max_views]
39 |
40 | classes.append(this_classes)
41 | total_classes += len(this_classes)
42 | if total_classes == 0:
43 | return None, None
44 |
45 | n_view = self.max_samples // total_classes
46 | n_view = min(n_view, self.max_views)
47 |
48 | X_ = torch.zeros((total_classes, n_view, feat_dim), dtype=torch.float).to(self.device)
49 | y_ = torch.zeros(total_classes, dtype=torch.float).to(self.device)
50 |
51 | X_ptr = 0
52 | for ii in range(batch_size):
53 | this_y_hat = y_hat[ii]
54 | this_y = y[ii]
55 | this_classes = classes[ii]
56 |
57 | for cls_id in this_classes:
58 | hard_indices = ((this_y_hat == cls_id) & (this_y != cls_id)).nonzero()
59 | easy_indices = ((this_y_hat == cls_id) & (this_y == cls_id)).nonzero()
60 |
61 | num_hard = hard_indices.shape[0]
62 | num_easy = easy_indices.shape[0]
63 |
64 | if num_hard >= n_view / 2 and num_easy >= n_view / 2:
65 | num_hard_keep = n_view // 2
66 | num_easy_keep = n_view - num_hard_keep
67 | elif num_hard >= n_view / 2:
68 | num_easy_keep = num_easy
69 | num_hard_keep = n_view - num_easy_keep
70 | elif num_easy >= n_view / 2:
71 | num_hard_keep = num_hard
72 | num_easy_keep = n_view - num_hard_keep
73 | else:
74 |                     print('this should never be touched! {} {} {}'.format(num_hard, num_easy, n_view))
75 | raise Exception
76 |
77 | perm = torch.randperm(num_hard)
78 | hard_indices = hard_indices[perm[:num_hard_keep]]
79 | perm = torch.randperm(num_easy)
80 | easy_indices = easy_indices[perm[:num_easy_keep]]
81 | indices = torch.cat((hard_indices, easy_indices), dim=0)
82 |
83 | X_[X_ptr, :, :] = X[ii, indices, :].squeeze(1)
84 | y_[X_ptr] = cls_id
85 | X_ptr += 1
86 | return X_, y_
87 |
88 | def _sample_negative(self, Q):
89 | class_num, memory_size, feat_size = Q.shape
90 |
91 | x_ = torch.zeros((class_num * memory_size, feat_size)).float().to(self.device)
92 | y_ = torch.zeros((class_num * memory_size, 1)).float().to(self.device)
93 |
94 | sample_ptr = 0
95 | for c in range(class_num):
96 | if c == 0:
97 | continue
98 | this_q = Q[c, :memory_size, :]
99 | x_[sample_ptr:sample_ptr + memory_size, ...] = this_q
100 | y_[sample_ptr:sample_ptr + memory_size, ...] = c
101 | sample_ptr += memory_size
102 | return x_, y_
103 |
104 | def _contrastive(self, X_anchor, y_anchor, queue=None):
105 | anchor_num, n_view = X_anchor.shape[0], X_anchor.shape[1]
106 |
107 | y_anchor = y_anchor.contiguous().view(-1, 1) # (anchor_num × n_view) × 1
108 | anchor_count = n_view
109 | anchor_feature = torch.cat(torch.unbind(X_anchor, dim=1), dim=0) # (anchor_num × n_view) × feat_dim
110 |
111 | if queue is not None:
112 | X_contrast, y_contrast = self._sample_negative(queue)
113 | y_contrast = y_contrast.contiguous().view(-1, 1)
114 | contrast_count = 1
115 | contrast_feature = X_contrast
116 | else:
117 | y_contrast = y_anchor
118 | contrast_count = n_view
119 | contrast_feature = torch.cat(torch.unbind(X_anchor, dim=1), dim=0)
120 |
121 | # (anchor_num × n_view) × (anchor_num × n_view)
122 | mask = torch.eq(y_anchor, y_contrast.T).float().to(self.device)
123 | # (anchor_num × n_view) × (anchor_num × n_view)
124 | anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature)
125 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
126 | logits = anchor_dot_contrast - logits_max.detach()
127 | mask = mask.repeat(anchor_count, contrast_count)
128 | neg_mask = 1 - mask
129 |
130 | logits_mask = torch.ones_like(mask).\
131 | scatter_(1, torch.arange(anchor_num * anchor_count).view(-1, 1).to(self.device), 0)
132 | mask = mask * logits_mask
133 |
134 | neg_logits = torch.exp(logits) * neg_mask
135 | neg_logits = neg_logits.sum(1, keepdim=True)
136 |
137 | exp_logits = torch.exp(logits)
138 |
139 | log_prob = logits - torch.log(exp_logits + neg_logits)
140 | mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-5)
141 |
142 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
143 | loss = loss.mean()
144 | return loss
145 |
146 | def forward(self, feats, labels=None, predict=None, queue=None):
147 | labels = labels.float().clone()
148 | labels = torch.nn.functional.interpolate(labels, (feats.shape[2], feats.shape[3]), mode='nearest')
149 | predict = torch.nn.functional.interpolate(predict, (feats.shape[2], feats.shape[3]), mode='nearest')
150 | labels = labels.long()
151 | assert labels.shape[-1] == feats.shape[-1], '{} {}'.format(labels.shape, feats.shape)
152 |
153 | batch_size = feats.shape[0]
154 |
155 | labels = labels.contiguous().view(batch_size, -1)
156 | predict = predict.contiguous().view(batch_size, -1)
157 | feats = feats.permute(0, 2, 3, 1)
158 | feats = feats.contiguous().view(feats.shape[0], -1, feats.shape[-1])
159 |
160 | # feats: N×(HW)×C
161 | # labels: N×(HW)
162 | # predict: N×(HW)
163 | feats_, labels_ = self._anchor_sampling(feats, labels, predict)
164 | loss = self._contrastive(feats_, labels_, queue=queue)
165 | return loss
166 |
--------------------------------------------------------------------------------
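Shape note for `codes/losses/pixel_contras_loss.py` above: `forward` takes projected pixel features of shape N×C×h×w plus label and prediction maps of shape N×1×H×W, which it downsamples to the feature resolution before anchor sampling. A sketch with random tensors (`device='cpu'` overrides the CUDA default):

```python
import torch
import torch.nn.functional as F
from codes.losses import PixelContrastLoss

loss_fn = PixelContrastLoss(device='cpu')                 # default device is 'cuda:0'

feats = F.normalize(torch.randn(2, 64, 56, 56), dim=1)    # N x C x h x w pixel embeddings
labels = torch.randint(0, 4, (2, 1, 224, 224)).float()    # ground-truth map
predict = torch.randint(0, 4, (2, 1, 224, 224)).float()   # predicted map

loss = loss_fn(feats, labels=labels, predict=predict)
print(loss.item())
```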
/codes/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .metrics import Dice, Jaccard, HD95, ASD
2 |
--------------------------------------------------------------------------------
/codes/metrics/metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | preds.shape: (N,1,H,W) || (N,1,H,W,D) || (N,H,W) || (N,H,W,D)
3 | labels.shape: (N,1,H,W) || (N,1,H,W,D) || (N,H,W) || (N,H,W,D)
4 | """
5 | import torch
6 | from medpy import metric
7 |
8 |
9 | class Dice:
10 |
11 | def __init__(self, name='Dice', class_indexs=[1], class_names=['xx']) -> None:
12 | super().__init__()
13 | self.name = name
14 | self.class_indexs = class_indexs
15 | self.class_names = class_names
16 |
17 | def __call__(self, preds, labels):
18 | res = {}
19 | for class_index, class_name in zip(self.class_indexs, self.class_names):
20 | preds_ = (preds == class_index).to(torch.int)
21 | labels_ = (labels == class_index).to(torch.int)
22 | intersection = (preds_ * labels_).sum()
23 | try:
24 |                 res[class_name] = (2 * intersection.item()) / (preds_.sum() + labels_.sum()).item()
25 | except ZeroDivisionError:
26 | res[class_name] = 1.0
27 | # res[class_name] = metric.dc(preds_.cpu().numpy(), labels_.cpu().numpy())
28 | return res
29 |
30 |
31 | class Jaccard:
32 | def __init__(self, name='Jaccard', class_indexs=[1], class_names=['xx']) -> None:
33 | super().__init__()
34 | self.name = name
35 | self.class_indexs = class_indexs
36 | self.class_names = class_names
37 |
38 | def __call__(self, preds, labels):
39 | res = {}
40 | for class_index, class_name in zip(self.class_indexs, self.class_names):
41 | preds_ = (preds == class_index).to(torch.int)
42 | labels_ = (labels == class_index).to(torch.int)
43 | intersection = (preds_ * labels_).sum()
44 | union = ((preds_ + labels_) != 0).sum()
45 |             # guard against an empty union: fall back to 1.0
46 |             try:
47 |                 res[class_name] = intersection.item() / union.item()
48 |             except ZeroDivisionError:
49 |                 res[class_name] = 1.0
50 | # res[class_name] = metric.jc(preds_.cpu().numpy(), labels_.cpu().numpy())
51 | return res
52 |
53 |
54 | class HD95:
55 | """
56 | 95th percentile of the Hausdorff Distance.
57 | """
58 |
59 | def __init__(self, name='95HD', class_indexs=[1], class_names=['xx']) -> None:
60 | super().__init__()
61 | self.name = name
62 | self.class_indexs = class_indexs
63 | self.class_names = class_names
64 |
65 | def __call__(self, preds, labels):
66 | if preds.size(1) == 1:
67 | preds = preds.squeeze(1)
68 | if labels.size(1) == 1:
69 | labels = labels.squeeze(1)
70 | res = {}
71 | for class_index, class_name in zip(self.class_indexs, self.class_names):
72 | res[class_name] = 0.
73 | for i in range(preds.size(0)):
74 | preds_ = (preds[i] == class_index).to(torch.int)
75 | labels_ = (labels[i] == class_index).to(torch.int)
76 | if preds_.sum() == 0.:
77 | preds_ = (preds_ == 0).to(torch.int)
78 | res[class_name] += torch.tensor(metric.hd95(preds_.cpu().numpy(), labels_.cpu().numpy()))
79 | res[class_name] /= preds.size(0)
80 | return res
81 |
82 |
83 | class ASD:
84 | """
85 | Average surface distance.
86 | """
87 |
88 | def __init__(self, name='ASD', class_indexs=[1], class_names=['xx']) -> None:
89 | super().__init__()
90 | self.name = name
91 | self.class_indexs = class_indexs
92 | self.class_names = class_names
93 |
94 | def __call__(self, preds, labels):
95 | if preds.size(1) == 1:
96 | preds = preds.squeeze(1)
97 | if labels.size(1) == 1:
98 | labels = labels.squeeze(1)
99 | res = {}
100 | for class_index, class_name in zip(self.class_indexs, self.class_names):
101 | res[class_name] = 0.
102 | for i in range(preds.size(0)):
103 | preds_ = (preds[i] == class_index).to(torch.int)
104 | labels_ = (labels[i] == class_index).to(torch.int)
105 | if preds_.sum() == 0.:
106 | preds_ = (preds_ == 0).to(torch.int)
107 | res[class_name] += torch.tensor(metric.asd(preds_.cpu().numpy(), labels_.cpu().numpy()))
108 | res[class_name] /= preds.size(0)
109 | return res
--------------------------------------------------------------------------------
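Usage note for `codes/metrics/metrics.py` above: each metric returns a per-class dict keyed by `class_names`. The names below are illustrative ACDC-style labels; `HD95` and `ASD` additionally require `medpy`:

```python
import torch
from codes.metrics import Dice, Jaccard

dice = Dice(class_indexs=[1, 2, 3], class_names=['RV', 'Myo', 'LV'])   # illustrative names
jac = Jaccard(class_indexs=[1, 2, 3], class_names=['RV', 'Myo', 'LV'])

preds = torch.randint(0, 4, (4, 1, 224, 224))
labels = torch.randint(0, 4, (4, 1, 224, 224))

print(dice(preds, labels))   # e.g. {'RV': ..., 'Myo': ..., 'LV': ...}
print(jac(preds, labels))
```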
/codes/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .unet import UNet
2 | from .unet_tf import UNetTF
3 |
--------------------------------------------------------------------------------
/codes/models/_base.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from ..losses import *
3 |
4 |
5 | class BaseModel2D(nn.Module):
6 |
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def _init_weights(self, **kwargs):
11 | for m in self.modules():
12 | if isinstance(m, nn.Conv2d):
13 | torch.nn.init.kaiming_normal_(m.weight)
14 | elif isinstance(m, nn.BatchNorm2d):
15 | nn.init.constant_(m.weight, 1)
16 | nn.init.constant_(m.bias, 0)
17 |
18 | def inference(self, x, **kwargs):
19 | logits = self(x)
20 | preds = torch.argmax(logits['seg'], dim=1, keepdim=True).to(torch.float)
21 | return preds
22 |
--------------------------------------------------------------------------------
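Contract note for `codes/models/_base.py` above: `inference` assumes `forward` returns a dict whose `'seg'` entry holds N×C×H×W logits. A hypothetical minimal subclass, just to illustrate that contract:

```python
import torch
from torch import nn
from codes.models._base import BaseModel2D

class TinySeg(BaseModel2D):          # hypothetical example model, not part of the repo
    def __init__(self, in_ch=1, n_classes=4):
        super().__init__()
        self.head = nn.Conv2d(in_ch, n_classes, kernel_size=1)
        self._init_weights()

    def forward(self, x):
        return {'seg': self.head(x)}   # logits, N x C x H x W

model = TinySeg()
preds = model.inference(torch.randn(2, 1, 224, 224))
print(preds.shape)   # torch.Size([2, 1, 224, 224]): per-pixel class indices (float)
```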
/codes/models/swin_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.checkpoint as checkpoint
4 | from einops import rearrange
5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
6 |
7 |
8 | class Mlp(nn.Module):
9 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
10 | super().__init__()
11 | out_features = out_features or in_features
12 | hidden_features = hidden_features or in_features
13 | self.fc1 = nn.Linear(in_features, hidden_features)
14 | self.act = act_layer()
15 | self.fc2 = nn.Linear(hidden_features, out_features)
16 | self.drop = nn.Dropout(drop)
17 |
18 | def forward(self, x):
19 | x = self.fc1(x)
20 | x = self.act(x)
21 | x = self.drop(x)
22 | x = self.fc2(x)
23 | x = self.drop(x)
24 | return x
25 |
26 |
27 | def window_partition(x, window_size):
28 | """
29 | Args:
30 | x: (B, H, W, C)
31 | window_size (int): window size
32 | Returns:
33 | windows: (num_windows*B, window_size, window_size, C)
34 | """
35 | B, H, W, C = x.shape
36 | x = x.view(B, H // window_size, window_size,
37 | W // window_size, window_size, C)
38 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous(
39 | ).view(-1, window_size, window_size, C)
40 | return windows
41 |
42 |
43 | def window_reverse(windows, window_size, H, W):
44 | """
45 | Args:
46 | windows: (num_windows*B, window_size, window_size, C)
47 | window_size (int): Window size
48 | H (int): Height of image
49 | W (int): Width of image
50 | Returns:
51 | x: (B, H, W, C)
52 | """
53 | B = int(windows.shape[0] / (H * W / window_size / window_size))
54 | x = windows.view(B, H // window_size, W // window_size,
55 | window_size, window_size, -1)
56 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
57 | return x
58 |
59 |
60 | class WindowAttention(nn.Module):
61 | r""" Window based multi-head self attention (W-MSA) module with relative position bias.
62 | It supports both of shifted and non-shifted window.
63 | Args:
64 | dim (int): Number of input channels.
65 | window_size (tuple[int]): The height and width of the window.
66 | num_heads (int): Number of attention heads.
67 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
68 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
69 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
70 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
71 | """
72 |
73 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
74 |
75 | super().__init__()
76 | self.dim = dim
77 | self.window_size = window_size # Wh, Ww
78 | self.num_heads = num_heads
79 | head_dim = dim // num_heads
80 | self.scale = qk_scale or head_dim ** -0.5
81 |
82 | # define a parameter table of relative position bias
83 | self.relative_position_bias_table = nn.Parameter(
84 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
85 |
86 | # get pair-wise relative position index for each token inside the window
87 | coords_h = torch.arange(self.window_size[0])
88 | coords_w = torch.arange(self.window_size[1])
89 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
90 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
91 | relative_coords = coords_flatten[:, :, None] - \
92 | coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
93 | relative_coords = relative_coords.permute(
94 | 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
95 | relative_coords[:, :, 0] += self.window_size[0] - \
96 | 1 # shift to start from 0
97 | relative_coords[:, :, 1] += self.window_size[1] - 1
98 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
99 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
100 | self.register_buffer("relative_position_index",
101 | relative_position_index)
102 |
103 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
104 | self.attn_drop = nn.Dropout(attn_drop)
105 | self.proj = nn.Linear(dim, dim)
106 | self.proj_drop = nn.Dropout(proj_drop)
107 |
108 | trunc_normal_(self.relative_position_bias_table, std=.02)
109 | self.softmax = nn.Softmax(dim=-1)
110 |
111 | def forward(self, x, mask=None):
112 | """
113 | Args:
114 | x: input features with shape of (num_windows*B, N, C)
115 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
116 | """
117 | B_, N, C = x.shape
118 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C //
119 | self.num_heads).permute(2, 0, 3, 1, 4)
120 | # make torchscript happy (cannot use tensor as tuple)
121 | q, k, v = qkv[0], qkv[1], qkv[2]
122 |
123 | q = q * self.scale
124 | attn = (q @ k.transpose(-2, -1))
125 |
126 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
127 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
128 | relative_position_bias = relative_position_bias.permute(
129 | 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
130 | attn = attn + relative_position_bias.unsqueeze(0)
131 |
132 | if mask is not None:
133 | nW = mask.shape[0]
134 | attn = attn.view(B_ // nW, nW, self.num_heads, N,
135 | N) + mask.unsqueeze(1).unsqueeze(0)
136 | attn = attn.view(-1, self.num_heads, N, N)
137 | attn = self.softmax(attn)
138 | else:
139 | attn = self.softmax(attn)
140 |
141 | attn = self.attn_drop(attn)
142 |
143 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
144 | x = self.proj(x)
145 | x = self.proj_drop(x)
146 | return x
147 |
148 | def extra_repr(self) -> str:
149 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
150 |
151 | def flops(self, N):
152 | # calculate flops for 1 window with token length of N
153 | flops = 0
154 | # qkv = self.qkv(x)
155 | flops += N * self.dim * 3 * self.dim
156 | # attn = (q @ k.transpose(-2, -1))
157 | flops += self.num_heads * N * (self.dim // self.num_heads) * N
158 | # x = (attn @ v)
159 | flops += self.num_heads * N * N * (self.dim // self.num_heads)
160 | # x = self.proj(x)
161 | flops += N * self.dim * self.dim
162 | return flops
163 |
164 |
165 | class SwinTransformerBlock(nn.Module):
166 | r""" Swin Transformer Block.
167 | Args:
168 | dim (int): Number of input channels.
169 | input_resolution (tuple[int]): Input resulotion.
170 | num_heads (int): Number of attention heads.
171 | window_size (int): Window size.
172 | shift_size (int): Shift size for SW-MSA.
173 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
174 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
175 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
176 | drop (float, optional): Dropout rate. Default: 0.0
177 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
178 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
179 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
180 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
181 | """
182 |
183 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
184 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
185 | act_layer=nn.GELU, norm_layer=nn.LayerNorm):
186 | super().__init__()
187 | self.dim = dim
188 | self.input_resolution = input_resolution
189 | self.num_heads = num_heads
190 | self.window_size = window_size
191 | self.shift_size = shift_size
192 | self.mlp_ratio = mlp_ratio
193 | if min(self.input_resolution) <= self.window_size:
194 | # if window size is larger than input resolution, we don't partition windows
195 | self.shift_size = 0
196 | self.window_size = min(self.input_resolution)
197 |         assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
198 |
199 | self.norm1 = norm_layer(dim)
200 | self.attn = WindowAttention(
201 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
202 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
203 |
204 | self.drop_path = DropPath(
205 | drop_path) if drop_path > 0. else nn.Identity()
206 | self.norm2 = norm_layer(dim)
207 | mlp_hidden_dim = int(dim * mlp_ratio)
208 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
209 | act_layer=act_layer, drop=drop)
210 |
211 | if self.shift_size > 0:
212 | # calculate attention mask for SW-MSA
213 | H, W = self.input_resolution
214 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
215 | h_slices = (slice(0, -self.window_size),
216 | slice(-self.window_size, -self.shift_size),
217 | slice(-self.shift_size, None))
218 | w_slices = (slice(0, -self.window_size),
219 | slice(-self.window_size, -self.shift_size),
220 | slice(-self.shift_size, None))
221 | cnt = 0
222 | for h in h_slices:
223 | for w in w_slices:
224 | img_mask[:, h, w, :] = cnt
225 | cnt += 1
226 |
227 | # nW, window_size, window_size, 1
228 | mask_windows = window_partition(img_mask, self.window_size)
229 | mask_windows = mask_windows.view(-1,
230 | self.window_size * self.window_size)
231 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
232 | attn_mask = attn_mask.masked_fill(
233 | attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
234 | else:
235 | attn_mask = None
236 |
237 | self.register_buffer("attn_mask", attn_mask)
238 |
239 | def forward(self, x):
240 | H, W = self.input_resolution
241 | B, L, C = x.shape
242 | assert L == H * W, "input feature has wrong size"
243 |
244 | shortcut = x
245 | x = self.norm1(x)
246 | x = x.view(B, H, W, C)
247 |
248 | # cyclic shift
249 | if self.shift_size > 0:
250 | shifted_x = torch.roll(
251 | x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
252 | else:
253 | shifted_x = x
254 |
255 | # partition windows
256 | # nW*B, window_size, window_size, C
257 | x_windows = window_partition(shifted_x, self.window_size)
258 | # nW*B, window_size*window_size, C
259 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
260 |
261 | # W-MSA/SW-MSA
262 | # nW*B, window_size*window_size, C
263 | attn_windows = self.attn(x_windows, mask=self.attn_mask)
264 |
265 | # merge windows
266 | attn_windows = attn_windows.view(-1,
267 | self.window_size, self.window_size, C)
268 | shifted_x = window_reverse(
269 | attn_windows, self.window_size, H, W) # B H' W' C
270 |
271 | # reverse cyclic shift
272 | if self.shift_size > 0:
273 | x = torch.roll(shifted_x, shifts=(
274 | self.shift_size, self.shift_size), dims=(1, 2))
275 | else:
276 | x = shifted_x
277 | x = x.view(B, H * W, C)
278 |
279 | # FFN
280 | x = shortcut + self.drop_path(x)
281 | x = x + self.drop_path(self.mlp(self.norm2(x)))
282 |
283 | return x
284 |
285 | def extra_repr(self) -> str:
286 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
287 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
288 |
289 | def flops(self):
290 | flops = 0
291 | H, W = self.input_resolution
292 | # norm1
293 | flops += self.dim * H * W
294 | # W-MSA/SW-MSA
295 | nW = H * W / self.window_size / self.window_size
296 | flops += nW * self.attn.flops(self.window_size * self.window_size)
297 | # mlp
298 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
299 | # norm2
300 | flops += self.dim * H * W
301 | return flops
302 |
303 |
304 | class PatchMerging(nn.Module):
305 | r""" Patch Merging Layer.
306 | Args:
307 | input_resolution (tuple[int]): Resolution of input feature.
308 | dim (int): Number of input channels.
309 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
310 | """
311 |
312 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
313 | super().__init__()
314 | self.input_resolution = input_resolution
315 | self.dim = dim
316 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
317 | self.norm = norm_layer(4 * dim)
318 |
319 | def forward(self, x):
320 | """
321 | x: B, H*W, C
322 | """
323 | H, W = self.input_resolution
324 | B, L, C = x.shape
325 | assert L == H * W, "input feature has wrong size"
326 |         assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
327 |
328 | x = x.view(B, H, W, C)
329 |
330 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
331 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
332 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
333 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
334 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
335 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
336 |
337 | x = self.norm(x)
338 | x = self.reduction(x)
339 |
340 | return x
341 |
342 | def extra_repr(self) -> str:
343 | return f"input_resolution={self.input_resolution}, dim={self.dim}"
344 |
345 | def flops(self):
346 | H, W = self.input_resolution
347 | flops = H * W * self.dim
348 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
349 | return flops
350 |
351 |
352 | class PatchExpand(nn.Module):
353 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
354 | super().__init__()
355 | self.input_resolution = input_resolution
356 | self.dim = dim
357 | self.expand = nn.Linear(
358 | dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
359 | self.norm = norm_layer(dim // dim_scale)
360 |
361 | def forward(self, x):
362 | """
363 | x: B, H*W, C
364 | """
365 | H, W = self.input_resolution
366 | x = self.expand(x)
367 | B, L, C = x.shape
368 | assert L == H * W, "input feature has wrong size"
369 |
370 | x = x.view(B, H, W, C)
371 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c',
372 | p1=2, p2=2, c=C // 4)
373 | x = x.view(B, -1, C // 4)
374 | x = self.norm(x)
375 |
376 | return x
377 |
378 |
379 | class FinalPatchExpand_X4(nn.Module):
380 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
381 | super().__init__()
382 | self.input_resolution = input_resolution
383 | self.dim = dim
384 | self.dim_scale = dim_scale
385 | self.expand = nn.Linear(dim, 16 * dim, bias=False)
386 | self.output_dim = dim
387 | self.norm = norm_layer(self.output_dim)
388 |
389 | def forward(self, x):
390 | """
391 | x: B, H*W, C
392 | """
393 | H, W = self.input_resolution
394 | x = self.expand(x)
395 | B, L, C = x.shape
396 | assert L == H * W, "input feature has wrong size"
397 |
398 | x = x.view(B, H, W, C)
399 | x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c',
400 | p1=self.dim_scale, p2=self.dim_scale, c=C // (self.dim_scale ** 2))
401 | x = x.view(B, -1, self.output_dim)
402 | x = self.norm(x)
403 |
404 | return x
405 |
406 |
407 | class BasicLayer(nn.Module):
408 | """ A basic Swin Transformer layer for one stage.
409 | Args:
410 | dim (int): Number of input channels.
411 | input_resolution (tuple[int]): Input resolution.
412 | depth (int): Number of blocks.
413 | num_heads (int): Number of attention heads.
414 | window_size (int): Local window size.
415 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
416 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
417 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
418 | drop (float, optional): Dropout rate. Default: 0.0
419 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
420 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
421 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
422 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
423 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
424 | """
425 |
426 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
427 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
428 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
429 |
430 | super().__init__()
431 | self.dim = dim
432 | self.input_resolution = input_resolution
433 | self.depth = depth
434 | self.use_checkpoint = use_checkpoint
435 |
436 | # build blocks
437 | self.blocks = nn.ModuleList([
438 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
439 | num_heads=num_heads, window_size=window_size,
440 | shift_size=0 if (
441 | i % 2 == 0) else window_size // 2,
442 | mlp_ratio=mlp_ratio,
443 | qkv_bias=qkv_bias, qk_scale=qk_scale,
444 | drop=drop, attn_drop=attn_drop,
445 | drop_path=drop_path[i] if isinstance(
446 | drop_path, list) else drop_path,
447 | norm_layer=norm_layer)
448 | for i in range(depth)])
449 |
450 | # patch merging layer
451 | if downsample is not None:
452 | self.downsample = downsample(
453 | input_resolution, dim=dim, norm_layer=norm_layer)
454 | else:
455 | self.downsample = None
456 |
457 | def forward(self, x):
458 | for blk in self.blocks:
459 | if self.use_checkpoint:
460 | x = checkpoint.checkpoint(blk, x)
461 | else:
462 | x = blk(x)
463 | if self.downsample is not None:
464 | x = self.downsample(x)
465 | return x
466 |
467 | def extra_repr(self) -> str:
468 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
469 |
470 | def flops(self):
471 | flops = 0
472 | for blk in self.blocks:
473 | flops += blk.flops()
474 | if self.downsample is not None:
475 | flops += self.downsample.flops()
476 | return flops
477 |
478 |
479 | class BasicLayer_up(nn.Module):
480 | """ A basic Swin Transformer layer for one stage.
481 | Args:
482 | dim (int): Number of input channels.
483 | input_resolution (tuple[int]): Input resolution.
484 | depth (int): Number of blocks.
485 | num_heads (int): Number of attention heads.
486 | window_size (int): Local window size.
487 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
488 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
489 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
490 | drop (float, optional): Dropout rate. Default: 0.0
491 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
492 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
493 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
494 |         upsample (nn.Module | None, optional): Upsample layer at the end of the layer. Default: None
495 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
496 | """
497 |
498 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
499 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
500 | drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):
501 |
502 | super().__init__()
503 | self.dim = dim
504 | self.input_resolution = input_resolution
505 | self.depth = depth
506 | self.use_checkpoint = use_checkpoint
507 |
508 | # build blocks
509 | self.blocks = nn.ModuleList([
510 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
511 | num_heads=num_heads, window_size=window_size,
512 | shift_size=0 if (
513 | i % 2 == 0) else window_size // 2,
514 | mlp_ratio=mlp_ratio,
515 | qkv_bias=qkv_bias, qk_scale=qk_scale,
516 | drop=drop, attn_drop=attn_drop,
517 | drop_path=drop_path[i] if isinstance(
518 | drop_path, list) else drop_path,
519 | norm_layer=norm_layer)
520 | for i in range(depth)])
521 |
522 |         # patch expanding layer
523 | if upsample is not None:
524 | self.upsample = PatchExpand(
525 | input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)
526 | else:
527 | self.upsample = None
528 |
529 | def forward(self, x):
530 | for blk in self.blocks:
531 | if self.use_checkpoint:
532 | x = checkpoint.checkpoint(blk, x)
533 | else:
534 | x = blk(x)
535 | if self.upsample is not None:
536 | x = self.upsample(x)
537 | return x
538 |
539 |
540 | class PatchEmbed(nn.Module):
541 | r""" Image to Patch Embedding
542 | Args:
543 | img_size (int): Image size. Default: 224.
544 | patch_size (int): Patch token size. Default: 4.
545 | in_chans (int): Number of input image channels. Default: 3.
546 | embed_dim (int): Number of linear projection output channels. Default: 96.
547 | norm_layer (nn.Module, optional): Normalization layer. Default: None
548 | """
549 |
550 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
551 | super().__init__()
552 | img_size = to_2tuple(img_size)
553 | patch_size = to_2tuple(patch_size)
554 | patches_resolution = [img_size[0] //
555 | patch_size[0], img_size[1] // patch_size[1]]
556 | self.img_size = img_size
557 | self.patch_size = patch_size
558 | self.patches_resolution = patches_resolution
559 | self.num_patches = patches_resolution[0] * patches_resolution[1]
560 |
561 | self.in_chans = in_chans
562 | self.embed_dim = embed_dim
563 |
564 | self.proj = nn.Conv2d(in_chans, embed_dim,
565 | kernel_size=patch_size, stride=patch_size)
566 | if norm_layer is not None:
567 | self.norm = norm_layer(embed_dim)
568 | else:
569 | self.norm = None
570 |
571 | def forward(self, x):
572 | B, C, H, W = x.shape
573 | # FIXME look at relaxing size constraints
574 | assert H == self.img_size[0] and W == self.img_size[1], \
575 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
576 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
577 | if self.norm is not None:
578 | x = self.norm(x)
579 | return x
580 |
581 | def flops(self):
582 | Ho, Wo = self.patches_resolution
583 | flops = Ho * Wo * self.embed_dim * self.in_chans * \
584 | (self.patch_size[0] * self.patch_size[1])
585 | if self.norm is not None:
586 | flops += Ho * Wo * self.embed_dim
587 | return flops
588 |
589 |
590 | class SwinTransformerSys(nn.Module):
591 | r""" Swin Transformer
592 |     A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
593 | https://arxiv.org/pdf/2103.14030
594 | Args:
595 | img_size (int | tuple(int)): Input image size. Default 224
596 | patch_size (int | tuple(int)): Patch size. Default: 4
597 | in_chans (int): Number of input image channels. Default: 3
598 | num_classes (int): Number of classes for classification head. Default: 1000
599 | embed_dim (int): Patch embedding dimension. Default: 96
600 | depths (tuple(int)): Depth of each Swin Transformer layer.
601 | num_heads (tuple(int)): Number of attention heads in different layers.
602 | window_size (int): Window size. Default: 7
603 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
604 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
605 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
606 | drop_rate (float): Dropout rate. Default: 0
607 | attn_drop_rate (float): Attention dropout rate. Default: 0
608 | drop_path_rate (float): Stochastic depth rate. Default: 0.1
609 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
610 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
611 | patch_norm (bool): If True, add normalization after patch embedding. Default: True
612 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
613 | """
614 |
615 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
616 | embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],
617 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
618 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
619 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
620 | use_checkpoint=False, final_upsample="expand_first", **kwargs):
621 | super().__init__()
622 |
623 | print(
624 | "SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(
625 | depths,
626 | depths_decoder, drop_path_rate, num_classes))
627 |
628 | self.num_classes = num_classes
629 | self.num_layers = len(depths)
630 | self.embed_dim = embed_dim
631 | self.ape = ape
632 | self.patch_norm = patch_norm
633 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
634 | self.num_features_up = int(embed_dim * 2)
635 | self.mlp_ratio = mlp_ratio
636 | self.final_upsample = final_upsample
637 |
638 | # split image into non-overlapping patches
639 | self.patch_embed = PatchEmbed(
640 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
641 | norm_layer=norm_layer if self.patch_norm else None)
642 | num_patches = self.patch_embed.num_patches
643 | patches_resolution = self.patch_embed.patches_resolution
644 | self.patches_resolution = patches_resolution
645 |
646 | # absolute position embedding
647 | if self.ape:
648 | self.absolute_pos_embed = nn.Parameter(
649 | torch.zeros(1, num_patches, embed_dim))
650 | trunc_normal_(self.absolute_pos_embed, std=.02)
651 |
652 | self.pos_drop = nn.Dropout(p=drop_rate)
653 |
654 | # stochastic depth
655 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
656 | sum(depths))] # stochastic depth decay rule
657 |
658 | # build encoder and bottleneck layers
659 | self.layers = nn.ModuleList()
660 | for i_layer in range(self.num_layers):
661 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
662 | input_resolution=(patches_resolution[0] // (2 ** i_layer),
663 | patches_resolution[1] // (2 ** i_layer)),
664 | depth=depths[i_layer],
665 | num_heads=num_heads[i_layer],
666 | window_size=window_size,
667 | mlp_ratio=self.mlp_ratio,
668 | qkv_bias=qkv_bias, qk_scale=qk_scale,
669 | drop=drop_rate, attn_drop=attn_drop_rate,
670 | drop_path=dpr[sum(depths[:i_layer]):sum(
671 | depths[:i_layer + 1])],
672 | norm_layer=norm_layer,
673 | downsample=PatchMerging if (
674 | i_layer < self.num_layers - 1) else None,
675 | use_checkpoint=use_checkpoint)
676 | self.layers.append(layer)
677 |
678 | # build decoder layers
679 | self.layers_up = nn.ModuleList()
680 | self.concat_back_dim = nn.ModuleList()
681 | for i_layer in range(self.num_layers):
682 | concat_linear = nn.Linear(2 * int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
683 | int(embed_dim * 2 ** (
684 | self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity()
685 | if i_layer == 0:
686 | layer_up = PatchExpand(
687 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
688 | patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
689 | dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), dim_scale=2, norm_layer=norm_layer)
690 | else:
691 | layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
692 | input_resolution=(
693 | patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
694 | patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
695 | depth=depths[(
696 | self.num_layers - 1 - i_layer)],
697 | num_heads=num_heads[(
698 | self.num_layers - 1 - i_layer)],
699 | window_size=window_size,
700 | mlp_ratio=self.mlp_ratio,
701 | qkv_bias=qkv_bias, qk_scale=qk_scale,
702 | drop=drop_rate, attn_drop=attn_drop_rate,
703 | drop_path=dpr[sum(depths[:(
704 | self.num_layers - 1 - i_layer)]):sum(
705 | depths[:(self.num_layers - 1 - i_layer) + 1])],
706 | norm_layer=norm_layer,
707 | upsample=PatchExpand if (
708 | i_layer < self.num_layers - 1) else None,
709 | use_checkpoint=use_checkpoint)
710 | self.layers_up.append(layer_up)
711 | self.concat_back_dim.append(concat_linear)
712 |
713 | self.norm = norm_layer(self.num_features)
714 | self.norm_up = norm_layer(self.embed_dim)
715 |
716 | if self.final_upsample == "expand_first":
717 | # print("---final upsample expand_first---")
718 | self.up = FinalPatchExpand_X4(input_resolution=(
719 | img_size // patch_size, img_size // patch_size), dim_scale=4, dim=embed_dim)
720 | self.output = nn.Conv2d(
721 | in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)
722 |
723 | self.apply(self._init_weights)
724 |
725 | def _init_weights(self, m):
726 | if isinstance(m, nn.Linear):
727 | trunc_normal_(m.weight, std=.02)
728 | if m.bias is not None:
729 | nn.init.constant_(m.bias, 0)
730 | elif isinstance(m, nn.LayerNorm):
731 | nn.init.constant_(m.bias, 0)
732 | nn.init.constant_(m.weight, 1.0)
733 |
734 | @torch.jit.ignore
735 | def no_weight_decay(self):
736 | return {'absolute_pos_embed'}
737 |
738 | @torch.jit.ignore
739 | def no_weight_decay_keywords(self):
740 | return {'relative_position_bias_table'}
741 |
742 | # Encoder and Bottleneck
743 | def forward_features(self, x):
744 | x = self.patch_embed(x)
745 | if self.ape:
746 | x = x + self.absolute_pos_embed
747 | x = self.pos_drop(x)
748 | x_downsample = []
749 |
750 | for layer in self.layers:
751 | x_downsample.append(x)
752 | x = layer(x)
753 |
754 | x = self.norm(x) # B L C
755 |
756 | return x, x_downsample
757 |
758 |     # Decoder and skip connection
759 | def forward_up_features(self, x, x_downsample):
760 | for inx, layer_up in enumerate(self.layers_up):
761 | if inx == 0:
762 | x = layer_up(x)
763 | else:
764 | x = torch.cat([x, x_downsample[3 - inx]], -1)
765 | x = self.concat_back_dim[inx](x)
766 | x = layer_up(x)
767 |
768 | x = self.norm_up(x) # B L C
769 |
770 | return x
771 |
772 | def up_x4(self, x):
773 | H, W = self.patches_resolution
774 | B, L, C = x.shape
775 |         assert L == H * W, "input features have wrong size"
776 |
777 | if self.final_upsample == "expand_first":
778 | x = self.up(x)
779 | x = x.view(B, 4 * H, 4 * W, -1)
780 | x = x.permute(0, 3, 1, 2) # B,C,H,W
781 | x = self.output(x)
782 |
783 | return x
784 |
785 | def forward(self, x):
786 | x, x_downsample = self.forward_features(x)
787 |         # print(x.shape)  # debug output; disabled
788 |         # for _ in x_downsample:
789 |         #     print(_.shape)
790 | x = self.forward_up_features(x, x_downsample)
791 | x = self.up_x4(x)
792 |
793 | return x
794 |
795 | def flops(self):
796 | flops = 0
797 | flops += self.patch_embed.flops()
798 | for i, layer in enumerate(self.layers):
799 | flops += layer.flops()
800 | flops += self.num_features * \
801 | self.patches_resolution[0] * \
802 | self.patches_resolution[1] // (2 ** self.num_layers)
803 | flops += self.num_features * self.num_classes
804 | return flops
805 |
806 |
807 | class SwinTransDecoder(nn.Module):
808 | def __init__(self,
809 | classes=2,
810 | embed_dim=96,
811 | norm_layer=nn.LayerNorm,
812 | img_size=224,
813 | patch_size=4,
814 | depths=[2, 2, 2, 2],
815 | num_heads=[3, 6, 12, 24],
816 | window_size=7,
817 | qkv_bias=True,
818 | qk_scale=None,
819 | drop_rate=0.,
820 | attn_drop_rate=0.,
821 | use_checkpoint=False,
822 | ape=True,
823 | mlp_ratio=4.,
824 | drop_path_rate=0.1,
825 | final_upsample="expand_first",
826 | patches_resolution=[56, 56],
827 | encoder_channels=[256, 512, 1024, 2048]):
828 | super().__init__(
829 | )
830 | self.patches_resolution = patches_resolution
831 | self.ape = ape
832 | self.final_upsample = final_upsample
833 | self.num_layers = len(depths)
834 | self.mlp_ratio = mlp_ratio
835 |
836 | if self.ape:
837 |             self.absolute_pos_embeds = [  # hard-coded for 224x224 inputs: 56**2, 28**2, 14**2 and 7**2 tokens
838 | nn.Parameter(torch.zeros(1, 3136, 96)),
839 | nn.Parameter(torch.zeros(1, 784, 192)),
840 | nn.Parameter(torch.zeros(1, 196, 384)),
841 | nn.Parameter(torch.zeros(1, 49, 768))
842 | ]
843 |             for pos_embed in self.absolute_pos_embeds:
844 |                 trunc_normal_(pos_embed, std=.02)
845 | self.pos_drop = nn.Dropout(p=drop_rate)
846 |
847 | # build decoder layers
848 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
849 | self.layers_up = nn.ModuleList()
850 | self.concat_back_dim = nn.ModuleList()
851 | for i_layer in range(self.num_layers):
852 | concat_linear = nn.Linear(2 * int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
853 | int(embed_dim * 2 ** (
854 | self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity()
855 | if i_layer == 0:
856 | layer_up = PatchExpand(
857 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
858 | patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
859 | dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), dim_scale=2, norm_layer=norm_layer)
860 | else:
861 | layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
862 | input_resolution=(
863 | patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
864 | patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
865 | depth=depths[(self.num_layers - 1 - i_layer)],
866 | num_heads=num_heads[(self.num_layers - 1 - i_layer)],
867 | window_size=window_size,
868 | mlp_ratio=self.mlp_ratio,
869 | qkv_bias=qkv_bias, qk_scale=qk_scale,
870 | drop=drop_rate, attn_drop=attn_drop_rate,
871 | drop_path=dpr[sum(depths[:(self.num_layers - 1 - i_layer)]):sum(
872 | depths[:(self.num_layers - 1 - i_layer) + 1])],
873 | norm_layer=norm_layer,
874 | upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,
875 | use_checkpoint=use_checkpoint)
876 | self.layers_up.append(layer_up)
877 | self.concat_back_dim.append(concat_linear)
878 |
879 | self.patches_embed = nn.ModuleList([
880 | PatchEmbed(img_size=56, patch_size=1, in_chans=encoder_channels[-4], embed_dim=96, norm_layer=norm_layer),
881 | PatchEmbed(img_size=28, patch_size=1, in_chans=encoder_channels[-3], embed_dim=192, norm_layer=norm_layer),
882 | PatchEmbed(img_size=14, patch_size=1, in_chans=encoder_channels[-2], embed_dim=384, norm_layer=norm_layer),
883 | PatchEmbed(img_size=7, patch_size=1, in_chans=encoder_channels[-1], embed_dim=768, norm_layer=norm_layer),
884 | ])
885 |
886 | self.norm = norm_layer(768)
887 | self.norm_up = norm_layer(96)
888 |
889 | if self.final_upsample == "expand_first":
890 | self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size),
891 | dim_scale=4, dim=embed_dim)
892 | self.output = nn.Conv2d(in_channels=embed_dim, out_channels=classes, kernel_size=1, bias=False)
893 |
894 | self.apply(self._init_weights)
895 |
896 | def _init_weights(self, m):
897 | if isinstance(m, nn.Linear):
898 | trunc_normal_(m.weight, std=.02)
899 | if isinstance(m, nn.Linear) and m.bias is not None:
900 | nn.init.constant_(m.bias, 0)
901 | elif isinstance(m, nn.LayerNorm):
902 | nn.init.constant_(m.bias, 0)
903 | nn.init.constant_(m.weight, 1.0)
904 |
905 | def forward_up_features(self, x, x_downsample):
906 | for inx, layer_up in enumerate(self.layers_up):
907 | if inx == 0:
908 | x = layer_up(x)
909 | else:
910 | x = torch.cat([x, x_downsample[3 - inx]], -1)
911 | x = self.concat_back_dim[inx](x)
912 | x = layer_up(x)
913 | x = self.norm_up(x) # B L C
914 | return x
915 |
916 | def up_x4(self, x):
917 | H, W = self.patches_resolution
918 | B, L, C = x.shape
919 |         assert L == H * W, "input features have wrong size"
920 |
921 | if self.final_upsample == "expand_first":
922 | x = self.up(x)
923 | x = x.view(B, 4 * H, 4 * W, -1)
924 | x = x.permute(0, 3, 1, 2) # B,C,H,W
925 | x = self.output(x)
926 | return x
927 |
928 | def forward(self, features, device):
929 | patches = []
930 | for i, f in enumerate(features[2:]):
931 | p_embed = self.patches_embed[i](f)
932 | if self.ape:
933 | p_embed = p_embed + self.absolute_pos_embeds[i].to(device)
934 | p_embed = self.pos_drop(p_embed)
935 | patches.append(p_embed)
936 | p = self.forward_up_features(patches[-1], patches)
937 | return self.up_x4(p)
938 |
--------------------------------------------------------------------------------
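The decoder above is easiest to sanity-check with dummy feature maps. The sketch below is not part of the repository; it only assumes the default arguments (`img_size=224`, `patches_resolution=[56, 56]`, `encoder_channels=[256, 512, 1024, 2048]`), under which `forward` consumes the last four maps of a ResNet-50-style pyramid (`features[2:]`) and expands them back to a full-resolution logit map:

```python
import torch
from codes.models.swin_decoder import SwinTransDecoder

decoder = SwinTransDecoder(classes=2)  # defaults: embed_dim=96, img_size=224, window_size=7

# Dummy ResNet-50-style pyramid; only features[2:] are consumed by the decoder.
features = [
    torch.randn(1, 3, 224, 224),    # raw input (ignored)
    torch.randn(1, 64, 112, 112),   # stem output (ignored)
    torch.randn(1, 256, 56, 56),
    torch.randn(1, 512, 28, 28),
    torch.randn(1, 1024, 14, 14),
    torch.randn(1, 2048, 7, 7),
]

with torch.no_grad():
    out = decoder(features, device='cpu')
print(out.shape)  # expected: torch.Size([1, 2, 224, 224])
```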
/codes/models/unet.py:
--------------------------------------------------------------------------------
1 | from ._base import BaseModel2D
2 | from typing import Optional, Union, List
3 | from segmentation_models_pytorch.encoders import get_encoder
4 | from segmentation_models_pytorch.base import SegmentationModel, SegmentationHead, ClassificationHead
5 | from segmentation_models_pytorch.unet.decoder import UnetDecoder
6 |
7 |
8 | class Unet(SegmentationModel):
9 |
10 | def __init__(
11 | self,
12 | encoder_name: str = "resnet34",
13 | encoder_depth: int = 5,
14 | encoder_weights: Optional[str] = "imagenet",
15 | decoder_use_batchnorm: bool = True,
16 | decoder_channels: List[int] = (256, 128, 64, 32, 16),
17 | decoder_attention_type: Optional[str] = None,
18 | in_channels: int = 3,
19 | classes: int = 1,
20 | activation: Optional[Union[str, callable]] = None,
21 | aux_params: Optional[dict] = None,
22 | ):
23 | super().__init__()
24 |
25 | self.encoder = get_encoder(
26 | encoder_name,
27 | in_channels=in_channels,
28 | depth=encoder_depth,
29 | weights=encoder_weights,
30 | )
31 |
32 | self.decoder = UnetDecoder(
33 | encoder_channels=self.encoder.out_channels,
34 | decoder_channels=decoder_channels,
35 | n_blocks=encoder_depth,
36 | use_batchnorm=decoder_use_batchnorm,
37 | center=True if encoder_name.startswith("vgg") else False,
38 | attention_type=decoder_attention_type,
39 | )
40 |
41 | self.segmentation_head = SegmentationHead(
42 | in_channels=decoder_channels[-1],
43 | out_channels=classes,
44 | activation=activation,
45 | kernel_size=3,
46 | )
47 |
48 | if aux_params is not None:
49 | self.classification_head = ClassificationHead(
50 | in_channels=self.encoder.out_channels[-1], **aux_params
51 | )
52 | else:
53 | self.classification_head = None
54 |
55 | self.name = "u-{}".format(encoder_name)
56 | self.initialize()
57 |
58 | def forward_features(self, x):
59 | features = self.encoder(x)
60 | return features
61 |
62 |
63 | class UNet(BaseModel2D):
64 |
65 | def __init__(self,
66 | encoder_name="resnet34",
67 | encoder_depth=5,
68 | encoder_weights="imagenet",
69 | decoder_use_batchnorm: bool = True,
70 | decoder_channels=(256, 128, 64, 32, 16),
71 | decoder_attention_type=None,
72 | in_channels=3,
73 | classes=1,
74 | activation=None,
75 | aux_params=None):
76 | super().__init__()
77 |
78 | self.segmentor = Unet(encoder_name=encoder_name,
79 | encoder_depth=encoder_depth,
80 | encoder_weights=encoder_weights,
81 | decoder_use_batchnorm=decoder_use_batchnorm,
82 | decoder_channels=decoder_channels,
83 | decoder_attention_type=decoder_attention_type,
84 | in_channels=in_channels,
85 | classes=classes,
86 | activation=activation,
87 | aux_params=aux_params)
88 | self.num_classes = classes
89 |
90 | def forward(self, x):
91 | return {'seg': self.segmentor(x)}
92 |
93 | def inference_features(self, x, **kwargs):
94 | feats = self.segmentor.forward_features(x)
95 | return {'feats': feats}
96 |
--------------------------------------------------------------------------------
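A minimal usage sketch for the `UNet` wrapper (not part of the repository; `encoder_weights=None` is used here only to avoid downloading ImageNet weights):

```python
import torch
from codes.models.unet import UNet

# encoder_weights=None keeps the sketch offline (no ImageNet download).
model = UNet(encoder_name='resnet50', encoder_weights=None, in_channels=3, classes=4).eval()

x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    out = model(x)                               # {'seg': logits}
    feats = model.inference_features(x)['feats']

print(out['seg'].shape)               # torch.Size([2, 4, 224, 224])
print([f.shape[1] for f in feats])    # ResNet-50 pyramid channels: [3, 64, 256, 512, 1024, 2048]
```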
/codes/models/unet_tf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from codes.models.swin_decoder import SwinTransDecoder
6 | from codes.models._base import BaseModel2D
7 | from codes.utils.init import kaiming_normal_init_weight
8 | from segmentation_models_pytorch.encoders import get_encoder
9 | from segmentation_models_pytorch.unet.decoder import UnetDecoder
10 | from segmentation_models_pytorch.base import SegmentationHead, ClassificationHead
11 |
12 |
13 | class EmbeddingHead(nn.Module):
14 | def __init__(self, dim_in, embed_dim=256, embed='convmlp'):
15 | super(EmbeddingHead, self).__init__()
16 |
17 | if embed == 'linear':
18 | self.embed = nn.Conv2d(dim_in, embed_dim, kernel_size=1)
19 | elif embed == 'convmlp':
20 | self.embed = nn.Sequential(
21 | nn.Conv2d(dim_in, dim_in, kernel_size=1),
22 | nn.BatchNorm2d(dim_in),
23 | nn.ReLU(),
24 | nn.Conv2d(dim_in, embed_dim, kernel_size=1)
25 | )
26 |
27 | def forward(self, x):
28 | return F.normalize(self.embed(x), p=2, dim=1)
29 |
30 |
31 | class UNetTF(BaseModel2D):
32 |
33 | def __init__(self,
34 | encoder_name="resnet50",
35 | encoder_depth=5,
36 | encoder_weights="imagenet",
37 | decoder_use_batchnorm=True,
38 | decoder_channels=(256, 128, 64, 32, 16),
39 | decoder_attention_type=None,
40 | in_channels=3,
41 | classes=2,
42 | activation=None,
43 | embed_dim=96,
44 | norm_layer=nn.LayerNorm,
45 | img_size=224,
46 | patch_size=4,
47 | depths=[2, 2, 2, 2],
48 | num_heads=[3, 6, 12, 24],
49 | window_size=7,
50 | qkv_bias=True,
51 | qk_scale=None,
52 | drop_rate=0.,
53 | attn_drop_rate=0.,
54 | use_checkpoint=False,
55 | ape=True,
56 | cls=True,
57 | contrast_embed=False,
58 | contrast_embed_dim=256,
59 | contrast_embed_index=-3,
60 | mlp_ratio=4.,
61 | drop_path_rate=0.1,
62 | final_upsample="expand_first",
63 | patches_resolution=[56, 56]
64 | ):
65 | super().__init__()
66 | self.cls = cls
67 | self.contrast_embed_index = contrast_embed_index
68 |
69 | self.encoder = get_encoder(
70 | encoder_name,
71 | in_channels=in_channels,
72 | depth=encoder_depth,
73 | weights=encoder_weights,
74 | )
75 | encoder_channels = self.encoder.out_channels
76 | self.cnn_decoder = UnetDecoder(
77 | encoder_channels=encoder_channels,
78 | decoder_channels=decoder_channels,
79 | n_blocks=encoder_depth,
80 | use_batchnorm=decoder_use_batchnorm,
81 | center=True if encoder_name.startswith("vgg") else False,
82 | attention_type=decoder_attention_type,
83 | )
84 | self.seg_head = SegmentationHead(
85 | in_channels=decoder_channels[-1],
86 | out_channels=classes,
87 | activation=activation,
88 | kernel_size=3,
89 | )
90 | self.swin_decoder = SwinTransDecoder(classes, embed_dim, norm_layer, img_size, patch_size, depths, num_heads,
91 | window_size, qkv_bias, qk_scale, drop_rate, attn_drop_rate, use_checkpoint,
92 | ape, mlp_ratio, drop_path_rate, final_upsample, patches_resolution,
93 | encoder_channels)
94 |
95 | self.cls_head = ClassificationHead(in_channels=encoder_channels[-1], classes=4) if cls else None
96 | self.embed_head = EmbeddingHead(dim_in=encoder_channels[contrast_embed_index],
97 | embed_dim=contrast_embed_dim) if contrast_embed else None
98 | self._init_weights()
99 |
100 | def _init_weights(self):
101 | kaiming_normal_init_weight(self.cnn_decoder)
102 | kaiming_normal_init_weight(self.seg_head)
103 | if self.cls_head is not None:
104 | kaiming_normal_init_weight(self.cls_head)
105 | if self.embed_head is not None:
106 | kaiming_normal_init_weight(self.embed_head.embed)
107 |
108 | def forward(self, x, device):
109 | features = self.encoder(x)
110 | seg = self.seg_head(self.cnn_decoder(*features))
111 | seg_tf = self.swin_decoder(features, device)
112 |
113 | embedding = self.embed_head(features[self.contrast_embed_index]) if self.embed_head else None
114 | cls = self.cls_head(features[-1]) if self.cls_head else None
115 | return {'seg': seg, 'seg_tf': seg_tf, 'cls': cls, 'embed': embedding}
116 |
117 | def inference(self, x, **kwargs):
118 | features = self.encoder(x)
119 | seg = self.seg_head(self.cnn_decoder(*features))
120 | preds = torch.argmax(seg, dim=1, keepdim=True).to(torch.float)
121 | return preds
122 |
123 | def inference_features(self, x, **kwargs):
124 | features = self.encoder(x)
125 | embedding = self.embed_head(features[self.contrast_embed_index]) if self.embed_head else None
126 | return {'feats': features, 'embed': embedding}
127 |
128 | def inference_tf(self, x, device, **kwargs):
129 | features = self.encoder(x)
130 | seg_tf = self.swin_decoder(features, device)
131 | preds = torch.argmax(seg_tf, dim=1, keepdim=True).to(torch.float)
132 | return preds
133 |
--------------------------------------------------------------------------------
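A minimal shape check for `UNetTF` (a sketch, not repository code): with the default 224x224 / ResNet-50 configuration, `forward` returns the CNN-decoder logits, the Swin-decoder logits, the 4-way rotation prediction and, when `contrast_embed=True`, an L2-normalized pixel embedding taken from `features[-3]`:

```python
import torch
from codes.models.unet_tf import UNetTF

# Defaults assume 224x224 inputs and a ResNet-50 encoder; encoder_weights=None keeps it offline.
model = UNetTF(encoder_name='resnet50', encoder_weights=None, classes=2,
               contrast_embed=True).eval()

x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    out = model(x, device='cpu')

print(out['seg'].shape)     # CNN decoder logits:        torch.Size([2, 2, 224, 224])
print(out['seg_tf'].shape)  # Swin decoder logits:       torch.Size([2, 2, 224, 224])
print(out['cls'].shape)     # rotation classifier:       torch.Size([2, 4])
print(out['embed'].shape)   # embedding of features[-3]: torch.Size([2, 256, 28, 28])
```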
/codes/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | from .poly import PolyLR
2 |
--------------------------------------------------------------------------------
/codes/schedulers/poly.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 |
3 |
4 | class PolyLR(_LRScheduler):
5 | def __init__(self, optimizer, max_iters, power=0.9, last_epoch=-1, min_lr=1e-6):
6 | self.power = power
7 | self.max_iters = max_iters + 1 # avoid zero lr
8 | self.min_lr = min_lr
9 | self.last_epoch = last_epoch
10 | super(PolyLR, self).__init__(optimizer, last_epoch)
11 |
12 | def get_lr(self):
13 | '''
14 | factor = pow(1 - self.last_epoch / self.max_iters, self.power)
15 | return [(base_lr) * factor for base_lr in self.base_lrs]
16 | '''
17 | return [max(base_lr * (1.0 - self.last_epoch / self.max_iters) ** self.power, self.min_lr)
18 | for base_lr in self.base_lrs]
19 |
20 | def __str__(self):
21 | return f'PolyLR(' \
22 | f'\n\tpower: {self.power}' \
23 | f'\n\tmax_iters: {self.max_iters}' \
24 | f'\n\tmin_lr: {self.min_lr}' \
25 | f'\n\tlast_epoch: {self.last_epoch}\n)'
26 |
--------------------------------------------------------------------------------
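A small sketch (not repository code) showing how `PolyLR` decays the learning rate of a dummy optimizer; note that `max_iters` is internally incremented by one so the final step never reaches a zero learning rate:

```python
import torch
from codes.schedulers.poly import PolyLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.01)
scheduler = PolyLR(optimizer, max_iters=6000, power=0.9, min_lr=1e-6)

for step in range(6000):
    optimizer.step()
    scheduler.step()
    if step in (0, 2999, 5999):
        print(step + 1, optimizer.param_groups[0]['lr'])
# lr follows base_lr * (1 - step / (max_iters + 1)) ** 0.9, clipped from below at min_lr
```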
/codes/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from .mt_trainer import MeanTeacherTrainer
2 | from .supervised_trainer import SupervisedTrainer
3 | from .ugpcl_trainer import UGPCLTrainer
4 |
--------------------------------------------------------------------------------
/codes/trainers/_base.py:
--------------------------------------------------------------------------------
1 | import os
2 | from abc import abstractmethod
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from tqdm import tqdm
8 | from prettytable import PrettyTable
9 | from colorama import Fore
10 | from ..utils import ramps
11 | from ..builder import _build_from_cfg, build_model, build_optimizer, build_scheduler
12 |
13 |
14 | class BaseTrainer:
15 |
16 | def __init__(self,
17 | model=None,
18 | optimizer=None,
19 | scheduler=None,
20 | criterions=None,
21 | metrics=None,
22 | logger=None,
23 | device='cuda',
24 | resume_from=None,
25 | labeled_bs=12,
26 | consistency=1.0,
27 | consistency_rampup=40.0,
28 | data_parallel=False,
29 | ckpt_save_path=None,
30 | max_iter=10000,
31 | eval_interval=1000,
32 | save_image_interval=50,
33 | save_ckpt_interval=2000) -> None:
34 | super(BaseTrainer, self).__init__()
35 | self.model = None
36 | # build cfg
37 | if model is not None:
38 | self.model = build_model(model).to(device)
39 | if optimizer is not None:
40 | self.optimizer = build_optimizer(self.model.parameters(), optimizer)
41 | if scheduler is not None:
42 | self.scheduler = build_scheduler(self.optimizer, scheduler)
43 | self.criterions = []
44 | if criterions is not None:
45 | for criterion_cfg in criterions:
46 | self.criterions.append(_build_from_cfg(criterion_cfg))
47 | self.metrics = []
48 | if metrics is not None:
49 | for metric_cfg in metrics:
50 | self.metrics.append(_build_from_cfg(metric_cfg))
51 |
52 | # semi-supervised params
53 | self.labeled_bs = labeled_bs
54 | self.consistency = consistency
55 | self.consistency_rampup = consistency_rampup
56 |
57 | # train params
58 | self.logger = logger
59 | self.device = device
60 | self.data_parallel = data_parallel
61 | self.ckpt_save_path = ckpt_save_path
62 |
63 | self.max_iter = max_iter
64 | self.eval_interval = eval_interval
65 | self.save_image_interval = save_image_interval
66 | self.save_ckpt_interval = save_ckpt_interval
67 |
68 | if self.data_parallel:
69 | self.model = torch.nn.DataParallel(self.model, device_ids=[0, 1])
70 |
71 | if resume_from is not None:
72 | ckpt = torch.load(resume_from)
73 |             self.model.load_state_dict(ckpt['state_dict'])
74 |             self.optimizer.load_state_dict(ckpt['optimizer'])
75 |             for state in self.optimizer.state.values():
76 |                 for k, v in state.items():
77 |                     if isinstance(v, torch.Tensor):
78 |                         state[k] = v.to(device)
79 |             self.scheduler.load_state_dict(ckpt['scheduler'])
80 | self.start_step = ckpt['step']
81 |
82 | logger.info(f'Resume from {resume_from}.')
83 | logger.info(f'Train from step {self.start_step}.')
84 | else:
85 | self.start_step = 0
86 | if self.model is not None:
87 | logger.info(f'\n{self.model}\n')
88 |
89 | logger.info(f'start training...')
90 |
91 | @abstractmethod
92 | def train_step(self, batch_data, step, save_image):
93 | loss = 0.
94 | log_infos, scalars, images = {}, {}, {}
95 | return loss, log_infos, scalars, images
96 |
97 | def val_step(self, batch_data):
98 | data, labels = batch_data['image'].to(self.device), batch_data['label'].to(self.device)
99 | preds = self.model.inference(data)
100 | metric_total_res = {}
101 | for metric in self.metrics:
102 | metric_total_res[metric.name] = metric(preds, labels)
103 | return metric_total_res
104 |
105 | def train(self, train_loader, val_loader):
106 | # iter_train_loader = iter(train_loader)
107 | max_epoch = self.max_iter // len(train_loader) + 1
108 | step = self.start_step
109 | self.model.train()
110 | with tqdm(total=self.max_iter - self.start_step, bar_format='[{elapsed}<{remaining}] ') as pbar:
111 | for _ in range(max_epoch):
112 | for batch_data in train_loader:
113 | save_image = True if (step + 1) % self.save_image_interval == 0 else False
114 |
115 | loss, log_infos, scalars, images = self.train_step(batch_data, step, save_image)
116 |
117 | self.optimizer.zero_grad()
118 | loss.backward()
119 | self.optimizer.step()
120 | self.scheduler.step()
121 |
122 | if (step + 1) % 10 == 0:
123 | scalars.update({'lr': self.scheduler.get_lr()[0]})
124 | log_infos.update({'lr': self.scheduler.get_lr()[0]})
125 | self.logger.update_scalars(scalars, step + 1)
126 | self.logger.info(f'[{step + 1}/{self.max_iter}] {log_infos}')
127 |
128 | if save_image:
129 | self.logger.update_images(images, step + 1)
130 |
131 | if (step + 1) % self.eval_interval == 0:
132 | if val_loader is not None:
133 | val_res, val_scalars, val_table = self.val(val_loader)
134 | self.logger.info(f'val result:\n{val_table.get_string()}')
135 | self.logger.update_scalars(val_scalars, step + 1)
136 | self.model.train()
137 |
138 | if (step + 1) % self.save_ckpt_interval == 0:
139 | if not os.path.exists(self.ckpt_save_path):
140 | os.makedirs(self.ckpt_save_path)
141 | self.save_ckpt(step + 1, f'{self.ckpt_save_path}/iter_{step + 1}.pth')
142 | step += 1
143 | pbar.update(1)
144 | if step >= self.max_iter:
145 | break
146 | if step >= self.max_iter:
147 | break
148 |
149 | if not os.path.exists(self.ckpt_save_path):
150 | os.makedirs(self.ckpt_save_path)
151 | torch.save(self.model.state_dict(), f'{self.ckpt_save_path}/ckpt_final.pth')
152 |
153 | @torch.no_grad()
154 | def val(self, val_loader, test=False):
155 | self.model.eval()
156 | val_res = None
157 | val_scalars = {}
158 | if self.logger is not None:
159 | self.logger.info('Evaluating...')
160 | if test:
161 | val_loader = tqdm(val_loader, desc='Testing', unit='batch',
162 | bar_format='%s{l_bar}{bar}{r_bar}%s' % (Fore.LIGHTCYAN_EX, Fore.RESET))
163 | for batch_data in val_loader:
164 | batch_res = self.val_step(batch_data) # {'Dice':{'c1':0.1, 'c2':0.1, ...}, ...}
165 | if val_res is None:
166 | val_res = batch_res
167 | else:
168 | for metric_name in val_res.keys():
169 | for key in val_res[metric_name].keys():
170 | val_res[metric_name][key] += batch_res[metric_name][key]
171 | for metric_name in val_res.keys():
172 | for key in val_res[metric_name].keys():
173 | val_res[metric_name][key] = val_res[metric_name][key] / len(val_loader)
174 | val_scalars[f'val/{metric_name}.{key}'] = val_res[metric_name][key]
175 |
176 |             val_res_list = [_.cpu() for _ in val_res[metric_name].values()]
177 |             val_res[metric_name]['Mean'] = np.mean(val_res_list[1:])  # mean over foreground classes (skip the first, background, class)
178 | val_scalars[f'val/{metric_name}.Mean'] = val_res[metric_name]['Mean']
179 |
180 | val_table = PrettyTable()
181 |         val_table.field_names = ['Metric'] + list(list(val_res.values())[0].keys())
182 | for metric_name in val_res.keys():
183 | if metric_name in ['Dice', 'Jaccard', 'Acc', 'IoU', 'Recall', 'Precision']:
184 | temp = [float(format(_ * 100, '.2f')) for _ in val_res[metric_name].values()]
185 | else:
186 | temp = [float(format(_, '.2f')) for _ in val_res[metric_name].values()]
187 | val_table.add_row([metric_name] + temp)
188 | return val_res, val_scalars, val_table
189 |
190 | def get_current_consistency_weight(self, epoch):
191 | # Consistency ramp-up from https://arxiv.org/abs/1610.02242
192 | return self.consistency * ramps.sigmoid_rampup(epoch, self.consistency_rampup)
193 |
194 | def save_ckpt(self, step, save_path):
195 | ckpt = {'state_dict': self.model.state_dict(),
196 | 'optimizer': self.optimizer.state_dict(),
197 | 'scheduler': self.scheduler.state_dict(),
198 | 'step': step}
199 | torch.save(ckpt, save_path)
200 | self.logger.info('Checkpoint saved!')
201 |
--------------------------------------------------------------------------------
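`get_current_consistency_weight` multiplies `consistency` by `ramps.sigmoid_rampup`. The ramp function itself is not shown in this dump; the sketch below assumes it is the standard sigmoid ramp-up `exp(-5 * (1 - t / T)^2)` used in SSL4MIS, so the numbers are illustrative only:

```python
import numpy as np

def sigmoid_rampup(current, rampup_length):
    """Assumed form of codes/utils/ramps.sigmoid_rampup (Laine & Aila, 2017)."""
    if rampup_length == 0:
        return 1.0
    current = float(np.clip(current, 0.0, rampup_length))
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))

consistency, consistency_rampup = 1.0, 40.0
for epoch in (0, 10, 20, 40):
    # mirrors BaseTrainer.get_current_consistency_weight(epoch)
    print(epoch, round(consistency * sigmoid_rampup(epoch, consistency_rampup), 4))
# ramps smoothly: ~0.0067 at epoch 0, ~0.2865 at epoch 20, 1.0 from epoch 40 onwards
```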
/codes/trainers/mt_trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Paper: Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results.
3 | Link: https://proceedings.neurips.cc/paper/2017/file/68053af2923e00204c3ca7c6a3150cf7-Paper.pdf
4 | Code modified from: https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_mean_teacher_2D.py
5 | """
6 |
7 | import os
8 | import torch
9 |
10 | from tqdm import tqdm
11 | from torchvision.utils import make_grid
12 | from ._base import BaseTrainer
13 | from ..builder import build_model
14 |
15 |
16 | class MeanTeacherTrainer(BaseTrainer):
17 |
18 | def __init__(self, model=None,
19 | ema_model=None,
20 | optimizer=None,
21 | scheduler=None,
22 | criterions=None,
23 | metrics=None,
24 | logger=None,
25 | device='cuda',
26 | resume_from=None,
27 | labeled_bs=12,
28 | alpha=0.99,
29 | consistency=1.0,
30 | consistency_rampup=40.0,
31 | data_parallel=False,
32 | ckpt_save_path=None,
33 | max_iter=60000,
34 | eval_interval=1000,
35 | save_image_interval=50,
36 | save_ckpt_interval=2000) -> None:
37 |
38 | super().__init__(model, optimizer, scheduler, criterions, metrics, logger, device, resume_from, labeled_bs,
39 | consistency, consistency_rampup, data_parallel, ckpt_save_path, max_iter, eval_interval,
40 | save_image_interval, save_ckpt_interval)
41 |
42 | self.ema_model = build_model(ema_model).to(device)
43 | for param in self.ema_model.parameters():
44 | param.detach_()
45 |
46 | self.alpha = alpha
47 |
48 | def train_step(self, batch_data, step, save_image):
49 | log_infos, scalars, images = {}, {}, {}
50 | data, label = batch_data['image'].to(self.device), batch_data['label'].to(self.device)
51 |
52 | unlabeled_data = data[self.labeled_bs:]
53 |
54 | noise = torch.clamp(torch.randn_like(unlabeled_data) * 0.1, -0.2, 0.2)
55 | ema_inputs = unlabeled_data + noise
56 |
57 | outputs = self.model(data)
58 | outputs_soft = torch.softmax(outputs['seg'], dim=1)
59 | with torch.no_grad():
60 | ema_output = self.ema_model(ema_inputs)
61 | ema_output_soft = torch.softmax(ema_output['seg'], dim=1)
62 |
63 | supervised_loss = 0.
64 | for criterion in self.criterions:
65 | loss_ = criterion(outputs['seg'][:self.labeled_bs], label[:self.labeled_bs])
66 | supervised_loss += loss_
67 | log_infos[criterion.name] = float(format(loss_, '.5f'))
68 | scalars[f'loss/{criterion.name}'] = loss_
69 |
70 | consistency_weight = self.get_current_consistency_weight(step // 150)
71 | if step < 1000:
72 | consistency_loss = 0.0
73 | else:
74 | consistency_loss = torch.mean((outputs_soft[self.labeled_bs:] - ema_output_soft) ** 2)
75 |
76 | loss = supervised_loss + consistency_weight * consistency_loss
77 |
78 | log_infos['con_weight'] = float(format(consistency_weight, '.5f'))
79 | log_infos['loss_con'] = float(format(consistency_loss, '.5f'))
80 | log_infos['loss'] = float(format(loss, '.5f'))
81 | scalars['consistency_weight'] = consistency_weight
82 | scalars['loss/loss_consistency'] = consistency_loss
83 | scalars['loss/total'] = loss
84 |
85 | preds = torch.argmax(outputs['seg'], dim=1, keepdim=True).to(torch.float)
86 |
87 | metric_res = self.metrics[0](preds, label)
88 | for key in metric_res.keys():
89 | log_infos[f'{self.metrics[0].name}.{key}'] = float(format(metric_res[key], '.5f'))
90 | scalars[f'train/{self.metrics[0].name}.{key}'] = metric_res[key]
91 |
92 | if save_image:
93 | grid_image = make_grid(data, 4, normalize=True)
94 | images['train/images'] = grid_image
95 | grid_image = make_grid(preds * 50., 4, normalize=False)
96 | images['train/preds'] = grid_image
97 | grid_image = make_grid(label * 50., 4, normalize=False)
98 | images['train/ground_truth'] = grid_image
99 | return loss, log_infos, scalars, images
100 |
101 | def train(self, train_loader, val_loader):
102 | max_epoch = self.max_iter // len(train_loader) + 1
103 | step = self.start_step
104 | with tqdm(total=self.max_iter - self.start_step, bar_format='[{elapsed}<{remaining}] ') as pbar:
105 | for _ in range(max_epoch):
106 | for batch_data in train_loader:
107 | save_image = True if (step + 1) % self.save_image_interval == 0 else False
108 |
109 | loss, log_infos, scalars, images = self.train_step(batch_data, step, save_image)
110 |
111 | self.optimizer.zero_grad()
112 | loss.backward()
113 | self.optimizer.step()
114 | self.scheduler.step()
115 |
116 | self.update_ema_variables(step)
117 |
118 | if (step + 1) % 10 == 0:
119 | scalars.update({'lr': self.scheduler.get_lr()[0]})
120 | log_infos.update({'lr': self.scheduler.get_lr()[0]})
121 | self.logger.update_scalars(scalars, step + 1)
122 | self.logger.info(f'[{step + 1}/{self.max_iter}] {log_infos}')
123 |
124 | if save_image:
125 | self.logger.update_images(images, step + 1)
126 |
127 | if (step + 1) % self.eval_interval == 0:
128 | if val_loader is not None:
129 | val_res, val_scalars, val_table = self.val(val_loader)
130 | self.logger.info(f'val result:\n{val_table.get_string()}')
131 | self.logger.update_scalars(val_scalars, step + 1)
132 | self.model.train()
133 |
134 | if (step + 1) % self.save_ckpt_interval == 0:
135 | if not os.path.exists(self.ckpt_save_path):
136 | os.makedirs(self.ckpt_save_path)
137 | self.save_ckpt(step + 1, f'{self.ckpt_save_path}/iter_{step + 1}.pth')
138 | step += 1
139 | pbar.update(1)
140 | if step >= self.max_iter:
141 | break
142 | if step >= self.max_iter:
143 | break
144 |
145 | if not os.path.exists(self.ckpt_save_path):
146 | os.makedirs(self.ckpt_save_path)
147 | torch.save(self.model.state_dict(), f'{self.ckpt_save_path}/ckpt_final.pth')
148 |
149 | def update_ema_variables(self, global_step):
150 | alpha = min(1 - 1 / (global_step + 1), self.alpha)
151 | for ema_param, param in zip(self.ema_model.parameters(), self.model.parameters()):
152 |             ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
153 |
--------------------------------------------------------------------------------
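The EMA update in `update_ema_variables` can be traced by hand. The standalone sketch below (not repository code) applies the same rule to a single scalar weight; the early steps copy the student almost directly because `alpha = min(1 - 1 / (step + 1), 0.99)` starts at 0 and only reaches 0.99 at step 99:

```python
import torch

alpha_max = 0.99                     # self.alpha in MeanTeacherTrainer
ema = torch.tensor([0.0])            # teacher weight
for global_step in range(4):
    student = torch.tensor([float(global_step)])    # pretend the student weight changes each step
    alpha = min(1 - 1 / (global_step + 1), alpha_max)
    ema.mul_(alpha).add_(student, alpha=1 - alpha)   # same rule as update_ema_variables
    print(global_step, round(alpha, 3), ema.item())
# alpha: 0.0, 0.5, 0.667, 0.75, ... -> 0.99 from step 99 on;
# the teacher tracks the student with steadily growing inertia.
```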
/codes/trainers/supervised_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ._base import BaseTrainer
3 | from torchvision.utils import make_grid
4 |
5 |
6 | class SupervisedTrainer(BaseTrainer):
7 |
8 | def __init__(self,
9 | model=None,
10 | optimizer=None,
11 | scheduler=None,
12 | criterions=None,
13 | metrics=None,
14 | logger=None,
15 | device='cuda',
16 | resume_from=None,
17 | labeled_bs=12,
18 | data_parallel=False,
19 | ckpt_save_path=None,
20 | max_iter=10000,
21 | eval_interval=1000,
22 | save_image_interval=50,
23 | save_ckpt_interval=2000) -> None:
24 |
25 | super().__init__(model, optimizer, scheduler, criterions, metrics, logger, device, resume_from, labeled_bs,
26 | 1.0, 40.0, data_parallel, ckpt_save_path, max_iter, eval_interval,
27 | save_image_interval, save_ckpt_interval)
28 |
29 | def train_step(self, batch_data, step, save_image):
30 | log_infos, scalars, images = {}, {}, {}
31 | data, label = batch_data['image'].to(self.device), batch_data['label'].to(self.device)
32 | logits = self.model(data)
33 | loss = 0.
34 | for criterion in self.criterions:
35 | loss_ = criterion(logits['seg'][:self.labeled_bs], label[:self.labeled_bs])
36 | loss += loss_
37 | log_infos[criterion.name] = float(format(loss_, '.5f'))
38 | scalars[f'loss/{criterion.name}'] = loss_
39 |
40 | log_infos['loss'] = float(format(loss, '.5f'))
41 | scalars['loss/total'] = loss
42 | preds = torch.argmax(torch.softmax(logits['seg'], dim=1), dim=1, keepdim=True).to(torch.float)
43 |
44 | metric_res = self.metrics[0](preds, label)
45 | for key in metric_res.keys():
46 | log_infos[f'{self.metrics[0].name}.{key}'] = float(format(metric_res[key], '.5f'))
47 | scalars[f'train/{self.metrics[0].name}.{key}'] = metric_res[key]
48 |
49 | images = {}
50 | if save_image:
51 | grid_image = make_grid(data[:self.labeled_bs], 4, normalize=True)
52 | images['train/image'] = grid_image
53 | grid_image = make_grid(preds[:self.labeled_bs] * 50., 4, normalize=False)
54 | images['train/pred'] = grid_image
55 | grid_image = make_grid(label[:self.labeled_bs] * 50., 4, normalize=False)
56 | images['train/label'] = grid_image
57 | return loss, log_infos, scalars, images
--------------------------------------------------------------------------------
/codes/trainers/ugpcl_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | import numpy as np
5 | import torch.nn.functional as F
6 | import torchvision.transforms.functional as TF
7 |
8 | from torch import nn
9 | from tqdm import tqdm
10 | from prettytable import PrettyTable
11 | from colorama import Fore
12 | from torchvision.utils import make_grid
13 | from ._base import BaseTrainer
14 | from ..utils import ramps
15 | from ..losses import PixelContrastLoss
16 |
17 |
18 | class UGPCLTrainer(BaseTrainer):
19 |
20 | def __init__(self,
21 | model=None,
22 | optimizer=None,
23 | scheduler=None,
24 | criterions=None,
25 | metrics=None,
26 | logger=None,
27 | device='cuda',
28 | resume_from=None,
29 | labeled_bs=8,
30 | data_parallel=False,
31 | ckpt_save_path=None,
32 | max_iter=6000,
33 | eval_interval=1000,
34 | save_image_interval=50,
35 | save_ckpt_interval=2000,
36 | consistency=0.1,
37 | consistency_rampup=40.0,
38 | tf_decoder_weight=0.4,
39 | cls_weight=0.1,
40 | contrast_type='ugpcl', # ugpcl, pseudo, sup, none
41 | contrast_weight=0.1,
42 | temperature=0.1,
43 | base_temperature=0.07,
44 | max_samples=1024,
45 | max_views=1,
46 | memory=True,
47 | memory_size=100,
48 | pixel_update_freq=10,
49 | pixel_classes=4,
50 | dim=256) -> None:
51 |
52 | super(UGPCLTrainer, self).__init__(model, optimizer, scheduler, criterions, metrics, logger, device,
53 | resume_from, labeled_bs, consistency, consistency_rampup, data_parallel,
54 | ckpt_save_path, max_iter, eval_interval, save_image_interval,
55 | save_ckpt_interval)
56 |
57 | self.tf_decoder_weight = tf_decoder_weight
58 | self.cls_weight = cls_weight
59 | self.cls_criterion = torch.nn.CrossEntropyLoss()
60 |
61 | self.contrast_type = contrast_type
62 | self.contrast_weight = contrast_weight
63 | self.contrast_criterion = PixelContrastLoss(temperature=temperature,
64 | base_temperature=base_temperature,
65 | max_samples=max_samples,
66 | max_views=max_views,
67 | device=device)
68 | # memory param
69 | self.memory = memory
70 | self.memory_size = memory_size
71 | self.pixel_update_freq = pixel_update_freq
72 |
73 | if self.memory:
74 | self.segment_queue = torch.randn(pixel_classes, self.memory_size, dim)
75 | self.segment_queue = nn.functional.normalize(self.segment_queue, p=2, dim=2)
76 | self.segment_queue_ptr = torch.zeros(pixel_classes, dtype=torch.long)
77 | self.pixel_queue = torch.zeros(pixel_classes, self.memory_size, dim)
78 | self.pixel_queue = nn.functional.normalize(self.pixel_queue, p=2, dim=2)
79 | self.pixel_queue_ptr = torch.zeros(pixel_classes, dtype=torch.long)
80 |
81 | def _dequeue_and_enqueue(self, keys, labels):
82 | batch_size = keys.shape[0]
83 | feat_dim = keys.shape[1]
84 |
85 | labels = torch.nn.functional.interpolate(labels, (keys.shape[2], keys.shape[3]), mode='nearest')
86 |
87 | for bs in range(batch_size):
88 | this_feat = keys[bs].contiguous().view(feat_dim, -1)
89 | this_label = labels[bs].contiguous().view(-1)
90 | this_label_ids = torch.unique(this_label)
91 | this_label_ids = [x for x in this_label_ids if x > 0]
92 | for lb in this_label_ids:
93 | idxs = (this_label == lb).nonzero()
94 | lb = int(lb.item())
95 | # segment enqueue and dequeue
96 | feat = torch.mean(this_feat[:, idxs], dim=1).squeeze(1)
97 | ptr = int(self.segment_queue_ptr[lb])
98 | self.segment_queue[lb, ptr, :] = nn.functional.normalize(feat.view(-1), p=2, dim=0)
99 | self.segment_queue_ptr[lb] = (self.segment_queue_ptr[lb] + 1) % self.memory_size
100 |
101 | # pixel enqueue and dequeue
102 | num_pixel = idxs.shape[0]
103 | perm = torch.randperm(num_pixel)
104 | K = min(num_pixel, self.pixel_update_freq)
105 | feat = this_feat[:, perm[:K]]
106 | feat = torch.transpose(feat, 0, 1)
107 | ptr = int(self.pixel_queue_ptr[lb])
108 |
109 | if ptr + K >= self.memory_size:
110 | self.pixel_queue[lb, -K:, :] = nn.functional.normalize(feat, p=2, dim=1)
111 | self.pixel_queue_ptr[lb] = 0
112 | else:
113 | self.pixel_queue[lb, ptr:ptr + K, :] = nn.functional.normalize(feat, p=2, dim=1)
114 | self.pixel_queue_ptr[lb] = (self.pixel_queue_ptr[lb] + 1) % self.memory_size
115 |
116 | @staticmethod
117 | def _random_rotate(image, label):
118 | angle = float(torch.empty(1).uniform_(-20., 20.).item())
119 | image = TF.rotate(image, angle)
120 | label = TF.rotate(label, angle)
121 | return image, label
122 |
123 | def train_step(self, batch_data, step, save_image):
124 | log_infos, scalars = {}, {}
125 | images = {}
126 | data_, label_ = batch_data['image'].to(self.device), batch_data['label'].to(self.device)
127 | # data, label = self._random_aug(data_, label_)
128 | if self.cls_weight >= 0.:
129 | images_, labels_ = [], []
130 | cls_label = []
131 | for image, label in zip(data_, label_):
132 | rot_times = random.randrange(0, 4)
133 | cls_label.append(rot_times)
134 | image = torch.rot90(image, rot_times, [1, 2])
135 | label = torch.rot90(label, rot_times, [1, 2])
136 | image, label = self._random_rotate(image, label)
137 | images_.append(image)
138 | labels_.append(label)
139 | cls_label = torch.tensor(cls_label).to(self.device)
140 | data = torch.stack(images_, dim=0).to(self.device)
141 | label = torch.stack(labels_, dim=0).to(self.device)
142 | else:
143 | data = data_
144 | label = label_
145 | cls_label = None
146 |
147 | outputs = self.model(data, self.device)
148 | seg = outputs['seg']
149 | seg_tf = outputs['seg_tf']
150 |
151 | supervised_loss = 0.
152 | for criterion in self.criterions:
153 | loss_ = criterion(seg[:self.labeled_bs], label[:self.labeled_bs]) + \
154 | self.tf_decoder_weight * criterion(seg_tf[:self.labeled_bs], label[:self.labeled_bs])
155 | supervised_loss += loss_
156 | log_infos[criterion.name] = float(format(loss_, '.5f'))
157 | scalars[f'loss/{criterion.name}'] = loss_
158 |
159 | loss_cls = self.cls_criterion(outputs['cls'], cls_label) if self.cls_weight > 0. else 0.
160 |
161 | seg_soft = torch.softmax(seg, dim=1)
162 | seg_tf_soft = torch.softmax(seg_tf, dim=1)
163 | consistency_weight = self.get_current_consistency_weight(step // 100)
164 | consistency_loss = torch.mean((seg_soft[self.labeled_bs:] - seg_tf_soft[self.labeled_bs:]) ** 2)
165 |
166 | loss = supervised_loss + consistency_weight * consistency_loss + self.cls_weight * loss_cls
167 |
168 | log_infos['loss_cls'] = float(format(loss_cls, '.5f'))
169 | log_infos['con_weight'] = float(format(consistency_weight, '.5f'))
170 | log_infos['loss_con'] = float(format(consistency_loss, '.5f'))
171 | log_infos['loss'] = float(format(loss, '.5f'))
172 | scalars['loss/loss_cls'] = loss_cls
173 | scalars['consistency_weight'] = consistency_weight
174 | scalars['loss/loss_consistency'] = consistency_loss
175 | scalars['loss/total'] = loss
176 |
177 | preds = torch.argmax(seg_soft, dim=1, keepdim=True).to(torch.float)
178 |
179 | log_infos['loss_contrast'] = 0.
180 | scalars['loss/contrast'] = 0.
181 | if step > 1000 and self.contrast_weight > 0.:
182 | # queue = torch.cat((self.segment_queue, self.pixel_queue), dim=1) if self.memory else None
183 | queue = self.segment_queue if self.memory else None
184 | if self.contrast_type == 'ugpcl':
185 | seg_mean = torch.mean(torch.stack([F.softmax(seg, dim=1), F.softmax(seg_tf, dim=1)]), dim=0)
186 | uncertainty = -1.0 * torch.sum(seg_mean * torch.log(seg_mean + 1e-6), dim=1, keepdim=True)
187 | threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(step, self.max_iter)) * np.log(2)
188 | uncertainty_mask = (uncertainty > threshold)
189 | mean_preds = torch.argmax(F.softmax(seg_mean, dim=1).detach(), dim=1, keepdim=True).float()
190 | certainty_pseudo = mean_preds.clone()
191 | certainty_pseudo[uncertainty_mask] = -1
192 | certainty_pseudo[:self.labeled_bs] = label[:self.labeled_bs]
193 | contrast_loss = self.contrast_criterion(outputs['embed'], certainty_pseudo, preds, queue=queue)
194 | scalars['uncertainty_rate'] = torch.sum(uncertainty_mask == True) / \
195 | (torch.sum(uncertainty_mask == True) + torch.sum(
196 | uncertainty_mask == False))
197 | if self.memory:
198 | self._dequeue_and_enqueue(outputs['embed'].detach(), certainty_pseudo.detach())
199 | if save_image:
200 | grid_image = make_grid(mean_preds * 50., 4, normalize=False)
201 | images['train/mean_preds'] = grid_image
202 | grid_image = make_grid(certainty_pseudo * 50., 4, normalize=False)
203 | images['train/certainty_pseudo'] = grid_image
204 | grid_image = make_grid(uncertainty, 4, normalize=False)
205 | images['train/uncertainty'] = grid_image
206 | grid_image = make_grid(uncertainty_mask.float(), 4, normalize=False)
207 | images['train/uncertainty_mask'] = grid_image
208 | elif self.contrast_type == 'pseudo':
209 | contrast_loss = self.contrast_criterion(outputs['embed'], preds.detach(), preds, queue=queue)
210 | if self.memory:
211 | self._dequeue_and_enqueue(outputs['embed'].detach(), preds.detach())
212 | elif self.contrast_type == 'sup':
213 | contrast_loss = self.contrast_criterion(outputs['embed'][:self.labeled_bs], label[:self.labeled_bs],
214 | preds[:self.labeled_bs], queue=queue)
215 | if self.memory:
216 | self._dequeue_and_enqueue(outputs['embed'].detach()[:self.labeled_bs],
217 | label.detach()[:self.labeled_bs])
218 | else:
219 | contrast_loss = 0.
220 | loss += self.contrast_weight * contrast_loss
221 | log_infos['loss_contrast'] = float(format(contrast_loss, '.5f'))
222 | scalars['loss/contrast'] = contrast_loss
223 |
224 | tf_preds = torch.argmax(seg_tf_soft, dim=1, keepdim=True).to(torch.float)
225 | metric_res = self.metrics[0](preds, label)
226 | for key in metric_res.keys():
227 | log_infos[f'{self.metrics[0].name}.{key}'] = float(format(metric_res[key], '.5f'))
228 | scalars[f'train/{self.metrics[0].name}.{key}'] = metric_res[key]
229 |
230 | if save_image:
231 | grid_image = make_grid(data, 4, normalize=True)
232 | images['train/images'] = grid_image
233 | grid_image = make_grid(preds * 50., 4, normalize=False)
234 | images['train/preds'] = grid_image
235 | grid_image = make_grid(tf_preds * 50., 4, normalize=False)
236 | images['train/tf_preds'] = grid_image
237 | grid_image = make_grid(label * 50., 4, normalize=False)
238 | images['train/labels'] = grid_image
239 |
240 | return loss, log_infos, scalars, images
241 |
242 | def val_step(self, batch_data):
243 | data, labels = batch_data['image'].to(self.device), batch_data['label'].to(self.device)
244 | preds = self.model.inference(data)
245 | metric_total_res = {}
246 | for metric in self.metrics:
247 | metric_total_res[metric.name] = metric(preds, labels)
248 | return metric_total_res
249 |
250 | def val_step_tf(self, batch_data):
251 | data, labels = batch_data['image'].to(self.device), batch_data['label'].to(self.device)
252 | preds = self.model.inference_tf(data, self.device)
253 | metric_total_res = {}
254 | for metric in self.metrics:
255 | metric_total_res[metric.name] = metric(preds, labels)
256 | return metric_total_res
257 |
258 | @torch.no_grad()
259 | def val_tf(self, val_loader, test=False):
260 | self.model.eval()
261 | val_res = None
262 | val_scalars = {}
263 | if self.logger is not None:
264 | self.logger.info('Evaluating...')
265 | if test:
266 | val_loader = tqdm(val_loader, desc='Testing', unit='batch',
267 | bar_format='%s{l_bar}{bar}{r_bar}%s' % (Fore.LIGHTCYAN_EX, Fore.RESET))
268 | for batch_data in val_loader:
269 | batch_res = self.val_step_tf(batch_data) # {'Dice':{'c1':0.1, 'c2':0.1, ...}, ...}
270 | if val_res is None:
271 | val_res = batch_res
272 | else:
273 | for metric_name in val_res.keys():
274 | for key in val_res[metric_name].keys():
275 | val_res[metric_name][key] += batch_res[metric_name][key]
276 | for metric_name in val_res.keys():
277 | for key in val_res[metric_name].keys():
278 | val_res[metric_name][key] = val_res[metric_name][key] / len(val_loader)
279 | val_scalars[f'val_tf/{metric_name}.{key}'] = val_res[metric_name][key]
280 |
281 | val_res_list = [_.cpu() for _ in val_res[metric_name].values()]
282 | val_res[metric_name]['Mean'] = np.mean(val_res_list[1:])
283 | val_scalars[f'val_tf/{metric_name}.Mean'] = val_res[metric_name]['Mean']
284 |
285 | val_table = PrettyTable()
286 |         val_table.field_names = ['Metric'] + list(list(val_res.values())[0].keys())
287 | for metric_name in val_res.keys():
288 | if metric_name in ['Dice', 'Jaccard', 'Acc', 'IoU', 'Recall', 'Precision']:
289 | temp = [float(format(_ * 100, '.2f')) for _ in val_res[metric_name].values()]
290 | else:
291 | temp = [float(format(_, '.2f')) for _ in val_res[metric_name].values()]
292 | val_table.add_row([metric_name] + temp)
293 | return val_res, val_scalars, val_table
294 |
295 | def train(self, train_loader, val_loader):
296 | # iter_train_loader = iter(train_loader)
297 | max_epoch = self.max_iter // len(train_loader) + 1
298 | step = self.start_step
299 | self.model.train()
300 | with tqdm(total=self.max_iter - self.start_step, bar_format='[{elapsed}<{remaining}] ') as pbar:
301 | for _ in range(max_epoch):
302 | for batch_data in train_loader:
303 | save_image = True if (step + 1) % self.save_image_interval == 0 else False
304 |
305 | loss, log_infos, scalars, images = self.train_step(batch_data, step, save_image)
306 |
307 | self.optimizer.zero_grad()
308 | loss.backward()
309 | self.optimizer.step()
310 | self.scheduler.step()
311 |
312 | if (step + 1) % 10 == 0:
313 | scalars.update({'lr': self.scheduler.get_lr()[0]})
314 | log_infos.update({'lr': self.scheduler.get_lr()[0]})
315 | self.logger.update_scalars(scalars, step + 1)
316 | self.logger.info(f'[{step + 1}/{self.max_iter}] {log_infos}')
317 |
318 | if save_image:
319 | self.logger.update_images(images, step + 1)
320 |
321 | if (step + 1) % self.eval_interval == 0:
322 | if val_loader is not None:
323 | val_res, val_scalars, val_table = self.val(val_loader)
324 | self.logger.info(f'val result:\n{val_table.get_string()}')
325 | self.logger.update_scalars(val_scalars, step + 1)
326 | self.model.train()
327 |
328 | val_res, val_scalars, val_table = self.val_tf(val_loader)
329 | self.logger.info(f'val_tf result:\n{val_table.get_string()}')
330 | self.logger.update_scalars(val_scalars, step + 1)
331 | self.model.train()
332 |
333 | if (step + 1) % self.save_ckpt_interval == 0:
334 | if not os.path.exists(self.ckpt_save_path):
335 | os.makedirs(self.ckpt_save_path)
336 | self.save_ckpt(step + 1, f'{self.ckpt_save_path}/iter_{step + 1}.pth')
337 | step += 1
338 | pbar.update(1)
339 | if step >= self.max_iter:
340 | break
341 | if step >= self.max_iter:
342 | break
343 |
344 | if not os.path.exists(self.ckpt_save_path):
345 | os.makedirs(self.ckpt_save_path)
346 | torch.save(self.model.state_dict(), f'{self.ckpt_save_path}/ckpt_final.pth')
347 |
--------------------------------------------------------------------------------
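The heart of the `ugpcl` branch in `train_step` above is the uncertainty gate: the softmax outputs of the two decoders are averaged, per-pixel entropy is computed, and pixels whose entropy exceeds a threshold that ramps from 0.75·ln(2) toward ln(2) are marked `-1` so `PixelContrastLoss` ignores them. A self-contained sketch of just that gating step, using random logits in place of real decoder outputs:

```python
import numpy as np
import torch
import torch.nn.functional as F


def sigmoid_rampup(current, rampup_length):
    # Same exponential ramp as codes/utils/ramps.py.
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))


# Dummy decoder outputs: batch of 4, 4 classes, 56x56 logits.
seg, seg_tf = torch.randn(4, 4, 56, 56), torch.randn(4, 4, 56, 56)
step, max_iter = 2000, 6000

# Average the two softmax predictions and take per-pixel entropy.
seg_mean = torch.mean(torch.stack([F.softmax(seg, dim=1), F.softmax(seg_tf, dim=1)]), dim=0)
uncertainty = -1.0 * torch.sum(seg_mean * torch.log(seg_mean + 1e-6), dim=1, keepdim=True)

# Threshold ramps from 0.75*ln(2) toward ln(2) over training.
threshold = (0.75 + 0.25 * sigmoid_rampup(step, max_iter)) * np.log(2)
uncertainty_mask = uncertainty > threshold

# Pseudo-labels: argmax of the mean prediction, with uncertain pixels set to -1.
certainty_pseudo = torch.argmax(seg_mean, dim=1, keepdim=True).float()
certainty_pseudo[uncertainty_mask] = -1

print('masked-out fraction:', uncertainty_mask.float().mean().item())
```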
/codes/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taovv/UGPCL/5980405b15ef1bc139be0447647a213b5a5f30b9/codes/utils/__init__.py
--------------------------------------------------------------------------------
/codes/utils/analyze.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import cv2
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from tqdm import tqdm
7 | from torchvision.transforms import transforms
8 | from codes.builder import build_model
9 | from codes.utils.utils import parse_yaml, Namespace
10 |
11 | color1 = (102, 253, 204, 1.)
12 | color2 = (255, 255, 102, 1.)
13 | color3 = (255, 255, 255, 1.)
14 |
15 |
16 | def _color(img, rgba=(102, 253, 204, 1.)):
17 | h, w, c = img.shape
18 | rgba = list(rgba)
19 | rgba[0] /= 255.
20 | rgba[1] /= 255.
21 | rgba[2] /= 255.
22 | img = img.tolist()
23 | for i in range(h):
24 | for j in range(w):
25 | if img[i][j] == [1., 1., 1., 1.]:
26 | img[i][j] = rgba
27 | return np.array(img)
28 |
29 |
30 | def mixed_color(img1, img2):
31 | h, w, c = img1.shape
32 | for i in range(h):
33 | for j in range(w):
34 | if (img1[i][j] == [0., 0., 0., 1.]).all() or (img2[i][j] == [0., 0., 0., 1.]).all():
35 | img1[i][j] = img1[i][j] + img2[i][j]
36 | img1[i][j][-1] = 1.
37 | else:
38 | img1[i][j] = img1[i][j] * 0.5 + img2[i][j] * 0.5
39 | return np.array(img1)
40 |
41 |
42 | def cal_dice(pred, gt):
43 | intersection = (pred * gt).sum()
44 | return (2 * intersection) / (pred.sum() + gt.sum()).item()
45 |
46 |
47 | def show_pred_gt(model,
48 | device='cuda',
49 | img_size=(3, 256, 256),
50 | datasets=r'D:\datasets\Breast\large\val',
51 | output_path=r'../analyze_results/hardmseg'):
52 | mean = [123.675, 116.28, 103.53]
53 | std = [58.395, 57.12, 57.375]
54 | if not os.path.exists(output_path):
55 | os.makedirs(output_path)
56 | model = model.to(device)
57 | img_path = os.path.join(datasets, 'images')
58 | mask_path = os.path.join(datasets, 'masks')
59 | names = os.listdir(img_path)
60 | dice_list = []
61 | for name in tqdm(names):
62 | if img_size[0] < 3:
63 | img = cv2.imread(os.path.join(img_path, name), cv2.IMREAD_GRAYSCALE)
64 |             img = img[:, :, np.newaxis]  # add a channel axis for grayscale input
65 | else:
66 | img = cv2.imread(os.path.join(img_path, name), cv2.IMREAD_COLOR)
67 | h, w, _ = img.shape
68 | img = cv2.resize(img, (img_size[1], img_size[2]))
69 | img = torch.tensor(img.transpose(2, 0, 1)).unsqueeze(0).to(device, torch.float32)
70 | img = transforms.Normalize(std=std, mean=mean)(img)
71 | img = torch.cat([img, img], dim=0)
72 | with torch.no_grad():
73 | pred = model.inference(img)[0].argmax(0)
74 |             pred = pred.view(img_size[1], img_size[2]).cpu().detach().numpy().astype(np.float32)
75 | mask = cv2.imread(os.path.join(mask_path, name.replace('jpg', 'png')), cv2.IMREAD_GRAYSCALE)
76 | dice = cal_dice(pred, mask / 255)
77 | dice_list.append(dice)
78 | mask = cv2.resize(mask, (img_size[1], img_size[2]))
79 | mask = (mask != 0).astype(np.float32)
80 | pred = cv2.cvtColor(pred, cv2.COLOR_GRAY2BGRA)
81 | mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGRA)
82 | plt.imsave(os.path.join(output_path, f'{int(dice * 10000)}_{name}'),
83 | _color(mask, color3) * 0.5 + _color(pred, color1) * 0.5)
84 | # plt.imsave(os.path.join(output_path, name), mixed_color(_color(mask, color1), _color(pred, color3)))
85 | f = open(f'{output_path}/result.txt', 'w')
86 | for dice, name in zip(dice_list, names):
87 | f.write(f'{name}:\t{dice * 100:.2f}\n')
88 | f.write(f'average:\t{np.mean(dice_list) * 100:.2f}\n')
89 | f.close()
90 |
91 |
92 | if __name__ == '__main__':
93 | args_dict = parse_yaml(r'F:\projects\PyCharmProjects\SSLSeg\configs\unet_breast_mri.yaml')
94 | model = build_model(args_dict['model'])
95 | # model = torch.nn.DataParallel(model)
96 | model.load_state_dict(
97 | torch.load(r'F:\projects\PyCharmProjects\SSLSeg\results\breast_mri\unet_breast_mri_0924194128\iter_6000.pth')[
98 | 'state_dict'])
99 | show_pred_gt(model,
100 | device='cuda',
101 | img_size=(3, 256, 256),
102 | datasets=r'F:\datasets\breast_mri\256x256\val',
103 | output_path=r'F:\DeskTop\preds_unet_')
104 |
--------------------------------------------------------------------------------
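`cal_dice` above assumes binary numpy masks with foreground set to 1. A quick toy check of the formula, reusing the function as written:

```python
import numpy as np


def cal_dice(pred, gt):
    # Same formula as codes/utils/analyze.py: 2|P∩G| / (|P| + |G|).
    intersection = (pred * gt).sum()
    return (2 * intersection) / (pred.sum() + gt.sum()).item()


pred = np.zeros((4, 4), dtype=np.float32)
gt = np.zeros((4, 4), dtype=np.float32)
pred[1:3, 1:3] = 1.  # 4 predicted foreground pixels
gt[1:4, 1:4] = 1.    # 9 ground-truth pixels, 4 of them overlapping
print(cal_dice(pred, gt))  # 2*4 / (4 + 9) ≈ 0.615
```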
/codes/utils/init.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from timm.models.layers import trunc_normal_
4 |
5 |
6 | def kaiming_normal_init_weight(model):
7 | for m in model.modules():
8 | if isinstance(m, nn.Conv2d):
9 | torch.nn.init.kaiming_normal_(m.weight)
10 | elif isinstance(m, nn.BatchNorm2d):
11 | m.weight.data.fill_(1)
12 | m.bias.data.zero_()
13 | elif isinstance(m, nn.Linear):
14 | trunc_normal_(m.weight, std=.02)
15 | if m.bias is not None:
16 | nn.init.constant_(m.bias, 0)
17 | elif isinstance(m, nn.LayerNorm):
18 | nn.init.constant_(m.bias, 0)
19 | nn.init.constant_(m.weight, 1.0)
20 |
21 |
22 | def xavier_normal_init_weight(model):
23 | for m in model.modules():
24 | if isinstance(m, nn.Conv2d):
25 | torch.nn.init.xavier_normal_(m.weight)
26 | elif isinstance(m, nn.BatchNorm2d):
27 | m.weight.data.fill_(1)
28 | m.bias.data.zero_()
--------------------------------------------------------------------------------
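These initializers are meant to be applied once to a freshly built model. A minimal usage sketch with a toy network (the toy module is illustrative only; real models in this repo come from `codes/models`):

```python
import torch
from torch import nn


def kaiming_normal_init_weight(model):
    # Conv2d / BatchNorm2d handling as in codes/utils/init.py
    # (Linear/LayerNorm branches omitted here to avoid the timm dependency).
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            torch.nn.init.kaiming_normal_(m.weight)
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()


toy = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
kaiming_normal_init_weight(toy)
print(float(toy[0].weight.std()))  # roughly sqrt(2 / fan_in) for Kaiming-normal init
```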
/codes/utils/ramps.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, Curious AI Ltd. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Functions for ramping hyperparameters up or down
9 |
10 | Each function takes the current training step or epoch, and the
11 | ramp length in the same format, and returns a multiplier between
12 | 0 and 1.
13 | """
14 |
15 |
16 | import numpy as np
17 |
18 |
19 | def sigmoid_rampup(current, rampup_length):
20 | """Exponential rampup from https://arxiv.org/abs/1610.02242"""
21 | if rampup_length == 0:
22 | return 1.0
23 | else:
24 | current = np.clip(current, 0.0, rampup_length)
25 | phase = 1.0 - current / rampup_length
26 | return float(np.exp(-5.0 * phase * phase))
27 |
28 |
29 | def linear_rampup(current, rampup_length):
30 | """Linear rampup"""
31 | assert current >= 0 and rampup_length >= 0
32 | if current >= rampup_length:
33 | return 1.0
34 | else:
35 | return current / rampup_length
36 |
37 |
38 | def cosine_rampdown(current, rampdown_length):
39 | """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
40 | assert 0 <= current <= rampdown_length
41 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))
42 |
--------------------------------------------------------------------------------
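In the trainers, `sigmoid_rampup` drives the consistency weight: `get_current_consistency_weight(step // 100)` presumably scales the configured `consistency` by the ramp value over `consistency_rampup` units of 100 iterations (the actual helper lives in `codes/trainers/_base.py`, which is not shown here). A small sketch of that schedule under those assumptions:

```python
import numpy as np


def sigmoid_rampup(current, rampup_length):
    # Copied from codes/utils/ramps.py.
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))


def get_current_consistency_weight(epoch, consistency=0.1, consistency_rampup=40.0):
    # Assumed form of the BaseTrainer helper: config weight times the ramp value.
    return consistency * sigmoid_rampup(epoch, consistency_rampup)


for it in (0, 1000, 2000, 4000, 6000):
    print(it, round(get_current_consistency_weight(it // 100), 5))
```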
/codes/utils/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import yaml
4 |
5 |
6 | class AverageMeter(object):
7 | """Computes and stores the average and current value"""
8 |
9 | def __init__(self):
10 | self.reset()
11 |
12 | def reset(self):
13 | self.val = 0
14 | self.avg = 0
15 | self.sum = 0
16 | self.count = 0
17 |
18 | def update(self, val, n=1):
19 | self.val = val
20 | self.sum += val * n
21 | self.count += n
22 | self.avg = self.sum / self.count
23 |
24 |
25 | class Namespace(object):
26 | def __init__(self, some_dict):
27 | for key, value in some_dict.items():
28 | assert isinstance(key, str) and re.match("[A-Za-z_-]", key)
29 | if isinstance(value, dict):
30 | self.__dict__[key] = Namespace(value)
31 | elif isinstance(value, list):
32 | value = value.copy()
33 | if isinstance(value[0], dict):
34 | for i in range(len(value)):
35 | value[i] = Namespace(value[i])
36 | self.__dict__[key] = value
37 | else:
38 | self.__dict__[key] = value
39 |
40 | def __getattr__(self, attribute):
41 |
42 | raise AttributeError(
43 |             f"Cannot find {attribute} in namespace. Please define {attribute} in your config file (xxx.yaml)!")
44 |
45 |
46 | def deep_update_dict(res_dict, in_dict):
47 | for key in in_dict.keys():
48 | if key in res_dict and isinstance(in_dict[key], dict) and isinstance(res_dict[key], dict) and \
49 | 'name' in in_dict[key].keys() and 'kwargs' in in_dict[key].keys() and \
50 | 'name' in res_dict[key].keys() and 'kwargs' in res_dict[key].keys() and \
51 | in_dict[key]['name'] == res_dict[key]['name']:
52 | deep_update_dict(res_dict[key]['kwargs'], in_dict[key]['kwargs'])
53 | else:
54 | res_dict[key] = in_dict[key]
55 |
56 |
57 | def parse_yaml(yaml_file_path):
58 | res = {}
59 | with open(yaml_file_path, 'r') as yaml_file:
60 | f = yaml.load(yaml_file, Loader=yaml.FullLoader)
61 | if 'include' in f:
62 | abs_path = os.path.abspath(yaml_file_path)
63 | abs_path = abs_path.replace('\\', '/')
64 | abs_path_list = abs_path.split('/')
65 | for include_file_path in f['include']:
66 | include_file_path = include_file_path.replace('\\', '/')
67 | include_path_list = include_file_path.split('/')
68 | if '' in include_path_list:
69 | include_path_list.remove('')
70 | if '.' in include_path_list:
71 | include_path_list.remove('.')
72 | n = include_path_list.count('..')
73 | include_file_path = '/'.join(abs_path_list[:-(n + 1)] + include_path_list[n:])
74 | with open(include_file_path, 'r') as include_file:
75 | deep_update_dict(res, yaml.load(include_file, Loader=yaml.FullLoader))
76 | # res.update(yaml.load(include_file, Loader=yaml.FullLoader))
77 | deep_update_dict(res, f)
78 | # res.update(f)
79 | if 'include' in res.keys():
80 | res.pop('include')
81 | return res
--------------------------------------------------------------------------------
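`parse_yaml` resolves each relative `include` against the including file and folds it in with `deep_update_dict`, so a comparison config only needs to override the entries it changes; the kwargs of a `{name, kwargs}` block are merged rather than replaced when the names match. A tiny illustration of that merge rule alone (plain dicts, no files):

```python
def deep_update_dict(res_dict, in_dict):
    # Same rule as codes/utils/utils.py: recurse only when both sides are
    # {name, kwargs} blocks with a matching name, otherwise overwrite the key.
    for key in in_dict.keys():
        if key in res_dict and isinstance(in_dict[key], dict) and isinstance(res_dict[key], dict) and \
                'name' in in_dict[key] and 'kwargs' in in_dict[key] and \
                'name' in res_dict[key] and 'kwargs' in res_dict[key] and \
                in_dict[key]['name'] == res_dict[key]['name']:
            deep_update_dict(res_dict[key]['kwargs'], in_dict[key]['kwargs'])
        else:
            res_dict[key] = in_dict[key]


base = {'scheduler': {'name': 'PolyLR', 'kwargs': {'max_iters': 6000, 'power': 0.9}}}
override = {'scheduler': {'name': 'PolyLR', 'kwargs': {'max_iters': 10000}}}
deep_update_dict(base, override)
print(base)  # power: 0.9 is kept, max_iters becomes 10000
```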
/configs/_datasets/acdc.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | name: acdc
3 | kwargs:
4 | root_dir: F:/datasets/ACDC/
5 | labeled_num: 7
6 | labeled_bs: 12
7 | batch_size: 24
8 | batch_size_val: 16
9 | num_workers: 0
10 | train_transforms:
11 | - name: RandomGenerator
12 | kwargs: { output_size: [ 256, 256 ] }
13 | - name: ToRGB
14 | - name: RandomCrop
15 | kwargs: { size: [ 256, 256 ] }
16 | - name: RandomFlip
17 | kwargs: { p: 0.5 }
18 | - name: ColorJitter
19 |       kwargs: { brightness: 0.4, contrast: 0.4, saturation: 0.4, hue: 0.1, p: 0.8 }
20 | val_transforms:
21 | - name: RandomGenerator
22 | kwargs:
23 | p_flip: 0.0
24 | p_rot: 0.0
25 | output_size: [ 256, 256 ]
26 | - name: ToRGB
--------------------------------------------------------------------------------
/configs/_datasets/acdc_224.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | name: acdc
3 | kwargs:
4 | root_dir: F:/datasets/ACDC/
5 | labeled_num: 7
6 | labeled_bs: 8
7 | batch_size: 16
8 | batch_size_val: 16
9 | num_workers: 0
10 | train_transforms:
11 | - name: RandomGenerator
12 | kwargs: { output_size: [ 224, 224 ], p_flip: 0.5, p_rot: 0.5 }
13 | - name: ToRGB
14 | - name: RandomCrop
15 | kwargs: { size: [ 224, 224 ] }
16 | - name: RandomFlip
17 | kwargs: { p: 0.5 }
18 | - name: ColorJitter
19 |       kwargs: { brightness: 0.4, contrast: 0.4, saturation: 0.4, hue: 0.1, p: 0.8 }
20 | val_transforms:
21 | - name: RandomGenerator
22 | kwargs: { output_size: [ 224, 224 ], p_flip: 0.5, p_rot: 0.5 }
23 | - name: ToRGB
--------------------------------------------------------------------------------
/configs/_models/unet_r50.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | name: UNet
3 | kwargs:
4 | in_channels: 1
5 | classes: 4
6 | encoder_name: resnet50
7 | encoder_weights: imagenet
--------------------------------------------------------------------------------
/configs/_models/unet_tf_r50.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | name: UNetTF
3 | kwargs:
4 | in_channels: 3
5 | classes: 4
6 | encoder_name: resnet50
7 | encoder_weights: imagenet
8 | contrast_embed: False
9 | contrast_embed_dim: 256
10 | contrast_embed_index: -3
--------------------------------------------------------------------------------
/configs/_trainers/mt.yaml:
--------------------------------------------------------------------------------
1 | train:
2 | name: MeanTeacherTrainer
3 | kwargs:
4 | criterions:
5 | - name: DiceLoss_
6 |         kwargs: { weight: 0.5, n_classes: 4, class_weight: [ 0.0, 1.0, 1.0, 1.0 ] }
7 | - name: CrossEntropyLoss_
8 | kwargs: { weight: 0.5 }
9 | metrics: # metrics config list
10 | - name: Dice
11 | kwargs:
12 | name: Dice
13 | class_indexs: [ 0, 1, 2, 3 ]
14 | class_names: [ bg, c1, c2, c3 ]
15 | - name: Jaccard
16 | kwargs:
17 | name: Jaccard
18 | class_indexs: [ 0, 1, 2, 3 ]
19 | class_names: [ bg, c1, c2, c3 ]
20 | labeled_bs: 8
21 | alpha: 0.99
22 | consistency: 0.1
23 | consistency_rampup: 40.0
24 | max_iter: 6000
25 | eval_interval: 1000
26 | save_image_interval: 50
27 | save_ckpt_interval: 1000
28 |
29 | ema_model:
30 | name: UNet
31 | kwargs:
32 | in_channels: 3
33 | classes: 2
34 | encoder_name: resnet50
35 | encoder_weights: imagenet
36 |
37 | optimizer:
38 | name: SGD
39 | kwargs:
40 | lr: 0.01
41 | weight_decay: 0.0001
42 | momentum: 0.9
43 |
44 | scheduler:
45 | name: StepLR
46 | kwargs:
47 | step_size: 2500
48 | gamma: 0.1
49 |
50 | #scheduler:
51 | # name: PolyLR
52 | # kwargs:
53 | # max_iters: 10000
54 | # power: 0.9
55 | # last_epoch: -1
56 | # min_lr: 0.0001
57 |
58 | logger:
59 | log_file: True
60 | tensorboard: True
61 |
--------------------------------------------------------------------------------
/configs/_trainers/supervised.yaml:
--------------------------------------------------------------------------------
1 | train:
2 | name: SupervisedTrainer
3 | kwargs:
4 | criterions:
5 | - name: DiceLoss_
6 |         kwargs: { weight: 0.5, n_classes: 4, class_weight: [ 0.0, 1.0, 1.0, 1.0 ] }
7 | - name: CrossEntropyLoss_
8 | kwargs: { weight: 0.5 }
9 | metrics: # metrics config list
10 | - name: Dice
11 | kwargs:
12 | name: Dice
13 | class_indexs: [ 0, 1, 2, 3 ]
14 | class_names: [ bg, c1, c2, c3 ]
15 | - name: Jaccard
16 | kwargs:
17 | name: Jaccard
18 | class_indexs: [ 0, 1, 2, 3 ]
19 | class_names: [ bg, c1, c2, c3 ]
20 | labeled_bs: 8
21 | max_iter: 6000
22 | eval_interval: 1000
23 | save_image_interval: 50
24 | save_ckpt_interval: 1000
25 |
26 | optimizer:
27 | name: SGD
28 | kwargs:
29 | lr: 0.01
30 | weight_decay: 0.0001
31 | momentum: 0.9
32 |
33 | scheduler:
34 | name: StepLR
35 | kwargs:
36 | step_size: 2500
37 | gamma: 0.1
38 |
39 | logger:
40 | log_file: True
41 | tensorboard: True
--------------------------------------------------------------------------------
/configs/_trainers/ugpcl.yaml:
--------------------------------------------------------------------------------
1 | train:
2 | name: UGPCLTrainer
3 | kwargs:
4 | criterions:
5 | - name: DiceLoss_
6 |         kwargs: { weight: 0.5, n_classes: 4, class_weight: [ 0.0, 1.0, 1.0, 1.0 ] }
7 | - name: CrossEntropyLoss_
8 | kwargs: { weight: 0.5 }
9 | metrics: # metrics config list
10 | - name: Dice
11 | kwargs:
12 | name: Dice
13 | class_indexs: [ 0, 1, 2, 3 ]
14 | class_names: [ bg, c1, c2, c3 ]
15 | - name: Jaccard
16 | kwargs:
17 | name: Jaccard
18 | class_indexs: [ 0, 1, 2, 3 ]
19 | class_names: [ bg, c1, c2, c3 ]
20 | labeled_bs: 8
21 | consistency: 0.1
22 | consistency_rampup: 60.0
23 | contrast_weight: 0.1
24 | temperature: 0.07
25 | base_temperature: 0.07
26 | memory: true
27 | max_samples: 1024
28 | max_views: 1
29 | memory_size: 2000
30 | pixel_update_freq: 10
31 | pixel_classes: 4
32 | max_iter: 6000
33 | eval_interval: 1000
34 | save_image_interval: 50
35 | save_ckpt_interval: 1000
36 |
37 | optimizer:
38 | name: SGD
39 | kwargs:
40 | lr: 0.01
41 | weight_decay: 0.0005
42 | momentum: 0.9
43 |
44 | scheduler:
45 | name: PolyLR
46 | kwargs:
47 | max_iters: 6000
48 | power: 0.9
49 | last_epoch: -1
50 | min_lr: 0.0001
51 |
52 | logger:
53 | log_file: True
54 | tensorboard: True
55 |
--------------------------------------------------------------------------------
/configs/comparison_acdc_224_136/mt_unet_r50.yaml:
--------------------------------------------------------------------------------
1 | name: mt_unet_r50
2 |
3 | include:
4 | - ../_datasets/acdc_224.yaml
5 | - ../_models/unet_r50.yaml
6 | - ../_trainers/mt.yaml
7 |
8 | train:
9 | name: MeanTeacherTrainer
10 | kwargs:
11 | consistency: 0.1
12 | consistency_rampup: 100.0 # max_iter // 100
13 | max_iter: 10000
14 | eval_interval: 1000
15 | save_image_interval: 50
16 | save_ckpt_interval: 1000
17 |
18 | scheduler:
19 | name: PolyLR
20 | kwargs:
21 | max_iters: 10000
22 |
23 | model:
24 | name: UNet
25 | kwargs:
26 | in_channels: 3
27 | classes: 4
28 |
29 | ema_model:
30 | name: UNet
31 | kwargs:
32 | in_channels: 3
33 | classes: 4
--------------------------------------------------------------------------------
/configs/comparison_acdc_224_136/ugpcl_unet_r50.yaml:
--------------------------------------------------------------------------------
1 | name: ugpcl_unet_r50
2 |
3 | include:
4 | - ../_datasets/acdc_224.yaml
5 | - ../_models/unet_tf_r50.yaml
6 | - ../_trainers/ugpcl.yaml
7 |
8 | dataset:
9 | name: acdc
10 | kwargs:
11 | labeled_bs: 8
12 | batch_size: 16
13 | train_transforms:
14 | - name: RandomGenerator
15 | kwargs: { output_size: [ 224, 224 ], p_flip: 0.0, p_rot: 0.0 } # Remove random rotation
16 | - name: ToRGB
17 | - name: RandomCrop
18 | kwargs: { size: [ 224, 224 ] }
19 | - name: RandomFlip
20 | kwargs: { p: 0.5 }
21 | - name: ColorJitter
22 |         kwargs: { brightness: 0.4, contrast: 0.4, saturation: 0.4, hue: 0.1, p: 0.8 }
23 |
24 | train:
25 | name: UGPCLTrainer
26 | kwargs:
27 | contrast_weight: 0.1
28 | labeled_bs: 8
29 | consistency: 0.01
30 | consistency_rampup: 100.0 # max_iter // 100
31 | memory: true
32 | max_samples: 1024
33 | max_views: 1
34 | memory_size: 500
35 | pixel_update_freq: 10
36 | pixel_classes: 4
37 | max_iter: 10000
38 |
39 | model:
40 | name: UNetTF
41 | kwargs:
42 | contrast_embed: True
43 |
44 | scheduler:
45 | name: PolyLR
46 | kwargs:
47 | max_iters: 10000
48 |
--------------------------------------------------------------------------------
/configs/comparison_acdc_224_136/unet_r50.yaml:
--------------------------------------------------------------------------------
1 | name: unet_r50
2 |
3 | include:
4 | - ../_datasets/acdc_224.yaml
5 | - ../_models/unet_r50.yaml
6 | - ../_trainers/supervised.yaml
7 |
8 | train:
9 | name: SupervisedTrainer
10 | kwargs:
11 | max_iter: 10000
12 | eval_interval: 1000
13 | save_image_interval: 50
14 | save_ckpt_interval: 1000
15 |
16 | model:
17 | name: UNet
18 | kwargs:
19 | in_channels: 3
20 | classes: 4
21 |
22 | scheduler:
23 | name: PolyLR
24 | kwargs:
25 | max_iters: 10000
--------------------------------------------------------------------------------
/configs/comparison_acdc_224_136/unet_r50_full.yaml:
--------------------------------------------------------------------------------
1 | name: unet_r50
2 |
3 | include:
4 | - ../_datasets/acdc_224.yaml
5 | - ../_models/unet_r50.yaml
6 | - ../_trainers/supervised.yaml
7 |
8 | dataset:
9 | name: acdc
10 | kwargs:
11 | labeled_bs: 16
12 |
13 | train:
14 | name: SupervisedTrainer
15 | kwargs:
16 | labeled_bs: 16
17 | max_iter: 10000
18 | eval_interval: 1000
19 | save_image_interval: 50
20 | save_ckpt_interval: 1000
21 |
22 | model:
23 | name: UNet
24 | kwargs:
25 | in_channels: 3
26 | classes: 4
27 |
28 | scheduler:
29 | name: PolyLR
30 | kwargs:
31 | max_iters: 10000
--------------------------------------------------------------------------------
/pics/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taovv/UGPCL/5980405b15ef1bc139be0447647a213b5a5f30b9/pics/overview.jpg
--------------------------------------------------------------------------------
/pics/preds.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taovv/UGPCL/5980405b15ef1bc139be0447647a213b5a5f30b9/pics/preds.jpg
--------------------------------------------------------------------------------
/pics/show_feats.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taovv/UGPCL/5980405b15ef1bc139be0447647a213b5a5f30b9/pics/show_feats.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | colorama
2 | h5py==3.4.0
3 | matplotlib==3.4.2
4 | MedPy==0.4.0
5 | numpy==1.21.1
6 | prettytable==2.1.0
7 | scipy==1.7.1
8 | segmentation-models-pytorch==0.2.0
9 | scikit-learn
10 | tensorboardX==2.4
11 | timm==0.4.12
12 | torch==1.9.1+cu102
13 | torchvision==0.10.1+cu102
14 | tqdm==4.62.0
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import numpy as np
5 | import cv2
6 |
7 | from tqdm import tqdm
8 | from colorama import Fore
9 | from prettytable import PrettyTable
10 | from codes.builder import *
11 | from codes.utils.utils import Namespace, parse_yaml
12 |
13 | from sklearn.manifold import TSNE
14 | import matplotlib.pyplot as plt
15 |
16 |
17 | def get_args(config_file):
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--config', type=str,
20 | default=config_file,
21 | help='train config file path: xxx.yaml')
22 | parser.add_argument('--device', type=str, default='cuda:1' if torch.cuda.is_available() else 'cpu')
23 | args = parser.parse_args()
24 | args_dict = parse_yaml(args.config)
25 | for key, value in Namespace(args_dict).__dict__.items():
26 | vars(args)[key] = value
27 | return args
28 |
29 |
30 | def save_result(img,
31 | seg,
32 | file_path,
33 | opacity=0.5,
34 | palette=[[0, 0, 0], [0, 255, 255], [255, 106, 106], [255, 250, 240]]):
35 | color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
36 | for label, color in enumerate(palette):
37 | color_seg[seg == label, :] = color
38 | # convert to BGR
39 | color_seg = color_seg[..., ::-1]
40 |
41 | img = img * (1 - opacity) + color_seg * opacity
42 | img = img.astype(np.uint8)
43 |
44 | dir_name = os.path.abspath(os.path.dirname(file_path))
45 | os.makedirs(dir_name, exist_ok=True)
46 | cv2.imwrite(file_path, img)
47 |
48 |
49 | def test(config_file, weights, save_pred=True):
50 | args = get_args(config_file)
51 |
52 | _, test_loader = build_dataloader(args.dataset, None)
53 |
54 | model = build_model(args.model)
55 | state_dict = torch.load(weights)
56 | if 'state_dict' in state_dict.keys():
57 | state_dict = state_dict['state_dict']
58 | model.load_state_dict(state_dict)
59 |
60 | print(F'\nModel: {model.__class__.__name__}')
61 | print(F'Weights file: {weights}')
62 |
63 | test_loader = tqdm(test_loader, desc='Testing', unit='batch',
64 | bar_format='%s{l_bar}{bar}{r_bar}%s' % (Fore.LIGHTCYAN_EX, Fore.RESET))
65 |
66 | metrics = []
67 | if args.train.kwargs.metrics is not None:
68 | for metric_cfg in args.train.kwargs.metrics:
69 | metrics.append(_build_from_cfg(metric_cfg))
70 |
71 | test_res = None
72 | model.to(args.device)
73 | model.eval()
74 | i = 0
75 | for batch_data in test_loader:
76 | data, label = batch_data['image'].to(args.device), batch_data['label'].to(args.device)
77 | preds = model.inference(data)
78 | batch_metric_res = {}
79 | for metric in metrics:
80 | batch_metric_res[metric.name] = metric(preds, label)
81 |
82 | if save_pred:
83 | for j in range(preds.size(0)):
84 | save_result(
85 | data[j].permute(1, 2, 0).cpu().numpy() * 255.,
86 | preds[j][0].cpu().numpy(),
87 | f'./shows/{i}.png',
88 | opacity=1.0)
89 | i += 1
90 |
91 | if test_res is None:
92 | test_res = batch_metric_res
93 | else:
94 | for metric_name in test_res.keys():
95 | for key in test_res[metric_name].keys():
96 | test_res[metric_name][key] += batch_metric_res[metric_name][key]
97 | for metric_name in test_res.keys():
98 | for key in test_res[metric_name].keys():
99 | test_res[metric_name][key] = test_res[metric_name][key] / len(test_loader)
100 |
101 | test_res_list = [_.cpu() for _ in test_res[metric_name].values()]
102 | test_res[metric_name]['Mean'] = np.mean(test_res_list[1:])
103 |
104 | test_table = PrettyTable()
105 |     test_table.field_names = ['Metric'] + list(list(test_res.values())[0].keys())
106 | for metric_name in test_res.keys():
107 | if metric_name in ['Dice', 'Jaccard', 'Acc', 'IoU', 'Recall', 'Precision']:
108 | temp = [float(format(_ * 100, '.2f')) for _ in test_res[metric_name].values()]
109 | else:
110 | temp = [float(format(_, '.2f')) for _ in test_res[metric_name].values()]
111 | test_table.add_row([metric_name] + temp)
112 | print(test_table.get_string())
113 |
114 |
115 | def show_features(config_file, weights):
116 | args = get_args(config_file)
117 |
118 | _, test_loader = build_dataloader(args.dataset, None)
119 |
120 | model = build_model(args.model)
121 | state_dict = torch.load(weights)
122 | if 'state_dict' in state_dict.keys():
123 | state_dict = state_dict['state_dict']
124 | model.load_state_dict(state_dict)
125 |
126 | print(F'\nModel: {model.__class__.__name__}')
127 | print(F'Weights file: {weights}')
128 |
129 | model.to(args.device)
130 | model.eval()
131 | for batch_data in test_loader:
132 | data, label = batch_data['image'].to(args.device), batch_data['label'].to(args.device)
133 | preds = model.inference_features(data)
134 | features_ = preds['feats'][-3].permute(0, 2, 3, 1).contiguous().view(16 * 28 * 28, -1).cpu().detach().numpy()
135 | labels_ = torch.nn.functional.interpolate(label, (28, 28), mode='nearest')
136 | labels_ = labels_.permute(0, 2, 3, 1).contiguous().view(16 * 28 * 28, -1).squeeze(1).cpu().detach().numpy()
137 |
138 | features = features_[labels_ != 0]
139 | labels = labels_[labels_ != 0]
140 | tsne = TSNE(n_components=2)
141 | X_tsne = tsne.fit_transform(features)
142 | x_min, x_max = X_tsne.min(0), X_tsne.max(0)
143 |         X_norm = (X_tsne - x_min) / (x_max - x_min)  # normalize to [0, 1]
144 | plt.figure(figsize=(8, 8))
145 | for i in range(X_norm.shape[0]):
146 | plt.text(X_norm[i, 0], X_norm[i, 1], str(labels[i]), color=plt.cm.Set1(labels[i]),
147 | fontdict={'weight': 'bold', 'size': 9})
148 | plt.xticks([])
149 | plt.yticks([])
150 | plt.show()
151 | exit()
152 |
153 |
154 | if __name__ == '__main__':
155 | # test(r'weights/ugpcl_acdc_136/ugpcl_unet_r50.yaml',
156 | # r'F:\DeskTop\UGPCL\weights\ugpcl_acdc_136\ugpcl_88.11.pth')
157 | show_features(r'weights/ugpcl_acdc_136/ugpcl_unet_r50.yaml',
158 | r'F:\DeskTop\UGPCL\weights\ugpcl_acdc_136\ugpcl_88.11.pth')
159 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import argparse
4 | import warnings
5 | import random
6 |
7 | import torch
8 | import numpy as np
9 |
10 | from datetime import datetime
11 | from codes.builder import build_dataloader, build_logger
12 | from codes.utils.utils import Namespace, parse_yaml
13 | from codes.trainers import *
14 |
15 |
16 | def get_args():
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--config', type=str, default='configs/comparison_acdc_224_136/ugpcl_unet_r50.yaml',
19 | help='train config file path: xxx.yaml')
20 | parser.add_argument('--work_dir', type=str,
21 | default=f'results/comparison_acdc_224_136',
22 | help='the dir to save logs and models')
23 | parser.add_argument('--resume_from', type=str,
24 | # default='results/comparison_acdc_224_136/ugcl_mem_unet_r50_0430155558/iter_1000.pth',
25 | default=None,
26 | help='the checkpoint file to resume from')
27 | parser.add_argument('--start_step', type=int, default=0)
28 | parser.add_argument('--device', type=str, default='cuda:0' if torch.cuda.is_available() else 'cpu')
29 | parser.add_argument('--data_parallel', type=bool, default=False)
30 | parser.add_argument('--seed', type=int, default=1337, help='random seed')
31 | parser.add_argument('--deterministic', type=bool, default=True,
32 | help='whether to set deterministic options for CUDNN backend.')
33 | args = parser.parse_args()
34 |
35 | args_dict = parse_yaml(args.config)
36 |
37 | for key, value in Namespace(args_dict).__dict__.items():
38 | if key in ['name', 'dataset', 'train', 'logger']:
39 | vars(args)[key] = value
40 |
41 | for key, value in Namespace(args_dict).__dict__.items():
42 | if key not in ['name', 'dataset', 'train', 'logger']:
43 | vars(args.train.kwargs)[key] = value
44 |
45 | if args.work_dir is None:
46 | args.work_dir = f'results/{args.dataset.name}'
47 | if args.resume_from is not None:
48 | args.logger.log_dir = os.path.split(os.path.abspath(args.resume_from))[0]
49 | args.logger.file_mode = 'a'
50 | else:
51 | args.logger.log_dir = f'{args.work_dir}/{args.name}_{datetime.now().strftime("%m%d%H%M%S")}'
52 | args.ckpt_save_path = args.logger.log_dir
53 |
54 | for key in args.__dict__.keys():
55 | if key not in args_dict.keys():
56 | args_dict[key] = args.__dict__[key]
57 |
58 | return args, args_dict
59 |
60 |
61 | def set_deterministic(seed):
62 | if seed is not None:
63 | random.seed(seed)
64 | np.random.seed(seed)
65 | torch.manual_seed(seed)
66 | torch.cuda.manual_seed(seed)
67 | torch.backends.cudnn.deterministic = True
68 | torch.backends.cudnn.benchmark = False
69 |
70 |
71 | def build_trainer(name,
72 | logger=None,
73 | device='cuda',
74 | data_parallel=False,
75 | ckpt_save_path=None,
76 | resume_from=None,
77 | **kwargs):
78 | return eval(f'{name}')(logger=logger, device=device, data_parallel=data_parallel, ckpt_save_path=ckpt_save_path,
79 | resume_from=resume_from, **kwargs)
80 |
81 |
82 | def train():
83 | args, args_dict = get_args()
84 | set_deterministic(args.seed)
85 |
86 | def worker_init_fn(worker_id):
87 | random.seed(worker_id + args.seed)
88 |
89 | train_loader, val_loader = build_dataloader(args.dataset, worker_init_fn)
90 | logger = build_logger(args.logger)
91 |
92 | args_yaml_info = yaml.dump(args_dict, sort_keys=False, default_flow_style=None)
93 | yaml_file_name = os.path.split(args.config)[-1]
94 | with open(os.path.join(args.ckpt_save_path, yaml_file_name), 'w') as f:
95 | f.write(args_yaml_info)
96 | f.close()
97 |
98 | logger.info(f'\n{args_yaml_info}\n')
99 |
100 | trainer = build_trainer(name=args.train.name,
101 | logger=logger,
102 | device=args.device,
103 | data_parallel=args.data_parallel,
104 | ckpt_save_path=args.ckpt_save_path,
105 | resume_from=args.resume_from,
106 | **args.train.kwargs.__dict__)
107 | trainer.train(train_loader, val_loader)
108 | logger.close()
109 |
110 |
111 | if __name__ == '__main__':
112 | warnings.filterwarnings("ignore")
113 | train()
114 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
2 | python train.py --config='configs/comparison_acdc_224_136/ugpcl_unet_r50.yaml' --device='cuda:1' \
3 |     --work_dir='results/comparison_acdc_224_136'
--------------------------------------------------------------------------------