├── .github └── workflows │ └── stale.yml ├── Config.py ├── Load_Dataset.py ├── README.md ├── Train_one_epoch.py ├── datasets └── MoNuSeg │ ├── Test_Folder │ ├── img │ │ ├── TCGA-2Z-A9J9-01A-01-TS1.tif │ │ ├── TCGA-44-2665-01B-06-BS6.tif │ │ ├── TCGA-69-7764-01A-01-TS1.tif │ │ ├── TCGA-A6-6782-01A-01-BS1.tif │ │ ├── TCGA-AC-A2FO-01A-01-TS1.tif │ │ ├── TCGA-AO-A0J2-01A-01-BSA.tif │ │ ├── TCGA-CU-A0YN-01A-02-BSB.tif │ │ ├── TCGA-EJ-A46H-01A-03-TSC.tif │ │ ├── TCGA-FG-A4MU-01B-01-TS1.tif │ │ ├── TCGA-GL-6846-01A-01-BS1.tif │ │ ├── TCGA-HC-7209-01A-01-TS1.tif │ │ ├── TCGA-HT-8564-01Z-00-DX1.tif │ │ ├── TCGA-IZ-8196-01A-01-BS1.tif │ │ └── TCGA-ZF-A9R5-01A-01-TS1.tif │ └── labelcol │ │ ├── TCGA-2Z-A9J9-01A-01-TS1.png │ │ ├── TCGA-44-2665-01B-06-BS6.png │ │ ├── TCGA-69-7764-01A-01-TS1.png │ │ ├── TCGA-A6-6782-01A-01-BS1.png │ │ ├── TCGA-AC-A2FO-01A-01-TS1.png │ │ ├── TCGA-AO-A0J2-01A-01-BSA.png │ │ ├── TCGA-CU-A0YN-01A-02-BSB.png │ │ ├── TCGA-EJ-A46H-01A-03-TSC.png │ │ ├── TCGA-FG-A4MU-01B-01-TS1.png │ │ ├── TCGA-GL-6846-01A-01-BS1.png │ │ ├── TCGA-HC-7209-01A-01-TS1.png │ │ ├── TCGA-HT-8564-01Z-00-DX1.png │ │ ├── TCGA-IZ-8196-01A-01-BS1.png │ │ └── TCGA-ZF-A9R5-01A-01-TS1.png │ ├── Train_Folder │ ├── img │ │ ├── TCGA-21-5784-01Z-00-DX1.png │ │ ├── TCGA-21-5786-01Z-00-DX1.png │ │ ├── TCGA-38-6178-01Z-00-DX1.png │ │ ├── TCGA-49-4488-01Z-00-DX1.png │ │ ├── TCGA-50-5931-01Z-00-DX1.png │ │ ├── TCGA-A7-A13E-01Z-00-DX1.png │ │ ├── TCGA-A7-A13F-01Z-00-DX1.png │ │ ├── TCGA-AR-A1AK-01Z-00-DX1.png │ │ ├── TCGA-AR-A1AS-01Z-00-DX1.png │ │ ├── TCGA-B0-5698-01Z-00-DX1.png │ │ ├── TCGA-B0-5710-01Z-00-DX1.png │ │ ├── TCGA-B0-5711-01Z-00-DX1.png │ │ ├── TCGA-CH-5767-01Z-00-DX1.png │ │ ├── TCGA-DK-A2I6-01A-01-TS1.png │ │ ├── TCGA-G2-A2EK-01A-02-TSB.png │ │ ├── TCGA-G9-6336-01Z-00-DX1.png │ │ ├── TCGA-G9-6348-01Z-00-DX1.png │ │ ├── TCGA-G9-6356-01Z-00-DX1.png │ │ ├── TCGA-G9-6362-01Z-00-DX1.png │ │ ├── TCGA-HE-7128-01Z-00-DX1.png │ │ ├── TCGA-HE-7130-01Z-00-DX1.png │ │ ├── 
TCGA-KB-A93J-01A-01-TS1.png │ │ ├── TCGA-NH-A8F7-01A-01-TS1.png │ │ └── TCGA-RD-A8N9-01A-01-TS1.png │ └── labelcol │ │ ├── TCGA-21-5784-01Z-00-DX1.png │ │ ├── TCGA-21-5786-01Z-00-DX1.png │ │ ├── TCGA-38-6178-01Z-00-DX1.png │ │ ├── TCGA-49-4488-01Z-00-DX1.png │ │ ├── TCGA-50-5931-01Z-00-DX1.png │ │ ├── TCGA-A7-A13E-01Z-00-DX1.png │ │ ├── TCGA-A7-A13F-01Z-00-DX1.png │ │ ├── TCGA-AR-A1AK-01Z-00-DX1.png │ │ ├── TCGA-AR-A1AS-01Z-00-DX1.png │ │ ├── TCGA-B0-5698-01Z-00-DX1.png │ │ ├── TCGA-B0-5710-01Z-00-DX1.png │ │ ├── TCGA-B0-5711-01Z-00-DX1.png │ │ ├── TCGA-CH-5767-01Z-00-DX1.png │ │ ├── TCGA-DK-A2I6-01A-01-TS1.png │ │ ├── TCGA-G2-A2EK-01A-02-TSB.png │ │ ├── TCGA-G9-6336-01Z-00-DX1.png │ │ ├── TCGA-G9-6348-01Z-00-DX1.png │ │ ├── TCGA-G9-6356-01Z-00-DX1.png │ │ ├── TCGA-G9-6362-01Z-00-DX1.png │ │ ├── TCGA-HE-7128-01Z-00-DX1.png │ │ ├── TCGA-HE-7130-01Z-00-DX1.png │ │ ├── TCGA-KB-A93J-01A-01-TS1.png │ │ ├── TCGA-NH-A8F7-01A-01-TS1.png │ │ └── TCGA-RD-A8N9-01A-01-TS1.png │ └── Val_Folder │ ├── img │ ├── TCGA-18-5592-01Z-00-DX1.png │ ├── TCGA-AY-A8YK-01A-01-TS1.png │ ├── TCGA-E2-A14V-01Z-00-DX1.png │ ├── TCGA-E2-A1B5-01Z-00-DX1.png │ ├── TCGA-G9-6363-01Z-00-DX1.png │ └── TCGA-HE-7129-01Z-00-DX1.png │ └── labelcol │ ├── TCGA-18-5592-01Z-00-DX1.png │ ├── TCGA-AY-A8YK-01A-01-TS1.png │ ├── TCGA-E2-A14V-01Z-00-DX1.png │ ├── TCGA-E2-A1B5-01Z-00-DX1.png │ ├── TCGA-G9-6363-01Z-00-DX1.png │ └── TCGA-HE-7129-01Z-00-DX1.png ├── docs └── Framework.jpg ├── nets ├── CTrans.py ├── UCTransNet.py └── UNet.py ├── requirements.txt ├── test_model.py ├── train_model.py └── utils.py /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | # This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time. 2 | # 3 | # You can adjust the behavior by modifying this file. 
4 | # For more information, see: 5 | # https://github.com/actions/stale 6 | name: Mark stale issues and pull requests 7 | 8 | on: 9 | schedule: 10 | - cron: '33 20 * * *' 11 | 12 | jobs: 13 | stale: 14 | 15 | runs-on: ubuntu-latest 16 | permissions: 17 | issues: write 18 | pull-requests: write 19 | 20 | steps: 21 | - uses: actions/stale@v5 22 | with: 23 | repo-token: ${{ secrets.GITHUB_TOKEN }} 24 | stale-issue-message: 'Stale issue message' 25 | stale-pr-message: 'Stale pull request message' 26 | stale-issue-label: 'no-issue-activity' 27 | stale-pr-label: 'no-pr-activity' 28 | -------------------------------------------------------------------------------- /Config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/6/19 2:44 下午 3 | # @Author : Haonan Wang 4 | # @File : Config.py 5 | # @Software: PyCharm 6 | import os 7 | import torch 8 | import time 9 | import ml_collections 10 | 11 | ## PARAMETERS OF THE MODEL 12 | save_model = True 13 | tensorboard = True 14 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 15 | use_cuda = torch.cuda.is_available() 16 | seed = 666 17 | os.environ['PYTHONHASHSEED'] = str(seed) 18 | 19 | cosineLR = True # whether use cosineLR or not 20 | n_channels = 3 21 | n_labels = 1 22 | epochs = 2000 23 | img_size = 224 24 | print_frequency = 1 25 | save_frequency = 5000 26 | vis_frequency = 10 27 | early_stopping_patience = 50 28 | 29 | pretrain = False 30 | task_name = 'MoNuSeg' # GlaS MoNuSeg 31 | # task_name = 'GlaS' 32 | learning_rate = 1e-3 33 | batch_size = 4 34 | 35 | 36 | # model_name = 'UCTransNet' 37 | model_name = 'UCTransNet_pretrain' 38 | 39 | train_dataset = './datasets/'+ task_name+ '/Train_Folder/' 40 | val_dataset = './datasets/'+ task_name+ '/Val_Folder/' 41 | test_dataset = './datasets/'+ task_name+ '/Test_Folder/' 42 | session_name = 'Test_session' + '_' + time.strftime('%m.%d_%Hh%M') 43 | save_path = task_name +'/'+ model_name +'/' + session_name + 
'/' 44 | model_path = save_path + 'models/' 45 | tensorboard_folder = save_path + 'tensorboard_logs/' 46 | logger_path = save_path + session_name + ".log" 47 | visualize_path = save_path + 'visualize_val/' 48 | 49 | 50 | ########################################################################## 51 | # CTrans configs 52 | ########################################################################## 53 | def get_CTranS_config(): 54 | config = ml_collections.ConfigDict() 55 | config.transformer = ml_collections.ConfigDict() 56 | config.KV_size = 960 # KV_size = Q1 + Q2 + Q3 + Q4 57 | config.transformer.num_heads = 4 58 | config.transformer.num_layers = 4 59 | config.expand_ratio = 4 # MLP channel dimension expand ratio 60 | config.transformer.embeddings_dropout_rate = 0.1 61 | config.transformer.attention_dropout_rate = 0.1 62 | config.transformer.dropout_rate = 0 63 | config.patch_sizes = [16,8,4,2] 64 | config.base_channel = 64 # base channel of U-Net 65 | config.n_classes = 1 66 | return config 67 | 68 | 69 | 70 | 71 | # used in testing phase, copy the session name in training phase 72 | test_session = "Test_session_07.03_20h39" -------------------------------------------------------------------------------- /Load_Dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/6/19 11:30 上午 3 | # @Author : Haonan Wang 4 | # @File : Load_Dataset.py 5 | # @Software: PyCharm 6 | import numpy as np 7 | import torch 8 | import random 9 | from scipy.ndimage.interpolation import zoom 10 | from torch.utils.data import Dataset 11 | from torchvision import transforms as T 12 | from torchvision.transforms import functional as F 13 | from typing import Callable 14 | import os 15 | import cv2 16 | from scipy import ndimage 17 | 18 | def random_rot_flip(image, label): 19 | k = np.random.randint(0, 4) 20 | image = np.rot90(image, k) 21 | label = np.rot90(label, k) 22 | axis = np.random.randint(0, 2) 23 | 
image = np.flip(image, axis=axis).copy() 24 | label = np.flip(label, axis=axis).copy() 25 | return image, label 26 | 27 | def random_rotate(image, label): 28 | angle = np.random.randint(-20, 20) 29 | image = ndimage.rotate(image, angle, order=0, reshape=False) 30 | label = ndimage.rotate(label, angle, order=0, reshape=False) 31 | return image, label 32 | 33 | class RandomGenerator(object): 34 | def __init__(self, output_size): 35 | self.output_size = output_size 36 | 37 | def __call__(self, sample): 38 | image, label = sample['image'], sample['label'] 39 | image, label = F.to_pil_image(image), F.to_pil_image(label) 40 | x, y = image.size 41 | if random.random() > 0.5: 42 | image, label = random_rot_flip(image, label) 43 | elif random.random() < 0.5: 44 | image, label = random_rotate(image, label) 45 | 46 | if x != self.output_size[0] or y != self.output_size[1]: 47 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3) # why not 3? 48 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 49 | image = F.to_tensor(image) 50 | label = to_long_tensor(label) 51 | sample = {'image': image, 'label': label} 52 | return sample 53 | 54 | class ValGenerator(object): 55 | def __init__(self, output_size): 56 | self.output_size = output_size 57 | 58 | def __call__(self, sample): 59 | image, label = sample['image'], sample['label'] 60 | image, label = F.to_pil_image(image), F.to_pil_image(label) 61 | x, y = image.size 62 | if x != self.output_size[0] or y != self.output_size[1]: 63 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3) # why not 3? 
64 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 65 | image = F.to_tensor(image) 66 | label = to_long_tensor(label) 67 | sample = {'image': image, 'label': label} 68 | return sample 69 | 70 | def to_long_tensor(pic): 71 | # handle numpy array 72 | img = torch.from_numpy(np.array(pic, np.uint8)) 73 | # backward compatibility 74 | return img.long() 75 | 76 | def correct_dims(*images): 77 | corr_images = [] 78 | # print(images) 79 | for img in images: 80 | if len(img.shape) == 2: 81 | corr_images.append(np.expand_dims(img, axis=2)) 82 | else: 83 | corr_images.append(img) 84 | 85 | if len(corr_images) == 1: 86 | return corr_images[0] 87 | else: 88 | return corr_images 89 | 90 | class ImageToImage2D(Dataset): 91 | """ 92 | Reads the images and applies the augmentation transform on them. 93 | Usage: 94 | 1. If used without the unet.model.Model wrapper, an instance of this object should be passed to 95 | torch.utils.data.DataLoader. Iterating through this returns the tuple of image, mask and image 96 | filename. 97 | 2. With unet.model.Model wrapper, an instance of this object should be passed as train or validation 98 | datasets. 99 | 100 | Args: 101 | dataset_path: path to the dataset. Structure of the dataset should be: 102 | dataset_path 103 | |-- images 104 | |-- img001.png 105 | |-- img002.png 106 | |-- ... 107 | |-- masks 108 | |-- img001.png 109 | |-- img002.png 110 | |-- ... 111 | 112 | joint_transform: augmentation transform, an instance of JointTransform2D. If bool(joint_transform) 113 | evaluates to False, torchvision.transforms.ToTensor will be used on both image and mask. 114 | one_hot_mask: bool, if True, returns the mask in one-hot encoded form. 
115 | """ 116 | 117 | def __init__(self, dataset_path: str, joint_transform: Callable = None, one_hot_mask: int = False, image_size: int =224) -> None: 118 | self.dataset_path = dataset_path 119 | self.image_size = image_size 120 | self.input_path = os.path.join(dataset_path, 'img') 121 | self.output_path = os.path.join(dataset_path, 'labelcol') 122 | self.images_list = os.listdir(self.input_path) 123 | self.one_hot_mask = one_hot_mask 124 | 125 | if joint_transform: 126 | self.joint_transform = joint_transform 127 | else: 128 | to_tensor = T.ToTensor() 129 | self.joint_transform = lambda x, y: (to_tensor(x), to_tensor(y)) 130 | 131 | def __len__(self): 132 | return len(os.listdir(self.input_path)) 133 | 134 | def __getitem__(self, idx): 135 | 136 | image_filename = self.images_list[idx] 137 | #print(image_filename[: -3]) 138 | # read image 139 | # print(os.path.join(self.input_path, image_filename)) 140 | # print(os.path.join(self.output_path, image_filename[: -3] + "png")) 141 | # print(os.path.join(self.input_path, image_filename)) 142 | image = cv2.imread(os.path.join(self.input_path, image_filename)) 143 | # print("img",image_filename) 144 | # print("1",image.shape) 145 | image = cv2.resize(image,(self.image_size,self.image_size)) 146 | # print(np.max(image), np.min(image)) 147 | # print("2",image.shape) 148 | # read mask image 149 | mask = cv2.imread(os.path.join(self.output_path, image_filename[: -3] + "png"),0) 150 | # print("mask",image_filename[: -3] + "png") 151 | # print(np.max(mask), np.min(mask)) 152 | mask = cv2.resize(mask,(self.image_size,self.image_size)) 153 | # print(np.max(mask), np.min(mask)) 154 | mask[mask<=0] = 0 155 | # (mask == 35).astype(int) 156 | mask[mask>0] = 1 157 | # print("11111",np.max(mask), np.min(mask)) 158 | 159 | # correct dimensions if needed 160 | image, mask = correct_dims(image, mask) 161 | # image, mask = F.to_pil_image(image), F.to_pil_image(mask) 162 | # print("11",image.shape) 163 | # print("22",mask.shape) 164 | 
sample = {'image': image, 'label': mask} 165 | 166 | if self.joint_transform: 167 | sample = self.joint_transform(sample) 168 | # sample = {'image': image, 'label': mask} 169 | # print("2222",np.max(mask), np.min(mask)) 170 | 171 | if self.one_hot_mask: 172 | assert self.one_hot_mask > 0, 'one_hot_mask must be nonnegative' 173 | mask = torch.zeros((self.one_hot_mask, mask.shape[1], mask.shape[2])).scatter_(0, mask.long(), 1) 174 | # mask = np.swapaxes(mask,2,0) 175 | # print(image.shape) 176 | # print("mask",mask) 177 | # mask = np.transpose(mask,(2,0,1)) 178 | # image = np.transpose(image,(2,0,1)) 179 | # print(image.shape) 180 | # print(mask.shape) 181 | # print(sample['image'].shape) 182 | 183 | return sample, image_filename 184 | 185 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [AAAI2022] UCTransNet 2 | 3 | 4 | This repo is the official implementation of 5 | ['UCTransNet: Rethinking the Skip Connections in U-Net from 6 | a Channel-wise Perspective with Transformer'](https://ojs.aaai.org/index.php/AAAI/article/view/20144) which is accepted at AAAI2022. 7 | 8 | ![framework](https://github.com/McGregorWwww/UCTransNet/blob/main/docs/Framework.jpg) 9 | 10 | We propose a Channel Transformer module (CTrans) and use it to 11 | replace the skip connections in original U-Net, thus we name it 'U-CTrans-Net'. 12 | 13 | 14 | **[Online Presentation Video](https://www.bilibili.com/video/BV1ZF411p7PM?spm_id_from=333.999.0.0) is available for brief introduction.** 15 | 16 | 🔥🔥🔥 For an improved version of UCTransNet, please refer to [UDTransNet](https://github.com/McGregorWwww/UDTransNet) ([Narrowing the semantic gaps in U-Net with learnable skip connections: The case of medical image segmentation 17 | ](https://arxiv.org/abs/2312.15182)), which achieves higher performance and lower computational cost. 
🔥🔥🔥 18 | 19 | 20 | ## Requirements 21 | 22 | Install from the ```requirements.txt``` using: 23 | ```angular2html 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | ## Usage 28 | 29 | *Note: If you run into problems with the code, the closed [issues](https://github.com/McGregorWwww/UCTransNet/issues?q=is%3Aissue+is%3Aclosed) may help.* 30 | 31 | ### 1. Data Preparation 32 | #### 1.1. GlaS and MoNuSeg Datasets 33 | The original data can be downloaded from the following links: 34 | * MoNuSeg Dataset - [Link (Original)](https://monuseg.grand-challenge.org/Data/) 35 | * GLAS Dataset - [Link (Original)](https://warwick.ac.uk/fac/cross_fac/tia/data/glascontest) 36 | 37 | Then prepare the datasets in the following format so they can be used directly by the code: 38 | ```angular2html 39 | ├── datasets 40 |    ├── GlaS 41 |    │   ├── Test_Folder 42 |    │   │   ├── img 43 |    │   │   └── labelcol 44 |    │   ├── Train_Folder 45 |    │   │   ├── img 46 |    │   │   └── labelcol 47 |    │   └── Val_Folder 48 |    │       ├── img 49 |    │       └── labelcol 50 |    └── MoNuSeg 51 |       ├── Test_Folder 52 |       │   ├── img 53 |       │   └── labelcol 54 |       ├── Train_Folder 55 |       │   ├── img 56 |       │   └── labelcol 57 |       └── Val_Folder 58 |           ├── img 59 |           └── labelcol 60 | ``` 61 | #### 1.2. Synapse Dataset 62 | The Synapse dataset we used is provided by the authors of TransUNet. 63 | Please go to [https://github.com/Beckschen/TransUNet/blob/main/datasets/README.md](https://github.com/Beckschen/TransUNet/blob/main/datasets/README.md) 64 | for details. 65 | 66 | #### (Optional) 🔥🔥 Using custom datasets. 67 | - If you want to apply UCTransNet to a custom dataset, the easiest way is to organize its files in the same structure as GlaS, as described above. 68 | 69 | - Ensure that every image file name ends with a three-character extension (e.g. `.jpg`, `.png` or `.tif`), and that the corresponding mask in `labelcol` has the same name but with the `.png` extension.
70 | 71 | - Any inconsistencies in the file structure or naming conventions may result in I/O errors. 72 | 73 | 74 | ### 2. Training 75 | As mentioned in the paper, we introduce two strategies 76 | to optimize UCTransNet. 77 | 78 | The first step is to change the settings in ```Config.py```; 79 | all the configuration options, including the learning rate, batch size, etc., are 80 | defined there. 81 | 82 | #### 2.1 Joint Training 83 | We optimize the convolution parameters 84 | in U-Net and the CTrans parameters together with a single loss. 85 | Run: 86 | ```angular2html 87 | python train_model.py 88 | ``` 89 | 90 | #### 2.2 Pre-training 91 | 92 | Our method only replaces the skip connections in U-Net, 93 | so the U-Net parameters can be reused as part of the pretrained weights. 94 | 95 | By first training a classical U-Net with ```nets/UNet.py``` 96 | and then using its pretrained weights to train UCTransNet, 97 | the CTrans module receives better initial features. 98 | 99 | This strategy can improve the convergence speed and may 100 | improve the final segmentation performance in some cases. 101 | 102 | 103 | ### 3. Testing 104 | #### 3.1. Get Pre-trained Models 105 | We provide pre-trained weights for GlaS and MoNuSeg. If you do not want to train the models yourself, you can download them from the following links: 106 | * GlaS: https://drive.google.com/file/d/1ciAwb2-0G1pZrt_lgSwd-7vH1STmxdYe/view?usp=sharing 107 | * MoNuSeg: https://drive.google.com/file/d/1CJvHoh3VrPsBn_njZDo6SvJF_yAVe5MK/view?usp=sharing 108 | #### 3.2. Test the Model and Visualize the Segmentation Results 109 | First, set ```test_session``` in ```Config.py``` to the session name generated during training (of the form ```Test_session_MM.DD_HHhMM```). 110 | Then run: 111 | ```angular2html 112 | python test_model.py 113 | ``` 114 | You will get the Dice and IoU scores as well as the visualization results.
115 | 116 | 🔥🔥 **The testing results of all classes in the Synapse dataset can be downloaded through [this link](https://drive.google.com/file/d/1E-ZJLkNc0AJSUKI1CCWdcROMS9wERI9s/view?usp=sharing).** 🔥🔥 117 | 118 | 119 | ### 4. Reproducibility 120 | In our code, we carefully set the random seed and put cuDNN in deterministic mode to eliminate randomness. 121 | However, some factors may still cause different training results, e.g., the CUDA version, the GPU type, and the number of GPUs. The GPU used in our experiments is an NVIDIA A40 (48 GB), and the CUDA version is 11.2. 122 | 123 | In multi-GPU settings in particular, the upsampling operation is a major source of nondeterminism. 124 | See https://pytorch.org/docs/stable/notes/randomness.html for more details. 125 | 126 | When training, we suggest training the model twice to verify whether the randomness has been eliminated. Because we use an early stopping strategy, **the final performance may change significantly due to the randomness**.
127 | 128 | ## Reference 129 | 130 | 131 | * UNet++: https://github.com/qubvel/segmentation_models.pytorch 132 | * Attention U-Net: https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets 133 | * MultiResUNet: https://github.com/makifozkanoglu/MultiResUNet-PyTorch 134 | * TransUNet: https://github.com/Beckschen/TransUNet 135 | * Swin-Unet: https://github.com/HuCaoFighting/Swin-Unet 136 | * MedT: https://github.com/jeya-maria-jose/Medical-Transformer 137 | 138 | 139 | 140 | ## Citations 141 | 142 | 143 | If this code is helpful for your study, please cite: 144 | ``` 145 | @article{UCTransNet, 146 | title={UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-Wise Perspective with Transformer}, 147 | volume={36}, 148 | url={https://ojs.aaai.org/index.php/AAAI/article/view/20144}, 149 | DOI={10.1609/aaai.v36i3.20144}, 150 | number={3}, 151 | journal={Proceedings of the AAAI Conference on Artificial Intelligence}, 152 | author={Wang, Haonan and Cao, Peng and Wang, Jiaqi and Zaiane, Osmar R.}, 153 | year={2022}, 154 | month={Jun.}, 155 | pages={2441-2449}} 156 | ``` 157 | 158 | 159 | ## Contact 160 | Haonan Wang ([haonan1wang@gmail.com](haonan1wang@gmail.com)) 161 | -------------------------------------------------------------------------------- /Train_one_epoch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/6/19 2:14 下午 3 | # @Author : Haonan Wang 4 | # @File : Train_one_epoch.py 5 | # @Software: PyCharm 6 | import torch.optim 7 | import os 8 | import time 9 | from utils import * 10 | import Config as config 11 | import warnings 12 | warnings.filterwarnings("ignore") 13 | 14 | 15 | def print_summary(epoch, i, nb_batch, loss, loss_name, batch_time, 16 | average_loss, average_time, iou, average_iou, 17 | dice, average_dice, acc, average_acc, mode, lr, logger): 18 | ''' 19 | mode = Train or Test 20 | ''' 21 | summary = ' [' + str(mode) + '] Epoch: [{0}][{1}/{2}] 
'.format( 22 | epoch, i, nb_batch) 23 | string = '' 24 | string += 'Loss:{:.3f} '.format(loss) 25 | string += '(Avg {:.4f}) '.format(average_loss) 26 | # string += 'IoU:{:.3f} '.format(iou) 27 | # string += '(Avg {:.4f}) '.format(average_iou) 28 | string += 'Dice:{:.4f} '.format(dice) 29 | string += '(Avg {:.4f}) '.format(average_dice) 30 | # string += 'Acc:{:.3f} '.format(acc) 31 | # string += '(Avg {:.4f}) '.format(average_acc) 32 | if mode == 'Train': 33 | string += 'LR {:.2e} '.format(lr) 34 | # string += 'Time {:.1f} '.format(batch_time) 35 | string += '(AvgTime {:.1f}) '.format(average_time) 36 | summary += string 37 | logger.info(summary) 38 | # print summary 39 | 40 | 41 | ################################################################################## 42 | #================================================================================= 43 | # Train One Epoch 44 | #================================================================================= 45 | ################################################################################## 46 | def train_one_epoch(loader, model, criterion, optimizer, writer, epoch, lr_scheduler, model_type, logger): 47 | logging_mode = 'Train' if model.training else 'Val' 48 | 49 | end = time.time() 50 | time_sum, loss_sum = 0, 0 51 | dice_sum, iou_sum, acc_sum = 0.0, 0.0, 0.0 52 | 53 | dices = [] 54 | for i, (sampled_batch, names) in enumerate(loader, 1): 55 | 56 | try: 57 | loss_name = criterion._get_name() 58 | except AttributeError: 59 | loss_name = criterion.__name__ 60 | 61 | # Take variable and put them to GPU 62 | images, masks = sampled_batch['image'], sampled_batch['label'] 63 | images, masks = images.cuda(), masks.cuda() 64 | 65 | 66 | # ==================================================== 67 | # Compute loss 68 | # ==================================================== 69 | 70 | preds = model(images) 71 | out_loss = criterion(preds, masks.float()) # Loss 72 | 73 | 74 | if model.training: 75 | optimizer.zero_grad() 76 
| out_loss.backward() 77 | optimizer.step() 78 | 79 | # print(masks.size()) 80 | # print(preds.size()) 81 | 82 | 83 | # train_iou = 0 84 | train_iou = iou_on_batch(masks,preds) 85 | train_dice = criterion._show_dice(preds, masks.float()) 86 | 87 | batch_time = time.time() - end 88 | # train_acc = acc_on_batch(masks,preds) 89 | if epoch % config.vis_frequency == 0 and logging_mode is 'Val': 90 | vis_path = config.visualize_path+str(epoch)+'/' 91 | if not os.path.isdir(vis_path): 92 | os.makedirs(vis_path) 93 | save_on_batch(images,masks,preds,names,vis_path) 94 | dices.append(train_dice) 95 | 96 | time_sum += len(images) * batch_time 97 | loss_sum += len(images) * out_loss 98 | iou_sum += len(images) * train_iou 99 | # acc_sum += len(images) * train_acc 100 | dice_sum += len(images) * train_dice 101 | 102 | if i == len(loader): 103 | average_loss = loss_sum / (config.batch_size*(i-1) + len(images)) 104 | average_time = time_sum / (config.batch_size*(i-1) + len(images)) 105 | train_iou_average = iou_sum / (config.batch_size*(i-1) + len(images)) 106 | # train_acc_average = acc_sum / (config.batch_size*(i-1) + len(images)) 107 | train_dice_avg = dice_sum / (config.batch_size*(i-1) + len(images)) 108 | else: 109 | average_loss = loss_sum / (i * config.batch_size) 110 | average_time = time_sum / (i * config.batch_size) 111 | train_iou_average = iou_sum / (i * config.batch_size) 112 | # train_acc_average = acc_sum / (i * config.batch_size) 113 | train_dice_avg = dice_sum / (i * config.batch_size) 114 | 115 | end = time.time() 116 | torch.cuda.empty_cache() 117 | 118 | if i % config.print_frequency == 0: 119 | print_summary(epoch + 1, i, len(loader), out_loss, loss_name, batch_time, 120 | average_loss, average_time, train_iou, train_iou_average, 121 | train_dice, train_dice_avg, 0, 0, logging_mode, 122 | lr=min(g["lr"] for g in optimizer.param_groups),logger=logger) 123 | 124 | if config.tensorboard: 125 | step = epoch * len(loader) + i 126 | writer.add_scalar(logging_mode 
+ '_' + loss_name, out_loss.item(), step) 127 | 128 | # plot metrics in tensorboard 129 | writer.add_scalar(logging_mode + '_iou', train_iou, step) 130 | # writer.add_scalar(logging_mode + '_acc', train_acc, step) 131 | writer.add_scalar(logging_mode + '_dice', train_dice, step) 132 | 133 | torch.cuda.empty_cache() 134 | 135 | if lr_scheduler is not None: 136 | lr_scheduler.step() 137 | # if epoch + 1 > 10: # Plateau 138 | # if lr_scheduler is not None: 139 | # lr_scheduler.step(train_dice_avg) 140 | return average_loss, train_dice_avg 141 | 142 | -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/img/TCGA-2Z-A9J9-01A-01-TS1.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/img/TCGA-2Z-A9J9-01A-01-TS1.tif -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/img/TCGA-44-2665-01B-06-BS6.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/img/TCGA-44-2665-01B-06-BS6.tif -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/img/TCGA-69-7764-01A-01-TS1.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/img/TCGA-69-7764-01A-01-TS1.tif -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/img/TCGA-A6-6782-01A-01-BS1.tif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/img/TCGA-A6-6782-01A-01-BS1.tif -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/img/TCGA-AC-A2FO-01A-01-TS1.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/img/TCGA-AC-A2FO-01A-01-TS1.tif -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/img/TCGA-AO-A0J2-01A-01-BSA.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/img/TCGA-AO-A0J2-01A-01-BSA.tif -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/img/TCGA-CU-A0YN-01A-02-BSB.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/img/TCGA-CU-A0YN-01A-02-BSB.tif -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/img/TCGA-EJ-A46H-01A-03-TSC.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/img/TCGA-EJ-A46H-01A-03-TSC.tif -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/img/TCGA-FG-A4MU-01B-01-TS1.tif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/img/TCGA-FG-A4MU-01B-01-TS1.tif -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/img/TCGA-GL-6846-01A-01-BS1.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/img/TCGA-GL-6846-01A-01-BS1.tif -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/img/TCGA-HC-7209-01A-01-TS1.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/img/TCGA-HC-7209-01A-01-TS1.tif -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/img/TCGA-HT-8564-01Z-00-DX1.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/img/TCGA-HT-8564-01Z-00-DX1.tif -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/img/TCGA-IZ-8196-01A-01-BS1.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/img/TCGA-IZ-8196-01A-01-BS1.tif -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/img/TCGA-ZF-A9R5-01A-01-TS1.tif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/img/TCGA-ZF-A9R5-01A-01-TS1.tif -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/labelcol/TCGA-2Z-A9J9-01A-01-TS1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/labelcol/TCGA-2Z-A9J9-01A-01-TS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/labelcol/TCGA-44-2665-01B-06-BS6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/labelcol/TCGA-44-2665-01B-06-BS6.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/labelcol/TCGA-69-7764-01A-01-TS1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/labelcol/TCGA-69-7764-01A-01-TS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/labelcol/TCGA-A6-6782-01A-01-BS1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/labelcol/TCGA-A6-6782-01A-01-BS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/labelcol/TCGA-AC-A2FO-01A-01-TS1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/labelcol/TCGA-AC-A2FO-01A-01-TS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/labelcol/TCGA-AO-A0J2-01A-01-BSA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/labelcol/TCGA-AO-A0J2-01A-01-BSA.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/labelcol/TCGA-CU-A0YN-01A-02-BSB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/labelcol/TCGA-CU-A0YN-01A-02-BSB.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/labelcol/TCGA-EJ-A46H-01A-03-TSC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/labelcol/TCGA-EJ-A46H-01A-03-TSC.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/labelcol/TCGA-FG-A4MU-01B-01-TS1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/labelcol/TCGA-FG-A4MU-01B-01-TS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/labelcol/TCGA-GL-6846-01A-01-BS1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/labelcol/TCGA-GL-6846-01A-01-BS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/labelcol/TCGA-HC-7209-01A-01-TS1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/labelcol/TCGA-HC-7209-01A-01-TS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/labelcol/TCGA-HT-8564-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/labelcol/TCGA-HT-8564-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/labelcol/TCGA-IZ-8196-01A-01-BS1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/labelcol/TCGA-IZ-8196-01A-01-BS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Test_Folder/labelcol/TCGA-ZF-A9R5-01A-01-TS1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Test_Folder/labelcol/TCGA-ZF-A9R5-01A-01-TS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-21-5784-01Z-00-DX1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-21-5784-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-21-5786-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-21-5786-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-38-6178-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-38-6178-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-49-4488-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-49-4488-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-50-5931-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-50-5931-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-A7-A13E-01Z-00-DX1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-A7-A13E-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-A7-A13F-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-A7-A13F-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-AR-A1AK-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-AR-A1AK-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-AR-A1AS-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-AR-A1AS-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-B0-5698-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-B0-5698-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-B0-5710-01Z-00-DX1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-B0-5710-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-B0-5711-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-B0-5711-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-CH-5767-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-CH-5767-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-DK-A2I6-01A-01-TS1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-DK-A2I6-01A-01-TS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-G2-A2EK-01A-02-TSB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-G2-A2EK-01A-02-TSB.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-G9-6336-01Z-00-DX1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-G9-6336-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-G9-6348-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-G9-6348-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-G9-6356-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-G9-6356-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-G9-6362-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-G9-6362-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-HE-7128-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-HE-7128-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-HE-7130-01Z-00-DX1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-HE-7130-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-KB-A93J-01A-01-TS1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-KB-A93J-01A-01-TS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-NH-A8F7-01A-01-TS1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-NH-A8F7-01A-01-TS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/img/TCGA-RD-A8N9-01A-01-TS1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/img/TCGA-RD-A8N9-01A-01-TS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-21-5784-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-21-5784-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-21-5786-01Z-00-DX1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-21-5786-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-38-6178-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-38-6178-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-49-4488-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-49-4488-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-50-5931-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-50-5931-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-A7-A13E-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-A7-A13E-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-A7-A13F-01Z-00-DX1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-A7-A13F-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-AR-A1AK-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-AR-A1AK-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-AR-A1AS-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-AR-A1AS-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-B0-5698-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-B0-5698-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-B0-5710-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-B0-5710-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-B0-5711-01Z-00-DX1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-B0-5711-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-CH-5767-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-CH-5767-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-DK-A2I6-01A-01-TS1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-DK-A2I6-01A-01-TS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-G2-A2EK-01A-02-TSB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-G2-A2EK-01A-02-TSB.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-G9-6336-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-G9-6336-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-G9-6348-01Z-00-DX1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-G9-6348-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-G9-6356-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-G9-6356-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-G9-6362-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-G9-6362-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-HE-7128-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-HE-7128-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-HE-7130-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-HE-7130-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-KB-A93J-01A-01-TS1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-KB-A93J-01A-01-TS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-NH-A8F7-01A-01-TS1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-NH-A8F7-01A-01-TS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Train_Folder/labelcol/TCGA-RD-A8N9-01A-01-TS1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Train_Folder/labelcol/TCGA-RD-A8N9-01A-01-TS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Val_Folder/img/TCGA-18-5592-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Val_Folder/img/TCGA-18-5592-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Val_Folder/img/TCGA-AY-A8YK-01A-01-TS1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Val_Folder/img/TCGA-AY-A8YK-01A-01-TS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Val_Folder/img/TCGA-E2-A14V-01Z-00-DX1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Val_Folder/img/TCGA-E2-A14V-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Val_Folder/img/TCGA-E2-A1B5-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Val_Folder/img/TCGA-E2-A1B5-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Val_Folder/img/TCGA-G9-6363-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Val_Folder/img/TCGA-G9-6363-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Val_Folder/img/TCGA-HE-7129-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Val_Folder/img/TCGA-HE-7129-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Val_Folder/labelcol/TCGA-18-5592-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Val_Folder/labelcol/TCGA-18-5592-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Val_Folder/labelcol/TCGA-AY-A8YK-01A-01-TS1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Val_Folder/labelcol/TCGA-AY-A8YK-01A-01-TS1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Val_Folder/labelcol/TCGA-E2-A14V-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Val_Folder/labelcol/TCGA-E2-A14V-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Val_Folder/labelcol/TCGA-E2-A1B5-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Val_Folder/labelcol/TCGA-E2-A1B5-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Val_Folder/labelcol/TCGA-G9-6363-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Val_Folder/labelcol/TCGA-G9-6363-01Z-00-DX1.png -------------------------------------------------------------------------------- /datasets/MoNuSeg/Val_Folder/labelcol/TCGA-HE-7129-01Z-00-DX1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/datasets/MoNuSeg/Val_Folder/labelcol/TCGA-HE-7129-01Z-00-DX1.png -------------------------------------------------------------------------------- /docs/Framework.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/McGregorWwww/UCTransNet/493bd8136b2bb7265d7876104f56688363c8d0d1/docs/Framework.jpg -------------------------------------------------------------------------------- /nets/CTrans.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : Haonan Wang 3 | # @File : CTrans.py 4 | # @Software: PyCharm 5 | # coding=utf-8 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | import copy 10 | import logging 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | from torch.nn import Dropout, Softmax, Conv2d, LayerNorm 16 | from torch.nn.modules.utils import _pair 17 | 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | class Channel_Embeddings(nn.Module): 22 | """Patch + position embeddings: a Conv2d with kernel_size == stride == patchsize turns a (B, C, H, W) input into (B, n_patches, C) tokens (C == in_channels), then adds learned position embeddings and applies dropout. The input's spatial size is assumed to equal img_size so the token count matches n_patches. forward(None) returns None. 23 | """ 24 | def __init__(self,config, patchsize, img_size, in_channels): 25 | super().__init__() 26 | img_size = _pair(img_size) 27 | patch_size = _pair(patchsize) 28 | n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) # non-overlapping patches per image 29 | 30 | self.patch_embeddings = Conv2d(in_channels=in_channels, 31 | out_channels=in_channels, 32 | kernel_size=patch_size, 33 | stride=patch_size) 34 | self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, in_channels)) 35 | self.dropout = Dropout(config.transformer["embeddings_dropout_rate"]) 36 | 37 | def forward(self, x): 38 | if x is None: 39 | return None 40 | x = self.patch_embeddings(x) # (B, hidden,
n_patches^(1/2), n_patches^(1/2)) 41 | x = x.flatten(2) 42 | x = x.transpose(-1, -2) # (B, n_patches, hidden) 43 | embeddings = x + self.position_embeddings # (1, n_patches, C) broadcasts over the batch 44 | embeddings = self.dropout(embeddings) 45 | return embeddings 46 | 47 | class Reconstruct(nn.Module): # token sequence (B, n_patch, hidden) -> feature map via upsample, conv, BatchNorm, ReLU 48 | def __init__(self, in_channels, out_channels, kernel_size, scale_factor): 49 | super(Reconstruct, self).__init__() 50 | if kernel_size == 3: 51 | padding = 1 # keeps H, W unchanged for a 3x3 conv at the default stride 52 | else: 53 | padding = 0 # NOTE(review): size-preserving only when kernel_size is 1 54 | self.conv = nn.Conv2d(in_channels, out_channels,kernel_size=kernel_size, padding=padding) 55 | self.norm = nn.BatchNorm2d(out_channels) 56 | self.activation = nn.ReLU(inplace=True) 57 | self.scale_factor = scale_factor 58 | 59 | def forward(self, x): 60 | if x is None: 61 | return None 62 | 63 | B, n_patch, hidden = x.size() # reshape from (B, n_patch, hidden) to (B, hidden, h, w) 64 | h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) # assumes n_patch is a perfect square; otherwise view() below fails 65 | x = x.permute(0, 2, 1) 66 | x = x.contiguous().view(B, hidden, h, w) 67 | x = nn.Upsample(scale_factor=self.scale_factor)(x) # NOTE(review): builds a new Upsample module on every call 68 | 69 | out = self.conv(x) 70 | out = self.norm(out) 71 | out = self.activation(out) 72 | return out 73 | 74 | class Attention_org(nn.Module): 75 | def __init__(self, config, vis,channel_num): 76 | super(Attention_org, self).__init__() 77 | self.vis = vis 78 | self.KV_size = config.KV_size 79 | self.channel_num = channel_num 80 | self.num_attention_heads = config.transformer["num_heads"] 81 | 82 | self.query1 = nn.ModuleList() 83 | self.query2 = nn.ModuleList() 84 | self.query3 = nn.ModuleList() 85 | self.query4 = nn.ModuleList() 86 | self.key = nn.ModuleList() 87 | self.value = nn.ModuleList() 88 | 89 | for _ in range(config.transformer["num_heads"]): 90 | query1 = nn.Linear(channel_num[0], channel_num[0], bias=False) 91 | query2 = nn.Linear(channel_num[1], channel_num[1], bias=False) 92 | query3 = nn.Linear(channel_num[2], channel_num[2], bias=False) 93 | query4 = nn.Linear(channel_num[3], channel_num[3], bias=False) 94 | key = nn.Linear(
self.KV_size, self.KV_size, bias=False) 95 | value = nn.Linear(self.KV_size, self.KV_size, bias=False) 96 | self.query1.append(copy.deepcopy(query1)) 97 | self.query2.append(copy.deepcopy(query2)) 98 | self.query3.append(copy.deepcopy(query3)) 99 | self.query4.append(copy.deepcopy(query4)) 100 | self.key.append(copy.deepcopy(key)) 101 | self.value.append(copy.deepcopy(value)) 102 | self.psi = nn.InstanceNorm2d(self.num_attention_heads) 103 | self.softmax = Softmax(dim=3) 104 | self.out1 = nn.Linear(channel_num[0], channel_num[0], bias=False) 105 | self.out2 = nn.Linear(channel_num[1], channel_num[1], bias=False) 106 | self.out3 = nn.Linear(channel_num[2], channel_num[2], bias=False) 107 | self.out4 = nn.Linear(channel_num[3], channel_num[3], bias=False) 108 | self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) 109 | self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) 110 | 111 | 112 | 113 | def forward(self, emb1,emb2,emb3,emb4, emb_all): 114 | multi_head_Q1_list = [] 115 | multi_head_Q2_list = [] 116 | multi_head_Q3_list = [] 117 | multi_head_Q4_list = [] 118 | multi_head_K_list = [] 119 | multi_head_V_list = [] 120 | if emb1 is not None: 121 | for query1 in self.query1: 122 | Q1 = query1(emb1) 123 | multi_head_Q1_list.append(Q1) 124 | if emb2 is not None: 125 | for query2 in self.query2: 126 | Q2 = query2(emb2) 127 | multi_head_Q2_list.append(Q2) 128 | if emb3 is not None: 129 | for query3 in self.query3: 130 | Q3 = query3(emb3) 131 | multi_head_Q3_list.append(Q3) 132 | if emb4 is not None: 133 | for query4 in self.query4: 134 | Q4 = query4(emb4) 135 | multi_head_Q4_list.append(Q4) 136 | for key in self.key: 137 | K = key(emb_all) 138 | multi_head_K_list.append(K) 139 | for value in self.value: 140 | V = value(emb_all) 141 | multi_head_V_list.append(V) 142 | # print(len(multi_head_Q4_list)) 143 | 144 | multi_head_Q1 = torch.stack(multi_head_Q1_list, dim=1) if emb1 is not None else None 145 | multi_head_Q2 = 
torch.stack(multi_head_Q2_list, dim=1) if emb2 is not None else None 146 | multi_head_Q3 = torch.stack(multi_head_Q3_list, dim=1) if emb3 is not None else None 147 | multi_head_Q4 = torch.stack(multi_head_Q4_list, dim=1) if emb4 is not None else None 148 | multi_head_K = torch.stack(multi_head_K_list, dim=1) 149 | multi_head_V = torch.stack(multi_head_V_list, dim=1) 150 | 151 | multi_head_Q1 = multi_head_Q1.transpose(-1, -2) if emb1 is not None else None 152 | multi_head_Q2 = multi_head_Q2.transpose(-1, -2) if emb2 is not None else None 153 | multi_head_Q3 = multi_head_Q3.transpose(-1, -2) if emb3 is not None else None 154 | multi_head_Q4 = multi_head_Q4.transpose(-1, -2) if emb4 is not None else None 155 | 156 | attention_scores1 = torch.matmul(multi_head_Q1, multi_head_K) if emb1 is not None else None 157 | attention_scores2 = torch.matmul(multi_head_Q2, multi_head_K) if emb2 is not None else None 158 | attention_scores3 = torch.matmul(multi_head_Q3, multi_head_K) if emb3 is not None else None 159 | attention_scores4 = torch.matmul(multi_head_Q4, multi_head_K) if emb4 is not None else None 160 | 161 | attention_scores1 = attention_scores1 / math.sqrt(self.KV_size) if emb1 is not None else None 162 | attention_scores2 = attention_scores2 / math.sqrt(self.KV_size) if emb2 is not None else None 163 | attention_scores3 = attention_scores3 / math.sqrt(self.KV_size) if emb3 is not None else None 164 | attention_scores4 = attention_scores4 / math.sqrt(self.KV_size) if emb4 is not None else None 165 | 166 | attention_probs1 = self.softmax(self.psi(attention_scores1)) if emb1 is not None else None 167 | attention_probs2 = self.softmax(self.psi(attention_scores2)) if emb2 is not None else None 168 | attention_probs3 = self.softmax(self.psi(attention_scores3)) if emb3 is not None else None 169 | attention_probs4 = self.softmax(self.psi(attention_scores4)) if emb4 is not None else None 170 | # print(attention_probs4.size()) 171 | 172 | if self.vis: 173 | weights = [] 174 | 
weights.append(attention_probs1.mean(1)) 175 | weights.append(attention_probs2.mean(1)) 176 | weights.append(attention_probs3.mean(1)) 177 | weights.append(attention_probs4.mean(1)) 178 | else: weights=None 179 | 180 | attention_probs1 = self.attn_dropout(attention_probs1) if emb1 is not None else None 181 | attention_probs2 = self.attn_dropout(attention_probs2) if emb2 is not None else None 182 | attention_probs3 = self.attn_dropout(attention_probs3) if emb3 is not None else None 183 | attention_probs4 = self.attn_dropout(attention_probs4) if emb4 is not None else None 184 | 185 | multi_head_V = multi_head_V.transpose(-1, -2) 186 | context_layer1 = torch.matmul(attention_probs1, multi_head_V) if emb1 is not None else None 187 | context_layer2 = torch.matmul(attention_probs2, multi_head_V) if emb2 is not None else None 188 | context_layer3 = torch.matmul(attention_probs3, multi_head_V) if emb3 is not None else None 189 | context_layer4 = torch.matmul(attention_probs4, multi_head_V) if emb4 is not None else None 190 | 191 | context_layer1 = context_layer1.permute(0, 3, 2, 1).contiguous() if emb1 is not None else None 192 | context_layer2 = context_layer2.permute(0, 3, 2, 1).contiguous() if emb2 is not None else None 193 | context_layer3 = context_layer3.permute(0, 3, 2, 1).contiguous() if emb3 is not None else None 194 | context_layer4 = context_layer4.permute(0, 3, 2, 1).contiguous() if emb4 is not None else None 195 | context_layer1 = context_layer1.mean(dim=3) if emb1 is not None else None 196 | context_layer2 = context_layer2.mean(dim=3) if emb2 is not None else None 197 | context_layer3 = context_layer3.mean(dim=3) if emb3 is not None else None 198 | context_layer4 = context_layer4.mean(dim=3) if emb4 is not None else None 199 | 200 | O1 = self.out1(context_layer1) if emb1 is not None else None 201 | O2 = self.out2(context_layer2) if emb2 is not None else None 202 | O3 = self.out3(context_layer3) if emb3 is not None else None 203 | O4 = 
self.out4(context_layer4) if emb4 is not None else None 204 | O1 = self.proj_dropout(O1) if emb1 is not None else None 205 | O2 = self.proj_dropout(O2) if emb2 is not None else None 206 | O3 = self.proj_dropout(O3) if emb3 is not None else None 207 | O4 = self.proj_dropout(O4) if emb4 is not None else None 208 | return O1,O2,O3,O4, weights 209 | 210 | 211 | 212 | 213 | class Mlp(nn.Module): 214 | def __init__(self,config, in_channel, mlp_channel): 215 | super(Mlp, self).__init__() 216 | self.fc1 = nn.Linear(in_channel, mlp_channel) 217 | self.fc2 = nn.Linear(mlp_channel, in_channel) 218 | self.act_fn = nn.GELU() 219 | self.dropout = Dropout(config.transformer["dropout_rate"]) 220 | self._init_weights() 221 | 222 | def _init_weights(self): 223 | nn.init.xavier_uniform_(self.fc1.weight) 224 | nn.init.xavier_uniform_(self.fc2.weight) 225 | nn.init.normal_(self.fc1.bias, std=1e-6) 226 | nn.init.normal_(self.fc2.bias, std=1e-6) 227 | 228 | def forward(self, x): 229 | x = self.fc1(x) 230 | x = self.act_fn(x) 231 | x = self.dropout(x) 232 | x = self.fc2(x) 233 | x = self.dropout(x) 234 | return x 235 | 236 | class Block_ViT(nn.Module): 237 | def __init__(self, config, vis, channel_num): 238 | super(Block_ViT, self).__init__() 239 | expand_ratio = config.expand_ratio 240 | self.attn_norm1 = LayerNorm(channel_num[0],eps=1e-6) 241 | self.attn_norm2 = LayerNorm(channel_num[1],eps=1e-6) 242 | self.attn_norm3 = LayerNorm(channel_num[2],eps=1e-6) 243 | self.attn_norm4 = LayerNorm(channel_num[3],eps=1e-6) 244 | self.attn_norm = LayerNorm(config.KV_size,eps=1e-6) 245 | self.channel_attn = Attention_org(config, vis, channel_num) 246 | 247 | self.ffn_norm1 = LayerNorm(channel_num[0],eps=1e-6) 248 | self.ffn_norm2 = LayerNorm(channel_num[1],eps=1e-6) 249 | self.ffn_norm3 = LayerNorm(channel_num[2],eps=1e-6) 250 | self.ffn_norm4 = LayerNorm(channel_num[3],eps=1e-6) 251 | self.ffn1 = Mlp(config,channel_num[0],channel_num[0]*expand_ratio) 252 | self.ffn2 = 
Mlp(config,channel_num[1],channel_num[1]*expand_ratio) 253 | self.ffn3 = Mlp(config,channel_num[2],channel_num[2]*expand_ratio) 254 | self.ffn4 = Mlp(config,channel_num[3],channel_num[3]*expand_ratio) 255 | 256 | 257 | def forward(self, emb1,emb2,emb3,emb4): 258 | embcat = [] 259 | org1 = emb1 260 | org2 = emb2 261 | org3 = emb3 262 | org4 = emb4 263 | for i in range(4): 264 | var_name = "emb"+str(i+1) 265 | tmp_var = locals()[var_name] 266 | if tmp_var is not None: 267 | embcat.append(tmp_var) 268 | 269 | emb_all = torch.cat(embcat,dim=2) 270 | cx1 = self.attn_norm1(emb1) if emb1 is not None else None 271 | cx2 = self.attn_norm2(emb2) if emb2 is not None else None 272 | cx3 = self.attn_norm3(emb3) if emb3 is not None else None 273 | cx4 = self.attn_norm4(emb4) if emb4 is not None else None 274 | emb_all = self.attn_norm(emb_all) 275 | cx1,cx2,cx3,cx4, weights = self.channel_attn(cx1,cx2,cx3,cx4,emb_all) 276 | cx1 = org1 + cx1 if emb1 is not None else None 277 | cx2 = org2 + cx2 if emb2 is not None else None 278 | cx3 = org3 + cx3 if emb3 is not None else None 279 | cx4 = org4 + cx4 if emb4 is not None else None 280 | 281 | org1 = cx1 282 | org2 = cx2 283 | org3 = cx3 284 | org4 = cx4 285 | x1 = self.ffn_norm1(cx1) if emb1 is not None else None 286 | x2 = self.ffn_norm2(cx2) if emb2 is not None else None 287 | x3 = self.ffn_norm3(cx3) if emb3 is not None else None 288 | x4 = self.ffn_norm4(cx4) if emb4 is not None else None 289 | x1 = self.ffn1(x1) if emb1 is not None else None 290 | x2 = self.ffn2(x2) if emb2 is not None else None 291 | x3 = self.ffn3(x3) if emb3 is not None else None 292 | x4 = self.ffn4(x4) if emb4 is not None else None 293 | x1 = x1 + org1 if emb1 is not None else None 294 | x2 = x2 + org2 if emb2 is not None else None 295 | x3 = x3 + org3 if emb3 is not None else None 296 | x4 = x4 + org4 if emb4 is not None else None 297 | 298 | return x1, x2, x3, x4, weights 299 | 300 | 301 | class Encoder(nn.Module): 302 | def __init__(self, config, vis, 
channel_num): 303 | super(Encoder, self).__init__() 304 | self.vis = vis 305 | self.layer = nn.ModuleList() 306 | self.encoder_norm1 = LayerNorm(channel_num[0],eps=1e-6) 307 | self.encoder_norm2 = LayerNorm(channel_num[1],eps=1e-6) 308 | self.encoder_norm3 = LayerNorm(channel_num[2],eps=1e-6) 309 | self.encoder_norm4 = LayerNorm(channel_num[3],eps=1e-6) 310 | for _ in range(config.transformer["num_layers"]): 311 | layer = Block_ViT(config, vis, channel_num) 312 | self.layer.append(copy.deepcopy(layer)) 313 | 314 | def forward(self, emb1,emb2,emb3,emb4): 315 | attn_weights = [] 316 | for layer_block in self.layer: 317 | emb1,emb2,emb3,emb4, weights = layer_block(emb1,emb2,emb3,emb4) 318 | if self.vis: 319 | attn_weights.append(weights) 320 | emb1 = self.encoder_norm1(emb1) if emb1 is not None else None 321 | emb2 = self.encoder_norm2(emb2) if emb2 is not None else None 322 | emb3 = self.encoder_norm3(emb3) if emb3 is not None else None 323 | emb4 = self.encoder_norm4(emb4) if emb4 is not None else None 324 | return emb1,emb2,emb3,emb4, attn_weights 325 | 326 | 327 | class ChannelTransformer(nn.Module): 328 | def __init__(self, config, vis, img_size, channel_num=[64, 128, 256, 512], patchSize=[32, 16, 8, 4]): 329 | super().__init__() 330 | 331 | self.patchSize_1 = patchSize[0] 332 | self.patchSize_2 = patchSize[1] 333 | self.patchSize_3 = patchSize[2] 334 | self.patchSize_4 = patchSize[3] 335 | self.embeddings_1 = Channel_Embeddings(config,self.patchSize_1, img_size=img_size, in_channels=channel_num[0]) 336 | self.embeddings_2 = Channel_Embeddings(config,self.patchSize_2, img_size=img_size//2, in_channels=channel_num[1]) 337 | self.embeddings_3 = Channel_Embeddings(config,self.patchSize_3, img_size=img_size//4, in_channels=channel_num[2]) 338 | self.embeddings_4 = Channel_Embeddings(config,self.patchSize_4, img_size=img_size//8, in_channels=channel_num[3]) 339 | self.encoder = Encoder(config, vis, channel_num) 340 | 341 | self.reconstruct_1 = 
Reconstruct(channel_num[0], channel_num[0], kernel_size=1,scale_factor=(self.patchSize_1,self.patchSize_1)) 342 | self.reconstruct_2 = Reconstruct(channel_num[1], channel_num[1], kernel_size=1,scale_factor=(self.patchSize_2,self.patchSize_2)) 343 | self.reconstruct_3 = Reconstruct(channel_num[2], channel_num[2], kernel_size=1,scale_factor=(self.patchSize_3,self.patchSize_3)) 344 | self.reconstruct_4 = Reconstruct(channel_num[3], channel_num[3], kernel_size=1,scale_factor=(self.patchSize_4,self.patchSize_4)) 345 | 346 | def forward(self,en1,en2,en3,en4): 347 | 348 | emb1 = self.embeddings_1(en1) 349 | emb2 = self.embeddings_2(en2) 350 | emb3 = self.embeddings_3(en3) 351 | emb4 = self.embeddings_4(en4) 352 | 353 | encoded1, encoded2, encoded3, encoded4, attn_weights = self.encoder(emb1,emb2,emb3,emb4) # (B, n_patch, hidden) 354 | x1 = self.reconstruct_1(encoded1) if en1 is not None else None 355 | x2 = self.reconstruct_2(encoded2) if en2 is not None else None 356 | x3 = self.reconstruct_3(encoded3) if en3 is not None else None 357 | x4 = self.reconstruct_4(encoded4) if en4 is not None else None 358 | 359 | x1 = x1 + en1 if en1 is not None else None 360 | x2 = x2 + en2 if en2 is not None else None 361 | x3 = x3 + en3 if en3 is not None else None 362 | x4 = x4 + en4 if en4 is not None else None 363 | 364 | return x1, x2, x3, x4, attn_weights 365 | 366 | -------------------------------------------------------------------------------- /nets/UCTransNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/7/8 8:59 上午 3 | # @File : UCTransNet.py 4 | # @Software: PyCharm 5 | import torch.nn as nn 6 | import torch 7 | import torch.nn.functional as F 8 | from .CTrans import ChannelTransformer 9 | 10 | def get_activation(activation_type): 11 | activation_type = activation_type.lower() 12 | if hasattr(nn, activation_type): 13 | return getattr(nn, activation_type)() 14 | else: 15 | return nn.ReLU() 16 
| 17 | def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'): 18 | layers = [] 19 | layers.append(ConvBatchNorm(in_channels, out_channels, activation)) 20 | 21 | for _ in range(nb_Conv - 1): 22 | layers.append(ConvBatchNorm(out_channels, out_channels, activation)) 23 | return nn.Sequential(*layers) 24 | 25 | class ConvBatchNorm(nn.Module): 26 | """(convolution => [BN] => ReLU)""" 27 | 28 | def __init__(self, in_channels, out_channels, activation='ReLU'): 29 | super(ConvBatchNorm, self).__init__() 30 | self.conv = nn.Conv2d(in_channels, out_channels, 31 | kernel_size=3, padding=1) 32 | self.norm = nn.BatchNorm2d(out_channels) 33 | self.activation = get_activation(activation) 34 | 35 | def forward(self, x): 36 | out = self.conv(x) 37 | out = self.norm(out) 38 | return self.activation(out) 39 | 40 | class DownBlock(nn.Module): 41 | """Downscaling with maxpool convolution""" 42 | def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'): 43 | super(DownBlock, self).__init__() 44 | self.maxpool = nn.MaxPool2d(2) 45 | self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation) 46 | 47 | def forward(self, x): 48 | out = self.maxpool(x) 49 | return self.nConvs(out) 50 | 51 | class Flatten(nn.Module): 52 | def forward(self, x): 53 | return x.view(x.size(0), -1) 54 | 55 | class CCA(nn.Module): 56 | """ 57 | CCA Block 58 | """ 59 | def __init__(self, F_g, F_x): 60 | super().__init__() 61 | self.mlp_x = nn.Sequential( 62 | Flatten(), 63 | nn.Linear(F_x, F_x)) 64 | self.mlp_g = nn.Sequential( 65 | Flatten(), 66 | nn.Linear(F_g, F_x)) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | def forward(self, g, x): 70 | # channel-wise attention 71 | avg_pool_x = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 72 | channel_att_x = self.mlp_x(avg_pool_x) 73 | avg_pool_g = F.avg_pool2d( g, (g.size(2), g.size(3)), stride=(g.size(2), g.size(3))) 74 | channel_att_g = self.mlp_g(avg_pool_g) 75 | channel_att_sum = 
(channel_att_x + channel_att_g)/2.0 76 | scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x) 77 | x_after_channel = x * scale 78 | out = self.relu(x_after_channel) 79 | return out 80 | 81 | class UpBlock_attention(nn.Module): 82 | def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'): 83 | super().__init__() 84 | self.up = nn.Upsample(scale_factor=2) 85 | self.coatt = CCA(F_g=in_channels//2, F_x=in_channels//2) 86 | self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation) 87 | 88 | def forward(self, x, skip_x): 89 | up = self.up(x) 90 | skip_x_att = self.coatt(g=up, x=skip_x) 91 | x = torch.cat([skip_x_att, up], dim=1) # dim 1 is the channel dimension 92 | return self.nConvs(x) 93 | 94 | class UCTransNet(nn.Module): 95 | def __init__(self, config,n_channels=3, n_classes=1,img_size=224,vis=False): 96 | super().__init__() 97 | self.vis = vis 98 | self.n_channels = n_channels 99 | self.n_classes = n_classes 100 | in_channels = config.base_channel 101 | self.inc = ConvBatchNorm(n_channels, in_channels) 102 | self.down1 = DownBlock(in_channels, in_channels*2, nb_Conv=2) 103 | self.down2 = DownBlock(in_channels*2, in_channels*4, nb_Conv=2) 104 | self.down3 = DownBlock(in_channels*4, in_channels*8, nb_Conv=2) 105 | self.down4 = DownBlock(in_channels*8, in_channels*8, nb_Conv=2) 106 | self.mtc = ChannelTransformer(config, vis, img_size, 107 | channel_num=[in_channels, in_channels*2, in_channels*4, in_channels*8], 108 | patchSize=config.patch_sizes) 109 | self.up4 = UpBlock_attention(in_channels*16, in_channels*4, nb_Conv=2) 110 | self.up3 = UpBlock_attention(in_channels*8, in_channels*2, nb_Conv=2) 111 | self.up2 = UpBlock_attention(in_channels*4, in_channels, nb_Conv=2) 112 | self.up1 = UpBlock_attention(in_channels*2, in_channels, nb_Conv=2) 113 | self.outc = nn.Conv2d(in_channels, n_classes, kernel_size=(1,1), stride=(1,1)) 114 | self.last_activation = nn.Sigmoid() # if using BCELoss 115 | 116 | def 
forward(self, x): 117 | x = x.float() 118 | x1 = self.inc(x) 119 | x2 = self.down1(x1) 120 | x3 = self.down2(x2) 121 | x4 = self.down3(x3) 122 | x5 = self.down4(x4) 123 | x1,x2,x3,x4,att_weights = self.mtc(x1,x2,x3,x4) 124 | x = self.up4(x5, x4) 125 | x = self.up3(x, x3) 126 | x = self.up2(x, x2) 127 | x = self.up1(x, x1) 128 | if self.n_classes ==1: 129 | logits = self.last_activation(self.outc(x)) 130 | else: 131 | logits = self.outc(x) # if nusing BCEWithLogitsLoss or class>1 132 | if self.vis: # visualize the attention maps 133 | return logits, att_weights 134 | else: 135 | return logits 136 | 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /nets/UNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | def get_activation(activation_type): 5 | activation_type = activation_type.lower() 6 | if hasattr(nn, activation_type): 7 | return getattr(nn, activation_type)() 8 | else: 9 | return nn.ReLU() 10 | 11 | def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'): 12 | layers = [] 13 | layers.append(ConvBatchNorm(in_channels, out_channels, activation)) 14 | 15 | for _ in range(nb_Conv - 1): 16 | layers.append(ConvBatchNorm(out_channels, out_channels, activation)) 17 | return nn.Sequential(*layers) 18 | 19 | class ConvBatchNorm(nn.Module): 20 | """(convolution => [BN] => ReLU)""" 21 | 22 | def __init__(self, in_channels, out_channels, activation='ReLU'): 23 | super(ConvBatchNorm, self).__init__() 24 | self.conv = nn.Conv2d(in_channels, out_channels, 25 | kernel_size=3, padding=1) 26 | self.norm = nn.BatchNorm2d(out_channels) 27 | self.activation = get_activation(activation) 28 | 29 | def forward(self, x): 30 | out = self.conv(x) 31 | out = self.norm(out) 32 | return self.activation(out) 33 | 34 | class DownBlock(nn.Module): 35 | """Downscaling with maxpool convolution""" 36 | 37 | def __init__(self, in_channels, 
out_channels, nb_Conv, activation='ReLU'): 38 | super(DownBlock, self).__init__() 39 | self.maxpool = nn.MaxPool2d(2) 40 | self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation) 41 | 42 | def forward(self, x): 43 | out = self.maxpool(x) 44 | return self.nConvs(out) 45 | 46 | class UpBlock(nn.Module): 47 | """Upscaling then conv""" 48 | 49 | def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'): 50 | super(UpBlock, self).__init__() 51 | 52 | # self.up = nn.Upsample(scale_factor=2) 53 | self.up = nn.ConvTranspose2d(in_channels//2,in_channels//2,(2,2),2) 54 | self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation) 55 | 56 | def forward(self, x, skip_x): 57 | out = self.up(x) 58 | x = torch.cat([out, skip_x], dim=1) # dim 1 is the channel dimension 59 | return self.nConvs(x) 60 | 61 | class UNet(nn.Module): 62 | def __init__(self, n_channels=3, n_classes=9): 63 | ''' 64 | n_channels : number of channels of the input. 65 | By default 3, because we have RGB images 66 | n_labels : number of channels of the ouput. 
67 | By default 3 (2 labels + 1 for the background) 68 | ''' 69 | super().__init__() 70 | self.n_channels = n_channels 71 | self.n_classes = n_classes 72 | # Question here 73 | in_channels = 64 74 | self.inc = ConvBatchNorm(n_channels, in_channels) 75 | self.down1 = DownBlock(in_channels, in_channels*2, nb_Conv=2) 76 | self.down2 = DownBlock(in_channels*2, in_channels*4, nb_Conv=2) 77 | self.down3 = DownBlock(in_channels*4, in_channels*8, nb_Conv=2) 78 | self.down4 = DownBlock(in_channels*8, in_channels*8, nb_Conv=2) 79 | self.up4 = UpBlock(in_channels*16, in_channels*4, nb_Conv=2) 80 | self.up3 = UpBlock(in_channels*8, in_channels*2, nb_Conv=2) 81 | self.up2 = UpBlock(in_channels*4, in_channels, nb_Conv=2) 82 | self.up1 = UpBlock(in_channels*2, in_channels, nb_Conv=2) 83 | self.outc = nn.Conv2d(in_channels, n_classes, kernel_size=(1,1)) 84 | if n_classes == 1: 85 | self.last_activation = nn.Sigmoid() 86 | else: 87 | self.last_activation = None 88 | 89 | def forward(self, x): 90 | # Question here 91 | x = x.float() 92 | x1 = self.inc(x) 93 | x2 = self.down1(x1) 94 | x3 = self.down2(x2) 95 | x4 = self.down3(x3) 96 | x5 = self.down4(x4) 97 | x = self.up4(x5, x4) 98 | x = self.up3(x, x3) 99 | x = self.up2(x, x2) 100 | x = self.up1(x, x1) 101 | if self.last_activation is not None: 102 | logits = self.last_activation(self.outc(x)) 103 | # print("111") 104 | else: 105 | logits = self.outc(x) 106 | # print("222") 107 | # logits = self.outc(x) # if using BCEWithLogitsLoss 108 | # print(logits.size()) 109 | return logits 110 | 111 | 112 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | albumentations==0.5.2 3 | cached-property==1.5.2 4 | certifi==2021.5.30 5 | contextlib2==21.6.0 6 | cssselect==1.1.0 7 | cycler==0.10.0 8 | decorator==4.4.2 9 | einops==0.3.0 10 | entmax==1.0 11 | h5py==3.3.0 12 | imageio==2.9.0 13 | 
imgaug==0.4.0 14 | joblib==1.0.1 15 | kiwisolver==1.3.1 16 | matplotlib==3.3.4 17 | mkl-fft==1.3.0 18 | mkl-random==1.2.2 19 | mkl-service==2.4.0 20 | ml-collections==0.1.0 21 | munch==2.5.0 22 | networkx==2.5.1 23 | numpy 24 | opencv-python==4.5.1.48 25 | opencv-python-headless==4.5.2.54 26 | pandas==1.1.5 27 | Pillow==8.3.1 28 | pretrainedmodels==0.7.4 29 | protobuf==3.17.3 30 | pyparsing==2.4.7 31 | pyquery 32 | python-dateutil==2.8.1 33 | pytz==2021.1 34 | PyWavelets==1.1.1 35 | PyYAML==5.4.1 36 | scikit-image==0.17.2 37 | scikit-learn==0.24.2 38 | scipy==1.5.4 39 | seaborn==0.11.2 40 | Shapely==1.7.1 41 | SimpleITK==2.0.2 42 | tensorboardX==2.4 43 | threadpoolctl==2.1.0 44 | tifffile==2021.7.2 45 | timm==0.4.12 46 | torch==1.8.1 47 | torchvision==0.9.1 48 | tqdm==4.61.2 49 | xlrd==1.2.0 50 | yacs==0.1.8 51 | -------------------------------------------------------------------------------- /test_model.py: -------------------------------------------------------------------------------- 1 | import torch.optim 2 | from Load_Dataset import ValGenerator, ImageToImage2D 3 | from torch.utils.data import DataLoader 4 | import warnings 5 | warnings.filterwarnings("ignore") 6 | import Config as config 7 | import matplotlib.pyplot as plt 8 | from tqdm import tqdm 9 | import os 10 | from nets.UCTransNet import UCTransNet 11 | from utils import * 12 | import cv2 13 | 14 | 15 | def show_image_with_dice(predict_save, labs, save_path): 16 | 17 | tmp_lbl = (labs).astype(np.float32) 18 | tmp_3dunet = (predict_save).astype(np.float32) 19 | dice_pred = 2 * np.sum(tmp_lbl * tmp_3dunet) / (np.sum(tmp_lbl) + np.sum(tmp_3dunet) + 1e-5) 20 | # dice_show = "%.3f" % (dice_pred) 21 | iou_pred = jaccard_score(tmp_lbl.reshape(-1),tmp_3dunet.reshape(-1)) 22 | # fig, ax = plt.subplots() 23 | # plt.gca().add_patch(patches.Rectangle(xy=(4, 4),width=120,height=20,color="white",linewidth=1)) 24 | if config.task_name is "MoNuSeg": 25 | predict_save = cv2.pyrUp(predict_save,(448,448)) 26 | 
predict_save = cv2.resize(predict_save,(2000,2000)) 27 | # kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], np.float32) #定义一个核 28 | # predict_save = cv2.filter2D(predict_save, -1, kernel=kernel) 29 | cv2.imwrite(save_path,predict_save * 255) 30 | else: 31 | cv2.imwrite(save_path,predict_save * 255) 32 | # plt.imshow(predict_save * 255,cmap='gray') 33 | # plt.text(x=10, y=24, s="Dice:" + str(dice_show), fontsize=5) 34 | # plt.axis("off") 35 | # remove the white borders 36 | # height, width = predict_save.shape 37 | # fig.set_size_inches(width / 100.0 / 3.0, height / 100.0 / 3.0) 38 | # plt.gca().xaxis.set_major_locator(plt.NullLocator()) 39 | # plt.gca().yaxis.set_major_locator(plt.NullLocator()) 40 | # plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0) 41 | # plt.margins(0, 0) 42 | # plt.savefig(save_path, dpi=2000) 43 | # plt.close() 44 | return dice_pred, iou_pred 45 | 46 | def vis_and_save_heatmap(model, input_img, img_RGB, labs, vis_save_path, dice_pred, dice_ens): 47 | model.eval() 48 | 49 | output = model(input_img.cuda()) 50 | pred_class = torch.where(output>0.5,torch.ones_like(output),torch.zeros_like(output)) 51 | predict_save = pred_class[0].cpu().data.numpy() 52 | predict_save = np.reshape(predict_save, (config.img_size, config.img_size)) 53 | dice_pred_tmp, iou_tmp = show_image_with_dice(predict_save, labs, save_path=vis_save_path+'_predict'+model_type+'.jpg') 54 | return dice_pred_tmp, iou_tmp 55 | 56 | 57 | 58 | if __name__ == '__main__': 59 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 60 | test_session = config.test_session 61 | if config.task_name is "GlaS": 62 | test_num = 80 63 | model_type = config.model_name 64 | model_path = "./GlaS/"+model_type+"/"+test_session+"/models/best_model-"+model_type+".pth.tar" 65 | 66 | elif config.task_name is "MoNuSeg": 67 | test_num = 14 68 | model_type = config.model_name 69 | model_path = "./MoNuSeg/"+model_type+"/"+test_session+"/models/best_model-"+model_type+".pth.tar" 70 | 71 | 
72 | save_path = config.task_name +'/'+ model_type +'/' + test_session + '/' 73 | vis_path = "./" + config.task_name + '_visualize_test/' 74 | if not os.path.exists(vis_path): 75 | os.makedirs(vis_path) 76 | 77 | checkpoint = torch.load(model_path, map_location='cuda') 78 | 79 | 80 | if model_type == 'UCTransNet': 81 | config_vit = config.get_CTranS_config() 82 | model = UCTransNet(config_vit,n_channels=config.n_channels,n_classes=config.n_labels) 83 | 84 | elif model_type == 'UCTransNet_pretrain': 85 | config_vit = config.get_CTranS_config() 86 | model = UCTransNet(config_vit,n_channels=config.n_channels,n_classes=config.n_labels) 87 | 88 | 89 | else: raise TypeError('Please enter a valid name for the model type') 90 | 91 | model = model.cuda() 92 | if torch.cuda.device_count() > 1: 93 | print ("Let's use {0} GPUs!".format(torch.cuda.device_count())) 94 | model = nn.DataParallel(model, device_ids=[0,1,2,3]) 95 | model.load_state_dict(checkpoint['state_dict']) 96 | print('Model loaded !') 97 | tf_test = ValGenerator(output_size=[config.img_size, config.img_size]) 98 | test_dataset = ImageToImage2D(config.test_dataset, tf_test,image_size=config.img_size) 99 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False) 100 | 101 | dice_pred = 0.0 102 | iou_pred = 0.0 103 | dice_ens = 0.0 104 | 105 | with tqdm(total=test_num, desc='Test visualize', unit='img', ncols=70, leave=True) as pbar: 106 | for i, (sampled_batch, names) in enumerate(test_loader, 1): 107 | test_data, test_label = sampled_batch['image'], sampled_batch['label'] 108 | arr=test_data.numpy() 109 | arr = arr.astype(np.float32()) 110 | lab=test_label.data.numpy() 111 | img_lab = np.reshape(lab, (lab.shape[1], lab.shape[2])) * 255 112 | fig, ax = plt.subplots() 113 | plt.imshow(img_lab, cmap='gray') 114 | plt.axis("off") 115 | height, width = config.img_size, config.img_size 116 | fig.set_size_inches(width / 100.0 / 3.0, height / 100.0 / 3.0) 117 | 
plt.gca().xaxis.set_major_locator(plt.NullLocator()) 118 | plt.gca().yaxis.set_major_locator(plt.NullLocator()) 119 | plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0) 120 | plt.margins(0, 0) 121 | plt.savefig(vis_path+str(i)+"_lab.jpg", dpi=300) 122 | plt.close() 123 | input_img = torch.from_numpy(arr) 124 | dice_pred_t,iou_pred_t = vis_and_save_heatmap(model, input_img, None, lab, 125 | vis_path+str(i), 126 | dice_pred=dice_pred, dice_ens=dice_ens) 127 | dice_pred+=dice_pred_t 128 | iou_pred+=iou_pred_t 129 | torch.cuda.empty_cache() 130 | pbar.update() 131 | print ("dice_pred",dice_pred/test_num) 132 | print ("iou_pred",iou_pred/test_num) 133 | 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2021/7/8 8:59 上午 3 | # @Author : Haonan Wang 4 | # @File : train.py 5 | # @Software: PyCharm 6 | import torch.optim 7 | from tensorboardX import SummaryWriter 8 | import os 9 | import numpy as np 10 | import random 11 | from torch.backends import cudnn 12 | from Load_Dataset import RandomGenerator,ValGenerator,ImageToImage2D 13 | from nets.UCTransNet import UCTransNet 14 | from torch.utils.data import DataLoader 15 | import logging 16 | from Train_one_epoch import train_one_epoch 17 | import Config as config 18 | from torchvision import transforms 19 | from utils import CosineAnnealingWarmRestarts, WeightedDiceBCE 20 | 21 | def logger_config(log_path): 22 | loggerr = logging.getLogger() 23 | loggerr.setLevel(level=logging.INFO) 24 | handler = logging.FileHandler(log_path, encoding='UTF-8') 25 | handler.setLevel(logging.INFO) 26 | formatter = logging.Formatter('%(message)s') 27 | handler.setFormatter(formatter) 28 | console = logging.StreamHandler() 29 | console.setLevel(logging.INFO) 30 | loggerr.addHandler(handler) 31 | loggerr.addHandler(console) 32 | 
return loggerr 33 | 34 | def save_checkpoint(state, save_path): 35 | ''' 36 | Save the current model. 37 | If the model is the best model since beginning of the training 38 | it will be copy 39 | ''' 40 | logger.info('\t Saving to {}'.format(save_path)) 41 | if not os.path.isdir(save_path): 42 | os.makedirs(save_path) 43 | 44 | epoch = state['epoch'] # epoch no 45 | best_model = state['best_model'] # bool 46 | model = state['model'] # model type 47 | 48 | if best_model: 49 | filename = save_path + '/' + \ 50 | 'best_model-{}.pth.tar'.format(model) 51 | else: 52 | filename = save_path + '/' + \ 53 | 'model-{}-{:02d}.pth.tar'.format(model, epoch) 54 | torch.save(state, filename) 55 | 56 | def worker_init_fn(worker_id): 57 | random.seed(config.seed + worker_id) 58 | 59 | ################################################################################## 60 | #================================================================================= 61 | # Main Loop: load model, 62 | #================================================================================= 63 | ################################################################################## 64 | def main_loop(batch_size=config.batch_size, model_type='', tensorboard=True): 65 | # Load train and val data 66 | train_tf= transforms.Compose([RandomGenerator(output_size=[config.img_size, config.img_size])]) 67 | val_tf = ValGenerator(output_size=[config.img_size, config.img_size]) 68 | train_dataset = ImageToImage2D(config.train_dataset, train_tf,image_size=config.img_size) 69 | val_dataset = ImageToImage2D(config.val_dataset, val_tf,image_size=config.img_size) 70 | train_loader = DataLoader(train_dataset, 71 | batch_size=config.batch_size, 72 | shuffle=True, 73 | worker_init_fn=worker_init_fn, 74 | num_workers=8, 75 | pin_memory=True) 76 | val_loader = DataLoader(val_dataset, 77 | batch_size=config.batch_size, 78 | shuffle=True, 79 | worker_init_fn=worker_init_fn, 80 | num_workers=8, 81 | pin_memory=True) 82 | 83 | lr = 
config.learning_rate 84 | logger.info(model_type) 85 | 86 | if model_type == 'UCTransNet': 87 | config_vit = config.get_CTranS_config() 88 | logger.info('transformer head num: {}'.format(config_vit.transformer.num_heads)) 89 | logger.info('transformer layers num: {}'.format(config_vit.transformer.num_layers)) 90 | logger.info('transformer expand ratio: {}'.format(config_vit.expand_ratio)) 91 | model = UCTransNet(config_vit,n_channels=config.n_channels,n_classes=config.n_labels) 92 | 93 | elif model_type == 'UCTransNet_pretrain': 94 | config_vit = config.get_CTranS_config() 95 | logger.info('transformer head num: {}'.format(config_vit.transformer.num_heads)) 96 | logger.info('transformer layers num: {}'.format(config_vit.transformer.num_layers)) 97 | logger.info('transformer expand ratio: {}'.format(config_vit.expand_ratio)) 98 | model = UCTransNet(config_vit,n_channels=config.n_channels,n_classes=config.n_labels) 99 | pretrained_UNet_model_path = "./nets/best_model-UNet.pth.tar" 100 | pretrained_UNet = torch.load(pretrained_UNet_model_path, map_location='cuda') 101 | pretrained_UNet = pretrained_UNet['state_dict'] 102 | model2_dict = model.state_dict() 103 | state_dict = {k:v for k,v in pretrained_UNet.items() if k in model2_dict.keys()} 104 | print(state_dict.keys()) 105 | model2_dict.update(state_dict) 106 | model.load_state_dict(model2_dict) 107 | logger.info('Load successful!') 108 | 109 | else: raise TypeError('Please enter a valid name for the model type') 110 | 111 | 112 | model = model.cuda() 113 | # if torch.cuda.device_count() > 1: 114 | # print ("Let's use {0} GPUs!".format(torch.cuda.device_count())) 115 | # model = nn.DataParallel(model, device_ids=[0]) 116 | criterion = WeightedDiceBCE(dice_weight=0.5,BCE_weight=0.5) 117 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr) # Choose optimize 118 | if config.cosineLR is True: 119 | lr_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, 
eta_min=1e-4) 120 | else: 121 | lr_scheduler = None 122 | if tensorboard: 123 | log_dir = config.tensorboard_folder 124 | logger.info('log dir: '.format(log_dir)) 125 | if not os.path.isdir(log_dir): 126 | os.makedirs(log_dir) 127 | writer = SummaryWriter(log_dir) 128 | else: 129 | writer = None 130 | 131 | max_dice = 0.0 132 | best_epoch = 1 133 | for epoch in range(config.epochs): # loop over the dataset multiple times 134 | logger.info('\n========= Epoch [{}/{}] ========='.format(epoch + 1, config.epochs + 1)) 135 | logger.info(config.session_name) 136 | # train for one epoch 137 | model.train(True) 138 | logger.info('Training with batch size : {}'.format(batch_size)) 139 | train_one_epoch(train_loader, model, criterion, optimizer, writer, epoch, None, model_type, logger) 140 | # evaluate on validation set 141 | logger.info('Validation') 142 | with torch.no_grad(): 143 | model.eval() 144 | val_loss, val_dice = train_one_epoch(val_loader, model, criterion, 145 | optimizer, writer, epoch, lr_scheduler,model_type,logger) 146 | 147 | # ============================================================= 148 | # Save best model 149 | # ============================================================= 150 | if val_dice > max_dice: 151 | if epoch+1 > 5: 152 | logger.info('\t Saving best model, mean dice increased from: {:.4f} to {:.4f}'.format(max_dice,val_dice)) 153 | max_dice = val_dice 154 | best_epoch = epoch + 1 155 | save_checkpoint({'epoch': epoch, 156 | 'best_model': True, 157 | 'model': model_type, 158 | 'state_dict': model.state_dict(), 159 | 'val_loss': val_loss, 160 | 'optimizer': optimizer.state_dict()}, config.model_path) 161 | else: 162 | logger.info('\t Mean dice:{:.4f} does not increase, ' 163 | 'the best is still: {:.4f} in epoch {}'.format(val_dice,max_dice, best_epoch)) 164 | early_stopping_count = epoch - best_epoch + 1 165 | logger.info('\t early_stopping_count: {}/{}'.format(early_stopping_count,config.early_stopping_patience)) 166 | 167 | if 
early_stopping_count > config.early_stopping_patience: 168 | logger.info('\t early_stopping!') 169 | break 170 | 171 | return model 172 | 173 | 174 | if __name__ == '__main__': 175 | deterministic = True 176 | if not deterministic: 177 | cudnn.benchmark = True 178 | cudnn.deterministic = False 179 | else: 180 | cudnn.benchmark = False 181 | cudnn.deterministic = True 182 | random.seed(config.seed) 183 | np.random.seed(config.seed) 184 | torch.manual_seed(config.seed) 185 | torch.cuda.manual_seed(config.seed) 186 | torch.cuda.manual_seed_all(config.seed) 187 | if not os.path.isdir(config.save_path): 188 | os.makedirs(config.save_path) 189 | 190 | logger = logger_config(log_path=config.logger_path) 191 | model = main_loop(model_type=config.model_name, tensorboard=True) 192 | 193 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import roc_auc_score,jaccard_score 3 | import cv2 4 | from torch import nn 5 | import torch.nn.functional as F 6 | import math 7 | from functools import wraps 8 | import warnings 9 | import weakref 10 | from torch.optim.optimizer import Optimizer 11 | 12 | class WeightedBCE(nn.Module): 13 | 14 | def __init__(self, weights=[0.4, 0.6]): 15 | super(WeightedBCE, self).__init__() 16 | self.weights = weights 17 | 18 | def forward(self, logit_pixel, truth_pixel): 19 | # print("====",logit_pixel.size()) 20 | logit = logit_pixel.view(-1) 21 | truth = truth_pixel.view(-1) 22 | assert(logit.shape==truth.shape) 23 | loss = F.binary_cross_entropy(logit, truth, reduction='none') 24 | pos = (truth>0.5).float() 25 | neg = (truth<0.5).float() 26 | pos_weight = pos.sum().item() + 1e-12 27 | neg_weight = neg.sum().item() + 1e-12 28 | loss = (self.weights[0]*pos*loss/pos_weight + self.weights[1]*neg*loss/neg_weight).sum() 29 | 30 | return loss 31 | 32 | class WeightedDiceLoss(nn.Module): 33 | 
def __init__(self, weights=[0.5, 0.5]): # W_pos=0.8, W_neg=0.2 34 | super(WeightedDiceLoss, self).__init__() 35 | self.weights = weights 36 | 37 | def forward(self, logit, truth, smooth=1e-5): 38 | batch_size = len(logit) 39 | logit = logit.view(batch_size,-1) 40 | truth = truth.view(batch_size,-1) 41 | assert(logit.shape==truth.shape) 42 | p = logit.view(batch_size,-1) 43 | t = truth.view(batch_size,-1) 44 | w = truth.detach() 45 | w = w*(self.weights[1]-self.weights[0])+self.weights[0] 46 | # p = w*(p*2-1) #convert to [0,1] --> [-1, 1] 47 | # t = w*(t*2-1) 48 | p = w*(p) 49 | t = w*(t) 50 | intersection = (p * t).sum(-1) 51 | union = (p * p).sum(-1) + (t * t).sum(-1) 52 | dice = 1 - (2*intersection + smooth) / (union +smooth) 53 | # print "------",dice.data 54 | 55 | loss = dice.mean() 56 | return loss 57 | 58 | class WeightedDiceBCE(nn.Module): 59 | def __init__(self,dice_weight=1,BCE_weight=1): 60 | super(WeightedDiceBCE, self).__init__() 61 | self.BCE_loss = WeightedBCE(weights=[0.5, 0.5]) 62 | self.dice_loss = WeightedDiceLoss(weights=[0.5, 0.5]) 63 | self.BCE_weight = BCE_weight 64 | self.dice_weight = dice_weight 65 | 66 | def _show_dice(self, inputs, targets): 67 | inputs[inputs>=0.5] = 1 68 | inputs[inputs<0.5] = 0 69 | # print("2",np.sum(tmp)) 70 | targets[targets>0] = 1 71 | targets[targets<=0] = 0 72 | hard_dice_coeff = 1.0 - self.dice_loss(inputs, targets) 73 | return hard_dice_coeff 74 | 75 | def forward(self, inputs, targets): 76 | # inputs = inputs.contiguous().view(-1) 77 | # targets = targets.contiguous().view(-1) 78 | # print "dice_loss", self.dice_loss(inputs, targets) 79 | # print "focal_loss", self.focal_loss(inputs, targets) 80 | dice = self.dice_loss(inputs, targets) 81 | BCE = self.BCE_loss(inputs, targets) 82 | # print "dice",dice 83 | # print "focal",focal 84 | dice_BCE_loss = self.dice_weight * dice + self.BCE_weight * BCE 85 | 86 | return dice_BCE_loss 87 | 88 | def auc_on_batch(masks, pred): 89 | '''Computes the mean Area Under ROC 
Curve over a batch during training''' 90 | aucs = [] 91 | for i in range(pred.shape[1]): 92 | prediction = pred[i][0].cpu().detach().numpy() 93 | # print("www",np.max(prediction), np.min(prediction)) 94 | mask = masks[i].cpu().detach().numpy() 95 | # print("rrr",np.max(mask), np.min(mask)) 96 | aucs.append(roc_auc_score(mask.reshape(-1), prediction.reshape(-1))) 97 | return np.mean(aucs) 98 | 99 | def iou_on_batch(masks, pred): 100 | '''Computes the mean Area Under ROC Curve over a batch during training''' 101 | ious = [] 102 | 103 | for i in range(pred.shape[0]): 104 | pred_tmp = pred[i][0].cpu().detach().numpy() 105 | # print("www",np.max(prediction), np.min(prediction)) 106 | mask_tmp = masks[i].cpu().detach().numpy() 107 | pred_tmp[pred_tmp>=0.5] = 1 108 | pred_tmp[pred_tmp<0.5] = 0 109 | # print("2",np.sum(tmp)) 110 | mask_tmp[mask_tmp>0] = 1 111 | mask_tmp[mask_tmp<=0] = 0 112 | # print("rrr",np.max(mask), np.min(mask)) 113 | ious.append(jaccard_score(mask_tmp.reshape(-1), pred_tmp.reshape(-1))) 114 | return np.mean(ious) 115 | 116 | def dice_coef(y_true, y_pred): 117 | smooth = 1e-5 118 | y_true_f = y_true.flatten() 119 | y_pred_f = y_pred.flatten() 120 | intersection = np.sum(y_true_f * y_pred_f) 121 | return (2. 
* intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth) 122 | 123 | def dice_on_batch(masks, pred): 124 | '''Computes the mean Area Under ROC Curve over a batch during training''' 125 | dices = [] 126 | 127 | for i in range(pred.shape[0]): 128 | pred_tmp = pred[i][0].cpu().detach().numpy() 129 | # print("www",np.max(prediction), np.min(prediction)) 130 | mask_tmp = masks[i].cpu().detach().numpy() 131 | pred_tmp[pred_tmp>=0.5] = 1 132 | pred_tmp[pred_tmp<0.5] = 0 133 | # print("2",np.sum(tmp)) 134 | mask_tmp[mask_tmp>0] = 1 135 | mask_tmp[mask_tmp<=0] = 0 136 | # print("rrr",np.max(mask), np.min(mask)) 137 | dices.append(dice_coef(mask_tmp, pred_tmp)) 138 | return np.mean(dices) 139 | 140 | def save_on_batch(images1, masks, pred, names, vis_path): 141 | '''Computes the mean Area Under ROC Curve over a batch during training''' 142 | for i in range(pred.shape[0]): 143 | pred_tmp = pred[i][0].cpu().detach().numpy() 144 | mask_tmp = masks[i].cpu().detach().numpy() 145 | pred_tmp[pred_tmp>=0.5] = 255 146 | pred_tmp[pred_tmp<0.5] = 0 147 | mask_tmp[mask_tmp>0] = 255 148 | mask_tmp[mask_tmp<=0] = 0 149 | 150 | cv2.imwrite(vis_path+ names[i][:-4]+"_pred.jpg", pred_tmp) 151 | cv2.imwrite(vis_path+names[i][:-4]+"_gt.jpg", mask_tmp) 152 | 153 | 154 | 155 | class _LRScheduler(object): 156 | 157 | def __init__(self, optimizer, last_epoch=-1): 158 | 159 | # Attach optimizer 160 | if not isinstance(optimizer, Optimizer): 161 | raise TypeError('{} is not an Optimizer'.format( 162 | type(optimizer).__name__)) 163 | self.optimizer = optimizer 164 | 165 | # Initialize epoch and base learning rates 166 | if last_epoch == -1: 167 | for group in optimizer.param_groups: 168 | group.setdefault('initial_lr', group['lr']) 169 | else: 170 | for i, group in enumerate(optimizer.param_groups): 171 | if 'initial_lr' not in group: 172 | raise KeyError("param 'initial_lr' is not specified " 173 | "in param_groups[{}] when resuming an optimizer".format(i)) 174 | self.base_lrs = 
list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 175 | self.last_epoch = last_epoch 176 | 177 | # Following https://github.com/pytorch/pytorch/issues/20124 178 | # We would like to ensure that `lr_scheduler.step()` is called after 179 | # `optimizer.step()` 180 | def with_counter(method): 181 | if getattr(method, '_with_counter', False): 182 | # `optimizer.step()` has already been replaced, return. 183 | return method 184 | 185 | # Keep a weak reference to the optimizer instance to prevent 186 | # cyclic references. 187 | instance_ref = weakref.ref(method.__self__) 188 | # Get the unbound method for the same purpose. 189 | func = method.__func__ 190 | cls = instance_ref().__class__ 191 | del method 192 | 193 | @wraps(func) 194 | def wrapper(*args, **kwargs): 195 | instance = instance_ref() 196 | instance._step_count += 1 197 | wrapped = func.__get__(instance, cls) 198 | return wrapped(*args, **kwargs) 199 | 200 | # Note that the returned function here is no longer a bound method, 201 | # so attributes like `__func__` and `__self__` no longer exist. 202 | wrapper._with_counter = True 203 | return wrapper 204 | 205 | self.optimizer.step = with_counter(self.optimizer.step) 206 | self.optimizer._step_count = 0 207 | self._step_count = 0 208 | 209 | self.step() 210 | 211 | def state_dict(self): 212 | """Returns the state of the scheduler as a :class:`dict`. 213 | 214 | It contains an entry for every variable in self.__dict__ which 215 | is not the optimizer. 216 | """ 217 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 218 | 219 | def load_state_dict(self, state_dict): 220 | """Loads the schedulers state. 221 | 222 | Arguments: 223 | state_dict (dict): scheduler state. Should be an object returned 224 | from a call to :meth:`state_dict`. 225 | """ 226 | self.__dict__.update(state_dict) 227 | 228 | def get_last_lr(self): 229 | """ Return last computed learning rate by current scheduler. 
230 | """ 231 | return self._last_lr 232 | 233 | def get_lr(self): 234 | # Compute learning rate using chainable form of the scheduler 235 | raise NotImplementedError 236 | 237 | def step(self, epoch=None): 238 | # Raise a warning if old pattern is detected 239 | # https://github.com/pytorch/pytorch/issues/20124 240 | if self._step_count == 1: 241 | if not hasattr(self.optimizer.step, "_with_counter"): 242 | warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler " 243 | "initialization. Please, make sure to call `optimizer.step()` before " 244 | "`lr_scheduler.step()`. See more details at " 245 | "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) 246 | 247 | # Just check if there were two first lr_scheduler.step() calls before optimizer.step() 248 | elif self.optimizer._step_count < 1: 249 | warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. " 250 | "In PyTorch 1.1.0 and later, you should call them in the opposite order: " 251 | "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " 252 | "will result in PyTorch skipping the first value of the learning rate schedule. 
" 253 | "See more details at " 254 | "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) 255 | self._step_count += 1 256 | 257 | class _enable_get_lr_call: 258 | 259 | def __init__(self, o): 260 | self.o = o 261 | 262 | def __enter__(self): 263 | self.o._get_lr_called_within_step = True 264 | return self 265 | 266 | def __exit__(self, type, value, traceback): 267 | self.o._get_lr_called_within_step = False 268 | return self 269 | 270 | with _enable_get_lr_call(self): 271 | if epoch is None: 272 | self.last_epoch += 1 273 | values = self.get_lr() 274 | else: 275 | self.last_epoch = epoch 276 | if hasattr(self, "_get_closed_form_lr"): 277 | values = self._get_closed_form_lr() 278 | else: 279 | values = self.get_lr() 280 | 281 | for param_group, lr in zip(self.optimizer.param_groups, values): 282 | param_group['lr'] = lr 283 | 284 | self._last_lr = [group['lr'] for group in self.optimizer.param_groups] 285 | 286 | class CosineAnnealingWarmRestarts(_LRScheduler): 287 | r"""Set the learning rate of each parameter group using a cosine annealing 288 | schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` 289 | is the number of epochs since the last restart and :math:`T_{i}` is the number 290 | of epochs between two warm restarts in SGDR: 291 | 292 | .. math:: 293 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + 294 | \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right) 295 | 296 | When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. 297 | When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`. 298 | 299 | It has been proposed in 300 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_. 301 | 302 | Args: 303 | optimizer (Optimizer): Wrapped optimizer. 304 | T_0 (int): Number of iterations for the first restart. 305 | T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1. 306 | eta_min (float, optional): Minimum learning rate. Default: 0. 
307 | last_epoch (int, optional): The index of last epoch. Default: -1. 308 | 309 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: 310 | https://arxiv.org/abs/1608.03983 311 | """ 312 | 313 | def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1): 314 | if T_0 <= 0 or not isinstance(T_0, int): 315 | raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) 316 | if T_mult < 1 or not isinstance(T_mult, int): 317 | raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult)) 318 | self.T_0 = T_0 319 | self.T_i = T_0 320 | self.T_mult = T_mult 321 | self.eta_min = eta_min 322 | 323 | super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch) 324 | 325 | self.T_cur = self.last_epoch 326 | 327 | def get_lr(self): 328 | if not self._get_lr_called_within_step: 329 | warnings.warn("To get the last learning rate computed by the scheduler, " 330 | "please use `get_last_lr()`.", DeprecationWarning) 331 | 332 | return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2 333 | for base_lr in self.base_lrs] 334 | 335 | def step(self, epoch=None): 336 | """Step could be called after every batch update 337 | 338 | Example: 339 | >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) 340 | >>> iters = len(dataloader) 341 | >>> for epoch in range(20): 342 | >>> for i, sample in enumerate(dataloader): 343 | >>> inputs, labels = sample['inputs'], sample['labels'] 344 | >>> scheduler.step(epoch + i / iters) 345 | >>> optimizer.zero_grad() 346 | >>> outputs = net(inputs) 347 | >>> loss = criterion(outputs, labels) 348 | >>> loss.backward() 349 | >>> optimizer.step() 350 | 351 | This function can be called in an interleaved way. 
352 | 353 | Example: 354 | >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) 355 | >>> for epoch in range(20): 356 | >>> scheduler.step() 357 | >>> scheduler.step(26) 358 | >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) 359 | """ 360 | 361 | if epoch is None and self.last_epoch < 0: 362 | epoch = 0 363 | 364 | if epoch is None: 365 | epoch = self.last_epoch + 1 366 | self.T_cur = self.T_cur + 1 367 | if self.T_cur >= self.T_i: 368 | self.T_cur = self.T_cur - self.T_i 369 | self.T_i = self.T_i * self.T_mult 370 | else: 371 | if epoch < 0: 372 | raise ValueError("Expected non-negative epoch, but got {}".format(epoch)) 373 | if epoch >= self.T_0: 374 | if self.T_mult == 1: 375 | self.T_cur = epoch % self.T_0 376 | else: 377 | n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult)) 378 | self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1) 379 | self.T_i = self.T_0 * self.T_mult ** (n) 380 | else: 381 | self.T_i = self.T_0 382 | self.T_cur = epoch 383 | self.last_epoch = math.floor(epoch) 384 | 385 | class _enable_get_lr_call: 386 | 387 | def __init__(self, o): 388 | self.o = o 389 | 390 | def __enter__(self): 391 | self.o._get_lr_called_within_step = True 392 | return self 393 | 394 | def __exit__(self, type, value, traceback): 395 | self.o._get_lr_called_within_step = False 396 | return self 397 | 398 | with _enable_get_lr_call(self): 399 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 400 | param_group['lr'] = lr 401 | 402 | self._last_lr = [group['lr'] for group in self.optimizer.param_groups] 403 | 404 | --------------------------------------------------------------------------------