├── LICENSE
├── README.md
├── checkpoint
│   └── readme.md
├── imgs
│   ├── MGCC.png
│   ├── framework.png
│   ├── gen_bus.png
│   └── gen_tus.png
├── split.py
├── src
│   ├── dataloader
│   │   └── dataset.py
│   ├── network
│   │   └── MGCC.py
│   └── utils
│       ├── losses.py
│       ├── metrics.py
│       ├── ramps.py
│       └── util.py
└── train.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Fenghe Tang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Multi-Level Global Context Cross Consistency for Semi-Supervised Ultrasound Image Segmentation with Diffusion Model

[Paper](https://arxiv.org/pdf/2305.09447) | [Code](https://github.com/FengheTan9/Multi-Level_Global_Context_Cross_Consistency)

A PyTorch code base for [Multi-Level Global Context Cross Consistency Model for Semi-Supervised Ultrasound Image Segmentation with Diffusion Model](https://arxiv.org/pdf/2305.09447).

## Introduction
Medical image segmentation is a critical step in computer-aided diagnosis, and convolutional neural networks are currently the most popular segmentation networks. However, their inherently local operations make it difficult to capture the global context of lesions with different positions, shapes, and sizes. Semi-supervised learning can learn from both labeled and unlabeled samples, alleviating the burden of manual labeling, but obtaining a large number of unlabeled images in medical scenarios remains challenging. To address these issues, we propose a Multi-level Global Context Cross-consistency (MGCC) framework that uses images generated by a Latent Diffusion Model (LDM) as unlabeled images for semi-supervised learning. The framework involves two stages. In the first stage, an LDM is used to generate synthetic medical images, which reduces the workload of data annotation and addresses privacy concerns associated with collecting medical data. In the second stage, varying levels of global context noise perturbation are added to the inputs of the auxiliary decoders, and output consistency is maintained between decoders to improve the representation ability.
Experiments conducted on an open-source breast ultrasound dataset and a private thyroid ultrasound dataset demonstrate the effectiveness of our framework in bridging the probability distribution and the semantic representation of medical images. Our approach enables the effective transfer of probability distribution knowledge to the segmentation network, resulting in improved segmentation accuracy.

### MGCC framework:

![framework](imgs/framework.png)

### MGCC model

![mgcc](imgs/MGCC.png)

### **Generation results**

**BUSI Result:**

![gen_busi](imgs/gen_bus.png)

**TUS Result:**

![gen_tus](imgs/gen_tus.png)

## Datasets

Please place the [BUSI](https://www.kaggle.com/aryashah2k/breast-ultrasound-images-dataset) dataset, or your own dataset, in the following structure (these are the default `./data/busi` paths used by `split.py` and `train.py`):
```
├── Multi-Level_Global_Context_Cross_Consistency
    ├── data
        ├── busi
        │   ├── images
        │   │   ├── benign (10).png
        │   │   ├── malignant (17).png
        │   │   ├── normal (14).png
        │   │   └── ...
        │   └── masks
        │       └── 0
        │           ├── benign (10).png
        │           ├── malignant (17).png
        │           ├── normal (14).png
        │           └── ...
        └── your dataset
            ├── images
            │   ├── 0a7e06.png
            │   ├── 0aab0a.png
            │   ├── 0b1761.png
            │   └── ...
            └── masks
                └── 0
                    ├── 0a7e06.png
                    ├── 0aab0a.png
                    ├── 0b1761.png
                    └── ...
```
## Environment

- GPU: NVIDIA GeForce RTX 4090
- PyTorch: 1.13.0 (CUDA 11.7)
- cudatoolkit: 11.7.1
- scikit-learn: 1.0.2

## Training and Validation

1. Generation Stage:

   You can follow this [work](https://github.com/mueller-franzes/medfusion).

2. Semi-supervised Learning Stage:

   First, split your dataset:

   ```
   python split.py
   ```

   Then, train on your dataset:

   ```
   python train.py
   ```

## Acknowledgements:

This code base uses helper functions from [CMU-Net](https://github.com/FengheTan9/CMU-Net), [SSL4MIS](https://github.com/HiLab-git/SSL4MIS) and [medFusion](https://github.com/mueller-franzes/medfusion).

## Citation

If you use our code, please cite our paper:

```tex
@article{tang2023multi,
  title={Multi-Level Global Context Cross Consistency Model for Semi-Supervised Ultrasound Image Segmentation with Diffusion Model},
  author={Tang, Fenghe and Ding, Jianrui and Wang, Lingtao and Xian, Min and Ning, Chunping},
  journal={arXiv preprint arXiv:2305.09447},
  year={2023}
}
```

--------------------------------------------------------------------------------
/checkpoint/readme.md:
--------------------------------------------------------------------------------
Keep this `checkpoint` folder: `train.py` saves the best model weights to `checkpoint/model.pth`.
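
Below is a minimal inference sketch showing how the saved weights might be loaded back. It assumes the default `MGCC` hyperparameters and the 256×256 input size used in `train.py`, and is only illustrative:

```python
import torch

from src.network.MGCC import MGCC

# Rebuild the network with the same hyperparameters used for training
# (train.py defaults: length=(3, 3, 3), kernel_size=7).
model = MGCC(length=(3, 3, 3), k=7)
model.load_state_dict(torch.load('checkpoint/model.pth', map_location='cpu'))
model.eval()  # in eval mode MGCC.forward returns only the main decoder output

with torch.no_grad():
    dummy = torch.randn(1, 3, 256, 256)   # images are resized to 256x256 in train.py
    prob = torch.sigmoid(model(dummy))    # logits -> per-pixel foreground probabilities
    mask = (prob > 0.5).float()           # binary segmentation mask
```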
-------------------------------------------------------------------------------- /imgs/MGCC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FengheTan9/Multi-Level-Global-Context-Cross-Consistency/a9c8bd3bb5bb6e0bef9a4bc69f7aa812888a7e0a/imgs/MGCC.png -------------------------------------------------------------------------------- /imgs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FengheTan9/Multi-Level-Global-Context-Cross-Consistency/a9c8bd3bb5bb6e0bef9a4bc69f7aa812888a7e0a/imgs/framework.png -------------------------------------------------------------------------------- /imgs/gen_bus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FengheTan9/Multi-Level-Global-Context-Cross-Consistency/a9c8bd3bb5bb6e0bef9a4bc69f7aa812888a7e0a/imgs/gen_bus.png -------------------------------------------------------------------------------- /imgs/gen_tus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FengheTan9/Multi-Level-Global-Context-Cross-Consistency/a9c8bd3bb5bb6e0bef9a4bc69f7aa812888a7e0a/imgs/gen_tus.png -------------------------------------------------------------------------------- /split.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from sklearn.model_selection import train_test_split 4 | 5 | name = 'busi' 6 | 7 | root = r'./data/' + name 8 | 9 | img_ids = glob(os.path.join(root, 'images', '*.png')) 10 | img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids] 11 | 12 | 13 | count = 1 14 | for i in [41, 64, 1337]: 15 | train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.3, random_state=i) 16 | filename = root + '/{}_train{}.txt'.format(name, count) 17 | with open(filename, 'w') as file: 18 | for i in train_img_ids: 19 | file.write(i + '\n') 20 | 21 | filename = root + '/{}_val{}.txt'.format(name, count) 22 | with open(filename, 'w') as file: 23 | for i in val_img_ids: 24 | file.writelines(i + '\n') 25 | 26 | print(train_img_ids) 27 | print(val_img_ids) 28 | count += 1 29 | -------------------------------------------------------------------------------- /src/dataloader/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import itertools 5 | from torch.utils.data.sampler import Sampler 6 | import cv2 7 | 8 | 9 | class SemiDataSets(Dataset): 10 | def __init__( 11 | self, 12 | base_dir=None, 13 | split="train", 14 | transform=None, 15 | train_file_dir="train.txt", 16 | val_file_dir="val.txt", 17 | ): 18 | self._base_dir = base_dir 19 | self.sample_list = [] 20 | self.split = split 21 | self.transform = transform 22 | self.train_list = [] 23 | self.semi_list = [] 24 | 25 | if self.split == "train": 26 | with open(os.path.join(self._base_dir, train_file_dir), "r") as f1: 27 | self.sample_list = f1.readlines() 28 | self.sample_list = [item.replace("\n", "") for item in self.sample_list] 29 | 30 | elif self.split == "val": 31 | with open(os.path.join(self._base_dir, val_file_dir), "r") as f: 32 | self.sample_list = f.readlines() 33 | self.sample_list = [item.replace("\n", "") for item in self.sample_list] 34 | 35 | print("total {} samples".format(len(self.sample_list))) 36 | 
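    # Note: every id listed in the split txt file is expected to resolve to
    # <base_dir>/images/<id>.png (color image) and <base_dir>/masks/0/<id>.png
    # (grayscale mask); the txt split files themselves are produced by split.py.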
37 | def __len__(self): 38 | return len(self.sample_list) 39 | 40 | def __getitem__(self, idx): 41 | 42 | case = self.sample_list[idx] 43 | 44 | image = cv2.imread(os.path.join(self._base_dir, 'images', case + '.png')) 45 | label = \ 46 | cv2.imread(os.path.join(self._base_dir, 'masks', '0', case + '.png'), cv2.IMREAD_GRAYSCALE)[ 47 | ..., None] 48 | 49 | augmented = self.transform(image=image, mask=label) 50 | image = augmented['image'] 51 | label = augmented['mask'] 52 | 53 | image = image.astype('float32') / 255 54 | image = image.transpose(2, 0, 1) 55 | 56 | label = label.astype('float32') / 255 57 | label = label.transpose(2, 0, 1) 58 | 59 | sample = {"image": image, "label": label, "idx": idx} 60 | return sample 61 | 62 | 63 | 64 | class TwoStreamBatchSampler(Sampler): 65 | """Iterate two sets of indices 66 | 67 | An 'epoch' is one iteration through the primary indices. 68 | During the epoch, the secondary indices are iterated through 69 | as many times as needed. 70 | """ 71 | 72 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 73 | self.primary_indices = primary_indices 74 | self.secondary_indices = secondary_indices 75 | self.secondary_batch_size = secondary_batch_size 76 | self.primary_batch_size = batch_size - secondary_batch_size 77 | 78 | assert len(self.primary_indices) >= self.primary_batch_size > 0 79 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 80 | 81 | def __iter__(self): 82 | primary_iter = iterate_once(self.primary_indices) 83 | secondary_iter = iterate_eternally(self.secondary_indices) 84 | return ( 85 | primary_batch + secondary_batch 86 | for (primary_batch, secondary_batch) in zip( 87 | grouper(primary_iter, self.primary_batch_size), 88 | grouper(secondary_iter, self.secondary_batch_size), 89 | ) 90 | ) 91 | 92 | def __len__(self): 93 | return len(self.primary_indices) // self.primary_batch_size 94 | 95 | 96 | def iterate_once(iterable): 97 | return np.random.permutation(iterable) 98 | 99 | 100 | def iterate_eternally(indices): 101 | def infinite_shuffles(): 102 | while True: 103 | yield np.random.permutation(indices) 104 | 105 | return itertools.chain.from_iterable(infinite_shuffles()) 106 | 107 | 108 | def grouper(iterable, n): 109 | "Collect data into fixed-length chunks or blocks" 110 | # grouper('ABCDEFG', 3) --> ABC DEF" 111 | args = [iter(iterable)] * n 112 | return zip(*args) 113 | 114 | -------------------------------------------------------------------------------- /src/network/MGCC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions.uniform import Uniform 4 | import numpy as np 5 | 6 | 7 | class MSAG(nn.Module): 8 | """ 9 | Multi-scale attention gate 10 | Arxiv: https://arxiv.org/abs/2210.13012 11 | """ 12 | def __init__(self, channel): 13 | super(MSAG, self).__init__() 14 | self.channel = channel 15 | self.pointwiseConv = nn.Sequential( 16 | nn.Conv2d(self.channel, self.channel, kernel_size=1, padding=0, bias=True), 17 | nn.BatchNorm2d(self.channel), 18 | ) 19 | self.ordinaryConv = nn.Sequential( 20 | nn.Conv2d(self.channel, self.channel, kernel_size=3, padding=1, stride=1, bias=True), 21 | nn.BatchNorm2d(self.channel), 22 | ) 23 | self.dilationConv = nn.Sequential( 24 | nn.Conv2d(self.channel, self.channel, kernel_size=3, padding=2, stride=1, dilation=2, bias=True), 25 | nn.BatchNorm2d(self.channel), 26 | ) 27 | self.voteConv = nn.Sequential( 28 | nn.Conv2d(self.channel * 3, 
self.channel, kernel_size=(1, 1)), 29 | nn.BatchNorm2d(self.channel), 30 | nn.Sigmoid() 31 | ) 32 | self.relu = nn.ReLU(inplace=True) 33 | 34 | def forward(self, x): 35 | x1 = self.pointwiseConv(x) 36 | x2 = self.ordinaryConv(x) 37 | x3 = self.dilationConv(x) 38 | _x = self.relu(torch.cat((x1, x2, x3), dim=1)) 39 | _x = self.voteConv(_x) 40 | x = x + x * _x 41 | return x 42 | 43 | 44 | class Residual(nn.Module): 45 | def __init__(self, fn): 46 | super().__init__() 47 | self.fn = fn 48 | 49 | def forward(self, x): 50 | return self.fn(x) + x 51 | 52 | 53 | class ConvMixerBlock(nn.Module): 54 | def __init__(self, dim=1024, depth=7, k=7): 55 | super(ConvMixerBlock, self).__init__() 56 | self.block = nn.Sequential( 57 | *[nn.Sequential( 58 | Residual(nn.Sequential( 59 | # deep wise 60 | nn.Conv2d(dim, dim, kernel_size=(k, k), groups=dim, padding=(k // 2, k // 2)), 61 | nn.GELU(), 62 | nn.BatchNorm2d(dim) 63 | )), 64 | nn.Conv2d(dim, dim, kernel_size=(1, 1)), 65 | nn.GELU(), 66 | nn.BatchNorm2d(dim) 67 | ) for i in range(depth)] 68 | ) 69 | 70 | def forward(self, x): 71 | x = self.block(x) 72 | return x 73 | 74 | 75 | class conv_block(nn.Module): 76 | def __init__(self, ch_in, ch_out): 77 | super(conv_block, self).__init__() 78 | self.conv = nn.Sequential( 79 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 80 | nn.BatchNorm2d(ch_out), 81 | nn.ReLU(inplace=True), 82 | nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 83 | nn.BatchNorm2d(ch_out), 84 | nn.ReLU(inplace=True) 85 | ) 86 | 87 | def forward(self, x): 88 | x = self.conv(x) 89 | return x 90 | 91 | 92 | class up_conv(nn.Module): 93 | def __init__(self, ch_in, ch_out): 94 | super(up_conv, self).__init__() 95 | self.up = nn.Sequential( 96 | nn.Upsample(scale_factor=2), 97 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 98 | nn.BatchNorm2d(ch_out), 99 | nn.ReLU(inplace=True) 100 | ) 101 | 102 | def forward(self, x): 103 | x = self.up(x) 104 | return x 105 | 106 | 107 | class FeatureNoise(nn.Module): 108 | def __init__(self, uniform_range=0.3): 109 | super(FeatureNoise, self).__init__() 110 | self.uni_dist = Uniform(-uniform_range, uniform_range) 111 | 112 | def feature_based_noise(self, x): 113 | noise_vector = self.uni_dist.sample( 114 | x.shape[1:]).to(x.device).unsqueeze(0) 115 | x_noise = x.mul(noise_vector) + x 116 | return x_noise 117 | 118 | def forward(self, x): 119 | x = self.feature_based_noise(x) 120 | return x 121 | 122 | 123 | def Dropout(x, p=0.3): 124 | x = torch.nn.functional.dropout(x, p) 125 | return x 126 | 127 | 128 | def FeatureDropout(x): 129 | attention = torch.mean(x, dim=1, keepdim=True) 130 | max_val, _ = torch.max(attention.view( 131 | x.size(0), -1), dim=1, keepdim=True) 132 | threshold = max_val * np.random.uniform(0.7, 0.9) 133 | threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention) 134 | drop_mask = (attention < threshold).float() 135 | x = x.mul(drop_mask) 136 | return x 137 | 138 | 139 | class Decoder(nn.Module): 140 | 141 | def __init__(self, dim_mult=4, with_masg=True): 142 | super(Decoder, self).__init__() 143 | self.with_masg = with_masg 144 | self.Up5 = up_conv(ch_in=256 * dim_mult, ch_out=128 * dim_mult) 145 | self.Up_conv5 = conv_block(ch_in=128 * 2 * dim_mult, ch_out=128 * dim_mult) 146 | self.Up4 = up_conv(ch_in=128 * dim_mult, ch_out=64 * dim_mult) 147 | self.Up_conv4 = conv_block(ch_in=64 * 2 * dim_mult, ch_out=64 * dim_mult) 148 | self.Up3 = up_conv(ch_in=64 * dim_mult, ch_out=32 * dim_mult) 149 | 
self.Up_conv3 = conv_block(ch_in=32 * 2 * dim_mult, ch_out=32 * dim_mult) 150 | self.Up2 = up_conv(ch_in=32 * dim_mult, ch_out=16 * dim_mult) 151 | self.Up_conv2 = conv_block(ch_in=16 * 2 * dim_mult, ch_out=16 * dim_mult) 152 | self.Conv_1x1 = nn.Conv2d(16 * dim_mult, 1, kernel_size=1, stride=1, padding=0) 153 | 154 | self.msag4 = MSAG(128 * dim_mult) 155 | self.msag3 = MSAG(64 * dim_mult) 156 | self.msag2 = MSAG(32 * dim_mult) 157 | self.msag1 = MSAG(16 * dim_mult) 158 | 159 | def forward(self, feature): 160 | x1, x2, x3, x4, x5 = feature 161 | if self.with_masg: 162 | x4 = self.msag4(x4) 163 | x3 = self.msag3(x3) 164 | x2 = self.msag2(x2) 165 | x1 = self.msag1(x1) 166 | 167 | d5 = self.Up5(x5) 168 | d5 = torch.cat((x4, d5), dim=1) 169 | d5 = self.Up_conv5(d5) 170 | 171 | d4 = self.Up4(d5) 172 | d4 = torch.cat((x3, d4), dim=1) 173 | d4 = self.Up_conv4(d4) 174 | 175 | d3 = self.Up3(d4) 176 | d3 = torch.cat((x2, d3), dim=1) 177 | d3 = self.Up_conv3(d3) 178 | 179 | d2 = self.Up2(d3) 180 | d2 = torch.cat((x1, d2), dim=1) 181 | d2 = self.Up_conv2(d2) 182 | d1 = self.Conv_1x1(d2) 183 | return d1 184 | 185 | 186 | class MGCC(nn.Module): 187 | def __init__(self, img_ch=3, length=(3, 3, 3), k=7, dim_mult=4): 188 | """ 189 | Multi-Level Global Context Cross Consistency Model 190 | Args: 191 | img_ch : input channel. 192 | output_ch: output channel. 193 | length: number of convMixer layers 194 | k: kernal size of convMixer 195 | 196 | """ 197 | super(MGCC, self).__init__() 198 | 199 | # Encoder 200 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 201 | self.Conv1 = conv_block(ch_in=img_ch, ch_out=16 * dim_mult) 202 | self.Conv2 = conv_block(ch_in=16 * dim_mult, ch_out=32 * dim_mult) 203 | self.Conv3 = conv_block(ch_in=32 * dim_mult, ch_out=64 * dim_mult) 204 | self.Conv4 = conv_block(ch_in=64 * dim_mult, ch_out=128 * dim_mult) 205 | self.Conv5 = conv_block(ch_in=128 * dim_mult, ch_out=256 * dim_mult) 206 | self.ConvMixer1 = ConvMixerBlock(dim=256 * dim_mult, depth=length[0], k=k) 207 | self.ConvMixer2 = ConvMixerBlock(dim=256 * dim_mult, depth=length[1], k=k) 208 | self.ConvMixer3 = ConvMixerBlock(dim=256 * dim_mult, depth=length[2], k=k) 209 | # main Decoder 210 | self.main_decoder = Decoder(dim_mult=dim_mult, with_masg=True) 211 | # aux Decoder 212 | self.aux_decoder1 = Decoder(dim_mult=dim_mult, with_masg=True) 213 | self.aux_decoder2 = Decoder(dim_mult=dim_mult, with_masg=True) 214 | self.aux_decoder3 = Decoder(dim_mult=dim_mult, with_masg=True) 215 | 216 | def forward(self, x): 217 | x1 = self.Conv1(x) 218 | x2 = self.Maxpool(x1) 219 | x2 = self.Conv2(x2) 220 | x3 = self.Maxpool(x2) 221 | x3 = self.Conv3(x3) 222 | x4 = self.Maxpool(x3) 223 | x4 = self.Conv4(x4) 224 | x5 = self.Maxpool(x4) 225 | x5 = self.Conv5(x5) 226 | 227 | if not self.training: 228 | x5 = self.ConvMixer1(x5) 229 | x5 = self.ConvMixer2(x5) 230 | x5 = self.ConvMixer3(x5) 231 | feature = [x1, x2, x3, x4, x5] 232 | main_seg = self.main_decoder(feature) 233 | return main_seg 234 | 235 | # FeatureNoise 236 | feature = [x1, x2, x3, x4, x5] 237 | aux1_feature = [FeatureDropout(i) for i in feature] 238 | aux_seg1 = self.aux_decoder1(aux1_feature) 239 | 240 | x5 = self.ConvMixer1(x5) 241 | feature = [x1, x2, x3, x4, x5] 242 | aux2_feature = [Dropout(i) for i in feature] 243 | aux_seg2 = self.aux_decoder2(aux2_feature) 244 | 245 | x5 = self.ConvMixer2(x5) 246 | feature = [x1, x2, x3, x4, x5] 247 | aux3_feature = [FeatureNoise()(i) for i in feature] 248 | aux_seg3 = self.aux_decoder3(aux3_feature) 249 | 250 | # main decoder 251 | 
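        # The main decoder consumes the deepest global context (all three ConvMixer
        # stages), while the auxiliary decoders above received progressively shallower
        # context (none, ConvMixer1, ConvMixer1-2) plus different feature perturbations
        # (FeatureDropout, Dropout, FeatureNoise). train.py enforces output consistency
        # between the main and auxiliary decoders on the unlabeled portion of each batch.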
x5 = self.ConvMixer3(x5) 252 | feature = [x1, x2, x3, x4, x5] 253 | main_seg = self.main_decoder(feature) 254 | return main_seg, aux_seg1, aux_seg2, aux_seg3 255 | -------------------------------------------------------------------------------- /src/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | __all__ = ['BCEDiceLoss'] 7 | 8 | 9 | class BCEDiceLoss(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, input, target): 14 | bce = F.binary_cross_entropy_with_logits(input, target) 15 | smooth = 1e-5 16 | input = torch.sigmoid(input) 17 | num = target.size(0) 18 | input = input.view(num, -1) 19 | target = target.view(num, -1) 20 | intersection = (input * target) 21 | dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth) 22 | dice = 1 - dice.sum() / num 23 | return 0.5 * bce + dice 24 | 25 | 26 | def compute_kl_loss(p, q): 27 | p_loss = F.kl_div(F.log_softmax(p, dim=-1), 28 | F.softmax(q, dim=-1), reduction='none') 29 | q_loss = F.kl_div(F.log_softmax(q, dim=-1), 30 | F.softmax(p, dim=-1), reduction='none') 31 | 32 | # Using function "sum" and "mean" are depending on your task 33 | p_loss = p_loss.mean() 34 | q_loss = q_loss.mean() 35 | 36 | loss = (p_loss + q_loss) / 2 37 | return loss 38 | -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_accuracy(SR, GT, threshold=0.5): 5 | SR = SR > threshold 6 | GT = GT == torch.max(GT) 7 | corr = torch.sum(SR == GT) 8 | tensor_size = SR.size(0)*SR.size(1)*SR.size(2)*SR.size(3) 9 | acc = float(corr)/float(tensor_size) 10 | return acc 11 | 12 | 13 | def get_sensitivity(SR, GT, threshold=0.5): 14 | # Sensitivity == Recall 15 | SE = 0 16 | SR = SR > threshold 17 | GT = GT == torch.max(GT) 18 | # TP : True Positive 19 | # FN : False Negative 20 | TP = ((SR == 1).byte() + (GT == 1).byte()) == 2 21 | FN = ((SR == 0).byte() + (GT == 1).byte()) == 2 22 | SE = float(torch.sum(TP))/(float(torch.sum(TP+FN)) + 1e-6) 23 | return SE 24 | 25 | def get_specificity(SR,GT,threshold=0.5): 26 | SP = 0 27 | SR = SR > threshold 28 | GT = GT == torch.max(GT) 29 | # TN : True Negative 30 | # FP : False Positive 31 | TN = ((SR == 0).byte() + (GT == 0).byte()) == 2 32 | FP = ((SR == 1).byte() + (GT == 0).byte()) == 2 33 | SP = float(torch.sum(TN))/(float(torch.sum(TN+FP)) + 1e-6) 34 | return SP 35 | 36 | def get_precision(SR,GT,threshold=0.5): 37 | PC = 0 38 | SR = SR > threshold 39 | GT = GT== torch.max(GT) 40 | # TP : True Positive 41 | # FP : False Positive 42 | TP = ((SR == 1).byte() + (GT == 1).byte()) == 2 43 | FP = ((SR == 1).byte() + (GT == 0).byte()) == 2 44 | PC = float(torch.sum(TP))/(float(torch.sum(TP+FP)) + 1e-6) 45 | return PC 46 | 47 | def iou_score(output, target): 48 | smooth = 1e-5 49 | 50 | if torch.is_tensor(output): 51 | output = torch.sigmoid(output).data.cpu().numpy() 52 | if torch.is_tensor(target): 53 | target = target.data.cpu().numpy() 54 | output_ = output > 0.5 55 | target_ = target > 0.5 56 | 57 | intersection = (output_ & target_).sum() 58 | union = (output_ | target_).sum() 59 | iou = (intersection + smooth) / (union + smooth) 60 | dice = (2 * iou) / (iou + 1) 61 | 62 | output_ = torch.tensor(output_) 63 | target_ = torch.tensor(target_) 64 | SE = get_sensitivity(output_, 
target_, threshold=0.5) 65 | PC = get_precision(output_, target_, threshold=0.5) 66 | SP= get_specificity(output_, target_, threshold=0.5) 67 | ACC=get_accuracy(output_, target_, threshold=0.5) 68 | F1 = 2*SE*PC/(SE+PC + 1e-6) 69 | return iou, dice, SE, PC, F1, SP, ACC 70 | 71 | 72 | def dice_coef(output, target): 73 | smooth = 1e-5 74 | output = torch.sigmoid(output).view(-1).data.cpu().numpy() 75 | target = target.view(-1).data.cpu().numpy() 76 | intersection = (output * target).sum() 77 | return (2. * intersection + smooth) / \ 78 | (output.sum() + target.sum() + smooth) 79 | 80 | 81 | -------------------------------------------------------------------------------- /src/utils/ramps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | """Functions for ramping hyperparameters up or down 4 | 5 | Each function takes the current training step or epoch, and the 6 | ramp length in the same format, and returns a multiplier between 7 | 0 and 1. 8 | """ 9 | 10 | def sigmoid_rampup(current, rampup_length): 11 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 12 | if rampup_length == 0: 13 | return 1.0 14 | else: 15 | current = np.clip(current, 0.0, rampup_length) 16 | phase = 1.0 - current / rampup_length 17 | return float(np.exp(-5.0 * phase * phase)) 18 | 19 | 20 | def linear_rampup(current, rampup_length): 21 | """Linear rampup""" 22 | assert current >= 0 and rampup_length >= 0 23 | if current >= rampup_length: 24 | return 1.0 25 | else: 26 | return current / rampup_length 27 | 28 | 29 | def cosine_rampdown(current, rampdown_length): 30 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 31 | assert 0 <= current <= rampdown_length 32 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 33 | -------------------------------------------------------------------------------- /src/utils/util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def str2bool(v): 4 | if v.lower() in ['true', 1]: 5 | return True 6 | elif v.lower() in ['false', 0]: 7 | return False 8 | else: 9 | raise argparse.ArgumentTypeError('Boolean value expected.') 10 | 11 | 12 | def count_params(model): 13 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 14 | 15 | 16 | class AverageMeter(object): 17 | """Computes and stores the average and current value""" 18 | 19 | def __init__(self): 20 | self.reset() 21 | 22 | def reset(self): 23 | self.val = 0 24 | self.avg = 0 25 | self.sum = 0 26 | self.count = 0 27 | 28 | def update(self, val, n=1): 29 | self.val = val 30 | self.sum += val * n 31 | self.count += n 32 | self.avg = self.sum / self.count 33 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import numpy as np 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | import torch.optim as optim 7 | from torch.utils.data import DataLoader 8 | from torchvision import transforms 9 | from albumentations.augmentations import transforms 10 | from albumentations.core.composition import Compose 11 | from albumentations import RandomRotate90, Resize 12 | import src.utils.losses as losses 13 | from src.utils.util import AverageMeter 14 | from src.utils.metrics import iou_score 15 | from src.utils import ramps 16 | from src.dataloader.dataset import (SemiDataSets, TwoStreamBatchSampler) 17 
| from src.network.MGCC import MGCC 18 | import os 19 | 20 | 21 | def seed_torch(seed): 22 | np.random.seed(seed) 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed(seed) 25 | torch.cuda.manual_seed_all(seed) 26 | torch.backends.cudnn.benchmark = False 27 | torch.backends.cudnn.deterministic = True 28 | random.seed(seed) 29 | np.random.seed(seed) 30 | os.environ['PYTHONHASHSEED'] = str(seed) 31 | 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--semi_percent', type=float, default=0.5) 35 | parser.add_argument('--base_dir', type=str, default="./data/busi", help='dir') 36 | parser.add_argument('--train_file_dir', type=str, default="busi_train1.txt", help='dir') 37 | parser.add_argument('--val_file_dir', type=str, default="busi_val1.txt", help='dir') 38 | parser.add_argument('--max_iterations', type=int, 39 | default=40000, help='maximum epoch number to train') 40 | parser.add_argument('--total_batch_size', type=int, default=8, 41 | help='batch_size per gpu') 42 | parser.add_argument('--base_lr', type=float, default=0.01, 43 | help='segmentation network learning rate') 44 | parser.add_argument('--seed', type=int, default=41, help='random seed') 45 | # label and unlabel 46 | parser.add_argument('--labeled_bs', type=int, default=4, 47 | help='labeled_batch_size per gpu') 48 | # costs 49 | parser.add_argument('--consistency', type=float, 50 | default=7, help='consistency') 51 | parser.add_argument('--consistency_rampup', type=float, 52 | default=200.0, help='consistency_rampup') 53 | # MGCC hyperparameter 54 | parser.add_argument('--kernel_size', type=int, 55 | default=7, help='ConvMixer kernel size') 56 | parser.add_argument('--length', type=tuple, 57 | default=(3, 3, 3), help='length of ConvMixer') 58 | args = parser.parse_args() 59 | 60 | seed_torch(args.seed) 61 | 62 | 63 | def getDataloader(args): 64 | train_transform = Compose([ 65 | RandomRotate90(), 66 | transforms.Flip(), 67 | Resize(256, 256), 68 | transforms.Normalize(), 69 | ]) 70 | val_transform = Compose([ 71 | Resize(256, 256), 72 | transforms.Normalize(), 73 | ]) 74 | labeled_slice = args.semi_percent 75 | db_train = SemiDataSets(base_dir=args.base_dir, split="train", transform=train_transform, 76 | train_file_dir=args.train_file_dir, val_file_dir=args.val_file_dir, 77 | ) 78 | db_val = SemiDataSets(base_dir=args.base_dir, split="val", transform=val_transform, 79 | train_file_dir=args.train_file_dir, val_file_dir=args.val_file_dir 80 | ) 81 | 82 | def worker_init_fn(worker_id): 83 | random.seed(args.seed + worker_id) 84 | 85 | total_slices = len(db_train) 86 | labeled_idxs = list(range(0, int(labeled_slice * total_slices))) 87 | unlabeled_idxs = list(range(int(labeled_slice * total_slices), total_slices)) 88 | print("label num:{}, unlabel num:{} percent:{}".format(len(labeled_idxs), len(unlabeled_idxs), labeled_slice)) 89 | batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, args.total_batch_size, args.labeled_bs) 90 | trainloader = DataLoader(db_train, batch_sampler=batch_sampler, 91 | num_workers=8, pin_memory=False, worker_init_fn=worker_init_fn) 92 | valloader = DataLoader(db_val, batch_size=1, shuffle=False, num_workers=1) 93 | 94 | return trainloader, valloader 95 | 96 | 97 | def get_current_consistency_weight(epoch): 98 | # Consistency ramp-up from https://arxiv.org/abs/1610.02242 99 | return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_rampup) 100 | 101 | 102 | def getModel(args): 103 | print("ConvMixer1:{}, ConvMixer2:{}, ConvMixer3:{}, 
kernal:{}".format(args.length[0], args.length[1], 104 | args.length[2], args.kernel_size)) 105 | return MGCC(length=args.length, k=args.kernel_size).cuda() 106 | 107 | 108 | def train(args): 109 | base_lr = args.base_lr 110 | max_iterations = int(args.max_iterations * args.semi_percent) 111 | trainloader, valloader = getDataloader(args) 112 | 113 | model = getModel(args) 114 | 115 | optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001) 116 | criterion = losses.__dict__['BCEDiceLoss']().cuda() 117 | 118 | print("{} iterations per epoch".format(len(trainloader))) 119 | best_iou = 0 120 | iter_num = 0 121 | max_epoch = max_iterations // len(trainloader) + 1 122 | 123 | for epoch_num in range(max_epoch): 124 | avg_meters = {'total_loss': AverageMeter(), 125 | 'train_iou': AverageMeter(), 126 | 'consistency_loss': AverageMeter(), 127 | 'supervised_loss': AverageMeter(), 128 | 'val_loss': AverageMeter(), 129 | 'val_iou': AverageMeter(), 130 | 'val_se': AverageMeter(), 131 | 'val_pc': AverageMeter(), 132 | 'val_f1': AverageMeter(), 133 | 'val_acc': AverageMeter() 134 | } 135 | model.train() 136 | for i_batch, sampled_batch in enumerate(trainloader): 137 | 138 | volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 139 | volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda() 140 | 141 | outputs, outputs_aux1, outputs_aux2, outputs_aux3 = model(volume_batch) 142 | 143 | outputs_soft = torch.sigmoid(outputs) 144 | outputs_aux1_soft = torch.sigmoid(outputs_aux1) 145 | outputs_aux2_soft = torch.sigmoid(outputs_aux2) 146 | outputs_aux3_soft = torch.sigmoid(outputs_aux3) 147 | 148 | loss_ce = criterion(outputs[:args.labeled_bs], 149 | label_batch[:args.labeled_bs][:]) 150 | loss_ce_aux1 = criterion(outputs_aux1[:args.labeled_bs], 151 | label_batch[:args.labeled_bs][:]) 152 | loss_ce_aux2 = criterion(outputs_aux2[:args.labeled_bs], 153 | label_batch[:args.labeled_bs][:]) 154 | loss_ce_aux3 = criterion(outputs_aux3[:args.labeled_bs], 155 | label_batch[:args.labeled_bs][:]) 156 | 157 | supervised_loss = (loss_ce + loss_ce_aux1 + loss_ce_aux2 + loss_ce_aux3) / 4 158 | 159 | consistency_weight = get_current_consistency_weight(iter_num // 150) 160 | consistency_loss_aux1 = torch.mean( 161 | (outputs_soft[args.labeled_bs:] - outputs_aux1_soft[args.labeled_bs:]) ** 2) 162 | consistency_loss_aux2 = torch.mean( 163 | (outputs_soft[args.labeled_bs:] - outputs_aux2_soft[args.labeled_bs:]) ** 2) 164 | consistency_loss_aux3 = torch.mean( 165 | (outputs_soft[args.labeled_bs:] - outputs_aux3_soft[args.labeled_bs:]) ** 2) 166 | 167 | consistency_loss = (consistency_loss_aux1 + consistency_loss_aux2 + consistency_loss_aux3) / 3 168 | loss = supervised_loss + consistency_weight * consistency_loss 169 | iou, dice, _, _, _, _, _ = iou_score(outputs[:args.labeled_bs], label_batch[:args.labeled_bs]) 170 | optimizer.zero_grad() 171 | loss.backward() 172 | optimizer.step() 173 | 174 | lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9 175 | for param_group in optimizer.param_groups: 176 | param_group['lr'] = lr_ 177 | 178 | iter_num = iter_num + 1 179 | 180 | avg_meters['total_loss'].update(loss.item(), volume_batch[:args.labeled_bs].size(0)) 181 | avg_meters['supervised_loss'].update(supervised_loss.item(), volume_batch[:args.labeled_bs].size(0)) 182 | avg_meters['consistency_loss'].update(consistency_loss.item(), volume_batch[args.labeled_bs:].size(0)) 183 | avg_meters['train_iou'].update(iou, volume_batch[:args.labeled_bs].size(0)) 184 | 185 | 
model.eval() 186 | with torch.no_grad(): 187 | for i_batch, sampled_batch in enumerate(valloader): 188 | input, target = sampled_batch['image'], sampled_batch['label'] 189 | input = input.cuda() 190 | target = target.cuda() 191 | output = model(input) 192 | loss = criterion(output, target) 193 | iou, _, SE, PC, F1, _, ACC = iou_score(output, target) 194 | avg_meters['val_loss'].update(loss.item(), input.size(0)) 195 | avg_meters['val_iou'].update(iou, input.size(0)) 196 | avg_meters['val_se'].update(SE, input.size(0)) 197 | avg_meters['val_pc'].update(PC, input.size(0)) 198 | avg_meters['val_f1'].update(F1, input.size(0)) 199 | avg_meters['val_acc'].update(ACC, input.size(0)) 200 | 201 | print( 202 | 'epoch [%3d/%d] train_loss %.4f supervised_loss %.4f consistency_loss %.4f train_iou: %.4f ' 203 | '- val_loss %.4f - val_iou %.4f - val_SE %.4f - val_PC %.4f - val_F1 %.4f - val_ACC %.4f' 204 | % (epoch_num, max_epoch, avg_meters['total_loss'].avg, 205 | avg_meters['supervised_loss'].avg, avg_meters['consistency_loss'].avg, avg_meters['train_iou'].avg, 206 | avg_meters['val_loss'].avg, avg_meters['val_iou'].avg, avg_meters['val_se'].avg, 207 | avg_meters['val_pc'].avg, avg_meters['val_f1'].avg, avg_meters['val_acc'].avg)) 208 | 209 | if avg_meters['val_iou'].avg > best_iou: 210 | torch.save(model.state_dict(), 'checkpoint/model.pth') 211 | best_iou = avg_meters['val_iou'].avg 212 | print("=> saved best model") 213 | 214 | return "Training Finished!" 215 | 216 | 217 | if __name__ == "__main__": 218 | train(args) 219 | --------------------------------------------------------------------------------
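For reference, a minimal end-to-end usage sketch of the two scripts above, assuming the BUSI images and masks are placed under `./data/busi` as described in the README; the flags shown simply restate the argparse defaults in `train.py`:

```bash
# Write busi_train{1,2,3}.txt / busi_val{1,2,3}.txt split files into ./data/busi
python split.py

# The checkpoint folder must exist; train.py saves the best weights there
mkdir -p checkpoint

# Semi-supervised training with 50% labeled data on the first split
python train.py --base_dir ./data/busi \
                --train_file_dir busi_train1.txt \
                --val_file_dir busi_val1.txt \
                --semi_percent 0.5 \
                --total_batch_size 8 \
                --labeled_bs 4
```

The model with the highest validation IoU is written to `checkpoint/model.pth`.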