├── LICENSE
├── README.md
├── checkpoint
│   └── readme.md
├── imgs
│   ├── MGCC.png
│   ├── framework.png
│   ├── gen_bus.png
│   └── gen_tus.png
├── split.py
├── src
│   ├── dataloader
│   │   └── dataset.py
│   ├── network
│   │   └── MGCC.py
│   └── utils
│       ├── losses.py
│       ├── metrics.py
│       ├── ramps.py
│       └── util.py
└── train.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Fenghe Tang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-Level Global Context Cross Consistency for Semi-Supervised Ultrasound Image Segmentation with Diffusion Model
2 |
3 | [Paper](https://arxiv.org/pdf/2305.09447) | [Code](https://github.com/FengheTan9/Multi-Level_Global_Context_Cross_Consistency)
4 |
5 | A PyTorch codebase for [Multi-Level Global Context Cross Consistency Model for Semi-Supervised Ultrasound Image Segmentation with Diffusion Model](https://arxiv.org/pdf/2305.09447).
6 |
7 | ## Introduction
8 | Medical image segmentation is a critical step in computer-aided diagnosis, and convolutional neural networks are popular segmentation networks nowadays. However, the inherently local operations of convolution make it difficult to focus on the global contextual information of lesions with different positions, shapes, and sizes. Semi-supervised learning can be used to learn from both labeled and unlabeled samples, alleviating the burden of manual labeling. However, obtaining a large number of unlabeled images in medical scenarios remains challenging. To address these issues, we propose a Multi-level Global Context Cross-consistency (MGCC) framework that uses images generated by a Latent Diffusion Model (LDM) as unlabeled images for semi-supervised learning. The framework involves two stages. In the first stage, an LDM is used to generate synthetic medical images, which reduces the workload of data annotation and addresses privacy concerns associated with collecting medical data. In the second stage, varying levels of global context noise perturbation are added to the inputs of the auxiliary decoders, and output consistency is maintained between decoders to improve the representation ability. Experiments conducted on open-source breast ultrasound and private thyroid ultrasound datasets demonstrate the effectiveness of our framework in bridging the probability distribution and the semantic representation of the medical image. Our approach enables the effective transfer of probability distribution knowledge to the segmentation network, resulting in improved segmentation accuracy.
9 |
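In code terms (see `train.py` and `src/utils/ramps.py`), the objective combines a supervised BCE-Dice term over the main and auxiliary decoder outputs on labeled pairs $(x_l, y_l)$ with a ramped consistency term between main and auxiliary predictions on unlabeled images $x_u$. A condensed sketch (notation ours), with $f$ the main decoder, $f_k$ the three auxiliary decoders, and $\sigma$ the sigmoid:

```math
\mathcal{L} = \frac{1}{4}\Big(\mathcal{L}_{\mathrm{BCE\text{-}Dice}}\big(f(x_l), y_l\big) + \sum_{k=1}^{3}\mathcal{L}_{\mathrm{BCE\text{-}Dice}}\big(f_k(x_l), y_l\big)\Big) + \lambda(t)\,\frac{1}{3}\sum_{k=1}^{3}\mathrm{MSE}\big(\sigma(f(x_u)),\,\sigma(f_k(x_u))\big)
```

The weight $\lambda(t) = \lambda_{\max}\, e^{-5(1 - t/T)^2}$ follows the sigmoid ramp-up in `src/utils/ramps.py`, with $\lambda_{\max}$ set by `--consistency` and $T$ by `--consistency_rampup`.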
10 | ### MGCC framework:
11 |
12 | 
13 |
14 | ### MGCC model
15 |
16 | 
17 |
18 | ### Generation results
19 |
20 | **BUSI Result:**
21 |
22 | 
23 |
24 | **TUS Result:**
25 |
26 | 
27 |
28 |
29 |
30 | ## Datasets
31 |
32 | Please organize the [BUSI](https://www.kaggle.com/aryashah2k/breast-ultrasound-images-dataset) dataset (or your own dataset) according to the following structure; the `busi_train*.txt` / `busi_val*.txt` split lists are generated by `split.py` (see below):
33 | ```
34 | ├── data
35 | │   ├── busi
36 | │   │   ├── images
37 | │   │   │   ├── benign (10).png
38 | │   │   │   ├── malignant (17).png
39 | │   │   │   ├── normal (14).png
40 | │   │   │   └── ...
41 | │   │   ├── masks
42 | │   │   │   └── 0
43 | │   │   │       ├── benign (10).png
44 | │   │   │       ├── malignant (17).png
45 | │   │   │       ├── normal (14).png
46 | │   │   │       └── ...
47 | │   │   ├── busi_train1.txt
48 | │   │   ├── busi_val1.txt
49 | │   │   └── ...
50 | │   └── your dataset
51 | │       ├── images
52 | │       │   ├── 0a7e06.png
53 | │       │   ├── 0aab0a.png
54 | │       │   ├── 0b1761.png
55 | │       │   └── ...
56 | │       └── masks
57 | │           └── 0
58 | │               ├── 0a7e06.png
59 | │               ├── 0aab0a.png
60 | │               ├── 0b1761.png
61 | │               └── ...
62 | ```
63 | ## Environment
64 |
65 | - GPU: NVIDIA GeForce RTX 4090
66 | - PyTorch: 1.13.0 (CUDA 11.7)
67 | - cudatoolkit: 11.7.1
68 | - scikit-learn: 1.0.2
69 |
70 | ## Training and Validation
71 |
72 | 1. Generation stage:
73 |
74 | To generate synthetic ultrasound images, you can follow [medfusion](https://github.com/mueller-franzes/medfusion).
75 |
76 | 2. Semi-supervised learning stage:
77 |
78 | First, split your dataset (this writes three random 70/30 train/val lists, e.g. `busi_train1.txt` and `busi_val1.txt`, into the data directory):
79 |
80 | ```bash
81 | python split.py
82 | ```
83 |
84 | Then, train on your dataset:
85 |
86 | ```bash
87 | python train.py
88 | ```
89 |
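Every hyperparameter in `train.py` is exposed as a command-line flag, so the defaults can be overridden per run; for example, the default BUSI configuration (50% of the training list treated as labeled, batches of 8 containing 4 labeled samples) is equivalent to:

```bash
python train.py \
    --base_dir ./data/busi \
    --train_file_dir busi_train1.txt \
    --val_file_dir busi_val1.txt \
    --semi_percent 0.5 \
    --total_batch_size 8 \
    --labeled_bs 4
```

The best model by validation IoU is saved to `checkpoint/model.pth`.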
90 | ## Acknowledgements:
91 |
92 | This codebase uses helper functions from [CMU-Net](https://github.com/FengheTan9/CMU-Net), [SSL4MIS](https://github.com/HiLab-git/SSL4MIS) and [medfusion](https://github.com/mueller-franzes/medfusion).
93 |
94 | ## Citation
95 |
96 | If you use our code, please cite our paper:
97 |
98 | ```tex
99 | @article{tang2023multi,
100 | title={Multi-Level Global Context Cross Consistency Model for Semi-Supervised Ultrasound Image Segmentation with Diffusion Model},
101 | author={Tang, Fenghe and Ding, Jianrui and Wang, Lingtao and Xian, Min and Ning, Chunping},
102 | journal={arXiv preprint arXiv:2305.09447},
103 | year={2023}
104 | }
105 | ```
106 |
107 |
--------------------------------------------------------------------------------
/checkpoint/readme.md:
--------------------------------------------------------------------------------
1 | Create this `checkpoint` directory to store model weights: `train.py` saves the best model (by validation IoU) to `checkpoint/model.pth`.
--------------------------------------------------------------------------------
/imgs/MGCC.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FengheTan9/Multi-Level-Global-Context-Cross-Consistency/a9c8bd3bb5bb6e0bef9a4bc69f7aa812888a7e0a/imgs/MGCC.png
--------------------------------------------------------------------------------
/imgs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FengheTan9/Multi-Level-Global-Context-Cross-Consistency/a9c8bd3bb5bb6e0bef9a4bc69f7aa812888a7e0a/imgs/framework.png
--------------------------------------------------------------------------------
/imgs/gen_bus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FengheTan9/Multi-Level-Global-Context-Cross-Consistency/a9c8bd3bb5bb6e0bef9a4bc69f7aa812888a7e0a/imgs/gen_bus.png
--------------------------------------------------------------------------------
/imgs/gen_tus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FengheTan9/Multi-Level-Global-Context-Cross-Consistency/a9c8bd3bb5bb6e0bef9a4bc69f7aa812888a7e0a/imgs/gen_tus.png
--------------------------------------------------------------------------------
/split.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from sklearn.model_selection import train_test_split
4 |
5 | name = 'busi'
6 |
7 | root = r'./data/' + name
8 |
9 | img_ids = glob(os.path.join(root, 'images', '*.png'))
10 | img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]
11 |
12 |
13 | # write three random 70/30 train/val splits, one per seed
14 | for count, seed in enumerate([41, 64, 1337], start=1):
15 |     train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.3, random_state=seed)
16 |     filename = root + '/{}_train{}.txt'.format(name, count)
17 |     with open(filename, 'w') as file:
18 |         for img_id in train_img_ids:
19 |             file.write(img_id + '\n')
20 |
21 |     filename = root + '/{}_val{}.txt'.format(name, count)
22 |     with open(filename, 'w') as file:
23 |         for img_id in val_img_ids:
24 |             file.write(img_id + '\n')
25 |
26 |     print(train_img_ids)
27 |     print(val_img_ids)
28 |
29 |
--------------------------------------------------------------------------------
/src/dataloader/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | import itertools
5 | from torch.utils.data.sampler import Sampler
6 | import cv2
7 |
8 |
9 | class SemiDataSets(Dataset):
10 | def __init__(
11 | self,
12 | base_dir=None,
13 | split="train",
14 | transform=None,
15 | train_file_dir="train.txt",
16 | val_file_dir="val.txt",
17 | ):
18 | self._base_dir = base_dir
19 | self.sample_list = []
20 | self.split = split
21 | self.transform = transform
22 | self.train_list = []
23 | self.semi_list = []
24 |
25 | if self.split == "train":
26 | with open(os.path.join(self._base_dir, train_file_dir), "r") as f1:
27 | self.sample_list = f1.readlines()
28 | self.sample_list = [item.replace("\n", "") for item in self.sample_list]
29 |
30 | elif self.split == "val":
31 | with open(os.path.join(self._base_dir, val_file_dir), "r") as f:
32 | self.sample_list = f.readlines()
33 | self.sample_list = [item.replace("\n", "") for item in self.sample_list]
34 |
35 | print("total {} samples".format(len(self.sample_list)))
36 |
37 | def __len__(self):
38 | return len(self.sample_list)
39 |
40 | def __getitem__(self, idx):
41 |
42 | case = self.sample_list[idx]
43 |
44 |         # image: 3-channel (BGR) uint8; mask: single channel, kept as HxWx1
45 |         image = cv2.imread(os.path.join(self._base_dir, 'images', case + '.png'))
46 |         mask_path = os.path.join(self._base_dir, 'masks', '0', case + '.png')
47 |         label = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)[..., None]
48 |
49 | augmented = self.transform(image=image, mask=label)
50 | image = augmented['image']
51 | label = augmented['mask']
52 |
53 | image = image.astype('float32') / 255
54 | image = image.transpose(2, 0, 1)
55 |
56 | label = label.astype('float32') / 255
57 | label = label.transpose(2, 0, 1)
58 |
59 | sample = {"image": image, "label": label, "idx": idx}
60 | return sample
61 |
62 |
63 |
64 | class TwoStreamBatchSampler(Sampler):
65 | """Iterate two sets of indices
66 |
67 | An 'epoch' is one iteration through the primary indices.
68 | During the epoch, the secondary indices are iterated through
69 | as many times as needed.
70 | """
71 |
72 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
73 | self.primary_indices = primary_indices
74 | self.secondary_indices = secondary_indices
75 | self.secondary_batch_size = secondary_batch_size
76 | self.primary_batch_size = batch_size - secondary_batch_size
77 |
78 | assert len(self.primary_indices) >= self.primary_batch_size > 0
79 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0
80 |
81 | def __iter__(self):
82 | primary_iter = iterate_once(self.primary_indices)
83 | secondary_iter = iterate_eternally(self.secondary_indices)
84 | return (
85 | primary_batch + secondary_batch
86 | for (primary_batch, secondary_batch) in zip(
87 | grouper(primary_iter, self.primary_batch_size),
88 | grouper(secondary_iter, self.secondary_batch_size),
89 | )
90 | )
91 |
92 | def __len__(self):
93 | return len(self.primary_indices) // self.primary_batch_size
94 |
95 |
96 | def iterate_once(iterable):
97 | return np.random.permutation(iterable)
98 |
99 |
100 | def iterate_eternally(indices):
101 | def infinite_shuffles():
102 | while True:
103 | yield np.random.permutation(indices)
104 |
105 | return itertools.chain.from_iterable(infinite_shuffles())
106 |
107 |
108 | def grouper(iterable, n):
109 |     """Collect data into fixed-length chunks or blocks."""
110 |     # grouper('ABCDEFG', 3) --> ABC DEF (any incomplete tail is dropped)
111 |     args = [iter(iterable)] * n
112 |     return zip(*args)
113 |
114 |
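115 | # Usage sketch (wiring as in train.py; index ranges here are illustrative):
116 | # every batch holds (batch_size - secondary_batch_size) primary indices
117 | # followed by secondary_batch_size secondary indices, so when the primary
118 | # set is the labeled data, labeled samples always sit at the front of the
119 | # batch and can be sliced out with batch[:labeled_bs].
120 | #
121 | #   labeled_idxs = list(range(0, 100))
122 | #   unlabeled_idxs = list(range(100, 200))
123 | #   sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs,
124 | #                                   batch_size=8, secondary_batch_size=4)
125 | #   loader = DataLoader(train_set, batch_sampler=sampler)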
--------------------------------------------------------------------------------
/src/network/MGCC.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.distributions.uniform import Uniform
4 | import numpy as np
5 |
6 |
7 | class MSAG(nn.Module):
8 | """
9 | Multi-scale attention gate
10 |     arXiv: https://arxiv.org/abs/2210.13012
11 | """
12 | def __init__(self, channel):
13 | super(MSAG, self).__init__()
14 | self.channel = channel
15 | self.pointwiseConv = nn.Sequential(
16 | nn.Conv2d(self.channel, self.channel, kernel_size=1, padding=0, bias=True),
17 | nn.BatchNorm2d(self.channel),
18 | )
19 | self.ordinaryConv = nn.Sequential(
20 | nn.Conv2d(self.channel, self.channel, kernel_size=3, padding=1, stride=1, bias=True),
21 | nn.BatchNorm2d(self.channel),
22 | )
23 | self.dilationConv = nn.Sequential(
24 | nn.Conv2d(self.channel, self.channel, kernel_size=3, padding=2, stride=1, dilation=2, bias=True),
25 | nn.BatchNorm2d(self.channel),
26 | )
27 | self.voteConv = nn.Sequential(
28 | nn.Conv2d(self.channel * 3, self.channel, kernel_size=(1, 1)),
29 | nn.BatchNorm2d(self.channel),
30 | nn.Sigmoid()
31 | )
32 | self.relu = nn.ReLU(inplace=True)
33 |
34 | def forward(self, x):
35 | x1 = self.pointwiseConv(x)
36 | x2 = self.ordinaryConv(x)
37 | x3 = self.dilationConv(x)
38 | _x = self.relu(torch.cat((x1, x2, x3), dim=1))
39 | _x = self.voteConv(_x)
40 | x = x + x * _x
41 | return x
42 |
43 |
44 | class Residual(nn.Module):
45 | def __init__(self, fn):
46 | super().__init__()
47 | self.fn = fn
48 |
49 | def forward(self, x):
50 | return self.fn(x) + x
51 |
52 |
53 | class ConvMixerBlock(nn.Module):
54 | def __init__(self, dim=1024, depth=7, k=7):
55 | super(ConvMixerBlock, self).__init__()
56 | self.block = nn.Sequential(
57 | *[nn.Sequential(
58 | Residual(nn.Sequential(
59 |                     # depthwise convolution
60 | nn.Conv2d(dim, dim, kernel_size=(k, k), groups=dim, padding=(k // 2, k // 2)),
61 | nn.GELU(),
62 | nn.BatchNorm2d(dim)
63 | )),
64 | nn.Conv2d(dim, dim, kernel_size=(1, 1)),
65 | nn.GELU(),
66 | nn.BatchNorm2d(dim)
67 | ) for i in range(depth)]
68 | )
69 |
70 | def forward(self, x):
71 | x = self.block(x)
72 | return x
73 |
74 |
75 | class conv_block(nn.Module):
76 | def __init__(self, ch_in, ch_out):
77 | super(conv_block, self).__init__()
78 | self.conv = nn.Sequential(
79 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
80 | nn.BatchNorm2d(ch_out),
81 | nn.ReLU(inplace=True),
82 | nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
83 | nn.BatchNorm2d(ch_out),
84 | nn.ReLU(inplace=True)
85 | )
86 |
87 | def forward(self, x):
88 | x = self.conv(x)
89 | return x
90 |
91 |
92 | class up_conv(nn.Module):
93 | def __init__(self, ch_in, ch_out):
94 | super(up_conv, self).__init__()
95 | self.up = nn.Sequential(
96 | nn.Upsample(scale_factor=2),
97 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
98 | nn.BatchNorm2d(ch_out),
99 | nn.ReLU(inplace=True)
100 | )
101 |
102 | def forward(self, x):
103 | x = self.up(x)
104 | return x
105 |
106 |
107 | class FeatureNoise(nn.Module):
108 | def __init__(self, uniform_range=0.3):
109 | super(FeatureNoise, self).__init__()
110 | self.uni_dist = Uniform(-uniform_range, uniform_range)
111 |
112 | def feature_based_noise(self, x):
113 | noise_vector = self.uni_dist.sample(
114 | x.shape[1:]).to(x.device).unsqueeze(0)
115 | x_noise = x.mul(noise_vector) + x
116 | return x_noise
117 |
118 | def forward(self, x):
119 | x = self.feature_based_noise(x)
120 | return x
121 |
122 |
123 | def Dropout(x, p=0.3):
124 | x = torch.nn.functional.dropout(x, p)
125 | return x
126 |
127 |
128 | def FeatureDropout(x):
129 | attention = torch.mean(x, dim=1, keepdim=True)
130 | max_val, _ = torch.max(attention.view(
131 | x.size(0), -1), dim=1, keepdim=True)
132 | threshold = max_val * np.random.uniform(0.7, 0.9)
133 | threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention)
134 | drop_mask = (attention < threshold).float()
135 | x = x.mul(drop_mask)
136 | return x
137 |
138 |
139 | class Decoder(nn.Module):
140 |
141 | def __init__(self, dim_mult=4, with_masg=True):
142 | super(Decoder, self).__init__()
143 | self.with_masg = with_masg
144 | self.Up5 = up_conv(ch_in=256 * dim_mult, ch_out=128 * dim_mult)
145 | self.Up_conv5 = conv_block(ch_in=128 * 2 * dim_mult, ch_out=128 * dim_mult)
146 | self.Up4 = up_conv(ch_in=128 * dim_mult, ch_out=64 * dim_mult)
147 | self.Up_conv4 = conv_block(ch_in=64 * 2 * dim_mult, ch_out=64 * dim_mult)
148 | self.Up3 = up_conv(ch_in=64 * dim_mult, ch_out=32 * dim_mult)
149 | self.Up_conv3 = conv_block(ch_in=32 * 2 * dim_mult, ch_out=32 * dim_mult)
150 | self.Up2 = up_conv(ch_in=32 * dim_mult, ch_out=16 * dim_mult)
151 | self.Up_conv2 = conv_block(ch_in=16 * 2 * dim_mult, ch_out=16 * dim_mult)
152 | self.Conv_1x1 = nn.Conv2d(16 * dim_mult, 1, kernel_size=1, stride=1, padding=0)
153 |
154 | self.msag4 = MSAG(128 * dim_mult)
155 | self.msag3 = MSAG(64 * dim_mult)
156 | self.msag2 = MSAG(32 * dim_mult)
157 | self.msag1 = MSAG(16 * dim_mult)
158 |
159 | def forward(self, feature):
160 | x1, x2, x3, x4, x5 = feature
161 | if self.with_masg:
162 | x4 = self.msag4(x4)
163 | x3 = self.msag3(x3)
164 | x2 = self.msag2(x2)
165 | x1 = self.msag1(x1)
166 |
167 | d5 = self.Up5(x5)
168 | d5 = torch.cat((x4, d5), dim=1)
169 | d5 = self.Up_conv5(d5)
170 |
171 | d4 = self.Up4(d5)
172 | d4 = torch.cat((x3, d4), dim=1)
173 | d4 = self.Up_conv4(d4)
174 |
175 | d3 = self.Up3(d4)
176 | d3 = torch.cat((x2, d3), dim=1)
177 | d3 = self.Up_conv3(d3)
178 |
179 | d2 = self.Up2(d3)
180 | d2 = torch.cat((x1, d2), dim=1)
181 | d2 = self.Up_conv2(d2)
182 | d1 = self.Conv_1x1(d2)
183 | return d1
184 |
185 |
186 | class MGCC(nn.Module):
187 | def __init__(self, img_ch=3, length=(3, 3, 3), k=7, dim_mult=4):
188 |         """
189 |         Multi-Level Global Context Cross Consistency Model
190 |         Args:
191 |             img_ch: number of input channels.
192 |             length: depths of the three ConvMixer blocks.
193 |             k: kernel size of the ConvMixer blocks.
194 |             dim_mult: channel width multiplier.
195 |
196 |         """
197 | super(MGCC, self).__init__()
198 |
199 | # Encoder
200 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
201 | self.Conv1 = conv_block(ch_in=img_ch, ch_out=16 * dim_mult)
202 | self.Conv2 = conv_block(ch_in=16 * dim_mult, ch_out=32 * dim_mult)
203 | self.Conv3 = conv_block(ch_in=32 * dim_mult, ch_out=64 * dim_mult)
204 | self.Conv4 = conv_block(ch_in=64 * dim_mult, ch_out=128 * dim_mult)
205 | self.Conv5 = conv_block(ch_in=128 * dim_mult, ch_out=256 * dim_mult)
206 | self.ConvMixer1 = ConvMixerBlock(dim=256 * dim_mult, depth=length[0], k=k)
207 | self.ConvMixer2 = ConvMixerBlock(dim=256 * dim_mult, depth=length[1], k=k)
208 | self.ConvMixer3 = ConvMixerBlock(dim=256 * dim_mult, depth=length[2], k=k)
209 | # main Decoder
210 | self.main_decoder = Decoder(dim_mult=dim_mult, with_masg=True)
211 | # aux Decoder
212 | self.aux_decoder1 = Decoder(dim_mult=dim_mult, with_masg=True)
213 | self.aux_decoder2 = Decoder(dim_mult=dim_mult, with_masg=True)
214 | self.aux_decoder3 = Decoder(dim_mult=dim_mult, with_masg=True)
215 |
216 | def forward(self, x):
217 | x1 = self.Conv1(x)
218 | x2 = self.Maxpool(x1)
219 | x2 = self.Conv2(x2)
220 | x3 = self.Maxpool(x2)
221 | x3 = self.Conv3(x3)
222 | x4 = self.Maxpool(x3)
223 | x4 = self.Conv4(x4)
224 | x5 = self.Maxpool(x4)
225 | x5 = self.Conv5(x5)
226 |
227 |         if not self.training:  # inference: full ConvMixer stack, main decoder only
228 | x5 = self.ConvMixer1(x5)
229 | x5 = self.ConvMixer2(x5)
230 | x5 = self.ConvMixer3(x5)
231 | feature = [x1, x2, x3, x4, x5]
232 | main_seg = self.main_decoder(feature)
233 | return main_seg
234 |
235 |         # aux decoder 1: feature-dropout perturbation
236 | feature = [x1, x2, x3, x4, x5]
237 | aux1_feature = [FeatureDropout(i) for i in feature]
238 | aux_seg1 = self.aux_decoder1(aux1_feature)
239 |
240 | x5 = self.ConvMixer1(x5)
241 | feature = [x1, x2, x3, x4, x5]
242 | aux2_feature = [Dropout(i) for i in feature]
243 | aux_seg2 = self.aux_decoder2(aux2_feature)
244 |
245 | x5 = self.ConvMixer2(x5)
246 | feature = [x1, x2, x3, x4, x5]
247 | aux3_feature = [FeatureNoise()(i) for i in feature]
248 | aux_seg3 = self.aux_decoder3(aux3_feature)
249 |
250 | # main decoder
251 | x5 = self.ConvMixer3(x5)
252 | feature = [x1, x2, x3, x4, x5]
253 | main_seg = self.main_decoder(feature)
254 | return main_seg, aux_seg1, aux_seg2, aux_seg3
255 |
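256 | # Smoke-test sketch (assumes 256x256 inputs, matching train.py's Resize):
257 | #
258 | #   model = MGCC(img_ch=3, length=(3, 3, 3), k=7, dim_mult=4)
259 | #   x = torch.randn(2, 3, 256, 256)
260 | #   model.train()
261 | #   main_seg, aux1, aux2, aux3 = model(x)  # four logit maps, each (2, 1, 256, 256)
262 | #   model.eval()
263 | #   main_seg = model(x)                    # eval mode returns only the main prediction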
--------------------------------------------------------------------------------
/src/utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | __all__ = ['BCEDiceLoss']
7 |
8 |
9 | class BCEDiceLoss(nn.Module):
10 | def __init__(self):
11 | super().__init__()
12 |
13 | def forward(self, input, target):
14 | bce = F.binary_cross_entropy_with_logits(input, target)
15 | smooth = 1e-5
16 | input = torch.sigmoid(input)
17 | num = target.size(0)
18 | input = input.view(num, -1)
19 | target = target.view(num, -1)
20 | intersection = (input * target)
21 | dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth)
22 | dice = 1 - dice.sum() / num
23 | return 0.5 * bce + dice
24 |
25 |
26 | def compute_kl_loss(p, q):
27 | p_loss = F.kl_div(F.log_softmax(p, dim=-1),
28 | F.softmax(q, dim=-1), reduction='none')
29 | q_loss = F.kl_div(F.log_softmax(q, dim=-1),
30 | F.softmax(p, dim=-1), reduction='none')
31 |
32 |     # whether to reduce with "sum" or "mean" depends on the task
33 | p_loss = p_loss.mean()
34 | q_loss = q_loss.mean()
35 |
36 | loss = (p_loss + q_loss) / 2
37 | return loss
38 |
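39 | # Note: BCEDiceLoss expects raw logits. The BCE term uses
40 | # binary_cross_entropy_with_logits and the Dice term applies sigmoid
41 | # internally, so do not pre-apply sigmoid. Usage sketch:
42 | #
43 | #   criterion = BCEDiceLoss()
44 | #   loss = criterion(logits, target)  # target: float mask in [0, 1], same shape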
--------------------------------------------------------------------------------
/src/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def get_accuracy(SR, GT, threshold=0.5):
5 | SR = SR > threshold
6 | GT = GT == torch.max(GT)
7 | corr = torch.sum(SR == GT)
8 | tensor_size = SR.size(0)*SR.size(1)*SR.size(2)*SR.size(3)
9 | acc = float(corr)/float(tensor_size)
10 | return acc
11 |
12 |
13 | def get_sensitivity(SR, GT, threshold=0.5):
14 | # Sensitivity == Recall
15 | SE = 0
16 | SR = SR > threshold
17 | GT = GT == torch.max(GT)
18 | # TP : True Positive
19 | # FN : False Negative
20 | TP = ((SR == 1).byte() + (GT == 1).byte()) == 2
21 | FN = ((SR == 0).byte() + (GT == 1).byte()) == 2
22 | SE = float(torch.sum(TP))/(float(torch.sum(TP+FN)) + 1e-6)
23 | return SE
24 |
25 | def get_specificity(SR, GT, threshold=0.5):
26 | SP = 0
27 | SR = SR > threshold
28 | GT = GT == torch.max(GT)
29 | # TN : True Negative
30 | # FP : False Positive
31 | TN = ((SR == 0).byte() + (GT == 0).byte()) == 2
32 | FP = ((SR == 1).byte() + (GT == 0).byte()) == 2
33 | SP = float(torch.sum(TN))/(float(torch.sum(TN+FP)) + 1e-6)
34 | return SP
35 |
36 | def get_precision(SR, GT, threshold=0.5):
37 |     PC = 0
38 |     SR = SR > threshold
39 |     GT = GT == torch.max(GT)
40 | # TP : True Positive
41 | # FP : False Positive
42 | TP = ((SR == 1).byte() + (GT == 1).byte()) == 2
43 | FP = ((SR == 1).byte() + (GT == 0).byte()) == 2
44 | PC = float(torch.sum(TP))/(float(torch.sum(TP+FP)) + 1e-6)
45 | return PC
46 |
47 | def iou_score(output, target):
48 | smooth = 1e-5
49 |
50 | if torch.is_tensor(output):
51 | output = torch.sigmoid(output).data.cpu().numpy()
52 | if torch.is_tensor(target):
53 | target = target.data.cpu().numpy()
54 | output_ = output > 0.5
55 | target_ = target > 0.5
56 |
57 | intersection = (output_ & target_).sum()
58 | union = (output_ | target_).sum()
59 | iou = (intersection + smooth) / (union + smooth)
60 |     dice = (2 * iou) / (iou + 1)  # convert IoU to the equivalent Dice score
61 |
62 | output_ = torch.tensor(output_)
63 | target_ = torch.tensor(target_)
64 | SE = get_sensitivity(output_, target_, threshold=0.5)
65 | PC = get_precision(output_, target_, threshold=0.5)
66 |     SP = get_specificity(output_, target_, threshold=0.5)
67 |     ACC = get_accuracy(output_, target_, threshold=0.5)
68 | F1 = 2*SE*PC/(SE+PC + 1e-6)
69 | return iou, dice, SE, PC, F1, SP, ACC
70 |
71 |
72 | def dice_coef(output, target):
73 | smooth = 1e-5
74 | output = torch.sigmoid(output).view(-1).data.cpu().numpy()
75 | target = target.view(-1).data.cpu().numpy()
76 | intersection = (output * target).sum()
77 | return (2. * intersection + smooth) / \
78 | (output.sum() + target.sum() + smooth)
79 |
80 |
81 |
--------------------------------------------------------------------------------
/src/utils/ramps.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | """Functions for ramping hyperparameters up or down
4 |
5 | Each function takes the current training step or epoch, and the
6 | ramp length in the same format, and returns a multiplier between
7 | 0 and 1.
8 | """
9 |
10 | def sigmoid_rampup(current, rampup_length):
11 | """Exponential rampup from https://arxiv.org/abs/1610.02242"""
12 | if rampup_length == 0:
13 | return 1.0
14 | else:
15 | current = np.clip(current, 0.0, rampup_length)
16 | phase = 1.0 - current / rampup_length
17 | return float(np.exp(-5.0 * phase * phase))
18 |
19 |
20 | def linear_rampup(current, rampup_length):
21 | """Linear rampup"""
22 | assert current >= 0 and rampup_length >= 0
23 | if current >= rampup_length:
24 | return 1.0
25 | else:
26 | return current / rampup_length
27 |
28 |
29 | def cosine_rampdown(current, rampdown_length):
30 | """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
31 | assert 0 <= current <= rampdown_length
32 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))
33 |
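34 | # Usage sketch (mirrors get_current_consistency_weight in train.py, where
35 | # one ramp step corresponds to 150 training iterations):
36 | #
37 | #   weight = consistency * sigmoid_rampup(iter_num // 150, consistency_rampup)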
--------------------------------------------------------------------------------
/src/utils/util.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def str2bool(v):
4 |     if v.lower() in ('true', '1'):
5 |         return True
6 |     elif v.lower() in ('false', '0'):
7 |         return False
8 | else:
9 | raise argparse.ArgumentTypeError('Boolean value expected.')
10 |
11 |
12 | def count_params(model):
13 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
14 |
15 |
16 | class AverageMeter(object):
17 | """Computes and stores the average and current value"""
18 |
19 | def __init__(self):
20 | self.reset()
21 |
22 | def reset(self):
23 | self.val = 0
24 | self.avg = 0
25 | self.sum = 0
26 | self.count = 0
27 |
28 | def update(self, val, n=1):
29 | self.val = val
30 | self.sum += val * n
31 | self.count += n
32 | self.avg = self.sum / self.count
33 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import numpy as np
4 | import torch
5 | import torch.backends.cudnn as cudnn
6 | import torch.optim as optim
7 | from torch.utils.data import DataLoader
8 | # albumentations supplies the image transforms used below (not torchvision)
9 | from albumentations.augmentations import transforms
10 | from albumentations.core.composition import Compose
11 | from albumentations import RandomRotate90, Resize
12 | import src.utils.losses as losses
13 | from src.utils.util import AverageMeter
14 | from src.utils.metrics import iou_score
15 | from src.utils import ramps
16 | from src.dataloader.dataset import (SemiDataSets, TwoStreamBatchSampler)
17 | from src.network.MGCC import MGCC
18 | import os
19 |
20 |
21 | def seed_torch(seed):
22 |     # seed python, numpy and torch (CPU and GPU) for reproducibility
23 |     random.seed(seed)
24 |     np.random.seed(seed)
25 |     torch.manual_seed(seed)
26 |     torch.cuda.manual_seed(seed)
27 |     torch.cuda.manual_seed_all(seed)
28 |     torch.backends.cudnn.benchmark = False
29 |     torch.backends.cudnn.deterministic = True
30 |     os.environ['PYTHONHASHSEED'] = str(seed)
31 |
32 |
33 | parser = argparse.ArgumentParser()
34 | parser.add_argument('--semi_percent', type=float, default=0.5)
35 | parser.add_argument('--base_dir', type=str, default="./data/busi", help='dir')
36 | parser.add_argument('--train_file_dir', type=str, default="busi_train1.txt", help='dir')
37 | parser.add_argument('--val_file_dir', type=str, default="busi_val1.txt", help='dir')
38 | parser.add_argument('--max_iterations', type=int,
39 |                     default=40000, help='maximum number of training iterations')
40 | parser.add_argument('--total_batch_size', type=int, default=8,
41 | help='batch_size per gpu')
42 | parser.add_argument('--base_lr', type=float, default=0.01,
43 | help='segmentation network learning rate')
44 | parser.add_argument('--seed', type=int, default=41, help='random seed')
45 | # label and unlabel
46 | parser.add_argument('--labeled_bs', type=int, default=4,
47 | help='labeled_batch_size per gpu')
48 | # costs
49 | parser.add_argument('--consistency', type=float,
50 | default=7, help='consistency')
51 | parser.add_argument('--consistency_rampup', type=float,
52 | default=200.0, help='consistency_rampup')
53 | # MGCC hyperparameter
54 | parser.add_argument('--kernel_size', type=int,
55 | default=7, help='ConvMixer kernel size')
56 | parser.add_argument('--length', type=int, nargs=3,
57 |                     default=(3, 3, 3), help='depths of the three ConvMixer blocks')
58 | args = parser.parse_args()
59 |
60 | seed_torch(args.seed)
61 |
62 |
63 | def getDataloader(args):
64 | train_transform = Compose([
65 | RandomRotate90(),
66 | transforms.Flip(),
67 | Resize(256, 256),
68 | transforms.Normalize(),
69 | ])
70 | val_transform = Compose([
71 | Resize(256, 256),
72 | transforms.Normalize(),
73 | ])
74 | labeled_slice = args.semi_percent
75 | db_train = SemiDataSets(base_dir=args.base_dir, split="train", transform=train_transform,
76 | train_file_dir=args.train_file_dir, val_file_dir=args.val_file_dir,
77 | )
78 | db_val = SemiDataSets(base_dir=args.base_dir, split="val", transform=val_transform,
79 | train_file_dir=args.train_file_dir, val_file_dir=args.val_file_dir
80 | )
81 |
82 | def worker_init_fn(worker_id):
83 | random.seed(args.seed + worker_id)
84 |
85 | total_slices = len(db_train)
86 | labeled_idxs = list(range(0, int(labeled_slice * total_slices)))
87 | unlabeled_idxs = list(range(int(labeled_slice * total_slices), total_slices))
88 |     print("labeled num: {}, unlabeled num: {}, labeled percent: {}".format(len(labeled_idxs), len(unlabeled_idxs), labeled_slice))
89 | batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, args.total_batch_size, args.labeled_bs)
90 | trainloader = DataLoader(db_train, batch_sampler=batch_sampler,
91 | num_workers=8, pin_memory=False, worker_init_fn=worker_init_fn)
92 | valloader = DataLoader(db_val, batch_size=1, shuffle=False, num_workers=1)
93 |
94 | return trainloader, valloader
95 |
96 |
97 | def get_current_consistency_weight(epoch):
98 | # Consistency ramp-up from https://arxiv.org/abs/1610.02242
99 | return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_rampup)
100 |
101 |
102 | def getModel(args):
103 |     print("ConvMixer1:{}, ConvMixer2:{}, ConvMixer3:{}, kernel:{}".format(args.length[0], args.length[1],
104 |                                                                           args.length[2], args.kernel_size))
105 | return MGCC(length=args.length, k=args.kernel_size).cuda()
106 |
107 |
108 | def train(args):
109 | base_lr = args.base_lr
110 | max_iterations = int(args.max_iterations * args.semi_percent)
111 | trainloader, valloader = getDataloader(args)
112 |
113 | model = getModel(args)
114 |
115 | optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
116 | criterion = losses.__dict__['BCEDiceLoss']().cuda()
117 |
118 | print("{} iterations per epoch".format(len(trainloader)))
119 | best_iou = 0
120 | iter_num = 0
121 | max_epoch = max_iterations // len(trainloader) + 1
122 |
123 | for epoch_num in range(max_epoch):
124 | avg_meters = {'total_loss': AverageMeter(),
125 | 'train_iou': AverageMeter(),
126 | 'consistency_loss': AverageMeter(),
127 | 'supervised_loss': AverageMeter(),
128 | 'val_loss': AverageMeter(),
129 | 'val_iou': AverageMeter(),
130 | 'val_se': AverageMeter(),
131 | 'val_pc': AverageMeter(),
132 | 'val_f1': AverageMeter(),
133 | 'val_acc': AverageMeter()
134 | }
135 | model.train()
136 | for i_batch, sampled_batch in enumerate(trainloader):
137 |
138 | volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
139 | volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
140 |
141 | outputs, outputs_aux1, outputs_aux2, outputs_aux3 = model(volume_batch)
142 |
143 | outputs_soft = torch.sigmoid(outputs)
144 | outputs_aux1_soft = torch.sigmoid(outputs_aux1)
145 | outputs_aux2_soft = torch.sigmoid(outputs_aux2)
146 | outputs_aux3_soft = torch.sigmoid(outputs_aux3)
147 |
148 |             loss_ce = criterion(outputs[:args.labeled_bs],
149 |                                 label_batch[:args.labeled_bs])
150 |             loss_ce_aux1 = criterion(outputs_aux1[:args.labeled_bs],
151 |                                      label_batch[:args.labeled_bs])
152 |             loss_ce_aux2 = criterion(outputs_aux2[:args.labeled_bs],
153 |                                      label_batch[:args.labeled_bs])
154 |             loss_ce_aux3 = criterion(outputs_aux3[:args.labeled_bs],
155 |                                      label_batch[:args.labeled_bs])
156 |
157 | supervised_loss = (loss_ce + loss_ce_aux1 + loss_ce_aux2 + loss_ce_aux3) / 4
158 |
159 |             consistency_weight = get_current_consistency_weight(iter_num // 150)  # ramp advances every 150 iterations
160 | consistency_loss_aux1 = torch.mean(
161 | (outputs_soft[args.labeled_bs:] - outputs_aux1_soft[args.labeled_bs:]) ** 2)
162 | consistency_loss_aux2 = torch.mean(
163 | (outputs_soft[args.labeled_bs:] - outputs_aux2_soft[args.labeled_bs:]) ** 2)
164 | consistency_loss_aux3 = torch.mean(
165 | (outputs_soft[args.labeled_bs:] - outputs_aux3_soft[args.labeled_bs:]) ** 2)
166 |
167 | consistency_loss = (consistency_loss_aux1 + consistency_loss_aux2 + consistency_loss_aux3) / 3
168 | loss = supervised_loss + consistency_weight * consistency_loss
169 | iou, dice, _, _, _, _, _ = iou_score(outputs[:args.labeled_bs], label_batch[:args.labeled_bs])
170 | optimizer.zero_grad()
171 | loss.backward()
172 | optimizer.step()
173 |
174 | lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9
175 | for param_group in optimizer.param_groups:
176 | param_group['lr'] = lr_
177 |
178 | iter_num = iter_num + 1
179 |
180 | avg_meters['total_loss'].update(loss.item(), volume_batch[:args.labeled_bs].size(0))
181 | avg_meters['supervised_loss'].update(supervised_loss.item(), volume_batch[:args.labeled_bs].size(0))
182 | avg_meters['consistency_loss'].update(consistency_loss.item(), volume_batch[args.labeled_bs:].size(0))
183 | avg_meters['train_iou'].update(iou, volume_batch[:args.labeled_bs].size(0))
184 |
185 | model.eval()
186 | with torch.no_grad():
187 | for i_batch, sampled_batch in enumerate(valloader):
188 | input, target = sampled_batch['image'], sampled_batch['label']
189 | input = input.cuda()
190 | target = target.cuda()
191 | output = model(input)
192 | loss = criterion(output, target)
193 | iou, _, SE, PC, F1, _, ACC = iou_score(output, target)
194 | avg_meters['val_loss'].update(loss.item(), input.size(0))
195 | avg_meters['val_iou'].update(iou, input.size(0))
196 | avg_meters['val_se'].update(SE, input.size(0))
197 | avg_meters['val_pc'].update(PC, input.size(0))
198 | avg_meters['val_f1'].update(F1, input.size(0))
199 | avg_meters['val_acc'].update(ACC, input.size(0))
200 |
201 | print(
202 | 'epoch [%3d/%d] train_loss %.4f supervised_loss %.4f consistency_loss %.4f train_iou: %.4f '
203 | '- val_loss %.4f - val_iou %.4f - val_SE %.4f - val_PC %.4f - val_F1 %.4f - val_ACC %.4f'
204 | % (epoch_num, max_epoch, avg_meters['total_loss'].avg,
205 | avg_meters['supervised_loss'].avg, avg_meters['consistency_loss'].avg, avg_meters['train_iou'].avg,
206 | avg_meters['val_loss'].avg, avg_meters['val_iou'].avg, avg_meters['val_se'].avg,
207 | avg_meters['val_pc'].avg, avg_meters['val_f1'].avg, avg_meters['val_acc'].avg))
208 |
209 | if avg_meters['val_iou'].avg > best_iou:
210 | torch.save(model.state_dict(), 'checkpoint/model.pth')
211 | best_iou = avg_meters['val_iou'].avg
212 | print("=> saved best model")
213 |
214 | return "Training Finished!"
215 |
216 |
217 | if __name__ == "__main__":
218 | train(args)
219 |
--------------------------------------------------------------------------------