├── README.md
├── coreset selection.py
├── dataset
├── train_fewshot_data.txt
├── train_fewshot_data
│ └── TRAIN001
│ │ ├── 20.png
│ │ ├── 20_gt.png
│ │ ├── 21.png
│ │ ├── 21_gt.png
│ │ ├── 22.png
│ │ ├── 22_gt.png
│ │ ├── 23.png
│ │ ├── 23_gt.png
│ │ ├── 24.png
│ │ ├── 24_gt.png
│ │ ├── 25.png
│ │ ├── 25_gt.png
│ │ ├── 26.png
│ │ ├── 26_gt.png
│ │ ├── 27.png
│ │ ├── 27_gt.png
│ │ ├── 28.png
│ │ ├── 28_gt.png
│ │ ├── 29.png
│ │ ├── 29_gt.png
│ │ ├── 30.png
│ │ └── 30_gt.png
├── val_fewshot_data.txt
└── val_fewshot_data
│ └── VAL001
│ ├── 20.png
│ ├── 20_gt.png
│ ├── 21.png
│ ├── 21_gt.png
│ ├── 22.png
│ ├── 22_gt.png
│ ├── 23.png
│ ├── 23_gt.png
│ ├── 24.png
│ ├── 24_gt.png
│ ├── 25.png
│ ├── 25_gt.png
│ ├── 26.png
│ ├── 26_gt.png
│ ├── 27.png
│ ├── 27_gt.png
│ ├── 28.png
│ ├── 28_gt.png
│ ├── 29.png
│ ├── 29_gt.png
│ ├── 30.png
│ ├── 30_gt.png
│ ├── 31.png
│ ├── 31_gt.png
│ ├── 32.png
│ ├── 32_gt.png
│ ├── 33.png
│ └── 33_gt.png
├── images
├── Figure 1.png
└── Figure 2.png
├── main.py
├── networks
├── __init__.py
├── unet.py
└── unetr.py
├── optimizers
├── __init__.py
└── lr_scheduler.py
├── requirements.txt
├── runs
├── __init__.py
└── log
│ └── log.txt
└── utils
├── __init__.py
├── data_utils.py
└── trainer.py
/README.md:
--------------------------------------------------------------------------------
1 | # Annotation-efficient-learning-for-OCT-segmentation
2 |
3 | This repository contains the code for the paper ["Annotation-efficient learning for OCT segmentation"](https://opg.optica.org/boe/fulltext.cfm?uri=boe-14-7-3294&id=531648). We propose an annotation-efficient learning method for OCT segmentation that can significantly reduce annotation costs and improve learning efficiency. We provide a generative pre-trained Transformer-based encoder and a CNN-based segmentation decoder, both pre-trained on open-access OCT datasets. The pre-trained model can be directly transferred to ROI segmentation on your own OCT images. We hope this helps improve the intelligence and application penetration of OCT.
4 |
5 | 
6 | 
7 |
8 | ## Dependencies
9 | python==3.8
10 | torch==1.11.1
11 | numpy==1.19.5
12 | monai==0.7.0
13 | timm==0.3.2
14 | tensorboardX==2.1
15 | torchvision==0.12.0
16 | opencv-python==4.5.5
17 |
18 | ## Usage
19 | 1. Clone the repository:
20 | ```
21 | git clone https://github.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation.git
22 | ```
23 |
24 | 2. Install the required dependencies:
25 | ```
26 | pip install -r requirements.txt
27 | ```
28 | 3. Download the pre-trained [phase1 model file](https://drive.google.com/file/d/1JHdL1HRZM86n4761uoO3_6N4p4zqBDmx/view?usp=sharing) (encoder weights) and [phase2 model file](https://drive.google.com/file/d/1gOihHsH4-GAtS6R6wzxkOQsaMAF6fj_h/view?usp=sharing) (decoder weights), and put them in the `./runs/` folder.
29 |
30 | 4. Edit the paths and parameters in `main.py` as needed (or override them from the command line, see the example after step 5).
31 |
32 | 5. Go to the corresponding folder and run:
33 | ```
34 | cd Annotation-efficient-learning-for-OCT-segmentation
35 | python main.py
36 | ```
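
Because `main.py` defines its options with `argparse`, most hyperparameters can also be overridden from the command line instead of editing the file. A possible invocation (all flags shown are defined in `main.py`):
```
python main.py --max_epochs 200 --batch_size 2 --optim_lr 1e-4
```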
37 |
38 | ## Training on your Dataset
39 | The expected dataset layout follows the `./dataset/` folder, which contains `train_fewshot_data` and `val_fewshot_data`. The image names are listed in `train_fewshot_data.txt` and `val_fewshot_data.txt`, and every image is paired with a ground-truth mask, as in the example below.
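
For instance, the provided few-shot training set is organized as follows (the masks mark the foreground with the pixel value given by `--label_index` in `main.py`, 255 by default):
```
dataset/
├── train_fewshot_data.txt      # one image per line, e.g. "TRAIN001/20.png"
└── train_fewshot_data/
    └── TRAIN001/
        ├── 20.png              # OCT image
        ├── 20_gt.png           # ground-truth mask for 20.png
        └── ...
```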
40 |
41 | ## Citation
42 | ```
43 | @article{
44 | title={Annotation-efficient learning for OCT segmentation},
45 | author={Zhang, Haoran and Yang, Jianlong and Zheng, Ce and Zhao, Shiqing and Zhang, Aili},
46 | journal={Biomedical Optics Express},
47 | volume={14},
48 | number={7},
49 | pages={3294--3307},
50 | year={2023},
51 | publisher={Optica Publishing Group}
52 | }
53 | ```
54 |
--------------------------------------------------------------------------------
/coreset selection.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def euclidean_dist(x, y):
5 | m, n = x.size(0), y.size(0)
6 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
7 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
8 | dist = xx + yy
9 |     dist.addmm_(x, y.t(), beta=1, alpha=-2)  # dist = xx + yy - 2 * x @ y.T; keyword args replace the deprecated positional addmm_ form
10 | dist = dist.clamp(min=1e-12).sqrt()
11 | return dist
12 |
13 | def coreset_selection(matrix, budget: int, metric, device, random_seed=None, index=None, already_selected=None,
14 | print_freq: int = 20):
15 | if type(matrix) == torch.Tensor:
16 | assert matrix.dim() == 2
17 | elif type(matrix) == np.ndarray:
18 | assert matrix.ndim == 2
19 | matrix = torch.from_numpy(matrix).requires_grad_(False).to(device)
20 |
21 | sample_num = matrix.shape[0]
22 | assert sample_num >= 1
23 |
24 | if budget < 0:
25 | raise ValueError("Illegal budget size.")
26 | elif budget > sample_num:
27 | budget = sample_num
28 |
29 | if index is not None:
30 | assert matrix.shape[0] == len(index)
31 | else:
32 | index = np.arange(sample_num)
33 |
34 | assert callable(metric)
35 |
36 | already_selected = np.array(already_selected)
37 |
38 | with torch.no_grad():
39 | np.random.seed(random_seed)
40 | if already_selected.__len__() == 0:
41 | select_result = np.zeros(sample_num, dtype=bool)
42 | # Randomly select one initial point.
43 | already_selected = [np.random.randint(0, sample_num)]
44 | budget -= 1
45 | select_result[already_selected] = True
46 | else:
47 | select_result = np.in1d(index, already_selected)
48 |
49 |         num_of_already_selected = np.sum(select_result)  # number of seed points already selected (1 if starting from a random point)
50 |
51 | # Initialize a (num_of_already_selected+budget-1)*sample_num matrix storing distances of pool points from
52 | # each clustering center.
53 | dis_matrix = -1 * torch.ones([num_of_already_selected + budget - 1, sample_num], requires_grad=False).to(device)
54 |
55 | dis_matrix[:num_of_already_selected, ~select_result] = metric(matrix[select_result], matrix[~select_result])
56 |
57 |         mins = torch.min(dis_matrix[:num_of_already_selected, :], dim=0).values  # minimum distance from each point to its nearest already-selected point
58 |
59 | for i in range(budget):
60 | if i % print_freq == 0:
61 | print("| Selecting [%3d/%3d]" % (i + 1, budget))
62 | p = torch.argmax(mins).item()
63 | select_result[p] = True
64 |
65 | if i == budget - 1:
66 | break
67 | mins[p] = -1
68 | dis_matrix[num_of_already_selected + i, ~select_result] = metric(matrix[[p]], matrix[~select_result])
69 | mins = torch.min(mins, dis_matrix[num_of_already_selected + i])
70 | return index[select_result]
71 |
72 | def main(txtName, numpyName, caseList, budget, device, seed):
73 | """
74 |     txtName: the .txt file containing the ordered name list of images, e.g. "TRAIN001/001.png".
75 |     numpyName: the .npy file containing the ordered 1-D feature vectors of the images (reshaped from their
76 |         2-D feature maps), taken from the encoder output. Its shape is (number of images, feature dimension).
77 |     caseList: the ordered list of case names, e.g. ['TRAIN001', 'TRAIN002', 'TRAIN003', 'TRAIN004'].
78 |     budget: the number of images to select for the core-set.
79 |     device: the CUDA device or CPU.
80 |     seed: the random seed that decides the initial point.
81 | """
82 |
83 | #loading txt list
84 | with open(txtName, mode='r') as F:
85 | imageList = F.readlines()
86 | #load image features
87 | already_selected = []
88 | matrix = np.load(numpyName)
89 | matrix = np.array(matrix, dtype=np.float32)
90 | print("number of images total:",matrix.shape[0],", feature dimension:",matrix.shape[1])
91 |
92 | for case in caseList:
93 | indexlist = []
94 | featurelist = []
95 | for i, imageName in enumerate(imageList):
96 | if case in imageName:
97 | indexlist.append(i)
98 | featurelist.append(matrix[i])
99 | x = torch.from_numpy(np.array(featurelist))
100 | y = torch.from_numpy(np.array(featurelist))
101 | dis_matrix = euclidean_dist(x, y)
102 |         min_idx = int(torch.argmin(torch.sum(dis_matrix, dim=0), dim=0).numpy())  # medoid of this case: smallest total distance to the other images
103 |         already_selected.append(indexlist[min_idx])
104 |
105 | subset = coreset_selection(matrix, budget=int(np.round(budget)),
106 | metric=euclidean_dist, device=device,
107 | random_seed=seed,
108 | already_selected=np.array(already_selected))
109 |
110 | print("{} images has been selected as for core-set!".format(len(subset)))
111 | print("The index of selected images are as follows:")
112 | print(subset)
113 |
114 | if __name__ == '__main__':
115 |
116 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
117 | txtName = './train_fewshot_data.txt'
118 | numpyName = 'imageFeature.npy'
119 | caseList = ['TRAIN001', 'TRAIN002', 'TRAIN003']
120 | budget = 50
121 | seed=700
122 |
123 | main(txtName, numpyName, caseList, budget, device, seed)
124 |
--------------------------------------------------------------------------------
/dataset/train_fewshot_data.txt:
--------------------------------------------------------------------------------
1 | TRAIN001/20.png
2 | TRAIN001/21.png
3 | TRAIN001/22.png
4 | TRAIN001/23.png
5 | TRAIN001/24.png
6 | TRAIN001/25.png
7 | TRAIN001/26.png
8 | TRAIN001/27.png
9 | TRAIN001/28.png
10 | TRAIN001/29.png
11 | TRAIN001/30.png
12 |
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/20.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/20_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/20_gt.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/21.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/21.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/21_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/21_gt.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/22.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/22_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/22_gt.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/23.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/23.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/23_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/23_gt.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/24.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/24.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/24_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/24_gt.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/25.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/25_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/25_gt.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/26.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/26.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/26_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/26_gt.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/27.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/27.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/27_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/27_gt.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/28.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/28.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/28_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/28_gt.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/29.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/29_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/29_gt.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/30.png
--------------------------------------------------------------------------------
/dataset/train_fewshot_data/TRAIN001/30_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/train_fewshot_data/TRAIN001/30_gt.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data.txt:
--------------------------------------------------------------------------------
1 | VAL001/20.png
2 | VAL001/21.png
3 | VAL001/22.png
4 | VAL001/23.png
5 | VAL001/24.png
6 | VAL001/25.png
7 | VAL001/26.png
8 | VAL001/27.png
9 | VAL001/28.png
10 | VAL001/29.png
11 | VAL001/30.png
12 | VAL001/31.png
13 | VAL001/32.png
14 | VAL001/33.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/20.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/20_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/20_gt.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/21.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/21.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/21_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/21_gt.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/22.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/22_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/22_gt.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/23.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/23.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/23_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/23_gt.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/24.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/24.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/24_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/24_gt.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/25.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/25_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/25_gt.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/26.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/26.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/26_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/26_gt.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/27.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/27.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/27_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/27_gt.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/28.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/28.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/28_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/28_gt.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/29.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/29_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/29_gt.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/30.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/30_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/30_gt.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/31.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/31.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/31_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/31_gt.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/32.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/32_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/32_gt.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/33.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/33.png
--------------------------------------------------------------------------------
/dataset/val_fewshot_data/VAL001/33_gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/dataset/val_fewshot_data/VAL001/33_gt.png
--------------------------------------------------------------------------------
/images/Figure 1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/images/Figure 1.png
--------------------------------------------------------------------------------
/images/Figure 2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/images/Figure 2.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import torch
5 | import torch.nn.parallel
6 | import torch.utils.data.distributed
7 | import networks.unetr
8 | from utils.trainer import run_training, val_epoch
9 | from utils.data_utils import get_loader
10 | from networks.unet import UNet
11 | from monai.losses import DiceCELoss
12 | from monai.metrics import DiceMetric
13 | from monai.transforms import AsDiscrete
14 | from monai.utils.enums import MetricReduction
15 | import random
16 |
17 | def main(train,val):
18 | parser = argparse.ArgumentParser(description="UNETR segmentation pipeline")
19 | parser.add_argument("--test_mode", default=False, type=bool, help="test mode or not")
20 | parser.add_argument("--label_index",default=[255],type=list, help="index of label for foreground segmentation")
21 | parser.add_argument("--phase2_weight", default="./runs/phase2 model.pth", type=str,help="load phase2 model weight")
22 | parser.add_argument("--phase1_weight", default="./runs/phase1 model.pth", type=str, help="load phase1 model weight")
23 | parser.add_argument("--logdir", default="./runs/log", type=str, help="directory to save the tensorboard logs")
24 | parser.add_argument('--list_dir', type=str, default='./dataset', help='list dir')
25 | parser.add_argument("--data_dir", default="./dataset", type=str, help="dataset directory")
26 | parser.add_argument("--train_text", default=train, type=str, help="training image list text name")
27 | parser.add_argument("--val_text", default=val, type=str, help="validating image list text name")
28 | parser.add_argument('--model_name', default='vit_large_patch16', type=str, metavar='MODEL',help='Name of model to train')
29 | parser.add_argument("--save_checkpoint", default=True,type=bool, help="save checkpoint during training")
30 | parser.add_argument("--max_epochs", default=100, type=int, help="max number of training epochs")
31 | parser.add_argument("--batch_size", default=4, type=int, help="number of batch size")
32 | parser.add_argument("--optim_lr", default=0.5*1e-4, type=float, help="optimization learning rate")
33 | parser.add_argument("--optim_name", default="adam", type=str, help="optimization algorithm")#adamw
34 | parser.add_argument("--reg_weight", default=1e-5, type=float, help="regularization weight")
35 | parser.add_argument("--momentum", default=0.99, type=float, help="momentum")
36 | parser.add_argument("--noamp", action="store_true", help="do NOT use amp for training")
37 | parser.add_argument("--val_every", default=1, type=int, help="validation frequency")
38 | parser.add_argument("--distributed", action="store_true", help="start distributed training")
39 | parser.add_argument("--world_size", default=1, type=int, help="number of nodes for distributed training")
40 | parser.add_argument("--rank", default=0, type=int, help="node rank for distributed training")
41 | parser.add_argument("--num_workers", default=2, type=int, help="number of workers")
42 | parser.add_argument("--norm_name", default="instance", type=str, help="normalization layer type in decoder")
43 | parser.add_argument("--feature_size", default=64, type=int, help="feature size dimention")#
44 | parser.add_argument("--in_channels", default=3, type=int, help="number of input channels")
45 | parser.add_argument("--out_channels", default=1, type=int, help="number of output channels")
46 | parser.add_argument("--res_block", default=True,type=bool, help="use residual blocks")
47 | parser.add_argument("--conv_block", default=True,type=bool, help="use conv blocks")
48 |     parser.add_argument('--input_size', default=224, type=int, help='input image size')
49 | parser.add_argument("--dropout_rate", default=0.4, type=float, help="dropout rate")
50 | parser.add_argument("--lrschedule", default="warmup_cosine", type=str, help="type of learning rate scheduler")
51 | parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs")
52 | parser.add_argument("--smooth_dr", default=1e-6, type=float, help="constant added to dice denominator to avoid nan")
53 | parser.add_argument("--smooth_nr", default=0.0, type=float, help="constant added to dice numerator to avoid zero")
54 |
55 | args = parser.parse_args()
56 | args.amp = not args.noamp
57 |
58 |     seed = 666  # set the seeds before launching training so that runs are reproducible
59 |     torch.manual_seed(seed)
60 |     torch.cuda.manual_seed_all(seed)
61 |     np.random.seed(seed)
62 |     random.seed(seed)
63 |     torch.backends.cudnn.deterministic = True
64 |     main_worker(gpu=0, args=args)
65 |
66 |
67 | def main_worker(gpu, args):
68 | np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True)
69 | args.gpu = gpu
70 | torch.cuda.set_device(args.gpu)
71 | torch.backends.cudnn.benchmark = True
72 |
73 | loader = get_loader(args)
74 | print(args.rank, " gpu", args.gpu)
75 | if args.rank == 0:
76 | print("Batch size is:", args.batch_size, "epochs", args.max_epochs)
77 |
78 |
79 | if (args.model_name is None) or args.model_name in ['vit_base_patch16','vit_large_patch16','vit_huge_patch14']:
80 | model = networks.unetr.__dict__[args.model_name](
81 | in_channels=args.in_channels,
82 | out_channels=args.out_channels,
83 | img_size=args.input_size,
84 | feature_size=args.feature_size,
85 | norm_name=args.norm_name,
86 | conv_block=True,
87 | res_block=True,
88 | dropout_rate=args.dropout_rate,
89 | num_classes=1000)
90 |
91 | # loading weights
92 | if not args.test_mode:
93 | if args.phase1_weight is not None and args.phase2_weight is None:# training phase 2
94 | checkpoint = torch.load(args.phase1_weight,map_location='cpu')
95 | checkpoint_model = checkpoint['model']
96 | print(checkpoint['model'])
97 | model.load_state_dict(checkpoint_model, strict=False)
98 | print("Use pretrained weights")
99 | #print(model.state_dict())
100 | elif args.phase1_weight is not None and args.phase2_weight is not None: # training phase 4
101 | #load weight of phase2 model first (for decoder)
102 | assert args.phase2_weight != None, 'No segmentation model weights loaded'
103 | model_dict = torch.load(args.phase2_weight, map_location="cpu")
104 | print(model_dict['state_dict'])
105 | model.load_state_dict(model_dict['state_dict'], strict=True)
106 | # then load weight of phase1 model (for encoder)
107 | checkpoint = torch.load(args.phase1_weight,map_location='cpu')
108 | checkpoint_model = checkpoint['model']
109 | print(checkpoint_model)
110 | model.load_state_dict(checkpoint_model, strict=False)
111 |
112 |
113 | else:# test mode, load best model
114 | assert args.phase2_weight != None, 'No segmentation model weights loaded'
115 | model_dict = torch.load(args.phase2_weight, map_location="cpu")
116 | model.load_state_dict(model_dict['state_dict'])
117 |
118 | elif args.model_name =='unet':
119 | model = UNet(n_channels = args.in_channels, n_classes = args.out_channels, bilinear=True)
120 | else:
121 | raise ValueError("Unsupported model " + str(args.model_name))
122 |
123 | #setup loss
124 | dice_loss = DiceCELoss(
125 | to_onehot_y=True, softmax=True, squared_pred=True, smooth_nr=args.smooth_nr, smooth_dr=args.smooth_dr)
126 |
127 | post_label = AsDiscrete(to_onehot=True, n_classes=args.out_channels)
128 | post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=args.out_channels)
129 | dice_acc = DiceMetric(include_background=True, reduction=MetricReduction.MEAN, get_not_nans=True)
130 |
131 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
132 | print("Total parameters count", pytorch_total_params)
133 |
134 | best_acc = 0
135 | start_epoch = 0
136 |
137 | model.cuda(args.gpu)
138 |
139 | if args.optim_name == "adam":
140 | optimizer = torch.optim.Adam(model.parameters(), lr=args.optim_lr, weight_decay=args.reg_weight)
141 | elif args.optim_name == "adamw":
142 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.optim_lr, weight_decay=args.reg_weight)
143 | elif args.optim_name == "sgd":
144 | optimizer = torch.optim.SGD(
145 | model.parameters(), lr=args.optim_lr, momentum=args.momentum, nesterov=True, weight_decay=args.reg_weight
146 | )
147 | else:
148 | raise ValueError("Unsupported Optimization Procedure: " + str(args.optim_name))
149 |
150 |     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1,
151 |                                                             patience=5)  # reduce LR when the monitored metric stops improving
152 |
153 | if args.model_name =='unet':
154 | optimizer = torch.optim.RMSprop(model.parameters(), lr=args.optim_lr, weight_decay=1e-8, momentum=0.99)
155 |         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=5)  # reduce LR when the monitored metric stops improving
156 |
157 | if not args.test_mode:
158 | accuracy = run_training(
159 | model=model,
160 | train_loader=loader[0],
161 | val_loader=loader[1],
162 | optimizer=optimizer,
163 | loss_func=dice_loss,
164 | acc_func=dice_acc,
165 | args=args,
166 | model_inferer=None,
167 | scheduler=scheduler,
168 | start_epoch=start_epoch,
169 | post_label=post_label,
170 | post_pred=post_pred,
171 | )
172 | else:
173 | accuracy, run_loss = val_epoch(
174 | model=model,
175 | loader=loader,
176 | args=args
177 | )
178 | print("final acc:", accuracy, 'loss:', run_loss)
179 | return accuracy
180 |
181 |
182 | if __name__ == "__main__":
183 |
184 | train = 'train_fewshot_data'
185 | val = 'val_fewshot_data'
186 | main(train,val)
187 |
--------------------------------------------------------------------------------
/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/networks/__init__.py
--------------------------------------------------------------------------------
/networks/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class UNet(nn.Module):
6 | def __init__(self, n_channels, n_classes, bilinear=False):
7 | super(UNet, self).__init__()
8 | self.n_channels = n_channels
9 | self.n_classes = n_classes
10 | self.bilinear = bilinear
11 |
12 | self.inc = DoubleConv(n_channels, 64)
13 | self.down1 = Down(64, 128)
14 | self.down2 = Down(128, 256)
15 | self.down3 = Down(256, 512)
16 | factor = 2 if bilinear else 1
17 | self.down4 = Down(512, 1024 // factor)
18 | self.up1 = Up(1024, 512 // factor, bilinear)
19 | self.up2 = Up(512, 256 // factor, bilinear)
20 | self.up3 = Up(256, 128 // factor, bilinear)
21 | self.up4 = Up(128, 64, bilinear)
22 | self.outc = OutConv(64, n_classes)
23 |
24 | def forward(self, x):
25 | x1 = self.inc(x)
26 | x2 = self.down1(x1)
27 | x3 = self.down2(x2)
28 | x4 = self.down3(x3)
29 | x5 = self.down4(x4)
30 | x = self.up1(x5, x4)
31 | x = self.up2(x, x3)
32 | x = self.up3(x, x2)
33 | x = self.up4(x, x1)
34 | logits = self.outc(x)
35 | return logits
36 |
37 |
38 | class DoubleConv(nn.Module):
39 | """(convolution => [BN] => ReLU) * 2"""
40 |
41 | def __init__(self, in_channels, out_channels, mid_channels=None):
42 | super().__init__()
43 | if not mid_channels:
44 | mid_channels = out_channels
45 | self.double_conv = nn.Sequential(
46 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
47 | nn.BatchNorm2d(mid_channels),
48 | nn.ReLU(inplace=True),
49 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
50 | nn.BatchNorm2d(out_channels),
51 | nn.ReLU(inplace=True)
52 | )
53 |
54 | def forward(self, x):
55 | return self.double_conv(x)
56 |
57 |
58 | class Down(nn.Module):
59 | """Downscaling with maxpool then double conv"""
60 |
61 | def __init__(self, in_channels, out_channels):
62 | super().__init__()
63 | self.maxpool_conv = nn.Sequential(
64 | nn.MaxPool2d(2),
65 | DoubleConv(in_channels, out_channels)
66 | )
67 |
68 | def forward(self, x):
69 | return self.maxpool_conv(x)
70 |
71 |
72 | class Up(nn.Module):
73 | """Upscaling then double conv"""
74 |
75 | def __init__(self, in_channels, out_channels, bilinear=True):
76 | super().__init__()
77 |
78 | # if bilinear, use the normal convolutions to reduce the number of channels
79 | if bilinear:
80 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
81 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
82 | else:
83 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
84 | self.conv = DoubleConv(in_channels, out_channels)
85 |
86 | def forward(self, x1, x2):
87 | x1 = self.up(x1)
88 |         # inputs are NCHW; pad x1 so its spatial size matches x2 before concatenation
89 | diffY = x2.size()[2] - x1.size()[2]
90 | diffX = x2.size()[3] - x1.size()[3]
91 |
92 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
93 | diffY // 2, diffY - diffY // 2])
94 | # if you have padding issues, see
95 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
96 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
97 | x = torch.cat([x2, x1], dim=1)
98 | return self.conv(x)
99 |
100 |
101 | class OutConv(nn.Module):
102 | def __init__(self, in_channels, out_channels):
103 | super(OutConv, self).__init__()
104 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
105 |
106 | def forward(self, x):
107 | return self.conv(x)
--------------------------------------------------------------------------------
/networks/unetr.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Tuple, Union
3 | import timm
4 | import torch
5 | import torch.nn as nn
6 | from timm.models.vision_transformer import PatchEmbed, Block
7 | from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
8 | from monai.networks.blocks.dynunet_block import UnetOutBlock
9 |
10 |
11 | def vit_base_patch16(**kwargs):
12 | model = VisionTransformer(
13 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
14 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
15 | return model
16 |
17 |
18 | def vit_large_patch16(**kwargs):
19 | model = VisionTransformer(
20 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
21 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
22 | return model
23 |
24 |
25 | def vit_huge_patch14(**kwargs):
26 | model = VisionTransformer(
27 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
28 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
29 | return model
30 |
31 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
32 |
33 | def __init__(
34 | self,
35 | in_channels: int = 3,
36 | out_channels: int = 1,
37 | img_size: int = 224,
38 | feature_size: int = 64,
39 | #hidden_size: int = 768,
40 | #mlp_dim: int = 3072,
41 | #num_heads: int = 12,
42 | #pos_embed: str = "perceptron",
43 | norm_name: Union[Tuple, str] = "instance",
44 | conv_block: bool = False,
45 | res_block: bool = True,
46 | dropout_rate: float = 0.0,
47 |
48 | patch_size=16,
49 | embed_dim=1024,
50 | depth=24,
51 | num_heads=16,
52 | mlp_ratio=4,
53 | qkv_bias=True,
54 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
55 | **kwargs
56 | ) -> None:
57 | """
58 | Args:
59 | in_channels: dimension of input channels.
60 | out_channels: dimension of output channels.
61 | img_size: dimension of input image.
62 | feature_size: dimension of network feature size.
63 | hidden_size: dimension of hidden layer.
64 | mlp_dim: dimension of feedforward layer.
65 | num_heads: number of attention heads.
66 | pos_embed: position embedding layer type.
67 | norm_name: feature normalization type and arguments.
68 | conv_block: bool argument to determine if convolutional block is used.
69 | res_block: bool argument to determine if residual block is used.
70 |             dropout_rate: fraction of the input units to drop.
71 | """
72 | super(VisionTransformer, self).__init__(embed_dim=768,**kwargs)
73 |
74 | #define vit
75 | # MAE encoder specifics
76 | self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)
77 | num_patches = self.patch_embed.num_patches
78 |
79 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
80 |
81 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
82 |
83 | self.blocks = nn.ModuleList([
84 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer, attn_drop=dropout_rate)
85 | for i in range(depth)])
86 | self.norm = norm_layer(embed_dim)
87 |
88 | #define CNN
89 | self.patch_size = (patch_size, patch_size)
90 | self.feat_size = (
91 | img_size // self.patch_size[0],
92 | img_size // self.patch_size[1])
93 | self.hidden_size = embed_dim
94 | self.classification = False
95 |
96 | self.encoder1 = UnetrBasicBlock(
97 | spatial_dims=2,
98 | in_channels=in_channels,
99 | out_channels=feature_size,
100 | kernel_size=3,
101 | stride=1,
102 | norm_name=norm_name,
103 | res_block=res_block,
104 | )
105 | self.encoder2 = UnetrPrUpBlock(
106 | spatial_dims=2,
107 | in_channels=embed_dim,
108 | out_channels=feature_size * 2,
109 | num_layer=2,
110 | kernel_size=3,
111 | stride=1,
112 | upsample_kernel_size=2,
113 | norm_name=norm_name,
114 | conv_block=conv_block,
115 | res_block=res_block,
116 | )
117 | self.encoder3 = UnetrPrUpBlock(
118 | spatial_dims=2,
119 | in_channels=embed_dim,
120 | out_channels=feature_size * 4,
121 | num_layer=1,
122 | kernel_size=3,
123 | stride=1,
124 | upsample_kernel_size=2,
125 | norm_name=norm_name,
126 | conv_block=conv_block,
127 | res_block=res_block,
128 | )
129 | self.encoder4 = UnetrPrUpBlock(
130 | spatial_dims=2,
131 | in_channels=embed_dim,
132 | out_channels=feature_size * 8,
133 | num_layer=0,
134 | kernel_size=3,
135 | stride=1,
136 | upsample_kernel_size=2,
137 | norm_name=norm_name,
138 | conv_block=conv_block,
139 | res_block=res_block,
140 | )
141 | self.decoder5 = UnetrUpBlock(
142 | spatial_dims=2,
143 | in_channels=embed_dim,
144 | out_channels=feature_size * 8,
145 | kernel_size=3,
146 | upsample_kernel_size=2,
147 | norm_name=norm_name,
148 | res_block=res_block,
149 | )
150 | self.decoder4 = UnetrUpBlock(
151 | spatial_dims=2,
152 | in_channels=feature_size * 8,
153 | out_channels=feature_size * 4,
154 | kernel_size=3,
155 | upsample_kernel_size=2,
156 | norm_name=norm_name,
157 | res_block=res_block,
158 | )
159 | self.decoder3 = UnetrUpBlock(
160 | spatial_dims=2,
161 | in_channels=feature_size * 4,
162 | out_channels=feature_size * 2,
163 | kernel_size=3,
164 | upsample_kernel_size=2,
165 | norm_name=norm_name,
166 | res_block=res_block,
167 | )
168 | self.decoder2 = UnetrUpBlock(
169 | spatial_dims=2,
170 | in_channels=feature_size * 2,
171 | out_channels=feature_size,
172 | kernel_size=3,
173 | upsample_kernel_size=2,
174 | norm_name=norm_name,
175 | res_block=res_block,
176 | )
177 | self.out = UnetOutBlock(spatial_dims=2, in_channels=feature_size, out_channels=out_channels) # type: ignore
178 |
179 | def proj_feat(self, x, hidden_size, feat_size):
180 | x = x.view(x.size(0), feat_size[0], feat_size[1], hidden_size)
181 | x = x.permute(0, 3, 1, 2).contiguous()
182 | return x
183 |
184 | def load_from(self, weights):
185 | with torch.no_grad():
186 | res_weight = weights
187 | # copy weights from patch embedding
188 | for i in weights["state_dict"]:
189 | print(i)
190 | self.vit.patch_embedding.position_embeddings.copy_(
191 | weights["state_dict"]["module.transformer.patch_embedding.position_embeddings_3d"]
192 | )
193 | self.vit.patch_embedding.cls_token.copy_(
194 | weights["state_dict"]["module.transformer.patch_embedding.cls_token"]
195 | )
196 | self.vit.patch_embedding.patch_embeddings[1].weight.copy_(
197 | weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.weight"]
198 | )
199 | self.vit.patch_embedding.patch_embeddings[1].bias.copy_(
200 | weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.bias"]
201 | )
202 |
203 | # copy weights from encoding blocks (default: num of blocks: 12)
204 | for bname, block in self.vit.blocks.named_children():
205 | print(block)
206 | block.loadFrom(weights, n_block=bname)
207 | # last norm layer of transformer
208 | self.vit.norm.weight.copy_(weights["state_dict"]["module.transformer.norm.weight"])
209 | self.vit.norm.bias.copy_(weights["state_dict"]["module.transformer.norm.bias"])
210 |
211 | def forward(self, x_in):
212 | #x, hidden_states_out = self.vit(x_in)
213 |         # ViT encoder forward
214 | # embed patches
215 | x = self.patch_embed(x_in)
216 | # add pos embed w/o cls token
217 | x = x + self.pos_embed[:, 1:, :]
218 |
219 | # append cls token
220 | cls_token = self.cls_token + self.pos_embed[:, :1, :]
221 | cls_tokens = cls_token.expand(x.shape[0], -1, -1)
222 | x = torch.cat((cls_tokens, x), dim=1)
223 |
224 | # apply Transformer blocks
225 | hidden_states_out = []
226 | for blk in self.blocks:
227 | x = blk(x)
228 | hidden_states_out.append(x[:,1:,:])
229 | x = self.norm(x)
230 | x = x[:,1:,:]
231 |
232 | #define CNN
233 | enc1 = self.encoder1(x_in)
234 | x2 = hidden_states_out[6]
235 | enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size))
236 | x3 = hidden_states_out[12]
237 | enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size))
238 | x4 = hidden_states_out[18]
239 | enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size))
240 |
241 | dec4 = self.proj_feat(x, self.hidden_size, self.feat_size)
242 | dec3 = self.decoder5(dec4, enc4)
243 | dec2 = self.decoder4(dec3, enc3)
244 | dec1 = self.decoder3(dec2, enc2)
245 | out = self.decoder2(dec1, enc1)
246 | logits = self.out(out)
247 |
248 | return logits
249 |
--------------------------------------------------------------------------------
/optimizers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/optimizers/__init__.py
--------------------------------------------------------------------------------
/optimizers/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 - 2021 MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | import math
13 | import warnings
14 | from typing import List
15 |
16 | from torch import nn as nn
17 | from torch.optim import Adam, Optimizer
18 | from torch.optim.lr_scheduler import LambdaLR, _LRScheduler
19 |
20 | __all__ = ["LinearLR", "ExponentialLR"]
21 |
22 |
23 | class _LRSchedulerMONAI(_LRScheduler):
24 | """Base class for increasing the learning rate between two boundaries over a number
25 | of iterations"""
26 |
27 | def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None:
28 | """
29 | Args:
30 | optimizer: wrapped optimizer.
31 | end_lr: the final learning rate.
32 | num_iter: the number of iterations over which the test occurs.
33 | last_epoch: the index of last epoch.
34 | Returns:
35 | None
36 | """
37 | self.end_lr = end_lr
38 | self.num_iter = num_iter
39 | super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch)
40 |
41 |
42 | class LinearLR(_LRSchedulerMONAI):
43 | """Linearly increases the learning rate between two boundaries over a number of
44 | iterations.
45 | """
46 |
47 | def get_lr(self):
48 | r = self.last_epoch / (self.num_iter - 1)
49 | return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs]
50 |
51 |
52 | class ExponentialLR(_LRSchedulerMONAI):
53 | """Exponentially increases the learning rate between two boundaries over a number of
54 | iterations.
55 | """
56 |
57 | def get_lr(self):
58 | r = self.last_epoch / (self.num_iter - 1)
59 | return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
60 |
61 |
62 | class WarmupCosineSchedule(LambdaLR):
63 | """Linear warmup and then cosine decay.
64 | Based on https://huggingface.co/ implementation.
65 | """
66 |
67 | def __init__(
68 | self, optimizer: Optimizer, warmup_steps: int, t_total: int, cycles: float = 0.5, last_epoch: int = -1
69 | ) -> None:
70 | """
71 | Args:
72 | optimizer: wrapped optimizer.
73 | warmup_steps: number of warmup iterations.
74 | t_total: total number of training iterations.
75 | cycles: cosine cycles parameter.
76 | last_epoch: the index of last epoch.
77 | Returns:
78 | None
79 | """
80 | self.warmup_steps = warmup_steps
81 | self.t_total = t_total
82 | self.cycles = cycles
83 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch)
84 |
85 | def lr_lambda(self, step):
86 | if step < self.warmup_steps:
87 | return float(step) / float(max(1.0, self.warmup_steps))
88 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
89 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
90 |
91 |
92 | class LinearWarmupCosineAnnealingLR(_LRScheduler):
93 | def __init__(
94 | self,
95 | optimizer: Optimizer,
96 | warmup_epochs: int,
97 | max_epochs: int,
98 | warmup_start_lr: float = 0.0,
99 | eta_min: float = 0.0,
100 | last_epoch: int = -1,
101 | ) -> None:
102 | """
103 | Args:
104 | optimizer (Optimizer): Wrapped optimizer.
105 | warmup_epochs (int): Maximum number of iterations for linear warmup
106 | max_epochs (int): Maximum number of iterations
107 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
108 | eta_min (float): Minimum learning rate. Default: 0.
109 | last_epoch (int): The index of last epoch. Default: -1.
110 | """
111 | self.warmup_epochs = warmup_epochs
112 | self.max_epochs = max_epochs
113 | self.warmup_start_lr = warmup_start_lr
114 | self.eta_min = eta_min
115 |
116 | super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)
117 |
118 | def get_lr(self) -> List[float]:
119 | """
120 | Compute learning rate using chainable form of the scheduler
121 | """
122 | if not self._get_lr_called_within_step:
123 | warnings.warn(
124 | "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning
125 | )
126 |
127 | if self.last_epoch == 0:
128 | return [self.warmup_start_lr] * len(self.base_lrs)
129 | elif self.last_epoch < self.warmup_epochs:
130 | return [
131 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
132 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
133 | ]
134 | elif self.last_epoch == self.warmup_epochs:
135 | return self.base_lrs
136 | elif (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
137 | return [
138 | group["lr"]
139 | + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
140 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
141 | ]
142 |
143 | return [
144 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
145 | / (
146 | 1
147 | + math.cos(
148 | math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
149 | )
150 | )
151 | * (group["lr"] - self.eta_min)
152 | + self.eta_min
153 | for group in self.optimizer.param_groups
154 | ]
155 |
156 | def _get_closed_form_lr(self) -> List[float]:
157 | """
158 | Called when epoch is passed as a param to the `step` function of the scheduler.
159 | """
160 | if self.last_epoch < self.warmup_epochs:
161 | return [
162 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
163 | for base_lr in self.base_lrs
164 | ]
165 |
166 | return [
167 | self.eta_min
168 | + 0.5
169 | * (base_lr - self.eta_min)
170 | * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
171 | for base_lr in self.base_lrs
172 | ]
173 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | python==3.8
2 | torch==1.11.1
3 | numpy==1.19.5
4 | monai==0.7.0
5 | timm==0.3.2
6 | tensorboardX==2.1
7 | torchvision==0.12.0
8 | opencv-python==4.5.5
9 |
--------------------------------------------------------------------------------
/runs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/runs/__init__.py
--------------------------------------------------------------------------------
/runs/log/log.txt:
--------------------------------------------------------------------------------
1 | Final training:0/99,loss:2.14935310681661
2 | Final validation:0/99,dice:0.021508501284427224,loss:1.7313948571681976,
3 | Final training:0/99,loss:2.1603429714838662
4 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SJTU-Intelligent-Optics-Lab/Annotation-efficient-learning-for-OCT-segmentation/d3bd2761765215dbef29eba0639d1ae462ebf1a2/utils/__init__.py
--------------------------------------------------------------------------------
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 | import os
4 | import cv2
5 | import numpy as np
6 | import torch
7 | from monai import data, transforms
8 | import PIL
9 | import torchvision
10 | from PIL import Image
11 | from torchvision import transforms
12 | from torch.utils.data import Dataset
13 |
14 | class Sampler(torch.utils.data.Sampler):
15 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, make_even=True):
16 | if num_replicas is None:
17 | if not torch.distributed.is_available():
18 | raise RuntimeError("Requires distributed package to be available")
19 | num_replicas = torch.distributed.get_world_size()
20 | if rank is None:
21 | if not torch.distributed.is_available():
22 | raise RuntimeError("Requires distributed package to be available")
23 | rank = torch.distributed.get_rank()
24 | self.shuffle = shuffle
25 | self.make_even = make_even
26 | self.dataset = dataset
27 | self.num_replicas = num_replicas
28 | self.rank = rank
29 | self.epoch = 0
30 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
31 | self.total_size = self.num_samples * self.num_replicas
32 | indices = list(range(len(self.dataset)))
33 | self.valid_length = len(indices[self.rank : self.total_size : self.num_replicas])
34 |
35 | def __iter__(self):
36 | if self.shuffle:
37 | g = torch.Generator()
38 | g.manual_seed(self.epoch)
39 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
40 | else:
41 | indices = list(range(len(self.dataset)))
42 | if self.make_even:
43 | if len(indices) < self.total_size:
44 | if self.total_size - len(indices) < len(indices):
45 | indices += indices[: (self.total_size - len(indices))]
46 | else:
47 | extra_ids = np.random.randint(low=0, high=len(indices), size=self.total_size - len(indices))
48 | indices += [indices[ids] for ids in extra_ids]
49 | assert len(indices) == self.total_size
50 | indices = indices[self.rank : self.total_size : self.num_replicas]
51 | self.num_samples = len(indices)
52 | return iter(indices)
53 |
54 | def __len__(self):
55 | return self.num_samples
56 |
57 | def set_epoch(self, epoch):
58 | self.epoch = epoch
59 |
60 |
61 | def get_loader(args):
62 | if args.test_mode:
63 | dataset_test = build_dataset(args.val_text, args=args)
64 | data_loader_val = torch.utils.data.DataLoader(
65 | dataset_test, sampler=None,
66 | batch_size=args.batch_size,
67 | num_workers=args.num_workers,
68 | drop_last=False)
69 | loader = data_loader_val
70 |
71 | else:
72 | dataset_train = build_dataset(args.train_text, args=args)
73 | dataset_val = build_dataset(args.val_text, args=args)
74 | data_loader_train = torch.utils.data.DataLoader(
75 | dataset_train, sampler=None,
76 | batch_size=args.batch_size,
77 | num_workers=args.num_workers,
78 | shuffle = True,
79 | drop_last=False)
80 | data_loader_val = torch.utils.data.DataLoader(
81 | dataset_val, sampler=None,
82 | batch_size=args.batch_size,
83 | num_workers=args.num_workers,
84 | shuffle=False,
85 | drop_last=False)
86 | loader = [data_loader_train, data_loader_val]
87 |
88 | return loader
89 |
90 | def build_dataset(split, args):
91 | transform = build_transform(split, args)
92 |
93 | if 'fewshot' in split:
94 | dataset = Fewshot_dataset(args.label_index, args.data_dir, args.list_dir, split, transform=transform)
95 | else:
96 | dataset = Parent_dataset(args.data_dir, args.list_dir, split, transform=transform)
97 | print(dataset)
98 | return dataset
99 |
100 | def build_transform(split, args):
101 | if 'train' in split:
102 | transform_image = transforms.Compose([
103 | transforms.ToPILImage(),
104 | transforms.Resize((args.input_size,args.input_size), interpolation=PIL.Image.BICUBIC), # 3 is bicubic
105 | transforms.RandomHorizontalFlip(p=0.5),#
106 | transforms.ToTensor(),
107 | transforms.Normalize(mean=[0.227, 0.227, 0.227], std=[0.1935, 0.1935, 0.1935])])
108 |
109 | transform_label = transforms.Compose([
110 | transforms.ToPILImage(),
111 | #transforms.RandomRotation(degrees=(-180, 180)), #
112 | transforms.Resize((args.input_size,args.input_size), interpolation=PIL.Image.NEAREST), # 3 is bicubic
113 | transforms.RandomHorizontalFlip(p=0.5),#
114 | transforms.ToTensor()])
115 | trans = [transform_image, transform_label]
116 | return trans
117 |
118 | elif 'val' in split:
119 | transform_image = transforms.Compose([
120 | transforms.ToPILImage(),
121 | transforms.Resize((args.input_size,args.input_size), interpolation=PIL.Image.BICUBIC), # 3 is bicubic
122 | transforms.ToTensor(),
123 | transforms.Normalize(mean=[0.227, 0.227, 0.227], std=[0.1935, 0.1935, 0.1935])])
124 | #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
125 | transform_label = transforms.Compose([
126 | transforms.ToPILImage(),
127 | transforms.Resize((args.input_size,args.input_size), interpolation=PIL.Image.NEAREST), # 3 is bicubic
128 | transforms.ToTensor()])
129 | trans = [transform_image, transform_label]
130 | return trans
131 |
132 | class Parent_dataset(Dataset):
133 | def __init__(self, data_dir, list_dir, split, transform=None):
134 | self.transform = transform
135 | self.split = split
136 | self.sample_list = open(os.path.join(list_dir, self.split+'.txt')).readlines()
137 | self.data_dir = data_dir
138 | def __len__(self):
139 | return len(self.sample_list)
140 |
141 | def __getitem__(self, idx):
142 | slice_name = self.sample_list[idx].strip('\n')
143 | image_path = os.path.join(self.data_dir, self.split, slice_name)
144 | image = cv2.imread(image_path,cv2.IMREAD_GRAYSCALE)
145 | h, w =image.shape[0], image.shape[1]
146 | image = np.expand_dims(np.array(image), axis = -1).repeat(3,2)
147 | label_path = os.path.join(self.data_dir, self.split, slice_name.replace('.png','_gt.png'))
148 |
149 | label = cv2.imread(label_path,cv2.IMREAD_GRAYSCALE)
150 | label = np.expand_dims(np.array(label), axis = -1)
151 |
152 | if self.transform:
153 | image = self.transform[0](image)
154 | label = self.transform[1](label)
155 | if self.split == 'val':
156 | return [image, label, image_path, h, w]
157 |
158 | return [image, label]
159 |
160 | class Fewshot_dataset(Dataset):
161 | def __init__(self, label_index, data_dir, list_dir, split, transform=None):
162 | self.transform = transform
163 | self.split = split
164 | self.sample_list = open(os.path.join(list_dir, self.split+'.txt')).readlines()
165 | self.data_dir = data_dir
166 | self.label_index=label_index
167 |
168 | def __len__(self):
169 | return len(self.sample_list)
170 |
171 | def __getitem__(self, idx):
172 | slice_name = self.sample_list[idx].strip('\n')
173 | if 'train' in self.split:
174 | image_path = os.path.join(self.data_dir, 'train_fewshot_data', slice_name)
175 | label_path = os.path.join(self.data_dir, 'train_fewshot_data', slice_name.replace('.png', '_gt.png'))
176 | elif 'val' in self.split:
177 | image_path = os.path.join(self.data_dir, 'val_fewshot_data', slice_name)
178 | label_path = os.path.join(self.data_dir, 'val_fewshot_data', slice_name.replace('.png', '_gt.png'))
179 | else:
180 | raise 'can not find the train or val dict'
181 | image = cv2.imread(image_path,cv2.IMREAD_GRAYSCALE)
182 |
183 | h, w =image.shape[0], image.shape[1]
184 | image = np.expand_dims(np.array(image), axis = -1).repeat(3,2)
185 |
186 | label = cv2.imread(label_path,cv2.IMREAD_GRAYSCALE)
187 |
188 | # for i in self.label_index:
189 | # if i<255:
190 | # label[label==i*25] = 255
191 | # label[label!=255] = 0
192 |
193 | label = np.expand_dims(np.array(label), axis = -1)
194 |
195 | if self.transform:
196 | seed = torch.random.seed()
197 | torch.random.manual_seed(seed)
198 | image = self.transform[0](image)
199 | torch.random.manual_seed(seed)
200 | label = self.transform[1](label)
201 |
202 | if 'val' in self.split:
203 | return [image, label, image_path, h, w]
204 |
205 | return [image, label]
206 |
--------------------------------------------------------------------------------
/utils/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import time
4 | import cv2
5 | import numpy as np
6 | import torch
7 | import torch.nn.parallel
8 | import torch.utils.data.distributed
9 | #from tensorboardX import SummaryWriter
10 | from torch.utils.tensorboard import SummaryWriter
11 | from torch.cuda.amp import GradScaler, autocast
12 | from torch.nn.modules.loss import BCELoss
13 |
14 | class EarlyStopping:
15 | """Early stops the training if validation loss doesn't improve after a given patience."""
16 |
17 | def __init__(self, patience=10, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
18 | """
19 | Args:
20 | patience (int): How long to wait after last time validation loss improved.
21 | Default: 7
22 | verbose (bool): If True, prints a message for each validation loss improvement.
23 | Default: False
24 | delta (float): Minimum change in the monitored quantity to qualify as an improvement.
25 | Default: 0
26 | path (str): Path for the checkpoint to be saved to.
27 | Default: 'checkpoint.pt'
28 | trace_func (function): trace print function.
29 | Default: print
30 | """
31 | self.patience = patience
32 | self.verbose = verbose
33 | self.counter = 0
34 | self.best_score = None
35 | self.early_stop = False
36 | self.val_loss_min = np.Inf
37 | self.delta = delta
38 | self.path = path
39 | self.trace_func = trace_func
40 |
41 | def __call__(self, val_loss):
42 |
43 | score = val_loss
44 |
45 | if self.best_score is None:
46 | self.best_score = score
47 | elif score < self.best_score + self.delta:
48 | self.counter += 1
49 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
50 | if self.counter >= self.patience:
51 | self.early_stop = True
52 | else:
53 | self.best_score = score
54 | self.counter = 0
55 |
56 |
57 | class DiceLoss(torch.nn.Module):
58 | def __init__(self, n_classes):
59 | super(DiceLoss, self).__init__()
60 | self.n_classes = n_classes
61 |
62 | def _one_hot_encoder(self, input_tensor):
63 | tensor_list = []
64 | for i in range(self.n_classes):
65 | temp_prob = input_tensor == i # * torch.ones_like(input_tensor)
66 | tensor_list.append(temp_prob.unsqueeze(1))
67 | output_tensor = torch.cat(tensor_list, dim=1)
68 | return output_tensor.float()
69 |
70 | def _dice_loss(self, score, target):
71 | target = target.float()
72 | smooth = 1e-5
73 | intersect = torch.sum(score * target)
74 | y_sum = torch.sum(target * target)
75 | z_sum = torch.sum(score * score)
76 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
77 | loss = 1 - loss
78 | return loss
79 |
80 | def forward(self, inputs, target, weight=None, softmax=False):
81 | if softmax:
82 | inputs = torch.softmax(inputs, dim=1)
83 | else:
84 | inputs = torch.sigmoid(inputs)
85 | #target = self._one_hot_encoder(target)
86 | if weight is None:
87 | weight = [1] * self.n_classes
88 | assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
89 | class_wise_dice = []
90 | loss = 0.0
91 | for i in range(0, self.n_classes):
92 | dice = self._dice_loss(inputs[:, i], target[:, i])
93 | class_wise_dice.append(1.0 - dice.item())
94 | loss += dice * weight[i]
95 | return loss / self.n_classes
96 |
97 | def dice(x, y):
98 | intersect = np.sum(np.sum(np.sum(x * y)))
99 | y_sum = np.sum(np.sum(np.sum(y)))
100 | smooth = 1e-5
101 | if y_sum == 0:
102 | return 0.0
103 | x_sum = np.sum(np.sum(np.sum(x)))
104 | return (2 * intersect +smooth) / (x_sum + y_sum + smooth)
105 |
106 |
107 | class AverageMeter(object):
108 | def __init__(self):
109 | self.reset()
110 |
111 | def reset(self):
112 | self.val = 0
113 | self.avg = 0
114 | self.sum = 0
115 | self.count = 0
116 |
117 | def update(self, val, n=1):
118 | self.val = val
119 | self.sum += val * n
120 | self.count += n
121 | self.avg = np.where(self.count > 0, self.sum / self.count, self.sum)
122 |
123 |
124 | def train_epoch(model, loader, optimizer, scaler, epoch, loss_func, args):
125 | model.train()
126 |
127 | ce_loss = torch.nn.BCEWithLogitsLoss()
128 | dice_loss = DiceLoss(1)
129 | start_time = time.time()
130 | run_loss = AverageMeter()
131 | for idx, batch_data in enumerate(loader):
132 |
133 | if isinstance(batch_data, list):
134 | data, target = batch_data
135 | else:
136 | data, target = batch_data["image"], batch_data["label"]
137 |
138 | if args.in_channels == 1:
139 | data = torch.unsqueeze(data[:, 0, :, :], dim=1)
140 |
141 | data, target = data.cuda(args.rank), target.cuda(args.rank)
142 | for param in model.parameters():
143 | param.grad = None
144 |
145 | with autocast(enabled=args.amp):
146 | logits = model(data)
147 |
148 | loss_ce = ce_loss(logits, target)
149 | loss_dice = dice_loss(logits, target, softmax=False)
150 | loss = 1.0 * loss_ce + 1.0 * loss_dice
151 | print('ce_loss: {:.3f} dice_loss: {:.3f}'.format(loss_ce, loss_dice))
152 |
153 | if args.amp:
154 | scaler.scale(loss).backward()
155 | scaler.step(optimizer)
156 | scaler.update()
157 | else:
158 | loss.backward()
159 | optimizer.step()
160 |
161 | run_loss.update(loss.item(), n=args.batch_size)
162 | if args.rank == 0:
163 | print(
164 | "Epoch {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)),
165 | "loss: {:.4f}".format(loss.item()),
166 | "lr: {:.8f}".format(optimizer.param_groups[0]['lr']),
167 | "time {:.2f}s".format(time.time() - start_time),
168 | )
169 | start_time = time.time()
170 | for param in model.parameters():
171 | param.grad = None
172 |
173 | return run_loss.avg
174 |
175 |
176 | def val_epoch(model, loader, epoch=None, acc_func=None, args=None, model_inferer=None, post_label=None, post_pred=None):
177 | model.eval()
178 | start_time = time.time()
179 | run_acc = AverageMeter()
180 | ce_loss = torch.nn.BCEWithLogitsLoss()
181 | dice_loss = DiceLoss(1)
182 | run_loss = AverageMeter()
183 | with torch.no_grad():
184 | for idx, batch_data in enumerate(loader):
185 | if isinstance(batch_data, list):
186 | data, target, image_path, h, w = batch_data
187 | else:
188 | data, target = batch_data["image"], batch_data["label"]
189 |
190 | if args.in_channels == 1:
191 | data = torch.unsqueeze(data[:, 0, :, :], dim=1) # 三通道改单通道
192 |
193 | data, target = data.cuda(args.rank), target.cuda(args.rank)
194 | with autocast(enabled=args.amp):
195 | if model_inferer is not None:
196 | logits = model_inferer(data)
197 | else:
198 | logits = model(data)
199 | #loss
200 | loss_ce = ce_loss(logits, target)
201 | loss_dice = dice_loss(logits, target, softmax=False)
202 | loss = 1.0 * loss_ce + 1.0 * loss_dice
203 | #print('ce_loss: {:.3f} dice_loss: {:.3f}'.format(loss_ce, loss_dice))
204 | #sigmoid for dice
205 | out = torch.sigmoid(logits)
206 | acc_list = []
207 | if out.is_cuda:
208 | target = target.cpu().numpy()
209 | out = out.cpu().detach().numpy()
210 | h = h.cpu().numpy()
211 | w = w.cpu().numpy()
212 | assert out.shape == target.shape, 'predict {} & target {} shape do not match'.format(out.shape, target.shape)
213 | for i in range(out.shape[0]):
214 | out[i] = np.where(out[i]>0.5, 1, 0)
215 |
216 | acc_list.append(dice(out[i],target[i]))
217 | pre = np.array(out[i],dtype=np.uint8)
218 | pre = np.squeeze(pre,axis=0)
219 | pre = cv2.resize(pre,(w[i], h[i]),interpolation=cv2.INTER_NEAREST)*255
220 | #postprocess
221 | # pre[0:1,:]=0
222 | # pre[:, 0:1] = 0
223 | # pre = fillHole(pre)
224 | # pre = save_max_objects(pre)
225 |
226 |
227 | pre_path = image_path[i].replace('val_fewshot_data', 'val_fewshot_results')
228 | if not os.path.isdir(os.path.split(pre_path)[0]):
229 | os.makedirs(os.path.split(pre_path)[0])
230 |
231 | path=pre_path.replace('.png', '_auto.png')
232 | cv2.imwrite(path, pre)
233 |
234 |
235 | avg_acc = np.mean(np.array(acc_list))
236 | run_acc.update(avg_acc, n=args.batch_size)
237 | run_loss.update(loss.item(), n=args.batch_size)
238 | if args.rank == 0:
239 | print(
240 | "Val {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)),
241 | "dice:", avg_acc,
242 | 'loss:',loss.cpu().numpy(),
243 | "time {:.2f}s".format(time.time() - start_time),
244 | )
245 | start_time = time.time()
246 | return run_acc.avg, run_loss.avg, os.path.split(pre_path)[0]
247 |
248 |
249 | def save_checkpoint(model, epoch, args, filename="model.pth", best_acc=0, optimizer=None, scheduler=None):
250 | state_dict = model.state_dict() if not args.distributed else model.module.state_dict()
251 | save_dict = {"epoch": epoch, "best_acc": best_acc, "state_dict": state_dict}
252 | if optimizer is not None:
253 | save_dict["optimizer"] = optimizer.state_dict()
254 | if scheduler is not None:
255 | save_dict["scheduler"] = scheduler.state_dict()
256 | filename = os.path.join(args.logdir, filename)
257 | torch.save(save_dict, filename)
258 | print("Saving checkpoint", filename)
259 |
260 |
261 | def run_training(
262 | model,
263 | train_loader,
264 | val_loader,
265 | optimizer,
266 | loss_func,
267 | acc_func,
268 | args,
269 | model_inferer=None,
270 | scheduler=None,
271 | start_epoch=0,
272 | post_label=None,
273 | post_pred=None,
274 | ):
275 | early_stopping = EarlyStopping(patience=10, verbose=True)
276 | spend_time = 0
277 | writer = None
278 | if args.logdir is not None and args.rank == 0:
279 | writer = SummaryWriter(log_dir=args.logdir)
280 | if args.rank == 0:
281 | print("Writing Tensorboard logs to ", args.logdir)
282 | scaler = None
283 | if args.amp:
284 | scaler = GradScaler()#using float16 to reduce memory
285 | val_acc_max = 0.0
286 | for epoch in range(start_epoch, args.max_epochs):
287 |
288 | print(args.rank, time.ctime(), "Epoch:", epoch)
289 | epoch_time = time.time()
290 | train_loss = train_epoch(
291 | model, train_loader, optimizer, scaler=scaler, epoch=epoch, loss_func=loss_func, args=args
292 | )#for training one epoch
293 |
294 | # if scheduler is not None:
295 | # scheduler.step()
296 |
297 | spend_time += time.time() - epoch_time
298 | if args.rank == 0:
299 | print(
300 | "Final training {}/{}".format(epoch, args.max_epochs - 1),
301 | "loss: {:.4f}".format(train_loss),
302 | "time {:.2f}s".format(time.time() - epoch_time),
303 | )
304 | with open(os.path.join(args.logdir, "log.txt"), mode="a", encoding="utf-8") as f:
305 | f.write("Final training:{}/{},".format(epoch, args.max_epochs - 1) + "loss:{}".format(train_loss) + "\n")
306 | if args.rank == 0 and writer is not None:
307 | writer.add_scalar("train_loss", train_loss, epoch)
308 | b_new_best = False
309 | if (epoch + 1) % args.val_every == 0:
310 | if args.distributed:
311 | torch.distributed.barrier()
312 | epoch_time = time.time()
313 | val_avg_acc, run_loss, file_path = val_epoch(
314 | model,
315 | val_loader,
316 | epoch=epoch,
317 | acc_func=acc_func,
318 | model_inferer=model_inferer,
319 | args=args,
320 | post_label=post_label,
321 | post_pred=post_pred,
322 | )
323 | if args.rank == 0:
324 | print(
325 | "Final validation {}/{}".format(epoch, args.max_epochs - 1),
326 | "acc:", val_avg_acc,
327 | 'std:',
328 | 'loss:',run_loss,
329 | "time {:.2f}s".format(time.time() - epoch_time),
330 | )
331 | with open(os.path.join(args.logdir, "log.txt"), mode="a", encoding="utf-8") as f:
332 | f.write("Final validation:{}/{},".format(epoch, args.max_epochs - 1) + "dice:{},".format(val_avg_acc)
333 | + "loss:{},".format(run_loss)+ "\n")
334 | if writer is not None:
335 | writer.add_scalar("val_acc", val_avg_acc, epoch)
336 | if val_avg_acc > val_acc_max:
337 | print("new best ({:.6f} --> {:.6f}). ".format(val_acc_max, val_avg_acc))
338 | val_acc_max = val_avg_acc
339 | b_new_best = True
340 | #fewshot save bset results
341 | if os.path.exists(file_path.replace('val_fewshot_results','val_fewshot_results_best')):
342 | shutil.rmtree(file_path.replace('val_fewshot_results','val_fewshot_results_best'))
343 | shutil.copytree(file_path,file_path.replace('val_fewshot_results','val_fewshot_results_best'))
344 |
345 | #unet save best results
346 | if os.path.exists('./dataset/val_fewshot_results_best'):
347 | shutil.rmtree('./dataset/val_fewshot_results_best')
348 | shutil.copytree('./dataset/val_fewshot_results',
349 | './dataset/val_fewshot_results_best')
350 |
351 | #save weights
352 | if args.rank == 0 and args.logdir is not None and args.save_checkpoint:
353 | save_checkpoint(
354 | model, epoch, args, best_acc=val_acc_max, optimizer=optimizer, scheduler=scheduler
355 | )
356 | if args.rank == 0 and args.logdir is not None and args.save_checkpoint:
357 | save_checkpoint(model, epoch, args, best_acc=val_acc_max, filename="model_final.pth")
358 | if b_new_best:
359 | print("Copying to model.pt new best model!!!!")
360 | shutil.copyfile(os.path.join(args.logdir, "model_final.pth"), os.path.join(args.logdir, "model.pth"))
361 |
362 | early_stopping(val_avg_acc)
363 | if early_stopping.early_stop:
364 | print("Early stop!")
365 | break
366 |
367 | if scheduler is not None:
368 | scheduler.step(-val_avg_acc)
369 |
370 | print("Training Finished !, Best Accuracy: ", val_acc_max, "Total time: {} s.".format(round(spend_time)))
371 |
372 | return val_acc_max
373 |
374 | from skimage import measure
375 | def save_max_objects(img):
376 | labels = measure.label(img, connectivity=1)
377 | jj = measure.regionprops(labels)
378 | # is_del = False
379 | if len(jj) == 0:
380 | out = img
381 | return out
382 | elif len(jj) == 1:
383 | out = img
384 | return out
385 | # is_del = False
386 | else:
387 | num = labels.max()
388 | del_array = np.array([0] * (num + 1))
389 | for k in range(num):
390 | if k == 0:
391 | initial_area = jj[0].area
392 | save_index = 1
393 | else:
394 | k_area = jj[k].area
395 |
396 | if initial_area < k_area:
397 | initial_area = k_area
398 | save_index = k + 1
399 |
400 | del_array[save_index] = 1
401 | del_mask = del_array[labels]
402 | out = img * del_mask
403 | return out
404 |
405 |
406 | def fillHole(im_in):
407 | im_in = im_in.astype(np.uint8)
408 | # print np.unique(im_in)
409 | im_floodfill = im_in.copy()
410 | # Mask used to flood filling.
411 | # Notice the size needs to be 2 pixels than the image.
412 | h, w = im_in.shape[:2]
413 | mask = np.zeros((h + 2, w + 2), np.uint8)
414 |
415 | # Floodfill from point (0, 0)
416 | cv2.floodFill(im_floodfill, mask, (0, 0), 255)
417 |
418 | # Invert floodfilled image
419 | im_floodfill_inv = cv2.bitwise_not(im_floodfill)
420 |
421 | # Combine the two images to get the foreground.
422 | im_out = im_in | im_floodfill_inv
423 | # print np.unique(im_out)
424 | return im_out
425 |
--------------------------------------------------------------------------------