├── README.md ├── coreset selection.py ├── dataset ├── train_fewshot_data.txt ├── train_fewshot_data │ └── TRAIN001 │ │ ├── 20.png │ │ ├── 20_gt.png │ │ ├── 21.png │ │ ├── 21_gt.png │ │ ├── 22.png │ │ ├── 22_gt.png │ │ ├── 23.png │ │ ├── 23_gt.png │ │ ├── 24.png │ │ ├── 24_gt.png │ │ ├── 25.png │ │ ├── 25_gt.png │ │ ├── 26.png │ │ ├── 26_gt.png │ │ ├── 27.png │ │ ├── 27_gt.png │ │ ├── 28.png │ │ ├── 28_gt.png │ │ ├── 29.png │ │ ├── 29_gt.png │ │ ├── 30.png │ │ └── 30_gt.png ├── val_fewshot_data.txt └── val_fewshot_data │ └── VAL001 │ ├── 20.png │ ├── 20_gt.png │ ├── 21.png │ ├── 21_gt.png │ ├── 22.png │ ├── 22_gt.png │ ├── 23.png │ ├── 23_gt.png │ ├── 24.png │ ├── 24_gt.png │ ├── 25.png │ ├── 25_gt.png │ ├── 26.png │ ├── 26_gt.png │ ├── 27.png │ ├── 27_gt.png │ ├── 28.png │ ├── 28_gt.png │ ├── 29.png │ ├── 29_gt.png │ ├── 30.png │ ├── 30_gt.png │ ├── 31.png │ ├── 31_gt.png │ ├── 32.png │ ├── 32_gt.png │ ├── 33.png │ └── 33_gt.png ├── images ├── Figure 1.png └── Figure 2.png ├── main.py ├── networks ├── __init__.py ├── unet.py └── unetr.py ├── optimizers ├── __init__.py └── lr_scheduler.py ├── requirements.txt ├── runs ├── __init__.py └── log │ └── log.txt └── utils ├── __init__.py ├── data_utils.py └── trainer.py /README.md: -------------------------------------------------------------------------------- 1 | # Annotation-efficient-learning-for-OCT-segmentation 2 | 3 | This repository contains the code for the paper ["Annotation-efficient learning for OCT segmentation"](https://opg.optica.org/boe/fulltext.cfm?uri=boe-14-7-3294&id=531648). We propose an annotation-efficient learning method for OCT segmentation that can significantly reduce annotation costs and improve learning efficiency. Here we provide a generative pre-trained transformer-based encoder and a CNN-based segmentation decoder, both pretrained on open-access OCT datasets. The proposed pre-trained model can be directly transferred to ROI segmentation on your own OCT images. We hope this helps improve the intelligence and application penetration of OCT. 4 | 5 | ![Overview](images/Figure%201.png) 6 | ![Model architecture](images/Figure%202.png) 7 | 8 | ## Dependencies 9 | python==3.8
10 | torch==1.11.1
11 | numpy==1.19.5
12 | monai==0.7.0
13 | timm==0.3.2
14 | tensorboardX==2.1
15 | torchvision==0.12.0
16 | opencv-python==4.5.5
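After installing the packages above, a minimal sketch like the following can be used to sanity-check the environment (it only assumes the listed packages are importable; exact version pins may differ on your system):

```
# Print the installed versions of the main dependencies and check CUDA availability.
import torch, torchvision, numpy, monai, timm, cv2

for name, module in [("torch", torch), ("torchvision", torchvision), ("numpy", numpy),
                     ("monai", monai), ("timm", timm), ("opencv-python", cv2)]:
    print(name, module.__version__)
print("CUDA available:", torch.cuda.is_available())
```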
17 | 18 | ## Usage 19 | 1. Clone the repository: 20 | ``` 21 | git clone https://github.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation.git 22 | ``` 23 | 24 | 2. Install the required dependencies: 25 | ``` 26 | pip install -r requirements.txt 27 | ``` 28 | 3. Download the pre-trained [phase1 model file](https://drive.google.com/file/d/1JHdL1HRZM86n4761uoO3_6N4p4zqBDmx/view?usp=sharing) for the encoder weights and the [phase2 model file](https://drive.google.com/file/d/1gOihHsH4-GAtS6R6wzxkOQsaMAF6fj_h/view?usp=sharing) for the decoder weights, then put them in the `./runs/` folder. 29 | 30 | 4. Edit the paths and parameters in `main.py` as needed. 31 | 32 | 5. Go to the repository folder and run: 33 | ``` 34 | cd Annotation-efficient-learning-for-OCT-segmentation 35 | python main.py 36 | ``` 37 | 38 | ## Training on your Dataset 39 | The expected dataset structure follows the `./dataset/` folder, which contains `train_fewshot_data` and `val_fewshot_data`. The image name lists are given in `train_fewshot_data.txt` and `val_fewshot_data.txt`. 40 | 41 | ## Citation 42 | ``` 43 | @article{zhang2023annotation, 44 | title={Annotation-efficient learning for OCT segmentation}, 45 | author={Zhang, Haoran and Yang, Jianlong and Zheng, Ce and Zhao, Shiqing and Zhang, Aili}, 46 | journal={Biomedical Optics Express}, 47 | volume={14}, 48 | number={7}, 49 | pages={3294--3307}, 50 | year={2023}, 51 | publisher={Optica Publishing Group} 52 | } 53 | ``` 54 | -------------------------------------------------------------------------------- /coreset selection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def euclidean_dist(x, y): 5 | m, n = x.size(0), y.size(0) 6 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 7 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 8 | dist = xx + yy 9 | dist.addmm_(x, y.t(), beta=1, alpha=-2) 10 | dist = dist.clamp(min=1e-12).sqrt() 11 | return dist 12 | 13 | def coreset_selection(matrix, budget: int, metric, device, random_seed=None, index=None, already_selected=None, 14 | print_freq: int = 20): 15 | if type(matrix) == torch.Tensor: 16 | assert matrix.dim() == 2 17 | elif type(matrix) == np.ndarray: 18 | assert matrix.ndim == 2 19 | matrix = torch.from_numpy(matrix).requires_grad_(False).to(device) 20 | 21 | sample_num = matrix.shape[0] 22 | assert sample_num >= 1 23 | 24 | if budget < 0: 25 | raise ValueError("Illegal budget size.") 26 | elif budget > sample_num: 27 | budget = sample_num 28 | 29 | if index is not None: 30 | assert matrix.shape[0] == len(index) 31 | else: 32 | index = np.arange(sample_num) 33 | 34 | assert callable(metric) 35 | 36 | already_selected = np.array(already_selected) 37 | 38 | with torch.no_grad(): 39 | np.random.seed(random_seed) 40 | if already_selected.__len__() == 0: 41 | select_result = np.zeros(sample_num, dtype=bool) 42 | # Randomly select one initial point. 43 | already_selected = [np.random.randint(0, sample_num)] 44 | budget -= 1 45 | select_result[already_selected] = True 46 | else: 47 | select_result = np.in1d(index, already_selected) 48 | 49 | num_of_already_selected = np.sum(select_result)# =1 50 | 51 | # Initialize a (num_of_already_selected+budget-1)*sample_num matrix storing distances of pool points from 52 | # each clustering center. 

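# The loop below implements k-center greedy (farthest-point) selection: `mins` holds, for every
# pool point, the distance to its nearest already-selected center; at each step the point farthest
# from all current centers (argmax of `mins`) is added, its distances to the remaining pool points
# are written into the next row of `dis_matrix`, and `mins` is updated by an element-wise minimum.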
53 | dis_matrix = -1 * torch.ones([num_of_already_selected + budget - 1, sample_num], requires_grad=False).to(device) 54 | 55 | dis_matrix[:num_of_already_selected, ~select_result] = metric(matrix[select_result], matrix[~select_result]) 56 | 57 | mins = torch.min(dis_matrix[:num_of_already_selected, :], dim=0).values# minimum distance from each point to its nearest already-selected center 58 | 59 | for i in range(budget): 60 | if i % print_freq == 0: 61 | print("| Selecting [%3d/%3d]" % (i + 1, budget)) 62 | p = torch.argmax(mins).item() 63 | select_result[p] = True 64 | 65 | if i == budget - 1: 66 | break 67 | mins[p] = -1 68 | dis_matrix[num_of_already_selected + i, ~select_result] = metric(matrix[[p]], matrix[~select_result]) 69 | mins = torch.min(mins, dis_matrix[num_of_already_selected + i]) 70 | return index[select_result] 71 | 72 | def main(txtName, numpyName, caseList, budget, device, seed): 73 | """ 74 | txtName: The .txt file containing the ordered name list of images. For example: "TRAIN001/001.png" 75 | numpyName: The .npy file containing the ordered 1-D feature vectors of the images (reshaped from 2-D feature maps), 76 | which are acquired from the output of the encoder. Its shape is (number of images, feature dimensions). 77 | caseList: The ordered list of case names. For example: ['TRAIN001', 'TRAIN002', 'TRAIN003', 'TRAIN004']. 78 | budget: The number of images to select for the core-set. 79 | device: The CUDA device or CPU. 80 | seed: The random seed that decides the initial point. 81 | """ 82 | 83 | # load the txt list 84 | with open(txtName, mode='r') as F: 85 | imageList = F.readlines() 86 | # load image features 87 | already_selected = [] 88 | matrix = np.load(numpyName) 89 | matrix = np.array(matrix, dtype=np.float32) 90 | print("number of images total:",matrix.shape[0],", feature dimension:",matrix.shape[1]) 91 | 92 | for case in caseList: 93 | indexlist = [] 94 | featurelist = [] 95 | for i, imageName in enumerate(imageList): 96 | if case in imageName: 97 | indexlist.append(i) 98 | featurelist.append(matrix[i]) 99 | x = torch.from_numpy(np.array(featurelist)) 100 | y = torch.from_numpy(np.array(featurelist)) 101 | dis_matrix = euclidean_dist(x, y) 102 | min_idx = int(torch.argmin(torch.sum(dis_matrix, dim=0), dim=0).numpy()) 103 | already_selected.append(indexlist[min_idx]) 104 | 105 | subset = coreset_selection(matrix, budget=int(np.round(budget)), 106 | metric=euclidean_dist, device=device, 107 | random_seed=seed, 108 | already_selected=np.array(already_selected)) 109 | 110 | print("{} images have been selected for the core-set!".format(len(subset))) 111 | print("The indices of the selected images are as follows:") 112 | print(subset) 113 | 114 | if __name__ == '__main__': 115 | 116 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 117 | txtName = './train_fewshot_data.txt' 118 | numpyName = 'imageFeature.npy' 119 | caseList = ['TRAIN001', 'TRAIN002', 'TRAIN003'] 120 | budget = 50 121 | seed=700 122 | 123 | main(txtName, numpyName, caseList, budget, device, seed) 124 | -------------------------------------------------------------------------------- /dataset/train_fewshot_data.txt: -------------------------------------------------------------------------------- 1 | TRAIN001/20.png 2 | TRAIN001/21.png 3 | TRAIN001/22.png 4 | TRAIN001/23.png 5 | TRAIN001/24.png 6 | TRAIN001/25.png 7 | TRAIN001/26.png 8 | TRAIN001/27.png 9 | TRAIN001/28.png 10 | TRAIN001/29.png 11 | TRAIN001/30.png 12 | -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/20.png: 

-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/20.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/20_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/20_gt.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/21.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/21_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/21_gt.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/22.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/22_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/22_gt.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/23.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/23_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/23_gt.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/24.png 
-------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/24_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/24_gt.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/25.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/25_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/25_gt.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/26.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/26.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/26_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/26_gt.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/27.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/27.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/27_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/27_gt.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/28.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/28.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/28_gt.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/28_gt.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/29.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/29_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/29_gt.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/30.png -------------------------------------------------------------------------------- /dataset/train_fewshot_data/TRAIN001/30_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/30_gt.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data.txt: -------------------------------------------------------------------------------- 1 | VAL001/20.png 2 | VAL001/21.png 3 | VAL001/22.png 4 | VAL001/23.png 5 | VAL001/24.png 6 | VAL001/25.png 7 | VAL001/26.png 8 | VAL001/27.png 9 | VAL001/28.png 10 | VAL001/29.png 11 | VAL001/30.png 12 | VAL001/31.png 13 | VAL001/32.png 14 | VAL001/33.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/20.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/20_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/20_gt.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/21.png -------------------------------------------------------------------------------- 
/dataset/val_fewshot_data/VAL001/21_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/21_gt.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/22.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/22_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/22_gt.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/23.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/23_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/23_gt.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/24.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/24_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/24_gt.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/25.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/25_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/25_gt.png 
-------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/26.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/26.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/26_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/26_gt.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/27.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/27.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/27_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/27_gt.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/28.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/28.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/28_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/28_gt.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/29.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/29_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/29_gt.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/30.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/30.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/30_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/30_gt.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/31.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/31.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/31_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/31_gt.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/32.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/32_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/32_gt.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/33.png -------------------------------------------------------------------------------- /dataset/val_fewshot_data/VAL001/33_gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/33_gt.png -------------------------------------------------------------------------------- /images/Figure 1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/images/Figure 1.png -------------------------------------------------------------------------------- /images/Figure 2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/images/Figure 2.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn.parallel 6 | import torch.utils.data.distributed 7 | import networks.unetr 8 | from utils.trainer import run_training, val_epoch 9 | from utils.data_utils import get_loader 10 | from networks.unet import UNet 11 | from monai.losses import DiceCELoss 12 | from monai.metrics import DiceMetric 13 | from monai.transforms import AsDiscrete 14 | from monai.utils.enums import MetricReduction 15 | import random 16 | 17 | def main(train,val): 18 | parser = argparse.ArgumentParser(description="UNETR segmentation pipeline") 19 | parser.add_argument("--test_mode", default=False, type=bool, help="test mode or not") 20 | parser.add_argument("--label_index",default=[255],type=list, help="index of label for foreground segmentation") 21 | parser.add_argument("--phase2_weight", default="./runs/phase2 model.pth", type=str,help="load phase2 model weight") 22 | parser.add_argument("--phase1_weight", default="./runs/phase1 model.pth", type=str, help="load phase1 model weight") 23 | parser.add_argument("--logdir", default="./runs/log", type=str, help="directory to save the tensorboard logs") 24 | parser.add_argument('--list_dir', type=str, default='./dataset', help='list dir') 25 | parser.add_argument("--data_dir", default="./dataset", type=str, help="dataset directory") 26 | parser.add_argument("--train_text", default=train, type=str, help="training image list text name") 27 | parser.add_argument("--val_text", default=val, type=str, help="validating image list text name") 28 | parser.add_argument('--model_name', default='vit_large_patch16', type=str, metavar='MODEL',help='Name of model to train') 29 | parser.add_argument("--save_checkpoint", default=True,type=bool, help="save checkpoint during training") 30 | parser.add_argument("--max_epochs", default=100, type=int, help="max number of training epochs") 31 | parser.add_argument("--batch_size", default=4, type=int, help="number of batch size") 32 | parser.add_argument("--optim_lr", default=0.5*1e-4, type=float, help="optimization learning rate") 33 | parser.add_argument("--optim_name", default="adam", type=str, help="optimization algorithm")#adamw 34 | parser.add_argument("--reg_weight", default=1e-5, type=float, help="regularization weight") 35 | parser.add_argument("--momentum", default=0.99, type=float, help="momentum") 36 | parser.add_argument("--noamp", action="store_true", help="do NOT use amp for training") 37 | parser.add_argument("--val_every", default=1, type=int, help="validation frequency") 38 | parser.add_argument("--distributed", action="store_true", help="start distributed training") 39 | parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training") 40 | parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training") 41 | parser.add_argument("--num_workers", default=2, type=int, help="number of workers") 42 | parser.add_argument("--norm_name", default="instance", type=str, help="normalization layer type in decoder") 43 | parser.add_argument("--feature_size", default=64, type=int, help="feature size dimention")# 44 | 
parser.add_argument("--in_channels", default=3, type=int, help="number of input channels") 45 | parser.add_argument("--out_channels", default=1, type=int, help="number of output channels") 46 | parser.add_argument("--res_block", default=True,type=bool, help="use residual blocks") 47 | parser.add_argument("--conv_block", default=True,type=bool, help="use conv blocks") 48 | parser.add_argument('--input_size', default=224, type=int,help='images input size')# 49 | parser.add_argument("--dropout_rate", default=0.4, type=float, help="dropout rate") 50 | parser.add_argument("--lrschedule", default="warmup_cosine", type=str, help="type of learning rate scheduler") 51 | parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs") 52 | parser.add_argument("--smooth_dr", default=1e-6, type=float, help="constant added to dice denominator to avoid nan") 53 | parser.add_argument("--smooth_nr", default=0.0, type=float, help="constant added to dice numerator to avoid zero") 54 | 55 | args = parser.parse_args() 56 | args.amp = not args.noamp 57 | 58 | main_worker(gpu=0, args=args) 59 | seed = 666 60 | torch.manual_seed(seed) 61 | torch.cuda.manual_seed_all(seed) 62 | np.random.seed(seed) 63 | random.seed(seed) 64 | torch.backends.cudnn.deterministic = True 65 | 66 | 67 | def main_worker(gpu, args): 68 | np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True) 69 | args.gpu = gpu 70 | torch.cuda.set_device(args.gpu) 71 | torch.backends.cudnn.benchmark = True 72 | 73 | loader = get_loader(args) 74 | print(args.rank, " gpu", args.gpu) 75 | if args.rank == 0: 76 | print("Batch size is:", args.batch_size, "epochs", args.max_epochs) 77 | 78 | 79 | if (args.model_name is None) or args.model_name in ['vit_base_patch16','vit_large_patch16','vit_huge_patch14']: 80 | model = networks.unetr.__dict__[args.model_name]( 81 | in_channels=args.in_channels, 82 | out_channels=args.out_channels, 83 | img_size=args.input_size, 84 | feature_size=args.feature_size, 85 | norm_name=args.norm_name, 86 | conv_block=True, 87 | res_block=True, 88 | dropout_rate=args.dropout_rate, 89 | num_classes=1000) 90 | 91 | # loading weights 92 | if not args.test_mode: 93 | if args.phase1_weight is not None and args.phase2_weight is None:# training phase 2 94 | checkpoint = torch.load(args.phase1_weight,map_location='cpu') 95 | checkpoint_model = checkpoint['model'] 96 | print(checkpoint['model']) 97 | model.load_state_dict(checkpoint_model, strict=False) 98 | print("Use pretrained weights") 99 | #print(model.state_dict()) 100 | elif args.phase1_weight is not None and args.phase2_weight is not None: # training phase 4 101 | #load weight of phase2 model first (for decoder) 102 | assert args.phase2_weight != None, 'No segmentation model weights loaded' 103 | model_dict = torch.load(args.phase2_weight, map_location="cpu") 104 | print(model_dict['state_dict']) 105 | model.load_state_dict(model_dict['state_dict'], strict=True) 106 | # then load weight of phase1 model (for encoder) 107 | checkpoint = torch.load(args.phase1_weight,map_location='cpu') 108 | checkpoint_model = checkpoint['model'] 109 | print(checkpoint_model) 110 | model.load_state_dict(checkpoint_model, strict=False) 111 | 112 | 113 | else:# test mode, load best model 114 | assert args.phase2_weight != None, 'No segmentation model weights loaded' 115 | model_dict = torch.load(args.phase2_weight, map_location="cpu") 116 | model.load_state_dict(model_dict['state_dict']) 117 | 118 | elif args.model_name =='unet': 119 | model = 
UNet(n_channels = args.in_channels, n_classes = args.out_channels, bilinear=True) 120 | else: 121 | raise ValueError("Unsupported model " + str(args.model_name)) 122 | 123 | #setup loss 124 | dice_loss = DiceCELoss( 125 | to_onehot_y=True, softmax=True, squared_pred=True, smooth_nr=args.smooth_nr, smooth_dr=args.smooth_dr) 126 | 127 | post_label = AsDiscrete(to_onehot=True, n_classes=args.out_channels) 128 | post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=args.out_channels) 129 | dice_acc = DiceMetric(include_background=True, reduction=MetricReduction.MEAN, get_not_nans=True) 130 | 131 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 132 | print("Total parameters count", pytorch_total_params) 133 | 134 | best_acc = 0 135 | start_epoch = 0 136 | 137 | model.cuda(args.gpu) 138 | 139 | if args.optim_name == "adam": 140 | optimizer = torch.optim.Adam(model.parameters(), lr=args.optim_lr, weight_decay=args.reg_weight) 141 | elif args.optim_name == "adamw": 142 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.optim_lr, weight_decay=args.reg_weight) 143 | elif args.optim_name == "sgd": 144 | optimizer = torch.optim.SGD( 145 | model.parameters(), lr=args.optim_lr, momentum=args.momentum, nesterov=True, weight_decay=args.reg_weight 146 | ) 147 | else: 148 | raise ValueError("Unsupported Optimization Procedure: " + str(args.optim_name)) 149 | 150 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, 151 | patience=5) # goal: maximize Dice score 152 | 153 | if args.model_name =='unet': 154 | optimizer = torch.optim.RMSprop(model.parameters(), lr=args.optim_lr, weight_decay=1e-8, momentum=0.99) 155 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=5) # goal: maximize Dice score 156 | 157 | if not args.test_mode: 158 | accuracy = run_training( 159 | model=model, 160 | train_loader=loader[0], 161 | val_loader=loader[1], 162 | optimizer=optimizer, 163 | loss_func=dice_loss, 164 | acc_func=dice_acc, 165 | args=args, 166 | model_inferer=None, 167 | scheduler=scheduler, 168 | start_epoch=start_epoch, 169 | post_label=post_label, 170 | post_pred=post_pred, 171 | ) 172 | else: 173 | accuracy, run_loss = val_epoch( 174 | model=model, 175 | loader=loader, 176 | args=args 177 | ) 178 | print("final acc:", accuracy, 'loss:', run_loss) 179 | return accuracy 180 | 181 | 182 | if __name__ == "__main__": 183 | 184 | train = 'train_fewshot_data' 185 | val = 'val_fewshot_data' 186 | main(train,val) 187 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/networks/__init__.py -------------------------------------------------------------------------------- /networks/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class UNet(nn.Module): 6 | def __init__(self, n_channels, n_classes, bilinear=False): 7 | super(UNet, self).__init__() 8 | self.n_channels = n_channels 9 | self.n_classes = n_classes 10 | self.bilinear = bilinear 11 | 12 | self.inc = DoubleConv(n_channels, 64) 13 | self.down1 = Down(64, 128) 14 | self.down2 = Down(128, 256) 15 | self.down3 = Down(256, 512) 16 | 
factor = 2 if bilinear else 1 17 | self.down4 = Down(512, 1024 // factor) 18 | self.up1 = Up(1024, 512 // factor, bilinear) 19 | self.up2 = Up(512, 256 // factor, bilinear) 20 | self.up3 = Up(256, 128 // factor, bilinear) 21 | self.up4 = Up(128, 64, bilinear) 22 | self.outc = OutConv(64, n_classes) 23 | 24 | def forward(self, x): 25 | x1 = self.inc(x) 26 | x2 = self.down1(x1) 27 | x3 = self.down2(x2) 28 | x4 = self.down3(x3) 29 | x5 = self.down4(x4) 30 | x = self.up1(x5, x4) 31 | x = self.up2(x, x3) 32 | x = self.up3(x, x2) 33 | x = self.up4(x, x1) 34 | logits = self.outc(x) 35 | return logits 36 | 37 | 38 | class DoubleConv(nn.Module): 39 | """(convolution => [BN] => ReLU) * 2""" 40 | 41 | def __init__(self, in_channels, out_channels, mid_channels=None): 42 | super().__init__() 43 | if not mid_channels: 44 | mid_channels = out_channels 45 | self.double_conv = nn.Sequential( 46 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 47 | nn.BatchNorm2d(mid_channels), 48 | nn.ReLU(inplace=True), 49 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), 50 | nn.BatchNorm2d(out_channels), 51 | nn.ReLU(inplace=True) 52 | ) 53 | 54 | def forward(self, x): 55 | return self.double_conv(x) 56 | 57 | 58 | class Down(nn.Module): 59 | """Downscaling with maxpool then double conv""" 60 | 61 | def __init__(self, in_channels, out_channels): 62 | super().__init__() 63 | self.maxpool_conv = nn.Sequential( 64 | nn.MaxPool2d(2), 65 | DoubleConv(in_channels, out_channels) 66 | ) 67 | 68 | def forward(self, x): 69 | return self.maxpool_conv(x) 70 | 71 | 72 | class Up(nn.Module): 73 | """Upscaling then double conv""" 74 | 75 | def __init__(self, in_channels, out_channels, bilinear=True): 76 | super().__init__() 77 | 78 | # if bilinear, use the normal convolutions to reduce the number of channels 79 | if bilinear: 80 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 81 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 82 | else: 83 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 84 | self.conv = DoubleConv(in_channels, out_channels) 85 | 86 | def forward(self, x1, x2): 87 | x1 = self.up(x1) 88 | # input is CHW 89 | diffY = x2.size()[2] - x1.size()[2] 90 | diffX = x2.size()[3] - x1.size()[3] 91 | 92 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 93 | diffY // 2, diffY - diffY // 2]) 94 | # if you have padding issues, see 95 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 96 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 97 | x = torch.cat([x2, x1], dim=1) 98 | return self.conv(x) 99 | 100 | 101 | class OutConv(nn.Module): 102 | def __init__(self, in_channels, out_channels): 103 | super(OutConv, self).__init__() 104 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 105 | 106 | def forward(self, x): 107 | return self.conv(x) -------------------------------------------------------------------------------- /networks/unetr.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Tuple, Union 3 | import timm 4 | import torch 5 | import torch.nn as nn 6 | from timm.models.vision_transformer import PatchEmbed, Block 7 | from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock 8 | from monai.networks.blocks.dynunet_block import UnetOutBlock 9 | 10 | 
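# The factory functions below build ViT encoders of increasing capacity (base/large/huge).
# VisionTransformer combines an MAE-style ViT encoder with a UNETR-style convolutional decoder:
# hidden states from intermediate transformer blocks (indices 6, 12 and 18 in forward()) are
# projected back to 2-D feature maps and used as skip connections for the upsampling path.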
11 | def vit_base_patch16(**kwargs): 12 | model = VisionTransformer( 13 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 14 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 15 | return model 16 | 17 | 18 | def vit_large_patch16(**kwargs): 19 | model = VisionTransformer( 20 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 21 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 22 | return model 23 | 24 | 25 | def vit_huge_patch14(**kwargs): 26 | model = VisionTransformer( 27 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 28 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 29 | return model 30 | 31 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 32 | 33 | def __init__( 34 | self, 35 | in_channels: int = 3, 36 | out_channels: int = 1, 37 | img_size: int = 224, 38 | feature_size: int = 64, 39 | #hidden_size: int = 768, 40 | #mlp_dim: int = 3072, 41 | #num_heads: int = 12, 42 | #pos_embed: str = "perceptron", 43 | norm_name: Union[Tuple, str] = "instance", 44 | conv_block: bool = False, 45 | res_block: bool = True, 46 | dropout_rate: float = 0.0, 47 | 48 | patch_size=16, 49 | embed_dim=1024, 50 | depth=24, 51 | num_heads=16, 52 | mlp_ratio=4, 53 | qkv_bias=True, 54 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 55 | **kwargs 56 | ) -> None: 57 | """ 58 | Args: 59 | in_channels: dimension of input channels. 60 | out_channels: dimension of output channels. 61 | img_size: dimension of input image. 62 | feature_size: dimension of network feature size. 63 | hidden_size: dimension of hidden layer. 64 | mlp_dim: dimension of feedforward layer. 65 | num_heads: number of attention heads. 66 | pos_embed: position embedding layer type. 67 | norm_name: feature normalization type and arguments. 68 | conv_block: bool argument to determine if convolutional block is used. 69 | res_block: bool argument to determine if residual block is used. 70 | dropout_rate: faction of the input units to drop. 
71 | """ 72 | super(VisionTransformer, self).__init__(embed_dim=768,**kwargs) 73 | 74 | #define vit 75 | # MAE encoder specifics 76 | self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim) 77 | num_patches = self.patch_embed.num_patches 78 | 79 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 80 | 81 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 82 | 83 | self.blocks = nn.ModuleList([ 84 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer, attn_drop=dropout_rate) 85 | for i in range(depth)]) 86 | self.norm = norm_layer(embed_dim) 87 | 88 | #define CNN 89 | self.patch_size = (patch_size, patch_size) 90 | self.feat_size = ( 91 | img_size // self.patch_size[0], 92 | img_size // self.patch_size[1]) 93 | self.hidden_size = embed_dim 94 | self.classification = False 95 | 96 | self.encoder1 = UnetrBasicBlock( 97 | spatial_dims=2, 98 | in_channels=in_channels, 99 | out_channels=feature_size, 100 | kernel_size=3, 101 | stride=1, 102 | norm_name=norm_name, 103 | res_block=res_block, 104 | ) 105 | self.encoder2 = UnetrPrUpBlock( 106 | spatial_dims=2, 107 | in_channels=embed_dim, 108 | out_channels=feature_size * 2, 109 | num_layer=2, 110 | kernel_size=3, 111 | stride=1, 112 | upsample_kernel_size=2, 113 | norm_name=norm_name, 114 | conv_block=conv_block, 115 | res_block=res_block, 116 | ) 117 | self.encoder3 = UnetrPrUpBlock( 118 | spatial_dims=2, 119 | in_channels=embed_dim, 120 | out_channels=feature_size * 4, 121 | num_layer=1, 122 | kernel_size=3, 123 | stride=1, 124 | upsample_kernel_size=2, 125 | norm_name=norm_name, 126 | conv_block=conv_block, 127 | res_block=res_block, 128 | ) 129 | self.encoder4 = UnetrPrUpBlock( 130 | spatial_dims=2, 131 | in_channels=embed_dim, 132 | out_channels=feature_size * 8, 133 | num_layer=0, 134 | kernel_size=3, 135 | stride=1, 136 | upsample_kernel_size=2, 137 | norm_name=norm_name, 138 | conv_block=conv_block, 139 | res_block=res_block, 140 | ) 141 | self.decoder5 = UnetrUpBlock( 142 | spatial_dims=2, 143 | in_channels=embed_dim, 144 | out_channels=feature_size * 8, 145 | kernel_size=3, 146 | upsample_kernel_size=2, 147 | norm_name=norm_name, 148 | res_block=res_block, 149 | ) 150 | self.decoder4 = UnetrUpBlock( 151 | spatial_dims=2, 152 | in_channels=feature_size * 8, 153 | out_channels=feature_size * 4, 154 | kernel_size=3, 155 | upsample_kernel_size=2, 156 | norm_name=norm_name, 157 | res_block=res_block, 158 | ) 159 | self.decoder3 = UnetrUpBlock( 160 | spatial_dims=2, 161 | in_channels=feature_size * 4, 162 | out_channels=feature_size * 2, 163 | kernel_size=3, 164 | upsample_kernel_size=2, 165 | norm_name=norm_name, 166 | res_block=res_block, 167 | ) 168 | self.decoder2 = UnetrUpBlock( 169 | spatial_dims=2, 170 | in_channels=feature_size * 2, 171 | out_channels=feature_size, 172 | kernel_size=3, 173 | upsample_kernel_size=2, 174 | norm_name=norm_name, 175 | res_block=res_block, 176 | ) 177 | self.out = UnetOutBlock(spatial_dims=2, in_channels=feature_size, out_channels=out_channels) # type: ignore 178 | 179 | def proj_feat(self, x, hidden_size, feat_size): 180 | x = x.view(x.size(0), feat_size[0], feat_size[1], hidden_size) 181 | x = x.permute(0, 3, 1, 2).contiguous() 182 | return x 183 | 184 | def load_from(self, weights): 185 | with torch.no_grad(): 186 | res_weight = weights 187 | # copy weights from patch embedding 188 | for i in weights["state_dict"]: 189 | print(i) 190 | self.vit.patch_embedding.position_embeddings.copy_( 191 | 
weights["state_dict"]["module.transformer.patch_embedding.position_embeddings_3d"] 192 | ) 193 | self.vit.patch_embedding.cls_token.copy_( 194 | weights["state_dict"]["module.transformer.patch_embedding.cls_token"] 195 | ) 196 | self.vit.patch_embedding.patch_embeddings[1].weight.copy_( 197 | weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.weight"] 198 | ) 199 | self.vit.patch_embedding.patch_embeddings[1].bias.copy_( 200 | weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.bias"] 201 | ) 202 | 203 | # copy weights from encoding blocks (default: num of blocks: 12) 204 | for bname, block in self.vit.blocks.named_children(): 205 | print(block) 206 | block.loadFrom(weights, n_block=bname) 207 | # last norm layer of transformer 208 | self.vit.norm.weight.copy_(weights["state_dict"]["module.transformer.norm.weight"]) 209 | self.vit.norm.bias.copy_(weights["state_dict"]["module.transformer.norm.bias"]) 210 | 211 | def forward(self, x_in): 212 | #x, hidden_states_out = self.vit(x_in) 213 | #define IVT 214 | # embed patches 215 | x = self.patch_embed(x_in) 216 | # add pos embed w/o cls token 217 | x = x + self.pos_embed[:, 1:, :] 218 | 219 | # append cls token 220 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 221 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 222 | x = torch.cat((cls_tokens, x), dim=1) 223 | 224 | # apply Transformer blocks 225 | hidden_states_out = [] 226 | for blk in self.blocks: 227 | x = blk(x) 228 | hidden_states_out.append(x[:,1:,:]) 229 | x = self.norm(x) 230 | x = x[:,1:,:] 231 | 232 | #define CNN 233 | enc1 = self.encoder1(x_in) 234 | x2 = hidden_states_out[6] 235 | enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size)) 236 | x3 = hidden_states_out[12] 237 | enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size)) 238 | x4 = hidden_states_out[18] 239 | enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size)) 240 | 241 | dec4 = self.proj_feat(x, self.hidden_size, self.feat_size) 242 | dec3 = self.decoder5(dec4, enc4) 243 | dec2 = self.decoder4(dec3, enc3) 244 | dec1 = self.decoder3(dec2, enc2) 245 | out = self.decoder2(dec1, enc1) 246 | logits = self.out(out) 247 | 248 | return logits 249 | -------------------------------------------------------------------------------- /optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/optimizers/__init__.py -------------------------------------------------------------------------------- /optimizers/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | import math 13 | import warnings 14 | from typing import List 15 | 16 | from torch import nn as nn 17 | from torch.optim import Adam, Optimizer 18 | from torch.optim.lr_scheduler import LambdaLR, _LRScheduler 19 | 20 | __all__ = ["LinearLR", "ExponentialLR"] 21 | 22 | 23 | class _LRSchedulerMONAI(_LRScheduler): 24 | """Base class for increasing the learning rate between two boundaries over a number 25 | of iterations""" 26 | 27 | def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None: 28 | """ 29 | Args: 30 | optimizer: wrapped optimizer. 31 | end_lr: the final learning rate. 32 | num_iter: the number of iterations over which the test occurs. 33 | last_epoch: the index of last epoch. 34 | Returns: 35 | None 36 | """ 37 | self.end_lr = end_lr 38 | self.num_iter = num_iter 39 | super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch) 40 | 41 | 42 | class LinearLR(_LRSchedulerMONAI): 43 | """Linearly increases the learning rate between two boundaries over a number of 44 | iterations. 45 | """ 46 | 47 | def get_lr(self): 48 | r = self.last_epoch / (self.num_iter - 1) 49 | return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] 50 | 51 | 52 | class ExponentialLR(_LRSchedulerMONAI): 53 | """Exponentially increases the learning rate between two boundaries over a number of 54 | iterations. 55 | """ 56 | 57 | def get_lr(self): 58 | r = self.last_epoch / (self.num_iter - 1) 59 | return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] 60 | 61 | 62 | class WarmupCosineSchedule(LambdaLR): 63 | """Linear warmup and then cosine decay. 64 | Based on https://huggingface.co/ implementation. 65 | """ 66 | 67 | def __init__( 68 | self, optimizer: Optimizer, warmup_steps: int, t_total: int, cycles: float = 0.5, last_epoch: int = -1 69 | ) -> None: 70 | """ 71 | Args: 72 | optimizer: wrapped optimizer. 73 | warmup_steps: number of warmup iterations. 74 | t_total: total number of training iterations. 75 | cycles: cosine cycles parameter. 76 | last_epoch: the index of last epoch. 77 | Returns: 78 | None 79 | """ 80 | self.warmup_steps = warmup_steps 81 | self.t_total = t_total 82 | self.cycles = cycles 83 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch) 84 | 85 | def lr_lambda(self, step): 86 | if step < self.warmup_steps: 87 | return float(step) / float(max(1.0, self.warmup_steps)) 88 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 89 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 90 | 91 | 92 | class LinearWarmupCosineAnnealingLR(_LRScheduler): 93 | def __init__( 94 | self, 95 | optimizer: Optimizer, 96 | warmup_epochs: int, 97 | max_epochs: int, 98 | warmup_start_lr: float = 0.0, 99 | eta_min: float = 0.0, 100 | last_epoch: int = -1, 101 | ) -> None: 102 | """ 103 | Args: 104 | optimizer (Optimizer): Wrapped optimizer. 105 | warmup_epochs (int): Maximum number of iterations for linear warmup 106 | max_epochs (int): Maximum number of iterations 107 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. 108 | eta_min (float): Minimum learning rate. Default: 0. 109 | last_epoch (int): The index of last epoch. Default: -1. 
110 | """ 111 | self.warmup_epochs = warmup_epochs 112 | self.max_epochs = max_epochs 113 | self.warmup_start_lr = warmup_start_lr 114 | self.eta_min = eta_min 115 | 116 | super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) 117 | 118 | def get_lr(self) -> List[float]: 119 | """ 120 | Compute learning rate using chainable form of the scheduler 121 | """ 122 | if not self._get_lr_called_within_step: 123 | warnings.warn( 124 | "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning 125 | ) 126 | 127 | if self.last_epoch == 0: 128 | return [self.warmup_start_lr] * len(self.base_lrs) 129 | elif self.last_epoch < self.warmup_epochs: 130 | return [ 131 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 132 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 133 | ] 134 | elif self.last_epoch == self.warmup_epochs: 135 | return self.base_lrs 136 | elif (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: 137 | return [ 138 | group["lr"] 139 | + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2 140 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 141 | ] 142 | 143 | return [ 144 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 145 | / ( 146 | 1 147 | + math.cos( 148 | math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs) 149 | ) 150 | ) 151 | * (group["lr"] - self.eta_min) 152 | + self.eta_min 153 | for group in self.optimizer.param_groups 154 | ] 155 | 156 | def _get_closed_form_lr(self) -> List[float]: 157 | """ 158 | Called when epoch is passed as a param to the `step` function of the scheduler. 
159 | """ 160 | if self.last_epoch < self.warmup_epochs: 161 | return [ 162 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 163 | for base_lr in self.base_lrs 164 | ] 165 | 166 | return [ 167 | self.eta_min 168 | + 0.5 169 | * (base_lr - self.eta_min) 170 | * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 171 | for base_lr in self.base_lrs 172 | ] 173 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python==3.8 2 | torch==1.11.1 3 | numpy==1.19.5 4 | monai==0.7.0 5 | timm==0.3.2 6 | tensorboardX==2.1 7 | torchvision==0.12.0 8 | opencv-python==4.5.5 9 | -------------------------------------------------------------------------------- /runs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/runs/__init__.py -------------------------------------------------------------------------------- /runs/log/log.txt: -------------------------------------------------------------------------------- 1 | Final training:0/99,loss:2.14935310681661 2 | Final validation:0/99,dice:0.021508501284427224,loss:1.7313948571681976, 3 | Final training:0/99,loss:2.1603429714838662 4 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/utils/__init__.py -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import os 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from monai import data, transforms 8 | import PIL 9 | import torchvision 10 | from PIL import Image 11 | from torchvision import transforms 12 | from torch.utils.data import Dataset 13 | 14 | class Sampler(torch.utils.data.Sampler): 15 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, make_even=True): 16 | if num_replicas is None: 17 | if not torch.distributed.is_available(): 18 | raise RuntimeError("Requires distributed package to be available") 19 | num_replicas = torch.distributed.get_world_size() 20 | if rank is None: 21 | if not torch.distributed.is_available(): 22 | raise RuntimeError("Requires distributed package to be available") 23 | rank = torch.distributed.get_rank() 24 | self.shuffle = shuffle 25 | self.make_even = make_even 26 | self.dataset = dataset 27 | self.num_replicas = num_replicas 28 | self.rank = rank 29 | self.epoch = 0 30 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 31 | self.total_size = self.num_samples * self.num_replicas 32 | indices = list(range(len(self.dataset))) 33 | self.valid_length = len(indices[self.rank : self.total_size : self.num_replicas]) 34 | 35 | def __iter__(self): 36 | if self.shuffle: 37 | g = torch.Generator() 38 | g.manual_seed(self.epoch) 39 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 40 | else: 41 | indices = list(range(len(self.dataset))) 42 | if 
self.make_even: 43 | if len(indices) < self.total_size: 44 | if self.total_size - len(indices) < len(indices): 45 | indices += indices[: (self.total_size - len(indices))] 46 | else: 47 | extra_ids = np.random.randint(low=0, high=len(indices), size=self.total_size - len(indices)) 48 | indices += [indices[ids] for ids in extra_ids] 49 | assert len(indices) == self.total_size 50 | indices = indices[self.rank : self.total_size : self.num_replicas] 51 | self.num_samples = len(indices) 52 | return iter(indices) 53 | 54 | def __len__(self): 55 | return self.num_samples 56 | 57 | def set_epoch(self, epoch): 58 | self.epoch = epoch 59 | 60 | 61 | def get_loader(args): 62 | if args.test_mode: 63 | dataset_test = build_dataset(args.val_text, args=args) 64 | data_loader_val = torch.utils.data.DataLoader( 65 | dataset_test, sampler=None, 66 | batch_size=args.batch_size, 67 | num_workers=args.num_workers, 68 | drop_last=False) 69 | loader = data_loader_val 70 | 71 | else: 72 | dataset_train = build_dataset(args.train_text, args=args) 73 | dataset_val = build_dataset(args.val_text, args=args) 74 | data_loader_train = torch.utils.data.DataLoader( 75 | dataset_train, sampler=None, 76 | batch_size=args.batch_size, 77 | num_workers=args.num_workers, 78 | shuffle = True, 79 | drop_last=False) 80 | data_loader_val = torch.utils.data.DataLoader( 81 | dataset_val, sampler=None, 82 | batch_size=args.batch_size, 83 | num_workers=args.num_workers, 84 | shuffle=False, 85 | drop_last=False) 86 | loader = [data_loader_train, data_loader_val] 87 | 88 | return loader 89 | 90 | def build_dataset(split, args): 91 | transform = build_transform(split, args) 92 | 93 | if 'fewshot' in split: 94 | dataset = Fewshot_dataset(args.label_index, args.data_dir, args.list_dir, split, transform=transform) 95 | else: 96 | dataset = Parent_dataset(args.data_dir, args.list_dir, split, transform=transform) 97 | print(dataset) 98 | return dataset 99 | 100 | def build_transform(split, args): 101 | if 'train' in split: 102 | transform_image = transforms.Compose([ 103 | transforms.ToPILImage(), 104 | transforms.Resize((args.input_size,args.input_size), interpolation=PIL.Image.BICUBIC), # 3 is bicubic 105 | transforms.RandomHorizontalFlip(p=0.5),# 106 | transforms.ToTensor(), 107 | transforms.Normalize(mean=[0.227, 0.227, 0.227], std=[0.1935, 0.1935, 0.1935])]) 108 | 109 | transform_label = transforms.Compose([ 110 | transforms.ToPILImage(), 111 | #transforms.RandomRotation(degrees=(-180, 180)), # 112 | transforms.Resize((args.input_size,args.input_size), interpolation=PIL.Image.NEAREST), # 3 is bicubic 113 | transforms.RandomHorizontalFlip(p=0.5),# 114 | transforms.ToTensor()]) 115 | trans = [transform_image, transform_label] 116 | return trans 117 | 118 | elif 'val' in split: 119 | transform_image = transforms.Compose([ 120 | transforms.ToPILImage(), 121 | transforms.Resize((args.input_size,args.input_size), interpolation=PIL.Image.BICUBIC), # 3 is bicubic 122 | transforms.ToTensor(), 123 | transforms.Normalize(mean=[0.227, 0.227, 0.227], std=[0.1935, 0.1935, 0.1935])]) 124 | #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 125 | transform_label = transforms.Compose([ 126 | transforms.ToPILImage(), 127 | transforms.Resize((args.input_size,args.input_size), interpolation=PIL.Image.NEAREST), # 3 is bicubic 128 | transforms.ToTensor()]) 129 | trans = [transform_image, transform_label] 130 | return trans 131 | 132 | class Parent_dataset(Dataset): 133 | def __init__(self, data_dir, list_dir, split, transform=None): 
134 | self.transform = transform 135 | self.split = split 136 | self.sample_list = open(os.path.join(list_dir, self.split+'.txt')).readlines() 137 | self.data_dir = data_dir 138 | def __len__(self): 139 | return len(self.sample_list) 140 | 141 | def __getitem__(self, idx): 142 | slice_name = self.sample_list[idx].strip('\n') 143 | image_path = os.path.join(self.data_dir, self.split, slice_name) 144 | image = cv2.imread(image_path,cv2.IMREAD_GRAYSCALE) 145 | h, w =image.shape[0], image.shape[1] 146 | image = np.expand_dims(np.array(image), axis = -1).repeat(3,2) 147 | label_path = os.path.join(self.data_dir, self.split, slice_name.replace('.png','_gt.png')) 148 | 149 | label = cv2.imread(label_path,cv2.IMREAD_GRAYSCALE) 150 | label = np.expand_dims(np.array(label), axis = -1) 151 | 152 | if self.transform: 153 | image = self.transform[0](image) 154 | label = self.transform[1](label) 155 | if self.split == 'val': 156 | return [image, label, image_path, h, w] 157 | 158 | return [image, label] 159 | 160 | class Fewshot_dataset(Dataset): 161 | def __init__(self, label_index, data_dir, list_dir, split, transform=None): 162 | self.transform = transform 163 | self.split = split 164 | self.sample_list = open(os.path.join(list_dir, self.split+'.txt')).readlines() 165 | self.data_dir = data_dir 166 | self.label_index=label_index 167 | 168 | def __len__(self): 169 | return len(self.sample_list) 170 | 171 | def __getitem__(self, idx): 172 | slice_name = self.sample_list[idx].strip('\n') 173 | if 'train' in self.split: 174 | image_path = os.path.join(self.data_dir, 'train_fewshot_data', slice_name) 175 | label_path = os.path.join(self.data_dir, 'train_fewshot_data', slice_name.replace('.png', '_gt.png')) 176 | elif 'val' in self.split: 177 | image_path = os.path.join(self.data_dir, 'val_fewshot_data', slice_name) 178 | label_path = os.path.join(self.data_dir, 'val_fewshot_data', slice_name.replace('.png', '_gt.png')) 179 | else: 180 | raise ValueError('cannot find the train or val data directory') 181 | image = cv2.imread(image_path,cv2.IMREAD_GRAYSCALE) 182 | 183 | h, w =image.shape[0], image.shape[1] 184 | image = np.expand_dims(np.array(image), axis = -1).repeat(3,2) 185 | 186 | label = cv2.imread(label_path,cv2.IMREAD_GRAYSCALE) 187 | 188 | # for i in self.label_index: 189 | # if i<255: 190 | # label[label==i*25] = 255 191 | # label[label!=255] = 0 192 | 193 | label = np.expand_dims(np.array(label), axis = -1) 194 | 195 | if self.transform: 196 | seed = torch.random.seed() 197 | torch.random.manual_seed(seed) 198 | image = self.transform[0](image) 199 | torch.random.manual_seed(seed) 200 | label = self.transform[1](label) 201 | 202 | if 'val' in self.split: 203 | return [image, label, image_path, h, w] 204 | 205 | return [image, label] 206 | -------------------------------------------------------------------------------- /utils/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import torch.nn.parallel 8 | import torch.utils.data.distributed 9 | #from tensorboardX import SummaryWriter 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torch.cuda.amp import GradScaler, autocast 12 | from torch.nn.modules.loss import BCELoss 13 | 14 | class EarlyStopping: 15 | """Early stops the training if validation loss doesn't improve after a given patience.""" 16 | 17 | def __init__(self, patience=10, verbose=False, delta=0, path='checkpoint.pt', trace_func=print): 18 | """
19 | Args: 20 | patience (int): How long to wait after last time validation loss improved. 21 | Default: 7 22 | verbose (bool): If True, prints a message for each validation loss improvement. 23 | Default: False 24 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 25 | Default: 0 26 | path (str): Path for the checkpoint to be saved to. 27 | Default: 'checkpoint.pt' 28 | trace_func (function): trace print function. 29 | Default: print 30 | """ 31 | self.patience = patience 32 | self.verbose = verbose 33 | self.counter = 0 34 | self.best_score = None 35 | self.early_stop = False 36 | self.val_loss_min = np.Inf 37 | self.delta = delta 38 | self.path = path 39 | self.trace_func = trace_func 40 | 41 | def __call__(self, val_loss): 42 | 43 | score = val_loss 44 | 45 | if self.best_score is None: 46 | self.best_score = score 47 | elif score < self.best_score + self.delta: 48 | self.counter += 1 49 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') 50 | if self.counter >= self.patience: 51 | self.early_stop = True 52 | else: 53 | self.best_score = score 54 | self.counter = 0 55 | 56 | 57 | class DiceLoss(torch.nn.Module): 58 | def __init__(self, n_classes): 59 | super(DiceLoss, self).__init__() 60 | self.n_classes = n_classes 61 | 62 | def _one_hot_encoder(self, input_tensor): 63 | tensor_list = [] 64 | for i in range(self.n_classes): 65 | temp_prob = input_tensor == i # * torch.ones_like(input_tensor) 66 | tensor_list.append(temp_prob.unsqueeze(1)) 67 | output_tensor = torch.cat(tensor_list, dim=1) 68 | return output_tensor.float() 69 | 70 | def _dice_loss(self, score, target): 71 | target = target.float() 72 | smooth = 1e-5 73 | intersect = torch.sum(score * target) 74 | y_sum = torch.sum(target * target) 75 | z_sum = torch.sum(score * score) 76 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 77 | loss = 1 - loss 78 | return loss 79 | 80 | def forward(self, inputs, target, weight=None, softmax=False): 81 | if softmax: 82 | inputs = torch.softmax(inputs, dim=1) 83 | else: 84 | inputs = torch.sigmoid(inputs) 85 | #target = self._one_hot_encoder(target) 86 | if weight is None: 87 | weight = [1] * self.n_classes 88 | assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size()) 89 | class_wise_dice = [] 90 | loss = 0.0 91 | for i in range(0, self.n_classes): 92 | dice = self._dice_loss(inputs[:, i], target[:, i]) 93 | class_wise_dice.append(1.0 - dice.item()) 94 | loss += dice * weight[i] 95 | return loss / self.n_classes 96 | 97 | def dice(x, y): 98 | intersect = np.sum(np.sum(np.sum(x * y))) 99 | y_sum = np.sum(np.sum(np.sum(y))) 100 | smooth = 1e-5 101 | if y_sum == 0: 102 | return 0.0 103 | x_sum = np.sum(np.sum(np.sum(x))) 104 | return (2 * intersect +smooth) / (x_sum + y_sum + smooth) 105 | 106 | 107 | class AverageMeter(object): 108 | def __init__(self): 109 | self.reset() 110 | 111 | def reset(self): 112 | self.val = 0 113 | self.avg = 0 114 | self.sum = 0 115 | self.count = 0 116 | 117 | def update(self, val, n=1): 118 | self.val = val 119 | self.sum += val * n 120 | self.count += n 121 | self.avg = np.where(self.count > 0, self.sum / self.count, self.sum) 122 | 123 | 124 | def train_epoch(model, loader, optimizer, scaler, epoch, loss_func, args): 125 | model.train() 126 | 127 | ce_loss = torch.nn.BCEWithLogitsLoss() 128 | dice_loss = DiceLoss(1) 129 | start_time = time.time() 130 | run_loss = AverageMeter() 131 | for idx, batch_data in 
enumerate(loader): 132 | 133 | if isinstance(batch_data, list): 134 | data, target = batch_data 135 | else: 136 | data, target = batch_data["image"], batch_data["label"] 137 | 138 | if args.in_channels == 1: 139 | data = torch.unsqueeze(data[:, 0, :, :], dim=1) 140 | 141 | data, target = data.cuda(args.rank), target.cuda(args.rank) 142 | for param in model.parameters(): 143 | param.grad = None 144 | 145 | with autocast(enabled=args.amp): 146 | logits = model(data) 147 | 148 | loss_ce = ce_loss(logits, target) 149 | loss_dice = dice_loss(logits, target, softmax=False) 150 | loss = 1.0 * loss_ce + 1.0 * loss_dice 151 | print('ce_loss: {:.3f} dice_loss: {:.3f}'.format(loss_ce, loss_dice)) 152 | 153 | if args.amp: 154 | scaler.scale(loss).backward() 155 | scaler.step(optimizer) 156 | scaler.update() 157 | else: 158 | loss.backward() 159 | optimizer.step() 160 | 161 | run_loss.update(loss.item(), n=args.batch_size) 162 | if args.rank == 0: 163 | print( 164 | "Epoch {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)), 165 | "loss: {:.4f}".format(loss.item()), 166 | "lr: {:.8f}".format(optimizer.param_groups[0]['lr']), 167 | "time {:.2f}s".format(time.time() - start_time), 168 | ) 169 | start_time = time.time() 170 | for param in model.parameters(): 171 | param.grad = None 172 | 173 | return run_loss.avg 174 | 175 | 176 | def val_epoch(model, loader, epoch=None, acc_func=None, args=None, model_inferer=None, post_label=None, post_pred=None): 177 | model.eval() 178 | start_time = time.time() 179 | run_acc = AverageMeter() 180 | ce_loss = torch.nn.BCEWithLogitsLoss() 181 | dice_loss = DiceLoss(1) 182 | run_loss = AverageMeter() 183 | with torch.no_grad(): 184 | for idx, batch_data in enumerate(loader): 185 | if isinstance(batch_data, list): 186 | data, target, image_path, h, w = batch_data 187 | else: 188 | data, target = batch_data["image"], batch_data["label"] 189 | 190 | if args.in_channels == 1: 191 | data = torch.unsqueeze(data[:, 0, :, :], dim=1) # convert 3-channel input to single channel 192 | 193 | data, target = data.cuda(args.rank), target.cuda(args.rank) 194 | with autocast(enabled=args.amp): 195 | if model_inferer is not None: 196 | logits = model_inferer(data) 197 | else: 198 | logits = model(data) 199 | #loss 200 | loss_ce = ce_loss(logits, target) 201 | loss_dice = dice_loss(logits, target, softmax=False) 202 | loss = 1.0 * loss_ce + 1.0 * loss_dice 203 | #print('ce_loss: {:.3f} dice_loss: {:.3f}'.format(loss_ce, loss_dice)) 204 | #sigmoid for dice 205 | out = torch.sigmoid(logits) 206 | acc_list = [] 207 | if out.is_cuda: 208 | target = target.cpu().numpy() 209 | out = out.cpu().detach().numpy() 210 | h = h.cpu().numpy() 211 | w = w.cpu().numpy() 212 | assert out.shape == target.shape, 'predict {} & target {} shape do not match'.format(out.shape, target.shape) 213 | for i in range(out.shape[0]): 214 | out[i] = np.where(out[i]>0.5, 1, 0) 215 | 216 | acc_list.append(dice(out[i],target[i])) 217 | pre = np.array(out[i],dtype=np.uint8) 218 | pre = np.squeeze(pre,axis=0) 219 | pre = cv2.resize(pre,(w[i], h[i]),interpolation=cv2.INTER_NEAREST)*255 220 | #postprocess 221 | # pre[0:1,:]=0 222 | # pre[:, 0:1] = 0 223 | # pre = fillHole(pre) 224 | # pre = save_max_objects(pre) 225 | 226 | 227 | pre_path = image_path[i].replace('val_fewshot_data', 'val_fewshot_results') 228 | if not os.path.isdir(os.path.split(pre_path)[0]): 229 | os.makedirs(os.path.split(pre_path)[0]) 230 | 231 | path=pre_path.replace('.png', '_auto.png') 232 | cv2.imwrite(path, pre) 233 | 234 | 235 | avg_acc = np.mean(np.array(acc_list)) 236 |
run_acc.update(avg_acc, n=args.batch_size) 237 | run_loss.update(loss.item(), n=args.batch_size) 238 | if args.rank == 0: 239 | print( 240 | "Val {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)), 241 | "dice:", avg_acc, 242 | 'loss:',loss.cpu().numpy(), 243 | "time {:.2f}s".format(time.time() - start_time), 244 | ) 245 | start_time = time.time() 246 | return run_acc.avg, run_loss.avg, os.path.split(pre_path)[0] 247 | 248 | 249 | def save_checkpoint(model, epoch, args, filename="model.pth", best_acc=0, optimizer=None, scheduler=None): 250 | state_dict = model.state_dict() if not args.distributed else model.module.state_dict() 251 | save_dict = {"epoch": epoch, "best_acc": best_acc, "state_dict": state_dict} 252 | if optimizer is not None: 253 | save_dict["optimizer"] = optimizer.state_dict() 254 | if scheduler is not None: 255 | save_dict["scheduler"] = scheduler.state_dict() 256 | filename = os.path.join(args.logdir, filename) 257 | torch.save(save_dict, filename) 258 | print("Saving checkpoint", filename) 259 | 260 | 261 | def run_training( 262 | model, 263 | train_loader, 264 | val_loader, 265 | optimizer, 266 | loss_func, 267 | acc_func, 268 | args, 269 | model_inferer=None, 270 | scheduler=None, 271 | start_epoch=0, 272 | post_label=None, 273 | post_pred=None, 274 | ): 275 | early_stopping = EarlyStopping(patience=10, verbose=True) 276 | spend_time = 0 277 | writer = None 278 | if args.logdir is not None and args.rank == 0: 279 | writer = SummaryWriter(log_dir=args.logdir) 280 | if args.rank == 0: 281 | print("Writing Tensorboard logs to ", args.logdir) 282 | scaler = None 283 | if args.amp: 284 | scaler = GradScaler()#using float16 to reduce memory 285 | val_acc_max = 0.0 286 | for epoch in range(start_epoch, args.max_epochs): 287 | 288 | print(args.rank, time.ctime(), "Epoch:", epoch) 289 | epoch_time = time.time() 290 | train_loss = train_epoch( 291 | model, train_loader, optimizer, scaler=scaler, epoch=epoch, loss_func=loss_func, args=args 292 | )#for training one epoch 293 | 294 | # if scheduler is not None: 295 | # scheduler.step() 296 | 297 | spend_time += time.time() - epoch_time 298 | if args.rank == 0: 299 | print( 300 | "Final training {}/{}".format(epoch, args.max_epochs - 1), 301 | "loss: {:.4f}".format(train_loss), 302 | "time {:.2f}s".format(time.time() - epoch_time), 303 | ) 304 | with open(os.path.join(args.logdir, "log.txt"), mode="a", encoding="utf-8") as f: 305 | f.write("Final training:{}/{},".format(epoch, args.max_epochs - 1) + "loss:{}".format(train_loss) + "\n") 306 | if args.rank == 0 and writer is not None: 307 | writer.add_scalar("train_loss", train_loss, epoch) 308 | b_new_best = False 309 | if (epoch + 1) % args.val_every == 0: 310 | if args.distributed: 311 | torch.distributed.barrier() 312 | epoch_time = time.time() 313 | val_avg_acc, run_loss, file_path = val_epoch( 314 | model, 315 | val_loader, 316 | epoch=epoch, 317 | acc_func=acc_func, 318 | model_inferer=model_inferer, 319 | args=args, 320 | post_label=post_label, 321 | post_pred=post_pred, 322 | ) 323 | if args.rank == 0: 324 | print( 325 | "Final validation {}/{}".format(epoch, args.max_epochs - 1), 326 | "acc:", val_avg_acc, 327 | 'std:', 328 | 'loss:',run_loss, 329 | "time {:.2f}s".format(time.time() - epoch_time), 330 | ) 331 | with open(os.path.join(args.logdir, "log.txt"), mode="a", encoding="utf-8") as f: 332 | f.write("Final validation:{}/{},".format(epoch, args.max_epochs - 1) + "dice:{},".format(val_avg_acc) 333 | + "loss:{},".format(run_loss)+ "\n") 334 | if writer is not 
None: 335 | writer.add_scalar("val_acc", val_avg_acc, epoch) 336 | if val_avg_acc > val_acc_max: 337 | print("new best ({:.6f} --> {:.6f}). ".format(val_acc_max, val_avg_acc)) 338 | val_acc_max = val_avg_acc 339 | b_new_best = True 340 | #fewshot save bset results 341 | if os.path.exists(file_path.replace('val_fewshot_results','val_fewshot_results_best')): 342 | shutil.rmtree(file_path.replace('val_fewshot_results','val_fewshot_results_best')) 343 | shutil.copytree(file_path,file_path.replace('val_fewshot_results','val_fewshot_results_best')) 344 | 345 | #unet save best results 346 | if os.path.exists('./dataset/val_fewshot_results_best'): 347 | shutil.rmtree('./dataset/val_fewshot_results_best') 348 | shutil.copytree('./dataset/val_fewshot_results', 349 | './dataset/val_fewshot_results_best') 350 | 351 | #save weights 352 | if args.rank == 0 and args.logdir is not None and args.save_checkpoint: 353 | save_checkpoint( 354 | model, epoch, args, best_acc=val_acc_max, optimizer=optimizer, scheduler=scheduler 355 | ) 356 | if args.rank == 0 and args.logdir is not None and args.save_checkpoint: 357 | save_checkpoint(model, epoch, args, best_acc=val_acc_max, filename="model_final.pth") 358 | if b_new_best: 359 | print("Copying to model.pt new best model!!!!") 360 | shutil.copyfile(os.path.join(args.logdir, "model_final.pth"), os.path.join(args.logdir, "model.pth")) 361 | 362 | early_stopping(val_avg_acc) 363 | if early_stopping.early_stop: 364 | print("Early stop!") 365 | break 366 | 367 | if scheduler is not None: 368 | scheduler.step(-val_avg_acc) 369 | 370 | print("Training Finished !, Best Accuracy: ", val_acc_max, "Total time: {} s.".format(round(spend_time))) 371 | 372 | return val_acc_max 373 | 374 | from skimage import measure 375 | def save_max_objects(img): 376 | labels = measure.label(img, connectivity=1) 377 | jj = measure.regionprops(labels) 378 | # is_del = False 379 | if len(jj) == 0: 380 | out = img 381 | return out 382 | elif len(jj) == 1: 383 | out = img 384 | return out 385 | # is_del = False 386 | else: 387 | num = labels.max() 388 | del_array = np.array([0] * (num + 1)) 389 | for k in range(num): 390 | if k == 0: 391 | initial_area = jj[0].area 392 | save_index = 1 393 | else: 394 | k_area = jj[k].area 395 | 396 | if initial_area < k_area: 397 | initial_area = k_area 398 | save_index = k + 1 399 | 400 | del_array[save_index] = 1 401 | del_mask = del_array[labels] 402 | out = img * del_mask 403 | return out 404 | 405 | 406 | def fillHole(im_in): 407 | im_in = im_in.astype(np.uint8) 408 | # print np.unique(im_in) 409 | im_floodfill = im_in.copy() 410 | # Mask used to flood filling. 411 | # Notice the size needs to be 2 pixels than the image. 412 | h, w = im_in.shape[:2] 413 | mask = np.zeros((h + 2, w + 2), np.uint8) 414 | 415 | # Floodfill from point (0, 0) 416 | cv2.floodFill(im_floodfill, mask, (0, 0), 255) 417 | 418 | # Invert floodfilled image 419 | im_floodfill_inv = cv2.bitwise_not(im_floodfill) 420 | 421 | # Combine the two images to get the foreground. 422 | im_out = im_in | im_floodfill_inv 423 | # print np.unique(im_out) 424 | return im_out 425 | --------------------------------------------------------------------------------
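
## Usage sketches

The schedulers in `optimizers/lr_scheduler.py` can be exercised in isolation. Below is a minimal sketch (not the wiring used in `main.py`, which is not shown here) that attaches `LinearWarmupCosineAnnealingLR` to a dummy optimizer and prints the learning rate per epoch; the model, learning rates, and epoch counts are placeholder values.
```
import torch
from optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

# Placeholder model/optimizer just to have a parameter group to schedule.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# 5 epochs of linear warmup from 1e-5, then cosine annealing down to 1e-6 by epoch 50.
scheduler = LinearWarmupCosineAnnealingLR(
    optimizer, warmup_epochs=5, max_epochs=50, warmup_start_lr=1e-5, eta_min=1e-6
)

for epoch in range(50):
    # ... train_epoch(...) would run here ...
    scheduler.step()
    print(epoch, scheduler.get_last_lr()[0])
```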
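`utils/data_utils.py` builds the few-shot loaders from an argument namespace. Below is a minimal sketch of calling `get_loader` directly; every field value is a placeholder (the real defaults live in `main.py`), and the split names are assumed to match the `*.txt` index files under `./dataset/`.
```
from argparse import Namespace
from utils.data_utils import get_loader

# Placeholder arguments; only the fields read by get_loader/build_dataset/build_transform are set.
args = Namespace(
    test_mode=False,
    train_text="train_fewshot_data",   # resolved to ./dataset/train_fewshot_data.txt
    val_text="val_fewshot_data",       # resolved to ./dataset/val_fewshot_data.txt
    data_dir="./dataset/",
    list_dir="./dataset/",
    label_index=[1],                   # assumed value; only stored by Fewshot_dataset
    input_size=224,
    batch_size=2,
    num_workers=2,
)

train_loader, val_loader = get_loader(args)
image, label = next(iter(train_loader))
print(image.shape, label.shape)  # e.g. torch.Size([2, 3, 224, 224]) and torch.Size([2, 1, 224, 224])
```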
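`fillHole` and `save_max_objects` are defined at the end of `utils/trainer.py` but are left commented out in `val_epoch`. Below is a minimal sketch of applying them offline to one of the saved `_auto.png` predictions; the path is only an example of what `val_epoch` writes.
```
import cv2
from utils.trainer import fillHole, save_max_objects

# Example path following val_epoch's naming: <val_fewshot_results>/<case>/<index>_auto.png
pred = cv2.imread("dataset/val_fewshot_results/VAL001/20_auto.png", cv2.IMREAD_GRAYSCALE)

pred = save_max_objects(pred)  # keep only the largest connected component
pred = fillHole(pred)          # flood-fill from the corner and fill enclosed holes

cv2.imwrite("dataset/val_fewshot_results/VAL001/20_auto_post.png", pred)
```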