├── .commitlintrc.yml
├── .drone.yml
├── .gitignore
├── .orange-ci.yml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── datasets
│   └── brains18.py
├── images
│   ├── efficiency.gif
│   └── logo.png
├── model.py
├── models
│   └── resnet.py
├── requirements.txt
├── setting.py
├── test.py
├── test_ci.py
├── toy_data
│   ├── MRBrainS18
│   │   ├── images
│   │   │   └── 070.nii.gz
│   │   ├── labels
│   │   │   └── 070.nii.gz
│   │   └── test_ci.txt
│   └── test_ci.txt
├── train.py
└── utils
    ├── file_process.py
    └── logger.py
/.commitlintrc.yml:
--------------------------------------------------------------------------------
1 | extends:
2 | - "@commitlint/config-conventional"
3 |
--------------------------------------------------------------------------------
/.drone.yml:
--------------------------------------------------------------------------------
1 | pipeline:
2 | build:
3 | name: testing phase
4 | image: cshwhale/dockerfiles:latest
5 | commands:
6 | - python train.py --ci_test
7 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # IPython
79 | profile_default/
80 | ipython_config.py
81 |
82 | # pyenv
83 | .python-version
84 |
85 | # pipenv
86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
88 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not
89 | # install all needed dependencies.
90 | #Pipfile.lock
91 |
92 | # celery beat schedule file
93 | celerybeat-schedule
94 |
95 | # SageMath parsed files
96 | *.sage.py
97 |
98 | # Environments
99 | .env
100 | .venv
101 | env/
102 | venv/
103 | ENV/
104 | env.bak/
105 | venv.bak/
106 |
107 | # Spyder project settings
108 | .spyderproject
109 | .spyproject
110 |
111 | # Rope project settings
112 | .ropeproject
113 |
114 | # mkdocs documentation
115 | /site
116 |
117 | # mypy
118 | .mypy_cache/
119 | .dmypy.json
120 | dmypy.json
121 |
122 | # Pyre type checker
123 | .pyre/
124 |
125 | # My configurations:
126 | data/**/*.nii.gz
127 | data/**/*.txt
128 | pretrain/**/*.pth
129 | trails/**/*.pth
130 |
--------------------------------------------------------------------------------
/.orange-ci.yml:
--------------------------------------------------------------------------------
1 | master:
2 | merge_request:
3 | - stages:
4 | - name: make commitlist
5 | type: git:commitList
6 | options:
7 | toFile: commits-data.json
8 | - name: do commitlint
9 | image: csighub.tencentyun.com/plugins/commitlint
10 | settings:
11 | from_file: commits-data.json
12 | push:
13 | - network: idc-ai-sse4
14 | stages:
15 | - name: testing phase
16 | image: cshwhale/dockerfiles:latest
17 | commands:
18 | - python train.py --ci_test
19 |
20 | $:
21 | tag_push:
22 | - stages:
23 | - name: changelog
24 | type: git:changeLog
25 | options:
26 | filename: CHANGELOG.md
27 | target: master
28 | envExport:
29 | latestChangeLog: LATEST_CHANGE_LOG
30 | - name: upload release
31 | type: git:release
32 | options:
33 | description: ${LATEST_CHANGE_LOG}
34 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent/MedicalNet/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/CONTRIBUTING.md
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Tencent is pleased to support the open source community by making MedicalNet available.
2 |
3 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 |
5 | MedicalNet is licensed under the MIT License, including the third-party component listed below.
6 |
7 | A copy of the MIT License is included in this file.
8 |
9 | Other dependency and license:
10 |
11 |
12 | Open Source Software Licensed Under the MIT License:
13 | --------------------------------------------------------------------
14 | 1. 3D-ResNets-PyTorch 3.0
15 | Copyright (c) 2017 Kensho Hara
16 |
17 |
18 | Terms of the MIT License:
19 | ---------------------------------------------------
20 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
21 |
22 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
23 |
24 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
25 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | ![MedicalNet logo](images/logo.png)
 2 |
 3 |
4 | # MedicalNet
 5 | This repository contains a PyTorch implementation of [Med3D: Transfer Learning for 3D Medical Image Analysis](https://arxiv.org/abs/1904.00625).
 6 | Many studies have shown that the performance of deep learning is significantly affected by the volume of training data. The MedicalNet project aggregated datasets with diverse modalities, target organs, and pathologies to build a relatively large training corpus. Based on this dataset, a series of 3D-ResNet pre-trained models and the corresponding transfer-learning training code are provided.
7 |
8 | ### License
 9 | MedicalNet is released under the MIT License (refer to the LICENSE file for details).
10 |
11 | ### Citing MedicalNet
12 | If you use this code or pre-trained models, please cite the following:
13 | ```
14 | @article{chen2019med3d,
15 | title={Med3D: Transfer Learning for 3D Medical Image Analysis},
16 | author={Chen, Sihong and Ma, Kai and Zheng, Yefeng},
17 | journal={arXiv preprint arXiv:1904.00625},
18 | year={2019}
19 | }
20 | ```
21 | ### Update (2019/07/30)
22 | We uploaded four pre-trained models trained on more datasets (23 in total).
23 | ```
24 | Model name : parameter settings
25 | resnet_10_23dataset.pth: --model resnet --model_depth 10 --resnet_shortcut B
26 | resnet_18_23dataset.pth: --model resnet --model_depth 18 --resnet_shortcut A
27 | resnet_34_23dataset.pth: --model resnet --model_depth 34 --resnet_shortcut A
28 | resnet_50_23dataset.pth: --model resnet --model_depth 50 --resnet_shortcut B
29 | ```
30 | We transferred the above pre-trained models to the multi-class segmentation task (left lung, right lung and background) on the Visceral dataset. The results are as follows:
31 |
32 |
33 | | Network | Pretrain | LungSeg (Dice) |
34 | |---|---|---|
35 | | 3D-ResNet10 | Train from scratch | 69.31% |
36 | | 3D-ResNet10 | MedicalNet | 96.56% |
37 | | 3D-ResNet18 | Train from scratch | 70.89% |
38 | | 3D-ResNet18 | MedicalNet | 94.68% |
39 | | 3D-ResNet34 | Train from scratch | 75.25% |
40 | | 3D-ResNet34 | MedicalNet | 94.14% |
41 | | 3D-ResNet50 | Train from scratch | 52.94% |
42 | | 3D-ResNet50 | MedicalNet | 89.25% |
72 |
73 |
74 |
75 |
76 | ### Contents
77 | 1. [Requirements](#Requirements)
78 | 2. [Installation](#Installation)
79 | 3. [Demo](#Demo)
80 | 4. [Experiments](#Experiments)
81 | 5. [TODO](#TODO)
82 | 6. [Acknowledgement](#Acknowledgement)
83 |
84 | ### Requirements
85 | - Python 3.7.0
86 | - PyTorch-0.4.1
87 | - CUDA Version 9.0
88 | - CUDNN 7.0.5
89 |
90 | ### Installation
91 | - Install Python 3.7.0
92 | - pip install -r requirements.txt
93 |
94 |
95 | ### Demo
 96 | - Structure of data directories
 97 |   (MedicalNet transfers the pre-trained model to other datasets; the MRBrainS18 dataset is used as the example here.)
 98 | ```
 99 | MedicalNet/
100 | |--datasets/: data preprocessing module
101 | |   |--brains18.py: MRBrainS18 data preprocessing script
102 | |--models/: model construction module
103 | |   |--resnet.py: 3D-ResNet network build script
104 | |--utils/: tools
105 | |   |--logger.py: logging script
106 | |--toy_data/: toy data for the CI test
107 | |--data/: data storage module
108 | |   |--MRBrainS18/: MRBrainS18 dataset
109 | |   |   |--images/: source images named by patient ID
110 | |   |   |--labels/: masks named by patient ID
111 | |   |--train.txt: training data list
112 | |   |--val.txt: validation data list
113 | |--pretrain/: pre-trained model storage module
114 | |--model.py: network generation script
115 | |--setting.py: parameter setting script
116 | |--train.py: MRBrainS18 training demo script
117 | |--test.py: MRBrainS18 testing demo script
118 | |--requirements.txt: dependency list
119 | |--README.md
120 | ```
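
Each line of an image list file (`train.txt`, `val.txt`, or the CI list) pairs an image path with its label path, separated by a single space and resolved against `--data_root`; for example, from `toy_data/test_ci.txt`:

```
MRBrainS18/images/070.nii.gz MRBrainS18/labels/070.nii.gz
```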
121 |
122 | - Network structure parameter settings
123 | ```
124 | Model name : parameter settings
125 | resnet_10.pth: --model resnet --model_depth 10 --resnet_shortcut B
126 | resnet_18.pth: --model resnet --model_depth 18 --resnet_shortcut A
127 | resnet_34.pth: --model resnet --model_depth 34 --resnet_shortcut A
128 | resnet_50.pth: --model resnet --model_depth 50 --resnet_shortcut B
129 | resnet_101.pth: --model resnet --model_depth 101 --resnet_shortcut B
130 | resnet_152.pth: --model resnet --model_depth 152 --resnet_shortcut B
131 | resnet_200.pth: --model resnet --model_depth 200 --resnet_shortcut B
132 | ```
133 |
134 | - After successfully completing basic installation, you'll be ready to run the demo.
135 | 1. Clone the MedicalNet repository
136 | ```
137 | git clone https://github.com/Tencent/MedicalNet
138 | ```
139 | 2. Download data & pre-trained models ([Google Drive](https://drive.google.com/file/d/13tnSvXY7oDIEloNFiGTsjUIYfS3g3BfG/view?usp=sharing) or [Tencent Weiyun](https://share.weiyun.com/55sZyIx))
140 |
141 | Unzip and move files
142 | ```
143 | mv MedicalNet_pytorch_files.zip MedicalNet/.
144 | cd MedicalNet
145 | unzip MedicalNet_pytorch_files.zip
146 | ```
147 | 3. Run the training code (e.g. 3D-ResNet-50)
148 | ```
149 | python train.py --gpu_id 0 1 # multi-gpu training on gpu 0,1
150 | or
151 | python train.py --gpu_id 0 # single-gpu training on gpu 0
152 | ```
153 | 4. Run the testing code (e.g. 3D-ResNet-50)
154 | ```
155 | python test.py --gpu_id 0 --resume_path trails/models/resnet_50_epoch_110_batch_0.pth.tar --img_list data/val.txt
156 | ```
157 |
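The test flow can also be driven programmatically. The sketch below mirrors the logic of `test.py` (the checkpoint path and flags are illustrative assumptions, and the resize of predictions back to the original volume size is omitted):

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from setting import parse_opts
from model import generate_model
from datasets.brains18 import BrainS18Dataset

sets = parse_opts()            # e.g. run with: --gpu_id 0 --img_list data/val.txt
sets.phase = 'test'

net, _ = generate_model(sets)  # builds the 3D-ResNet with its segmentation head
checkpoint = torch.load('trails/models/resnet_50_epoch_110_batch_0.pth.tar')
net.load_state_dict(checkpoint['state_dict'])
net.eval()

loader = DataLoader(BrainS18Dataset(sets.data_root, sets.img_list, sets),
                    batch_size=1, shuffle=False)
with torch.no_grad():
    for volume in loader:
        if not sets.no_cuda:
            volume = volume.cuda()
        probs = F.softmax(net(volume), dim=1)   # (1, n_seg_classes, D', H', W')
        pred = probs.argmax(dim=1)              # per-voxel class ids at network resolution
```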
158 | ### Experiments
159 | - Computational Cost
160 | ```
161 | GPU: NVIDIA Tesla P40
162 | ```
163 |
164 |
165 | | Network | Parameters (M) | Running time (s) |
166 | |---|---|---|
167 | | 3D-ResNet10 | 14.36 | 0.18 |
168 | | 3D-ResNet18 | 32.99 | 0.19 |
169 | | 3D-ResNet34 | 63.31 | 0.22 |
170 | | 3D-ResNet50 | 46.21 | 0.21 |
171 | | 3D-ResNet101 | 85.31 | 0.29 |
172 | | 3D-ResNet152 | 117.51 | 0.34 |
173 | | 3D-ResNet200 | 126.74 | 0.45 |
203 |
204 |
205 |
206 | - Performance
207 | ```
208 | Visualization of the segmentation results of our approach vs. the comparison methods after the same number of training epochs.
209 | This demonstrates the training-convergence efficiency and accuracy gained from our MedicalNet pre-trained models.
210 | ```
211 | ![Segmentation efficiency comparison](images/efficiency.gif)
212 |
213 |
214 | ```
215 | Results of transferring MedicalNet pre-trained models to lung segmentation (LungSeg) and pulmonary nodule classification (NoduleCls), evaluated with Dice and accuracy respectively.
216 | ```
217 |
218 |
219 | | Network | Pretrain | LungSeg (Dice) | NoduleCls (accuracy) |
220 | |---|---|---|---|
221 | | 3D-ResNet10 | Train from scratch | 71.30% | 79.80% |
222 | | 3D-ResNet10 | MedicalNet | 87.16% | 86.87% |
223 | | 3D-ResNet18 | Train from scratch | 75.22% | 80.80% |
224 | | 3D-ResNet18 | MedicalNet | 87.26% | 88.89% |
225 | | 3D-ResNet34 | Train from scratch | 76.82% | 83.84% |
226 | | 3D-ResNet34 | MedicalNet | 89.31% | 89.90% |
227 | | 3D-ResNet50 | Train from scratch | 71.75% | 84.85% |
228 | | 3D-ResNet50 | MedicalNet | 93.31% | 89.90% |
229 | | 3D-ResNet101 | Train from scratch | 72.10% | 81.82% |
230 | | 3D-ResNet101 | MedicalNet | 92.79% | 90.91% |
231 | | 3D-ResNet152 | Train from scratch | 73.29% | 73.74% |
232 | | 3D-ResNet152 | MedicalNet | 92.33% | 90.91% |
233 | | 3D-ResNet200 | Train from scratch | 71.29% | 76.77% |
234 | | 3D-ResNet200 | MedicalNet | 92.06% | 90.91% |
300 |
301 |
302 |
303 | - Please refer to [Med3D: Transfer Learning for 3D Medical Image Analysis](https://arxiv.org/abs/1904.00625) for more details.
304 |
305 | ### TODO
306 | - [x] 3D-ResNet series pre-trained models
307 | - [x] Transfer learning training code
308 | - [x] Training with multi-gpu
309 | - [ ] 3D efficient pre-trained models (e.g., 3D-MobileNet, 3D-ShuffleNet)
310 | - [ ] 2D medical pre-trained models
311 | - [x] Pre-trained MedicalNet models based on more medical datasets
312 |
313 | ### Acknowledgement
314 | We thank [3D-ResNets-PyTorch](https://github.com/kenshohara/3D-ResNets-PyTorch) for the released code and [MRBrainS18](https://mrbrains18.isi.uu.nl/) for the dataset, on which MedicalNet is built.
315 |
316 | ### Contribution
317 | If you want to contribute to MedicalNet, be sure to review the [contribution guidelines](https://github.com/Tencent/MedicalNet/blob/master/CONTRIBUTING.md).
318 |
--------------------------------------------------------------------------------
/datasets/brains18.py:
--------------------------------------------------------------------------------
1 | '''
2 | Dataset for training
3 | Written by Whalechen
4 | '''
5 |
6 | import math
7 | import os
8 | import random
9 |
10 | import numpy as np
11 | from torch.utils.data import Dataset
12 | import nibabel
13 | from scipy import ndimage
14 |
15 | class BrainS18Dataset(Dataset):
16 |
17 | def __init__(self, root_dir, img_list, sets):
18 | with open(img_list, 'r') as f:
19 | self.img_list = [line.strip() for line in f]
 20 |         print("Processing {} data entries".format(len(self.img_list)))
21 | self.root_dir = root_dir
22 | self.input_D = sets.input_D
23 | self.input_H = sets.input_H
24 | self.input_W = sets.input_W
25 | self.phase = sets.phase
26 |
27 | def __nii2tensorarray__(self, data):
28 | [z, y, x] = data.shape
29 | new_data = np.reshape(data, [1, z, y, x])
30 | new_data = new_data.astype("float32")
31 |
32 | return new_data
33 |
34 | def __len__(self):
35 | return len(self.img_list)
36 |
37 | def __getitem__(self, idx):
38 |
39 | if self.phase == "train":
40 | # read image and labels
41 | ith_info = self.img_list[idx].split(" ")
42 | img_name = os.path.join(self.root_dir, ith_info[0])
43 | label_name = os.path.join(self.root_dir, ith_info[1])
44 | assert os.path.isfile(img_name)
45 | assert os.path.isfile(label_name)
46 | img = nibabel.load(img_name) # We have transposed the data from WHD format to DHW
47 | assert img is not None
48 | mask = nibabel.load(label_name)
49 | assert mask is not None
50 |
51 | # data processing
52 | img_array, mask_array = self.__training_data_process__(img, mask)
53 |
 54 |             # convert to tensor array
55 | img_array = self.__nii2tensorarray__(img_array)
56 | mask_array = self.__nii2tensorarray__(mask_array)
57 |
58 | assert img_array.shape == mask_array.shape, "img shape:{} is not equal to mask shape:{}".format(img_array.shape, mask_array.shape)
59 | return img_array, mask_array
60 |
61 | elif self.phase == "test":
62 | # read image
63 | ith_info = self.img_list[idx].split(" ")
64 | img_name = os.path.join(self.root_dir, ith_info[0])
65 | print(img_name)
66 | assert os.path.isfile(img_name)
67 | img = nibabel.load(img_name)
68 | assert img is not None
69 |
70 | # data processing
71 | img_array = self.__testing_data_process__(img)
72 |
 73 |             # convert to tensor array
74 | img_array = self.__nii2tensorarray__(img_array)
75 |
76 | return img_array
77 |
78 |
79 | def __drop_invalid_range__(self, volume, label=None):
80 | """
81 | Cut off the invalid area
82 | """
83 | zero_value = volume[0, 0, 0]
84 | non_zeros_idx = np.where(volume != zero_value)
85 |
86 | [max_z, max_h, max_w] = np.max(np.array(non_zeros_idx), axis=1)
87 | [min_z, min_h, min_w] = np.min(np.array(non_zeros_idx), axis=1)
88 |
89 | if label is not None:
90 | return volume[min_z:max_z, min_h:max_h, min_w:max_w], label[min_z:max_z, min_h:max_h, min_w:max_w]
91 | else:
92 | return volume[min_z:max_z, min_h:max_h, min_w:max_w]
93 |
94 |
 95 |     def __random_center_crop__(self, data, label):
 96 |         """
 97 |         Random crop around the labeled region
 98 |         """
 99 |         from random import random
100 |         target_indexs = np.where(label > 0)
101 | [img_d, img_h, img_w] = data.shape
102 | [max_D, max_H, max_W] = np.max(np.array(target_indexs), axis=1)
103 | [min_D, min_H, min_W] = np.min(np.array(target_indexs), axis=1)
104 | [target_depth, target_height, target_width] = np.array([max_D, max_H, max_W]) - np.array([min_D, min_H, min_W])
105 | Z_min = int((min_D - target_depth*1.0/2) * random())
106 | Y_min = int((min_H - target_height*1.0/2) * random())
107 | X_min = int((min_W - target_width*1.0/2) * random())
108 |
109 | Z_max = int(img_d - ((img_d - (max_D + target_depth*1.0/2)) * random()))
110 | Y_max = int(img_h - ((img_h - (max_H + target_height*1.0/2)) * random()))
111 | X_max = int(img_w - ((img_w - (max_W + target_width*1.0/2)) * random()))
112 |
113 | Z_min = np.max([0, Z_min])
114 | Y_min = np.max([0, Y_min])
115 | X_min = np.max([0, X_min])
116 |
117 | Z_max = np.min([img_d, Z_max])
118 | Y_max = np.min([img_h, Y_max])
119 | X_max = np.min([img_w, X_max])
120 |
121 | Z_min = int(Z_min)
122 | Y_min = int(Y_min)
123 | X_min = int(X_min)
124 |
125 | Z_max = int(Z_max)
126 | Y_max = int(Y_max)
127 | X_max = int(X_max)
128 |
129 | return data[Z_min: Z_max, Y_min: Y_max, X_min: X_max], label[Z_min: Z_max, Y_min: Y_max, X_min: X_max]
130 |
131 |
132 |
133 |     def __intensity_normalize_one_volume__(self, volume):
134 |         """
135 |         Normalize the intensity of an nd volume based on the mean and std of the nonzero region
136 |         inputs:
137 |             volume: the input nd volume
138 |         outputs:
139 |             out: the normalized nd volume
140 |         """
141 |
142 | pixels = volume[volume > 0]
143 | mean = pixels.mean()
144 | std = pixels.std()
145 | out = (volume - mean)/std
146 | out_random = np.random.normal(0, 1, size = volume.shape)
147 | out[volume == 0] = out_random[volume == 0]
148 | return out
149 |
150 | def __resize_data__(self, data):
151 | """
152 | Resize the data to the input size
153 | """
154 | [depth, height, width] = data.shape
155 | scale = [self.input_D*1.0/depth, self.input_H*1.0/height, self.input_W*1.0/width]
156 | data = ndimage.interpolation.zoom(data, scale, order=0)
157 |
158 | return data
159 |
160 |
161 | def __crop_data__(self, data, label):
162 | """
163 | Random crop with different methods:
164 | """
165 | # random center crop
166 | data, label = self.__random_center_crop__ (data, label)
167 |
168 | return data, label
169 |
170 | def __training_data_process__(self, data, label):
171 |         # crop data according to the net input size
172 | data = data.get_data()
173 | label = label.get_data()
174 |
175 | # drop out the invalid range
176 | data, label = self.__drop_invalid_range__(data, label)
177 |
178 | # crop data
179 | data, label = self.__crop_data__(data, label)
180 |
181 | # resize data
182 | data = self.__resize_data__(data)
183 | label = self.__resize_data__(label)
184 |
185 |         # normalize data
186 |         data = self.__intensity_normalize_one_volume__(data)
187 |
188 | return data, label
189 |
190 |
191 | def __testing_data_process__(self, data):
192 |         # load the volume and resize it to the net input size
193 | data = data.get_data()
194 |
195 | # resize data
196 | data = self.__resize_data__(data)
197 |
198 |         # normalize data
199 |         data = self.__intensity_normalize_one_volume__(data)
200 |
201 | return data
202 |
--------------------------------------------------------------------------------
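A note on the normalization step above: `__intensity_normalize_one_volume__` standardizes the nonzero (foreground) voxels and then fills the zero background with unit-Gaussian noise so its statistics match the normalized foreground. A standalone numpy sketch (shapes and values are illustrative):

```python
import numpy as np

np.random.seed(0)
vol = np.zeros((2, 3, 3), dtype="float32")
vol[0] = np.arange(1, 10, dtype="float32").reshape(3, 3)  # nonzero "brain" voxels

pixels = vol[vol > 0]
out = (vol - pixels.mean()) / pixels.std()      # standardize using foreground stats only
noise = np.random.normal(0, 1, size=vol.shape)
out[vol == 0] = noise[vol == 0]                 # background becomes N(0, 1) noise

print(round(float(out[vol > 0].mean()), 6))     # ~0.0: foreground is now zero-mean
```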
/images/efficiency.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent/MedicalNet/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/images/efficiency.gif
--------------------------------------------------------------------------------
/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent/MedicalNet/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/images/logo.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from models import resnet
4 |
5 |
6 | def generate_model(opt):
7 | assert opt.model in [
8 | 'resnet'
9 | ]
10 |
11 | if opt.model == 'resnet':
12 | assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
13 |
14 | if opt.model_depth == 10:
15 | model = resnet.resnet10(
16 | sample_input_W=opt.input_W,
17 | sample_input_H=opt.input_H,
18 | sample_input_D=opt.input_D,
19 | shortcut_type=opt.resnet_shortcut,
20 | no_cuda=opt.no_cuda,
21 | num_seg_classes=opt.n_seg_classes)
22 | elif opt.model_depth == 18:
23 | model = resnet.resnet18(
24 | sample_input_W=opt.input_W,
25 | sample_input_H=opt.input_H,
26 | sample_input_D=opt.input_D,
27 | shortcut_type=opt.resnet_shortcut,
28 | no_cuda=opt.no_cuda,
29 | num_seg_classes=opt.n_seg_classes)
30 | elif opt.model_depth == 34:
31 | model = resnet.resnet34(
32 | sample_input_W=opt.input_W,
33 | sample_input_H=opt.input_H,
34 | sample_input_D=opt.input_D,
35 | shortcut_type=opt.resnet_shortcut,
36 | no_cuda=opt.no_cuda,
37 | num_seg_classes=opt.n_seg_classes)
38 | elif opt.model_depth == 50:
39 | model = resnet.resnet50(
40 | sample_input_W=opt.input_W,
41 | sample_input_H=opt.input_H,
42 | sample_input_D=opt.input_D,
43 | shortcut_type=opt.resnet_shortcut,
44 | no_cuda=opt.no_cuda,
45 | num_seg_classes=opt.n_seg_classes)
46 | elif opt.model_depth == 101:
47 | model = resnet.resnet101(
48 | sample_input_W=opt.input_W,
49 | sample_input_H=opt.input_H,
50 | sample_input_D=opt.input_D,
51 | shortcut_type=opt.resnet_shortcut,
52 | no_cuda=opt.no_cuda,
53 | num_seg_classes=opt.n_seg_classes)
54 | elif opt.model_depth == 152:
55 | model = resnet.resnet152(
56 | sample_input_W=opt.input_W,
57 | sample_input_H=opt.input_H,
58 | sample_input_D=opt.input_D,
59 | shortcut_type=opt.resnet_shortcut,
60 | no_cuda=opt.no_cuda,
61 | num_seg_classes=opt.n_seg_classes)
62 | elif opt.model_depth == 200:
63 | model = resnet.resnet200(
64 | sample_input_W=opt.input_W,
65 | sample_input_H=opt.input_H,
66 | sample_input_D=opt.input_D,
67 | shortcut_type=opt.resnet_shortcut,
68 | no_cuda=opt.no_cuda,
69 | num_seg_classes=opt.n_seg_classes)
70 |
71 | if not opt.no_cuda:
72 | if len(opt.gpu_id) > 1:
73 | model = model.cuda()
74 | model = nn.DataParallel(model, device_ids=opt.gpu_id)
75 | net_dict = model.state_dict()
76 | else:
77 | import os
78 | os.environ["CUDA_VISIBLE_DEVICES"]=str(opt.gpu_id[0])
79 | model = model.cuda()
80 | model = nn.DataParallel(model, device_ids=None)
81 | net_dict = model.state_dict()
82 | else:
83 | net_dict = model.state_dict()
84 |
85 | # load pretrain
86 | if opt.phase != 'test' and opt.pretrain_path:
87 | print ('loading pretrained model {}'.format(opt.pretrain_path))
88 | pretrain = torch.load(opt.pretrain_path)
89 | pretrain_dict = {k: v for k, v in pretrain['state_dict'].items() if k in net_dict.keys()}
90 |
91 | net_dict.update(pretrain_dict)
92 | model.load_state_dict(net_dict)
93 |
94 | new_parameters = []
95 | for pname, p in model.named_parameters():
96 | for layer_name in opt.new_layer_names:
97 | if pname.find(layer_name) >= 0:
98 | new_parameters.append(p)
99 | break
100 |
101 | new_parameters_id = list(map(id, new_parameters))
102 | base_parameters = list(filter(lambda p: id(p) not in new_parameters_id, model.parameters()))
103 | parameters = {'base_parameters': base_parameters,
104 | 'new_parameters': new_parameters}
105 |
106 | return model, parameters
107 |
108 | return model, model.parameters()
109 |
--------------------------------------------------------------------------------
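A note on the transfer-learning mechanics above: `generate_model` only copies pretrained weights whose keys already exist in the freshly built network, so newly added layers (e.g. `conv_seg`) keep their random initialization. A condensed, standalone sketch of that merge (the path is an illustrative assumption):

```python
import torch
from torch import nn

def load_pretrain(model: nn.Module, path: str) -> nn.Module:
    # Merge matching pretrained weights into `model`; unmatched layers keep their init.
    net_dict = model.state_dict()          # with DataParallel, keys look like 'module.conv1.weight'
    pretrain = torch.load(path, map_location='cpu')
    matched = {k: v for k, v in pretrain['state_dict'].items() if k in net_dict}
    net_dict.update(matched)               # backbone weights replaced; e.g. conv_seg stays random
    model.load_state_dict(net_dict)
    return model
```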
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import math
6 | from functools import partial
7 |
8 | __all__ = [
9 | 'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
10 | 'resnet152', 'resnet200'
11 | ]
12 |
13 |
14 | def conv3x3x3(in_planes, out_planes, stride=1, dilation=1):
15 | # 3x3x3 convolution with padding
16 | return nn.Conv3d(
17 | in_planes,
18 | out_planes,
19 | kernel_size=3,
20 | dilation=dilation,
21 | stride=stride,
22 | padding=dilation,
23 | bias=False)
24 |
25 |
26 | def downsample_basic_block(x, planes, stride, no_cuda=False):
27 | out = F.avg_pool3d(x, kernel_size=1, stride=stride)
28 | zero_pads = torch.Tensor(
29 | out.size(0), planes - out.size(1), out.size(2), out.size(3),
30 | out.size(4)).zero_()
31 | if not no_cuda:
32 | if isinstance(out.data, torch.cuda.FloatTensor):
33 | zero_pads = zero_pads.cuda()
34 |
35 | out = Variable(torch.cat([out.data, zero_pads], dim=1))
36 |
37 | return out
38 |
39 |
40 | class BasicBlock(nn.Module):
41 | expansion = 1
42 |
43 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
44 | super(BasicBlock, self).__init__()
45 | self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation)
46 | self.bn1 = nn.BatchNorm3d(planes)
47 | self.relu = nn.ReLU(inplace=True)
48 | self.conv2 = conv3x3x3(planes, planes, dilation=dilation)
49 | self.bn2 = nn.BatchNorm3d(planes)
50 | self.downsample = downsample
51 | self.stride = stride
52 | self.dilation = dilation
53 |
54 | def forward(self, x):
55 | residual = x
56 |
57 | out = self.conv1(x)
58 | out = self.bn1(out)
59 | out = self.relu(out)
60 | out = self.conv2(out)
61 | out = self.bn2(out)
62 |
63 | if self.downsample is not None:
64 | residual = self.downsample(x)
65 |
66 | out += residual
67 | out = self.relu(out)
68 |
69 | return out
70 |
71 |
72 | class Bottleneck(nn.Module):
73 | expansion = 4
74 |
75 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
76 | super(Bottleneck, self).__init__()
77 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
78 | self.bn1 = nn.BatchNorm3d(planes)
79 | self.conv2 = nn.Conv3d(
80 | planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False)
81 | self.bn2 = nn.BatchNorm3d(planes)
82 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
83 | self.bn3 = nn.BatchNorm3d(planes * 4)
84 | self.relu = nn.ReLU(inplace=True)
85 | self.downsample = downsample
86 | self.stride = stride
87 | self.dilation = dilation
88 |
89 | def forward(self, x):
90 | residual = x
91 |
92 | out = self.conv1(x)
93 | out = self.bn1(out)
94 | out = self.relu(out)
95 |
96 | out = self.conv2(out)
97 | out = self.bn2(out)
98 | out = self.relu(out)
99 |
100 | out = self.conv3(out)
101 | out = self.bn3(out)
102 |
103 | if self.downsample is not None:
104 | residual = self.downsample(x)
105 |
106 | out += residual
107 | out = self.relu(out)
108 |
109 | return out
110 |
111 |
112 | class ResNet(nn.Module):
113 |
114 | def __init__(self,
115 | block,
116 | layers,
117 | sample_input_D,
118 | sample_input_H,
119 | sample_input_W,
120 | num_seg_classes,
121 | shortcut_type='B',
122 | no_cuda = False):
123 | self.inplanes = 64
124 | self.no_cuda = no_cuda
125 | super(ResNet, self).__init__()
126 | self.conv1 = nn.Conv3d(
127 | 1,
128 | 64,
129 | kernel_size=7,
130 | stride=(2, 2, 2),
131 | padding=(3, 3, 3),
132 | bias=False)
133 |
134 | self.bn1 = nn.BatchNorm3d(64)
135 | self.relu = nn.ReLU(inplace=True)
136 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
137 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
138 | self.layer2 = self._make_layer(
139 | block, 128, layers[1], shortcut_type, stride=2)
140 | self.layer3 = self._make_layer(
141 | block, 256, layers[2], shortcut_type, stride=1, dilation=2)
142 | self.layer4 = self._make_layer(
143 | block, 512, layers[3], shortcut_type, stride=1, dilation=4)
144 |
145 | self.conv_seg = nn.Sequential(
146 | nn.ConvTranspose3d(
147 | 512 * block.expansion,
148 | 32,
149 | 2,
150 | stride=2
151 | ),
152 | nn.BatchNorm3d(32),
153 | nn.ReLU(inplace=True),
154 | nn.Conv3d(
155 | 32,
156 | 32,
157 | kernel_size=3,
158 | stride=(1, 1, 1),
159 | padding=(1, 1, 1),
160 | bias=False),
161 | nn.BatchNorm3d(32),
162 | nn.ReLU(inplace=True),
163 | nn.Conv3d(
164 | 32,
165 | num_seg_classes,
166 | kernel_size=1,
167 | stride=(1, 1, 1),
168 | bias=False)
169 | )
170 |
171 | for m in self.modules():
172 | if isinstance(m, nn.Conv3d):
173 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out')
174 | elif isinstance(m, nn.BatchNorm3d):
175 | m.weight.data.fill_(1)
176 | m.bias.data.zero_()
177 |
178 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1):
179 | downsample = None
180 | if stride != 1 or self.inplanes != planes * block.expansion:
181 | if shortcut_type == 'A':
182 | downsample = partial(
183 | downsample_basic_block,
184 | planes=planes * block.expansion,
185 | stride=stride,
186 | no_cuda=self.no_cuda)
187 | else:
188 | downsample = nn.Sequential(
189 | nn.Conv3d(
190 | self.inplanes,
191 | planes * block.expansion,
192 | kernel_size=1,
193 | stride=stride,
194 | bias=False), nn.BatchNorm3d(planes * block.expansion))
195 |
196 | layers = []
197 | layers.append(block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample))
198 | self.inplanes = planes * block.expansion
199 | for i in range(1, blocks):
200 | layers.append(block(self.inplanes, planes, dilation=dilation))
201 |
202 | return nn.Sequential(*layers)
203 |
204 | def forward(self, x):
205 | x = self.conv1(x)
206 | x = self.bn1(x)
207 | x = self.relu(x)
208 | x = self.maxpool(x)
209 | x = self.layer1(x)
210 | x = self.layer2(x)
211 | x = self.layer3(x)
212 | x = self.layer4(x)
213 | x = self.conv_seg(x)
214 |
215 | return x
216 |
217 | def resnet10(**kwargs):
218 |     """Constructs a ResNet-10 model.
219 | """
220 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)
221 | return model
222 |
223 |
224 | def resnet18(**kwargs):
225 | """Constructs a ResNet-18 model.
226 | """
227 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
228 | return model
229 |
230 |
231 | def resnet34(**kwargs):
232 | """Constructs a ResNet-34 model.
233 | """
234 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
235 | return model
236 |
237 |
238 | def resnet50(**kwargs):
239 | """Constructs a ResNet-50 model.
240 | """
241 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
242 | return model
243 |
244 |
245 | def resnet101(**kwargs):
246 | """Constructs a ResNet-101 model.
247 | """
248 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
249 | return model
250 |
251 |
252 | def resnet152(**kwargs):
253 |     """Constructs a ResNet-152 model.
254 | """
255 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
256 | return model
257 |
258 |
259 | def resnet200(**kwargs):
260 |     """Constructs a ResNet-200 model.
261 | """
262 | model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs)
263 | return model
264 |
--------------------------------------------------------------------------------
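On the two shortcut types: 'A' (`downsample_basic_block`) is parameter-free, using average pooling plus zero-padded channels, while 'B' learns a 1x1x1 convolution with batch norm. A minimal CPU smoke test of the backbone (input size and class count are illustrative; see the README for the shortcut type matching each released checkpoint):

```python
import torch
from models import resnet

model = resnet.resnet10(
    sample_input_D=14, sample_input_H=28, sample_input_W=28,
    shortcut_type='B',
    no_cuda=True,
    num_seg_classes=2)

x = torch.rand(1, 1, 14, 28, 28)   # (batch, channel, depth, height, width)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 2, 4, 8, 8]): layers 3/4 use dilation instead of striding,
                # and conv_seg upsamples the backbone output by 2
```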
/requirements.txt:
--------------------------------------------------------------------------------
1 | # python requirements
2 | pip>=9.0.1
3 | #logging==0.4.9.6
4 | torch==0.4.1
5 | numpy==1.15.4
6 | nibabel==2.4.1
7 | scipy==1.1.0
8 | argparse==1.1
--------------------------------------------------------------------------------
/setting.py:
--------------------------------------------------------------------------------
1 | '''
2 | Configs for training & testing
3 | Written by Whalechen
4 | '''
5 |
6 | import argparse
7 |
8 | def parse_opts():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument(
11 | '--data_root',
12 | default='./data',
13 | type=str,
14 | help='Root directory path of data')
15 | parser.add_argument(
16 | '--img_list',
17 | default='./data/train.txt',
18 | type=str,
19 | help='Path for image list file')
20 | parser.add_argument(
21 | '--n_seg_classes',
22 | default=2,
23 | type=int,
24 | help="Number of segmentation classes"
25 | )
26 | parser.add_argument(
 27 |         '--learning_rate',  # set to 0.001 when fine-tuning
28 | default=0.001,
29 | type=float,
30 | help=
31 | 'Initial learning rate (divided by 10 while training by lr scheduler)')
32 | parser.add_argument(
33 | '--num_workers',
34 | default=4,
35 | type=int,
36 | help='Number of jobs')
37 | parser.add_argument(
38 | '--batch_size', default=1, type=int, help='Batch Size')
39 | parser.add_argument(
40 | '--phase', default='train', type=str, help='Phase of train or test')
41 | parser.add_argument(
42 | '--save_intervals',
43 | default=10,
44 | type=int,
 45 |         help='Iteration interval for saving the model')
46 | parser.add_argument(
47 | '--n_epochs',
48 | default=200,
49 | type=int,
50 | help='Number of total epochs to run')
51 | parser.add_argument(
52 | '--input_D',
53 | default=56,
54 | type=int,
55 | help='Input size of depth')
56 | parser.add_argument(
57 | '--input_H',
58 | default=448,
59 | type=int,
60 | help='Input size of height')
61 | parser.add_argument(
62 | '--input_W',
63 | default=448,
64 | type=int,
65 | help='Input size of width')
66 | parser.add_argument(
67 | '--resume_path',
68 | default='',
69 | type=str,
70 | help=
 71 |         'Path of the model checkpoint to resume from.'
72 | )
73 | parser.add_argument(
74 | '--pretrain_path',
75 | default='pretrain/resnet_50.pth',
76 | type=str,
77 | help=
78 | 'Path for pretrained model.'
79 | )
80 | parser.add_argument(
81 | '--new_layer_names',
82 | #default=['upsample1', 'cmp_layer3', 'upsample2', 'cmp_layer2', 'upsample3', 'cmp_layer1', 'upsample4', 'cmp_conv1', 'conv_seg'],
83 | default=['conv_seg'],
 84 |         nargs='+',
 85 |         help='New layers except for the backbone')
86 | parser.add_argument(
87 | '--no_cuda', action='store_true', help='If true, cuda is not used.')
88 | parser.set_defaults(no_cuda=False)
89 | parser.add_argument(
90 | '--gpu_id',
91 | nargs='+',
92 | type=int,
93 | help='Gpu id lists')
94 | parser.add_argument(
95 | '--model',
96 | default='resnet',
97 | type=str,
 98 |         help='Model type (currently only resnet is supported)')
99 | parser.add_argument(
100 | '--model_depth',
101 | default=50,
102 | type=int,
103 |         help='Depth of resnet (10 | 18 | 34 | 50 | 101 | 152 | 200)')
104 | parser.add_argument(
105 | '--resnet_shortcut',
106 | default='B',
107 | type=str,
108 | help='Shortcut type of resnet (A | B)')
109 | parser.add_argument(
110 | '--manual_seed', default=1, type=int, help='Manually set random seed')
111 | parser.add_argument(
112 | '--ci_test', action='store_true', help='If true, ci testing is used.')
113 | args = parser.parse_args()
114 | args.save_folder = "./trails/models/{}_{}".format(args.model, args.model_depth)
115 |
116 | return args
117 |
--------------------------------------------------------------------------------
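For reference, `parse_opts` can be exercised directly; the flag values below are illustrative:

```python
import sys
from setting import parse_opts

# Simulate a command line (values are illustrative):
sys.argv = ['train.py', '--gpu_id', '0', '--model_depth', '10',
            '--resnet_shortcut', 'B', '--n_seg_classes', '2']
sets = parse_opts()
print(sets.save_folder)   # ./trails/models/resnet_10
print(sets.gpu_id)        # [0]
```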
/test.py:
--------------------------------------------------------------------------------
1 | from setting import parse_opts
2 | from datasets.brains18 import BrainS18Dataset
3 | from model import generate_model
4 | import torch
5 | import numpy as np
6 | from torch.utils.data import DataLoader
7 | import torch.nn.functional as F
8 | from scipy import ndimage
9 | import nibabel as nib
10 | import sys
11 | import os
12 | from utils.file_process import load_lines
14 |
15 |
16 | def seg_eval(pred, label, clss):
17 | """
18 | calculate the dice between prediction and ground truth
19 | input:
20 | pred: predicted mask
 21 |         label: ground truth
 22 |         clss: e.g. [0, 1] for binary classes
23 | """
24 | Ncls = len(clss)
25 | dices = np.zeros(Ncls)
26 | [depth, height, width] = pred.shape
27 | for idx, cls in enumerate(clss):
28 | # binary map
29 | pred_cls = np.zeros([depth, height, width])
30 | pred_cls[np.where(pred == cls)] = 1
31 | label_cls = np.zeros([depth, height, width])
32 | label_cls[np.where(label == cls)] = 1
33 |
 34 |         # calculate the intersection and the denominator (|pred| + |label|)
35 | s = pred_cls + label_cls
36 | inter = len(np.where(s >= 2)[0])
37 | conv = len(np.where(s >= 1)[0]) + inter
38 | try:
39 | dice = 2.0 * inter / conv
 40 |         except ZeroDivisionError:
41 | print("conv is zeros when dice = 2.0 * inter / conv")
42 | dice = -1
43 |
44 | dices[idx] = dice
45 |
46 | return dices
47 |
48 | def test(data_loader, model, img_names, sets):
49 | masks = []
50 | model.eval() # for testing
51 | for batch_id, batch_data in enumerate(data_loader):
52 | # forward
53 | volume = batch_data
54 | if not sets.no_cuda:
55 | volume = volume.cuda()
56 | with torch.no_grad():
57 | probs = model(volume)
58 | probs = F.softmax(probs, dim=1)
59 |
60 | # resize mask to original size
61 | [batchsize, _, mask_d, mask_h, mask_w] = probs.shape
62 | data = nib.load(os.path.join(sets.data_root, img_names[batch_id]))
63 | data = data.get_data()
64 | [depth, height, width] = data.shape
65 | mask = probs[0]
66 | scale = [1, depth*1.0/mask_d, height*1.0/mask_h, width*1.0/mask_w]
67 | mask = ndimage.interpolation.zoom(mask, scale, order=1)
68 | mask = np.argmax(mask, axis=0)
69 |
70 | masks.append(mask)
71 |
72 | return masks
73 |
74 |
75 | if __name__ == '__main__':
 76 |     # setting
77 | sets = parse_opts()
78 | sets.target_type = "normal"
79 | sets.phase = 'test'
80 |
81 | # getting model
82 | checkpoint = torch.load(sets.resume_path)
83 | net, _ = generate_model(sets)
84 | net.load_state_dict(checkpoint['state_dict'])
85 |
86 | # data tensor
 87 |     testing_data = BrainS18Dataset(sets.data_root, sets.img_list, sets)
88 | data_loader = DataLoader(testing_data, batch_size=1, shuffle=False, num_workers=1, pin_memory=False)
89 |
90 | # testing
91 | img_names = [info.split(" ")[0] for info in load_lines(sets.img_list)]
92 | masks = test(data_loader, net, img_names, sets)
93 |
94 | # evaluation: calculate dice
95 | label_names = [info.split(" ")[1] for info in load_lines(sets.img_list)]
96 | Nimg = len(label_names)
97 | dices = np.zeros([Nimg, sets.n_seg_classes])
98 | for idx in range(Nimg):
99 | label = nib.load(os.path.join(sets.data_root, label_names[idx]))
100 | label = label.get_data()
101 | dices[idx, :] = seg_eval(masks[idx], label, range(sets.n_seg_classes))
102 |
103 | # print result
104 | for idx in range(1, sets.n_seg_classes):
105 | mean_dice_per_task = np.mean(dices[:, idx])
106 | print('mean dice for class-{} is {}'.format(idx, mean_dice_per_task))
107 |
--------------------------------------------------------------------------------
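The Dice computation in `seg_eval` relies on the identity |pred| + |label| = |union| + |intersection|; a standalone sanity check on a toy mask:

```python
import numpy as np

pred  = np.array([[[0, 1], [1, 1]]])   # 3 foreground voxels
label = np.array([[[0, 1], [1, 0]]])   # 2 foreground voxels

s = (pred == 1).astype(int) + (label == 1).astype(int)
inter = int(np.sum(s >= 2))            # |pred ∩ label| = 2
denom = int(np.sum(s >= 1)) + inter    # |pred ∪ label| + |∩| = |pred| + |label| = 5

print(2.0 * inter / denom)             # 0.8
```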
/test_ci.py:
--------------------------------------------------------------------------------
1 | if __name__ == "__main__":
2 | print("test successful!")
--------------------------------------------------------------------------------
/toy_data/MRBrainS18/images/070.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent/MedicalNet/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/toy_data/MRBrainS18/images/070.nii.gz
--------------------------------------------------------------------------------
/toy_data/MRBrainS18/labels/070.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent/MedicalNet/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/toy_data/MRBrainS18/labels/070.nii.gz
--------------------------------------------------------------------------------
/toy_data/MRBrainS18/test_ci.txt:
--------------------------------------------------------------------------------
1 | MRBrainS18/images/070.nii.gz MRBrainS18/labels/070.nii.gz
2 |
--------------------------------------------------------------------------------
/toy_data/test_ci.txt:
--------------------------------------------------------------------------------
1 | MRBrainS18/images/070.nii.gz MRBrainS18/labels/070.nii.gz
2 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | Training code for MRBrainS18 datasets segmentation
3 | Written by Whalechen
4 | '''
5 |
6 | from setting import parse_opts
7 | from datasets.brains18 import BrainS18Dataset
8 | from model import generate_model
9 | import torch
10 | import numpy as np
11 | from torch import nn
12 | from torch import optim
13 | from torch.optim import lr_scheduler
14 | from torch.utils.data import DataLoader
15 | import time
16 | from utils.logger import log
17 | from scipy import ndimage
18 | import os
19 |
20 | def train(data_loader, model, optimizer, scheduler, total_epochs, save_interval, save_folder, sets):
21 | # settings
22 | batches_per_epoch = len(data_loader)
23 | log.info('{} epochs in total, {} batches per epoch'.format(total_epochs, batches_per_epoch))
24 | loss_seg = nn.CrossEntropyLoss(ignore_index=-1)
25 |
26 | print("Current setting is:")
27 | print(sets)
28 | print("\n\n")
29 | if not sets.no_cuda:
30 | loss_seg = loss_seg.cuda()
31 |
32 | model.train()
33 | train_time_sp = time.time()
34 | for epoch in range(total_epochs):
35 | log.info('Start epoch {}'.format(epoch))
36 |
37 | scheduler.step()
38 | log.info('lr = {}'.format(scheduler.get_lr()))
39 |
40 | for batch_id, batch_data in enumerate(data_loader):
41 | # getting data batch
 42 |             batch_id_sp = epoch * batches_per_epoch + batch_id
43 | volumes, label_masks = batch_data
44 |
45 | if not sets.no_cuda:
46 | volumes = volumes.cuda()
47 |
48 | optimizer.zero_grad()
49 | out_masks = model(volumes)
50 | # resize label
51 | [n, _, d, h, w] = out_masks.shape
52 | new_label_masks = np.zeros([n, d, h, w])
53 | for label_id in range(n):
54 | label_mask = label_masks[label_id]
55 | [ori_c, ori_d, ori_h, ori_w] = label_mask.shape
56 | label_mask = np.reshape(label_mask, [ori_d, ori_h, ori_w])
57 | scale = [d*1.0/ori_d, h*1.0/ori_h, w*1.0/ori_w]
58 | label_mask = ndimage.interpolation.zoom(label_mask, scale, order=0)
59 | new_label_masks[label_id] = label_mask
60 |
61 | new_label_masks = torch.tensor(new_label_masks).to(torch.int64)
62 | if not sets.no_cuda:
63 | new_label_masks = new_label_masks.cuda()
64 |
65 | # calculating loss
66 | loss_value_seg = loss_seg(out_masks, new_label_masks)
67 | loss = loss_value_seg
68 | loss.backward()
69 | optimizer.step()
70 |
71 | avg_batch_time = (time.time() - train_time_sp) / (1 + batch_id_sp)
72 | log.info(
73 | 'Batch: {}-{} ({}), loss = {:.3f}, loss_seg = {:.3f}, avg_batch_time = {:.3f}'\
74 | .format(epoch, batch_id, batch_id_sp, loss.item(), loss_value_seg.item(), avg_batch_time))
75 |
76 | if not sets.ci_test:
77 | # save model
78 | if batch_id == 0 and batch_id_sp != 0 and batch_id_sp % save_interval == 0:
79 | #if batch_id_sp != 0 and batch_id_sp % save_interval == 0:
80 | model_save_path = '{}_epoch_{}_batch_{}.pth.tar'.format(save_folder, epoch, batch_id)
81 | model_save_dir = os.path.dirname(model_save_path)
82 | if not os.path.exists(model_save_dir):
83 | os.makedirs(model_save_dir)
84 |
85 | log.info('Save checkpoints: epoch = {}, batch_id = {}'.format(epoch, batch_id))
86 | torch.save({
 87 |                         'epoch': epoch,
88 | 'batch_id': batch_id,
89 | 'state_dict': model.state_dict(),
90 | 'optimizer': optimizer.state_dict()},
91 | model_save_path)
92 |
93 | print('Finished training')
94 | if sets.ci_test:
95 | exit()
96 |
97 |
98 | if __name__ == '__main__':
 99 |     # setting
100 | sets = parse_opts()
101 | if sets.ci_test:
102 | sets.img_list = './toy_data/test_ci.txt'
103 | sets.n_epochs = 1
104 | sets.no_cuda = True
105 | sets.data_root = './toy_data'
106 | sets.pretrain_path = ''
107 | sets.num_workers = 0
108 | sets.model_depth = 10
109 | sets.resnet_shortcut = 'A'
110 | sets.input_D = 14
111 | sets.input_H = 28
112 | sets.input_W = 28
113 |
114 |
115 |
116 | # getting model
117 | torch.manual_seed(sets.manual_seed)
118 | model, parameters = generate_model(sets)
119 | print (model)
120 | # optimizer
121 | if sets.ci_test:
122 | params = [{'params': parameters, 'lr': sets.learning_rate}]
123 | else:
124 | params = [
125 | { 'params': parameters['base_parameters'], 'lr': sets.learning_rate },
126 | { 'params': parameters['new_parameters'], 'lr': sets.learning_rate*100 }
127 | ]
128 | optimizer = torch.optim.SGD(params, momentum=0.9, weight_decay=1e-3)
129 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
130 |
131 | # train from resume
132 | if sets.resume_path:
133 | if os.path.isfile(sets.resume_path):
134 | print("=> loading checkpoint '{}'".format(sets.resume_path))
135 | checkpoint = torch.load(sets.resume_path)
136 | model.load_state_dict(checkpoint['state_dict'])
137 | optimizer.load_state_dict(checkpoint['optimizer'])
138 | print("=> loaded checkpoint '{}' (epoch {})"
139 | .format(sets.resume_path, checkpoint['epoch']))
140 |
141 | # getting data
142 | sets.phase = 'train'
143 | if sets.no_cuda:
144 | sets.pin_memory = False
145 | else:
146 | sets.pin_memory = True
147 | training_dataset = BrainS18Dataset(sets.data_root, sets.img_list, sets)
148 | data_loader = DataLoader(training_dataset, batch_size=sets.batch_size, shuffle=True, num_workers=sets.num_workers, pin_memory=sets.pin_memory)
149 |
150 | # training
151 | train(data_loader, model, optimizer, scheduler, total_epochs=sets.n_epochs, save_interval=sets.save_intervals, save_folder=sets.save_folder, sets=sets)
152 |
--------------------------------------------------------------------------------
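The label-resizing step inside `train()` shrinks each ground-truth mask to the network's output resolution with nearest-neighbor interpolation (`order=0`), which preserves integer class ids. A standalone sketch (shapes are illustrative assumptions):

```python
import numpy as np
from scipy import ndimage

label = np.zeros((8, 16, 16), dtype=np.int64)
label[2:6, 4:12, 4:12] = 1                     # a toy foreground cube

d, h, w = 4, 8, 8                              # assumed network output size
scale = [d / label.shape[0], h / label.shape[1], w / label.shape[2]]
resized = ndimage.zoom(label, scale, order=0)  # order=0: no blending between class ids

print(resized.shape, np.unique(resized))       # (4, 8, 8) [0 1]
```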
/utils/file_process.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | import os
5 | import os.path as osp
6 |
7 | def load_lines(file_path):
8 | """Read file into a list of lines.
9 |
10 | Input
11 | file_path: file path
12 |
13 | Output
14 | lines: an array of lines
15 | """
16 | with open(file_path, 'r') as fio:
17 | lines = fio.read().splitlines()
18 | return lines
19 |
--------------------------------------------------------------------------------
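`load_lines` is what `test.py` uses to read the image lists; a small usage sketch with the toy list shipped in the repo:

```python
from utils.file_process import load_lines

for line in load_lines('toy_data/test_ci.txt'):
    img_path, label_path = line.split(" ")
    print(img_path, '->', label_path)
```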
/utils/logger.py:
--------------------------------------------------------------------------------
1 | '''
2 | Written by Whalechen
3 | '''
4 |
5 | import logging
6 |
7 | logging.basicConfig(
8 | format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
9 | datefmt='%Y-%m-%d %H:%M:%S',
10 | level=logging.DEBUG)
11 |
12 | log = logging.getLogger()
13 |
--------------------------------------------------------------------------------
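The shared logger is configured once at import time, so any module can simply do (output format follows the `basicConfig` call above; the file name and line shown are illustrative):

```python
from utils.logger import log

log.info("training started")
# e.g. 2019-07-30 12:00:00 INFO     [example.py:3] training started
```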