├── .commitlintrc.yml ├── .drone.yml ├── .gitignore ├── .orange-ci.yml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── datasets └── brains18.py ├── images ├── efficiency.gif └── logo.png ├── model.py ├── models └── resnet.py ├── requirements.txt ├── setting.py ├── test.py ├── test_ci.py ├── toy_data ├── MRBrainS18 │ ├── images │ │ └── 070.nii.gz │ ├── labels │ │ └── 070.nii.gz │ └── test_ci.txt └── test_ci.txt ├── train.py └── utils ├── file_process.py └── logger.py /.commitlintrc.yml: -------------------------------------------------------------------------------- 1 | extends: 2 | - "@commitlint/config-conventional" 3 | -------------------------------------------------------------------------------- /.drone.yml: -------------------------------------------------------------------------------- 1 | pipeline: 2 | build: 3 | name: testing phase 4 | image: cshwhale/dockerfiles:latest 5 | commands: 6 | - python train.py --ci_test 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # pipenv 86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 88 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 89 | # install all needed dependencies. 
90 | #Pipfile.lock 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env 100 | .venv 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | .dmypy.json 120 | dmypy.json 121 | 122 | # Pyre type checker 123 | .pyre/ 124 | 125 | # My configurations: 126 | data/**/*.nii.gz 127 | data/**/*.txt 128 | pretrain/**/*.pth 129 | trails/**/*.pth 130 | -------------------------------------------------------------------------------- /.orange-ci.yml: -------------------------------------------------------------------------------- 1 | master: 2 | merge_request: 3 | - stages: 4 | - name: make commitlist 5 | type: git:commitList 6 | options: 7 | toFile: commits-data.json 8 | - name: do commitlint 9 | image: csighub.tencentyun.com/plugins/commitlint 10 | settings: 11 | from_file: commits-data.json 12 | push: 13 | - network: idc-ai-sse4 14 | stages: 15 | - name: testing phase 16 | image: cshwhale/dockerfiles:latest 17 | commands: 18 | - python train.py --ci_test 19 | 20 | $: 21 | tag_push: 22 | - stages: 23 | - name: changelog 24 | type: git:changeLog 25 | options: 26 | filename: CHANGELOG.md 27 | target: master 28 | envExport: 29 | latestChangeLog: LATEST_CHANGE_LOG 30 | - name: upload release 31 | type: git:release 32 | options: 33 | description: ${LATEST_CHANGE_LOG} 34 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/MedicalNet/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/CONTRIBUTING.md -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | Tencent is pleased to support the open source community by making MedicalNet available. 2 | 3 | Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 4 | 5 | MedicalNet is licensed under the MIT License, including the third-party component listed below. 6 | 7 | A copy of the MIT License is included in this file. 8 | 9 | Other dependency and license: 10 | 11 | 12 | Open Source Software Licensed Under the MIT License: 13 | -------------------------------------------------------------------- 14 | 1. 3D-ResNets-PyTorch 3.0 15 | Copyright (c) 2017 Kensho Hara 16 | 17 | 18 | Terms of the MIT License: 19 | --------------------------------------------------- 20 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 21 | 22 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 23 | 24 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # MedicalNet 5 | This repository contains a PyTorch implementation of [Med3D: Transfer Learning for 3D Medical Image Analysis](https://arxiv.org/abs/1904.00625). 6 | Many studies have shown that the performance of deep learning is significantly affected by the volume of training data. The MedicalNet project aggregated datasets with diverse modalities, target organs, and pathologies to build a relatively large dataset. Based on this dataset, a series of 3D-ResNet pre-trained models and corresponding transfer-learning training code are provided. 7 | 8 | ### License 9 | MedicalNet is released under the MIT License (refer to the LICENSE file for details). 10 | 11 | ### Citing MedicalNet 12 | If you use this code or pre-trained models, please cite the following: 13 | ``` 14 | @article{chen2019med3d, 15 | title={Med3D: Transfer Learning for 3D Medical Image Analysis}, 16 | author={Chen, Sihong and Ma, Kai and Zheng, Yefeng}, 17 | journal={arXiv preprint arXiv:1904.00625}, 18 | year={2019} 19 | } 20 | ``` 21 | ### Update(2019/07/30) 22 | We uploaded 4 pre-trained models trained on more datasets (23 in total). 23 | ``` 24 | Model name : parameters settings 25 | resnet_10_23dataset.pth: --model resnet --model_depth 10 --resnet_shortcut B 26 | resnet_18_23dataset.pth: --model resnet --model_depth 18 --resnet_shortcut A 27 | resnet_34_23dataset.pth: --model resnet --model_depth 34 --resnet_shortcut A 28 | resnet_50_23dataset.pth: --model resnet --model_depth 50 --resnet_shortcut B 29 | ``` 30 | We transferred the above pre-trained models to the multi-class segmentation task (left lung, right lung and background) on the Visceral dataset.
The results are as follows: 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 |
NetworkPretrainLungSeg(Dice)
3D-ResNet10Train from scratch69.31%
MedicalNet96.56%
3D-ResNet18Train from scratch70.89%
MedicalNet94.68%
3D-ResNet34Train from scratch75.25%
MedicalNet94.14%
3D-ResNet50Train from scratch52.94%
MedicalNet89.25%
74 | 75 | 76 | ### Contents 77 | 1. [Requirements](#Requirements) 78 | 2. [Installation](#Installation) 79 | 3. [Demo](#Demo) 80 | 4. [Experiments](#Experiments) 81 | 5. [TODO](#TODO) 82 | 6. [Acknowledgement](#Acknowledgement) 83 | 84 | ### Requirements 85 | - Python 3.7.0 86 | - PyTorch-0.4.1 87 | - CUDA Version 9.0 88 | - CUDNN 7.0.5 89 | 90 | ### Installation 91 | - Install Python 3.7.0 92 | - pip install -r requirements.txt 93 | 94 | 95 | ### Demo 96 | - Structure of data directories 97 | ``` 98 | MedicalNet is used to transfer the pre-trained model to other datasets (here the MRBrainS18 dataset is used as an example). 99 | MedicalNet/ 100 | |--datasets/:Data preprocessing module 101 | | |--brains18.py:MRBrainS18 data preprocessing script 102 | |--models/:Model construction module 103 | | |--resnet.py:3D-ResNet network build script 104 | |--utils/:tools 105 | | |--logger.py:Logging script 106 | |--toy_data/:For CI test 107 | |--data/:Data storage module 108 | | |--MRBrainS18/:MRBrainS18 dataset 109 | | | |--images/:source image named with patient ID 110 | | | |--labels/:mask named with patient ID 111 | | |--train.txt: training data lists 112 | | |--val.txt: validation data lists 113 | |--pretrain/:Pre-trained models storage module 114 | |--model.py: Network processing script 115 | |--setting.py: Parameter setting script 116 | |--train.py: MRBrainS18 training demo script 117 | |--test.py: MRBrainS18 testing demo script 118 | |--requirements.txt: Dependent library list 119 | |--README.md 120 | ``` 121 | 122 | - Network structure parameter settings 123 | ``` 124 | Model name : parameters settings 125 | resnet_10.pth: --model resnet --model_depth 10 --resnet_shortcut B 126 | resnet_18.pth: --model resnet --model_depth 18 --resnet_shortcut A 127 | resnet_34.pth: --model resnet --model_depth 34 --resnet_shortcut A 128 | resnet_50.pth: --model resnet --model_depth 50 --resnet_shortcut B 129 | resnet_101.pth: --model resnet --model_depth 101 --resnet_shortcut B 130 |
resnet_152.pth: --model resnet --model_depth 152 --resnet_shortcut B 131 | resnet_200.pth: --model resnet --model_depth 200 --resnet_shortcut B 132 | ``` 133 | 134 | - After successfully completing basic installation, you'll be ready to run the demo. 135 | 1. Clone the MedicalNet repository 136 | ``` 137 | git clone https://github.com/Tencent/MedicalNet 138 | ``` 139 | 2. Download data & pre-trained models ([Google Drive](https://drive.google.com/file/d/13tnSvXY7oDIEloNFiGTsjUIYfS3g3BfG/view?usp=sharing) or [Tencent Weiyun](https://share.weiyun.com/55sZyIx)) 140 | 141 | Unzip and move files 142 | ``` 143 | mv MedicalNet_pytorch_files.zip MedicalNet/. 144 | cd MedicalNet 145 | unzip MedicalNet_pytorch_files.zip 146 | ``` 147 | 3. Run the training code (e.g. 3D-ResNet-50) 148 | ``` 149 | python train.py --gpu_id 0 1 # multi-gpu training on gpu 0,1 150 | or 151 | python train.py --gpu_id 0 # single-gpu training on gpu 0 152 | ``` 153 | 4. Run the testing code (e.g. 3D-ResNet-50) 154 | ``` 155 | python test.py --gpu_id 0 --resume_path trails/models/resnet_50_epoch_110_batch_0.pth.tar --img_list data/val.txt 156 | ``` 157 | 158 | ### Experiments 159 | - Computational Cost 160 | ``` 161 | GPU:NVIDIA Tesla P40 162 | ``` 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 |
NetworkParameters (M)Running time (s)
3D-ResNet1014.360.18
3D-ResNet1832.990.19
3D-ResNet3463.310.22
3D-ResNet5046.210.21
3D-ResNet10185.310.29
3D-ResNet152117.510.34
3D-ResNet200126.740.45
205 | 206 | - Performance 207 | ``` 208 | Visualization of the segmentation results of our approach vs. the baselines after the same number of training epochs. 209 | It demonstrates the faster training convergence and higher accuracy obtained with our MedicalNet pre-trained models. 210 | ``` 211 | 212 | 213 | 214 | ``` 215 | Results of transferring MedicalNet pre-trained models to lung segmentation (LungSeg) and pulmonary nodule classification (NoduleCls), evaluated with Dice and accuracy, respectively. 216 | ``` 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 |
NetworkPretrainLungSeg(Dice)NoduleCls(accuracy)
3D-ResNet10Train from scratch71.30%79.80%
MedicalNet87.16%86.87%
3D-ResNet18Train from scratch75.22%80.80%
MedicalNet87.26%88.89%
3D-ResNet34Train from scratch76.82%83.84%
MedicalNet89.31%89.90%
3D-ResNet50Train from scratch71.75%84.85%
MedicalNet93.31%89.90%
3D-ResNet101Train from scratch72.10%81.82%
MedicalNet92.79%90.91%
3D-ResNet152Train from scratch73.29%73.74%
MedicalNet92.33%90.91%
3D-ResNet200Train from scratch71.29%76.77%
MedicalNet92.06%90.91%
302 | 303 | - Please refer to [Med3D: Transfer Learning for 3D Medical Image Analysis](https://arxiv.org/abs/1904.00625) for more details: 304 | 305 | ### TODO 306 | - [x] 3D-ResNet series pre-trained models 307 | - [x] Transfer learning training code 308 | - [x] Training with multi-gpu 309 | - [ ] 3D efficient pre-trained models(e.g., 3D-MobileNet, 3D-ShuffleNet) 310 | - [ ] 2D medical pre-trained models 311 | - [x] Pre-trained MedicalNet models based on more medical dataset 312 | 313 | ### Acknowledgement 314 | We thank [3D-ResNets-PyTorch](https://github.com/kenshohara/3D-ResNets-PyTorch) and [MRBrainS18](https://mrbrains18.isi.uu.nl/) which we build MedicalNet refer to this releasing code and the dataset. 315 | 316 | ### Contribution 317 | If you want to contribute to MedicalNet, be sure to review the [contribution guidelines](https://github.com/Tencent/MedicalNet/blob/master/CONTRIBUTING.md) 318 | -------------------------------------------------------------------------------- /datasets/brains18.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Dataset for training 3 | Written by Whalechen 4 | ''' 5 | 6 | import math 7 | import os 8 | import random 9 | 10 | import numpy as np 11 | from torch.utils.data import Dataset 12 | import nibabel 13 | from scipy import ndimage 14 | 15 | class BrainS18Dataset(Dataset): 16 | 17 | def __init__(self, root_dir, img_list, sets): 18 | with open(img_list, 'r') as f: 19 | self.img_list = [line.strip() for line in f] 20 | print("Processing {} datas".format(len(self.img_list))) 21 | self.root_dir = root_dir 22 | self.input_D = sets.input_D 23 | self.input_H = sets.input_H 24 | self.input_W = sets.input_W 25 | self.phase = sets.phase 26 | 27 | def __nii2tensorarray__(self, data): 28 | [z, y, x] = data.shape 29 | new_data = np.reshape(data, [1, z, y, x]) 30 | new_data = new_data.astype("float32") 31 | 32 | return new_data 33 | 34 | def __len__(self): 35 | return len(self.img_list) 36 | 
37 | def __getitem__(self, idx): 38 | 39 | if self.phase == "train": 40 | # read image and labels 41 | ith_info = self.img_list[idx].split(" ") 42 | img_name = os.path.join(self.root_dir, ith_info[0]) 43 | label_name = os.path.join(self.root_dir, ith_info[1]) 44 | assert os.path.isfile(img_name) 45 | assert os.path.isfile(label_name) 46 | img = nibabel.load(img_name) # We have transposed the data from WHD format to DHW 47 | assert img is not None 48 | mask = nibabel.load(label_name) 49 | assert mask is not None 50 | 51 | # data processing 52 | img_array, mask_array = self.__training_data_process__(img, mask) 53 | 54 | # 2 tensor array 55 | img_array = self.__nii2tensorarray__(img_array) 56 | mask_array = self.__nii2tensorarray__(mask_array) 57 | 58 | assert img_array.shape == mask_array.shape, "img shape:{} is not equal to mask shape:{}".format(img_array.shape, mask_array.shape) 59 | return img_array, mask_array 60 | 61 | elif self.phase == "test": 62 | # read image 63 | ith_info = self.img_list[idx].split(" ") 64 | img_name = os.path.join(self.root_dir, ith_info[0]) 65 | print(img_name) 66 | assert os.path.isfile(img_name) 67 | img = nibabel.load(img_name) 68 | assert img is not None 69 | 70 | # data processing 71 | img_array = self.__testing_data_process__(img) 72 | 73 | # 2 tensor array 74 | img_array = self.__nii2tensorarray__(img_array) 75 | 76 | return img_array 77 | 78 | 79 | def __drop_invalid_range__(self, volume, label=None): 80 | """ 81 | Cut off the invalid area 82 | """ 83 | zero_value = volume[0, 0, 0] 84 | non_zeros_idx = np.where(volume != zero_value) 85 | 86 | [max_z, max_h, max_w] = np.max(np.array(non_zeros_idx), axis=1) 87 | [min_z, min_h, min_w] = np.min(np.array(non_zeros_idx), axis=1) 88 | 89 | if label is not None: 90 | return volume[min_z:max_z, min_h:max_h, min_w:max_w], label[min_z:max_z, min_h:max_h, min_w:max_w] 91 | else: 92 | return volume[min_z:max_z, min_h:max_h, min_w:max_w] 93 | 94 | 95 | def __random_center_crop__(self, data, 
label): 96 | from random import random 97 | """ 98 | Random crop 99 | """ 100 | target_indexs = np.where(label>0) 101 | [img_d, img_h, img_w] = data.shape 102 | [max_D, max_H, max_W] = np.max(np.array(target_indexs), axis=1) 103 | [min_D, min_H, min_W] = np.min(np.array(target_indexs), axis=1) 104 | [target_depth, target_height, target_width] = np.array([max_D, max_H, max_W]) - np.array([min_D, min_H, min_W]) 105 | Z_min = int((min_D - target_depth*1.0/2) * random()) 106 | Y_min = int((min_H - target_height*1.0/2) * random()) 107 | X_min = int((min_W - target_width*1.0/2) * random()) 108 | 109 | Z_max = int(img_d - ((img_d - (max_D + target_depth*1.0/2)) * random())) 110 | Y_max = int(img_h - ((img_h - (max_H + target_height*1.0/2)) * random())) 111 | X_max = int(img_w - ((img_w - (max_W + target_width*1.0/2)) * random())) 112 | 113 | Z_min = np.max([0, Z_min]) 114 | Y_min = np.max([0, Y_min]) 115 | X_min = np.max([0, X_min]) 116 | 117 | Z_max = np.min([img_d, Z_max]) 118 | Y_max = np.min([img_h, Y_max]) 119 | X_max = np.min([img_w, X_max]) 120 | 121 | Z_min = int(Z_min) 122 | Y_min = int(Y_min) 123 | X_min = int(X_min) 124 | 125 | Z_max = int(Z_max) 126 | Y_max = int(Y_max) 127 | X_max = int(X_max) 128 | 129 | return data[Z_min: Z_max, Y_min: Y_max, X_min: X_max], label[Z_min: Z_max, Y_min: Y_max, X_min: X_max] 130 | 131 | 132 | 133 | def __itensity_normalize_one_volume__(self, volume): 134 | """ 135 | normalize the itensity of an nd volume based on the mean and std of nonzeor region 136 | inputs: 137 | volume: the input nd volume 138 | outputs: 139 | out: the normalized nd volume 140 | """ 141 | 142 | pixels = volume[volume > 0] 143 | mean = pixels.mean() 144 | std = pixels.std() 145 | out = (volume - mean)/std 146 | out_random = np.random.normal(0, 1, size = volume.shape) 147 | out[volume == 0] = out_random[volume == 0] 148 | return out 149 | 150 | def __resize_data__(self, data): 151 | """ 152 | Resize the data to the input size 153 | """ 154 | [depth, height, 
width] = data.shape 155 | scale = [self.input_D*1.0/depth, self.input_H*1.0/height, self.input_W*1.0/width] 156 | data = ndimage.interpolation.zoom(data, scale, order=0) 157 | 158 | return data 159 | 160 | 161 | def __crop_data__(self, data, label): 162 | """ 163 | Random crop with different methods: 164 | """ 165 | # random center crop 166 | data, label = self.__random_center_crop__ (data, label) 167 | 168 | return data, label 169 | 170 | def __training_data_process__(self, data, label): 171 | # crop data according net input size 172 | data = data.get_data() 173 | label = label.get_data() 174 | 175 | # drop out the invalid range 176 | data, label = self.__drop_invalid_range__(data, label) 177 | 178 | # crop data 179 | data, label = self.__crop_data__(data, label) 180 | 181 | # resize data 182 | data = self.__resize_data__(data) 183 | label = self.__resize_data__(label) 184 | 185 | # normalization datas 186 | data = self.__itensity_normalize_one_volume__(data) 187 | 188 | return data, label 189 | 190 | 191 | def __testing_data_process__(self, data): 192 | # crop data according net input size 193 | data = data.get_data() 194 | 195 | # resize data 196 | data = self.__resize_data__(data) 197 | 198 | # normalization datas 199 | data = self.__itensity_normalize_one_volume__(data) 200 | 201 | return data 202 | -------------------------------------------------------------------------------- /images/efficiency.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/MedicalNet/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/images/efficiency.gif -------------------------------------------------------------------------------- /images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/MedicalNet/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/images/logo.png -------------------------------------------------------------------------------- 
/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from models import resnet 4 | 5 | 6 | def generate_model(opt): 7 | assert opt.model in [ 8 | 'resnet' 9 | ] 10 | 11 | if opt.model == 'resnet': 12 | assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200] 13 | 14 | if opt.model_depth == 10: 15 | model = resnet.resnet10( 16 | sample_input_W=opt.input_W, 17 | sample_input_H=opt.input_H, 18 | sample_input_D=opt.input_D, 19 | shortcut_type=opt.resnet_shortcut, 20 | no_cuda=opt.no_cuda, 21 | num_seg_classes=opt.n_seg_classes) 22 | elif opt.model_depth == 18: 23 | model = resnet.resnet18( 24 | sample_input_W=opt.input_W, 25 | sample_input_H=opt.input_H, 26 | sample_input_D=opt.input_D, 27 | shortcut_type=opt.resnet_shortcut, 28 | no_cuda=opt.no_cuda, 29 | num_seg_classes=opt.n_seg_classes) 30 | elif opt.model_depth == 34: 31 | model = resnet.resnet34( 32 | sample_input_W=opt.input_W, 33 | sample_input_H=opt.input_H, 34 | sample_input_D=opt.input_D, 35 | shortcut_type=opt.resnet_shortcut, 36 | no_cuda=opt.no_cuda, 37 | num_seg_classes=opt.n_seg_classes) 38 | elif opt.model_depth == 50: 39 | model = resnet.resnet50( 40 | sample_input_W=opt.input_W, 41 | sample_input_H=opt.input_H, 42 | sample_input_D=opt.input_D, 43 | shortcut_type=opt.resnet_shortcut, 44 | no_cuda=opt.no_cuda, 45 | num_seg_classes=opt.n_seg_classes) 46 | elif opt.model_depth == 101: 47 | model = resnet.resnet101( 48 | sample_input_W=opt.input_W, 49 | sample_input_H=opt.input_H, 50 | sample_input_D=opt.input_D, 51 | shortcut_type=opt.resnet_shortcut, 52 | no_cuda=opt.no_cuda, 53 | num_seg_classes=opt.n_seg_classes) 54 | elif opt.model_depth == 152: 55 | model = resnet.resnet152( 56 | sample_input_W=opt.input_W, 57 | sample_input_H=opt.input_H, 58 | sample_input_D=opt.input_D, 59 | shortcut_type=opt.resnet_shortcut, 60 | no_cuda=opt.no_cuda, 61 | num_seg_classes=opt.n_seg_classes) 62 | elif opt.model_depth == 200: 63 | 
model = resnet.resnet200( 64 | sample_input_W=opt.input_W, 65 | sample_input_H=opt.input_H, 66 | sample_input_D=opt.input_D, 67 | shortcut_type=opt.resnet_shortcut, 68 | no_cuda=opt.no_cuda, 69 | num_seg_classes=opt.n_seg_classes) 70 | 71 | if not opt.no_cuda: 72 | if len(opt.gpu_id) > 1: 73 | model = model.cuda() 74 | model = nn.DataParallel(model, device_ids=opt.gpu_id) 75 | net_dict = model.state_dict() 76 | else: 77 | import os 78 | os.environ["CUDA_VISIBLE_DEVICES"]=str(opt.gpu_id[0]) 79 | model = model.cuda() 80 | model = nn.DataParallel(model, device_ids=None) 81 | net_dict = model.state_dict() 82 | else: 83 | net_dict = model.state_dict() 84 | 85 | # load pretrain 86 | if opt.phase != 'test' and opt.pretrain_path: 87 | print ('loading pretrained model {}'.format(opt.pretrain_path)) 88 | pretrain = torch.load(opt.pretrain_path) 89 | pretrain_dict = {k: v for k, v in pretrain['state_dict'].items() if k in net_dict.keys()} 90 | 91 | net_dict.update(pretrain_dict) 92 | model.load_state_dict(net_dict) 93 | 94 | new_parameters = [] 95 | for pname, p in model.named_parameters(): 96 | for layer_name in opt.new_layer_names: 97 | if pname.find(layer_name) >= 0: 98 | new_parameters.append(p) 99 | break 100 | 101 | new_parameters_id = list(map(id, new_parameters)) 102 | base_parameters = list(filter(lambda p: id(p) not in new_parameters_id, model.parameters())) 103 | parameters = {'base_parameters': base_parameters, 104 | 'new_parameters': new_parameters} 105 | 106 | return model, parameters 107 | 108 | return model, model.parameters() 109 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from functools import partial 7 | 8 | __all__ = [ 9 | 'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 
'resnet101', 10 | 'resnet152', 'resnet200' 11 | ] 12 | 13 | 14 | def conv3x3x3(in_planes, out_planes, stride=1, dilation=1): 15 | # 3x3x3 convolution with padding 16 | return nn.Conv3d( 17 | in_planes, 18 | out_planes, 19 | kernel_size=3, 20 | dilation=dilation, 21 | stride=stride, 22 | padding=dilation, 23 | bias=False) 24 | 25 | 26 | def downsample_basic_block(x, planes, stride, no_cuda=False): 27 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 28 | zero_pads = torch.Tensor( 29 | out.size(0), planes - out.size(1), out.size(2), out.size(3), 30 | out.size(4)).zero_() 31 | if not no_cuda: 32 | if isinstance(out.data, torch.cuda.FloatTensor): 33 | zero_pads = zero_pads.cuda() 34 | 35 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 36 | 37 | return out 38 | 39 | 40 | class BasicBlock(nn.Module): 41 | expansion = 1 42 | 43 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 44 | super(BasicBlock, self).__init__() 45 | self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation) 46 | self.bn1 = nn.BatchNorm3d(planes) 47 | self.relu = nn.ReLU(inplace=True) 48 | self.conv2 = conv3x3x3(planes, planes, dilation=dilation) 49 | self.bn2 = nn.BatchNorm3d(planes) 50 | self.downsample = downsample 51 | self.stride = stride 52 | self.dilation = dilation 53 | 54 | def forward(self, x): 55 | residual = x 56 | 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | out = self.conv2(out) 61 | out = self.bn2(out) 62 | 63 | if self.downsample is not None: 64 | residual = self.downsample(x) 65 | 66 | out += residual 67 | out = self.relu(out) 68 | 69 | return out 70 | 71 | 72 | class Bottleneck(nn.Module): 73 | expansion = 4 74 | 75 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 76 | super(Bottleneck, self).__init__() 77 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 78 | self.bn1 = nn.BatchNorm3d(planes) 79 | self.conv2 = nn.Conv3d( 80 | planes, 
planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False) 81 | self.bn2 = nn.BatchNorm3d(planes) 82 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False) 83 | self.bn3 = nn.BatchNorm3d(planes * 4) 84 | self.relu = nn.ReLU(inplace=True) 85 | self.downsample = downsample 86 | self.stride = stride 87 | self.dilation = dilation 88 | 89 | def forward(self, x): 90 | residual = x 91 | 92 | out = self.conv1(x) 93 | out = self.bn1(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv2(out) 97 | out = self.bn2(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv3(out) 101 | out = self.bn3(out) 102 | 103 | if self.downsample is not None: 104 | residual = self.downsample(x) 105 | 106 | out += residual 107 | out = self.relu(out) 108 | 109 | return out 110 | 111 | 112 | class ResNet(nn.Module): 113 | 114 | def __init__(self, 115 | block, 116 | layers, 117 | sample_input_D, 118 | sample_input_H, 119 | sample_input_W, 120 | num_seg_classes, 121 | shortcut_type='B', 122 | no_cuda = False): 123 | self.inplanes = 64 124 | self.no_cuda = no_cuda 125 | super(ResNet, self).__init__() 126 | self.conv1 = nn.Conv3d( 127 | 1, 128 | 64, 129 | kernel_size=7, 130 | stride=(2, 2, 2), 131 | padding=(3, 3, 3), 132 | bias=False) 133 | 134 | self.bn1 = nn.BatchNorm3d(64) 135 | self.relu = nn.ReLU(inplace=True) 136 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 137 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type) 138 | self.layer2 = self._make_layer( 139 | block, 128, layers[1], shortcut_type, stride=2) 140 | self.layer3 = self._make_layer( 141 | block, 256, layers[2], shortcut_type, stride=1, dilation=2) 142 | self.layer4 = self._make_layer( 143 | block, 512, layers[3], shortcut_type, stride=1, dilation=4) 144 | 145 | self.conv_seg = nn.Sequential( 146 | nn.ConvTranspose3d( 147 | 512 * block.expansion, 148 | 32, 149 | 2, 150 | stride=2 151 | ), 152 | nn.BatchNorm3d(32), 153 | 
nn.ReLU(inplace=True), 154 | nn.Conv3d( 155 | 32, 156 | 32, 157 | kernel_size=3, 158 | stride=(1, 1, 1), 159 | padding=(1, 1, 1), 160 | bias=False), 161 | nn.BatchNorm3d(32), 162 | nn.ReLU(inplace=True), 163 | nn.Conv3d( 164 | 32, 165 | num_seg_classes, 166 | kernel_size=1, 167 | stride=(1, 1, 1), 168 | bias=False) 169 | ) 170 | 171 | for m in self.modules(): 172 | if isinstance(m, nn.Conv3d): 173 | m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out') 174 | elif isinstance(m, nn.BatchNorm3d): 175 | m.weight.data.fill_(1) 176 | m.bias.data.zero_() 177 | 178 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1): 179 | downsample = None 180 | if stride != 1 or self.inplanes != planes * block.expansion: 181 | if shortcut_type == 'A': 182 | downsample = partial( 183 | downsample_basic_block, 184 | planes=planes * block.expansion, 185 | stride=stride, 186 | no_cuda=self.no_cuda) 187 | else: 188 | downsample = nn.Sequential( 189 | nn.Conv3d( 190 | self.inplanes, 191 | planes * block.expansion, 192 | kernel_size=1, 193 | stride=stride, 194 | bias=False), nn.BatchNorm3d(planes * block.expansion)) 195 | 196 | layers = [] 197 | layers.append(block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample)) 198 | self.inplanes = planes * block.expansion 199 | for i in range(1, blocks): 200 | layers.append(block(self.inplanes, planes, dilation=dilation)) 201 | 202 | return nn.Sequential(*layers) 203 | 204 | def forward(self, x): 205 | x = self.conv1(x) 206 | x = self.bn1(x) 207 | x = self.relu(x) 208 | x = self.maxpool(x) 209 | x = self.layer1(x) 210 | x = self.layer2(x) 211 | x = self.layer3(x) 212 | x = self.layer4(x) 213 | x = self.conv_seg(x) 214 | 215 | return x 216 | 217 | def resnet10(**kwargs): 218 | """Constructs a ResNet-18 model. 219 | """ 220 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 221 | return model 222 | 223 | 224 | def resnet18(**kwargs): 225 | """Constructs a ResNet-18 model. 
226 | """ 227 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 228 | return model 229 | 230 | 231 | def resnet34(**kwargs): 232 | """Constructs a ResNet-34 model. 233 | """ 234 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 235 | return model 236 | 237 | 238 | def resnet50(**kwargs): 239 | """Constructs a ResNet-50 model. 240 | """ 241 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 242 | return model 243 | 244 | 245 | def resnet101(**kwargs): 246 | """Constructs a ResNet-101 model. 247 | """ 248 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 249 | return model 250 | 251 | 252 | def resnet152(**kwargs): 253 | """Constructs a ResNet-101 model. 254 | """ 255 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 256 | return model 257 | 258 | 259 | def resnet200(**kwargs): 260 | """Constructs a ResNet-101 model. 261 | """ 262 | model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs) 263 | return model 264 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # python requirements 2 | pip>=9.0.1 3 | #logging==0.4.9.6 4 | torch==0.4.1 5 | numpy==1.15.4 6 | nibabel==2.4.1 7 | scipy==1.1.0 8 | argparse==1.1 -------------------------------------------------------------------------------- /setting.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Configs for training & testing 3 | Written by Whalechen 4 | ''' 5 | 6 | import argparse 7 | 8 | def parse_opts(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument( 11 | '--data_root', 12 | default='./data', 13 | type=str, 14 | help='Root directory path of data') 15 | parser.add_argument( 16 | '--img_list', 17 | default='./data/train.txt', 18 | type=str, 19 | help='Path for image list file') 20 | parser.add_argument( 21 | '--n_seg_classes', 22 | default=2, 23 | type=int, 24 | help="Number of segmentation classes" 25 | ) 26 
| parser.add_argument( 27 | '--learning_rate', # set to 0.001 when finetune 28 | default=0.001, 29 | type=float, 30 | help= 31 | 'Initial learning rate (divided by 10 while training by lr scheduler)') 32 | parser.add_argument( 33 | '--num_workers', 34 | default=4, 35 | type=int, 36 | help='Number of jobs') 37 | parser.add_argument( 38 | '--batch_size', default=1, type=int, help='Batch Size') 39 | parser.add_argument( 40 | '--phase', default='train', type=str, help='Phase of train or test') 41 | parser.add_argument( 42 | '--save_intervals', 43 | default=10, 44 | type=int, 45 | help='Interation for saving model') 46 | parser.add_argument( 47 | '--n_epochs', 48 | default=200, 49 | type=int, 50 | help='Number of total epochs to run') 51 | parser.add_argument( 52 | '--input_D', 53 | default=56, 54 | type=int, 55 | help='Input size of depth') 56 | parser.add_argument( 57 | '--input_H', 58 | default=448, 59 | type=int, 60 | help='Input size of height') 61 | parser.add_argument( 62 | '--input_W', 63 | default=448, 64 | type=int, 65 | help='Input size of width') 66 | parser.add_argument( 67 | '--resume_path', 68 | default='', 69 | type=str, 70 | help= 71 | 'Path for resume model.' 72 | ) 73 | parser.add_argument( 74 | '--pretrain_path', 75 | default='pretrain/resnet_50.pth', 76 | type=str, 77 | help= 78 | 'Path for pretrained model.' 
79 | ) 80 | parser.add_argument( 81 | '--new_layer_names', 82 | #default=['upsample1', 'cmp_layer3', 'upsample2', 'cmp_layer2', 'upsample3', 'cmp_layer1', 'upsample4', 'cmp_conv1', 'conv_seg'], 83 | default=['conv_seg'], 84 | type=list, 85 | help='New layer except for backbone') 86 | parser.add_argument( 87 | '--no_cuda', action='store_true', help='If true, cuda is not used.') 88 | parser.set_defaults(no_cuda=False) 89 | parser.add_argument( 90 | '--gpu_id', 91 | nargs='+', 92 | type=int, 93 | help='Gpu id lists') 94 | parser.add_argument( 95 | '--model', 96 | default='resnet', 97 | type=str, 98 | help='(resnet | preresnet | wideresnet | resnext | densenet | ') 99 | parser.add_argument( 100 | '--model_depth', 101 | default=50, 102 | type=int, 103 | help='Depth of resnet (10 | 18 | 34 | 50 | 101)') 104 | parser.add_argument( 105 | '--resnet_shortcut', 106 | default='B', 107 | type=str, 108 | help='Shortcut type of resnet (A | B)') 109 | parser.add_argument( 110 | '--manual_seed', default=1, type=int, help='Manually set random seed') 111 | parser.add_argument( 112 | '--ci_test', action='store_true', help='If true, ci testing is used.') 113 | args = parser.parse_args() 114 | args.save_folder = "./trails/models/{}_{}".format(args.model, args.model_depth) 115 | 116 | return args 117 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from setting import parse_opts 2 | from datasets.brains18 import BrainS18Dataset 3 | from model import generate_model 4 | import torch 5 | import numpy as np 6 | from torch.utils.data import DataLoader 7 | import torch.nn.functional as F 8 | from scipy import ndimage 9 | import nibabel as nib 10 | import sys 11 | import os 12 | from utils.file_process import load_lines 13 | import numpy as np 14 | 15 | 16 | def seg_eval(pred, label, clss): 17 | """ 18 | calculate the dice between prediction and ground truth 19 | 
input: 20 | pred: predicted mask 21 | label: groud truth 22 | clss: eg. [0, 1] for binary class 23 | """ 24 | Ncls = len(clss) 25 | dices = np.zeros(Ncls) 26 | [depth, height, width] = pred.shape 27 | for idx, cls in enumerate(clss): 28 | # binary map 29 | pred_cls = np.zeros([depth, height, width]) 30 | pred_cls[np.where(pred == cls)] = 1 31 | label_cls = np.zeros([depth, height, width]) 32 | label_cls[np.where(label == cls)] = 1 33 | 34 | # cal the inter & conv 35 | s = pred_cls + label_cls 36 | inter = len(np.where(s >= 2)[0]) 37 | conv = len(np.where(s >= 1)[0]) + inter 38 | try: 39 | dice = 2.0 * inter / conv 40 | except: 41 | print("conv is zeros when dice = 2.0 * inter / conv") 42 | dice = -1 43 | 44 | dices[idx] = dice 45 | 46 | return dices 47 | 48 | def test(data_loader, model, img_names, sets): 49 | masks = [] 50 | model.eval() # for testing 51 | for batch_id, batch_data in enumerate(data_loader): 52 | # forward 53 | volume = batch_data 54 | if not sets.no_cuda: 55 | volume = volume.cuda() 56 | with torch.no_grad(): 57 | probs = model(volume) 58 | probs = F.softmax(probs, dim=1) 59 | 60 | # resize mask to original size 61 | [batchsize, _, mask_d, mask_h, mask_w] = probs.shape 62 | data = nib.load(os.path.join(sets.data_root, img_names[batch_id])) 63 | data = data.get_data() 64 | [depth, height, width] = data.shape 65 | mask = probs[0] 66 | scale = [1, depth*1.0/mask_d, height*1.0/mask_h, width*1.0/mask_w] 67 | mask = ndimage.interpolation.zoom(mask, scale, order=1) 68 | mask = np.argmax(mask, axis=0) 69 | 70 | masks.append(mask) 71 | 72 | return masks 73 | 74 | 75 | if __name__ == '__main__': 76 | # settting 77 | sets = parse_opts() 78 | sets.target_type = "normal" 79 | sets.phase = 'test' 80 | 81 | # getting model 82 | checkpoint = torch.load(sets.resume_path) 83 | net, _ = generate_model(sets) 84 | net.load_state_dict(checkpoint['state_dict']) 85 | 86 | # data tensor 87 | testing_data =BrainS18Dataset(sets.data_root, sets.img_list, sets) 88 | 
data_loader = DataLoader(testing_data, batch_size=1, shuffle=False, num_workers=1, pin_memory=False) 89 | 90 | # testing 91 | img_names = [info.split(" ")[0] for info in load_lines(sets.img_list)] 92 | masks = test(data_loader, net, img_names, sets) 93 | 94 | # evaluation: calculate dice 95 | label_names = [info.split(" ")[1] for info in load_lines(sets.img_list)] 96 | Nimg = len(label_names) 97 | dices = np.zeros([Nimg, sets.n_seg_classes]) 98 | for idx in range(Nimg): 99 | label = nib.load(os.path.join(sets.data_root, label_names[idx])) 100 | label = label.get_data() 101 | dices[idx, :] = seg_eval(masks[idx], label, range(sets.n_seg_classes)) 102 | 103 | # print result 104 | for idx in range(1, sets.n_seg_classes): 105 | mean_dice_per_task = np.mean(dices[:, idx]) 106 | print('mean dice for class-{} is {}'.format(idx, mean_dice_per_task)) 107 | -------------------------------------------------------------------------------- /test_ci.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | print("test successful!") -------------------------------------------------------------------------------- /toy_data/MRBrainS18/images/070.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/MedicalNet/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/toy_data/MRBrainS18/images/070.nii.gz -------------------------------------------------------------------------------- /toy_data/MRBrainS18/labels/070.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent/MedicalNet/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/toy_data/MRBrainS18/labels/070.nii.gz -------------------------------------------------------------------------------- /toy_data/MRBrainS18/test_ci.txt: -------------------------------------------------------------------------------- 1 | MRBrainS18/images/070.nii.gz 
MRBrainS18/labels/070.nii.gz 2 | -------------------------------------------------------------------------------- /toy_data/test_ci.txt: -------------------------------------------------------------------------------- 1 | MRBrainS18/images/070.nii.gz MRBrainS18/labels/070.nii.gz 2 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Training code for MRBrainS18 datasets segmentation 3 | Written by Whalechen 4 | ''' 5 | 6 | from setting import parse_opts 7 | from datasets.brains18 import BrainS18Dataset 8 | from model import generate_model 9 | import torch 10 | import numpy as np 11 | from torch import nn 12 | from torch import optim 13 | from torch.optim import lr_scheduler 14 | from torch.utils.data import DataLoader 15 | import time 16 | from utils.logger import log 17 | from scipy import ndimage 18 | import os 19 | 20 | def train(data_loader, model, optimizer, scheduler, total_epochs, save_interval, save_folder, sets): 21 | # settings 22 | batches_per_epoch = len(data_loader) 23 | log.info('{} epochs in total, {} batches per epoch'.format(total_epochs, batches_per_epoch)) 24 | loss_seg = nn.CrossEntropyLoss(ignore_index=-1) 25 | 26 | print("Current setting is:") 27 | print(sets) 28 | print("\n\n") 29 | if not sets.no_cuda: 30 | loss_seg = loss_seg.cuda() 31 | 32 | model.train() 33 | train_time_sp = time.time() 34 | for epoch in range(total_epochs): 35 | log.info('Start epoch {}'.format(epoch)) 36 | 37 | scheduler.step() 38 | log.info('lr = {}'.format(scheduler.get_lr())) 39 | 40 | for batch_id, batch_data in enumerate(data_loader): 41 | # getting data batch 42 | batch_id_sp = epoch * batches_per_epoch 43 | volumes, label_masks = batch_data 44 | 45 | if not sets.no_cuda: 46 | volumes = volumes.cuda() 47 | 48 | optimizer.zero_grad() 49 | out_masks = model(volumes) 50 | # resize label 51 | [n, _, d, h, w] = out_masks.shape 52 | 
new_label_masks = np.zeros([n, d, h, w]) 53 | for label_id in range(n): 54 | label_mask = label_masks[label_id] 55 | [ori_c, ori_d, ori_h, ori_w] = label_mask.shape 56 | label_mask = np.reshape(label_mask, [ori_d, ori_h, ori_w]) 57 | scale = [d*1.0/ori_d, h*1.0/ori_h, w*1.0/ori_w] 58 | label_mask = ndimage.interpolation.zoom(label_mask, scale, order=0) 59 | new_label_masks[label_id] = label_mask 60 | 61 | new_label_masks = torch.tensor(new_label_masks).to(torch.int64) 62 | if not sets.no_cuda: 63 | new_label_masks = new_label_masks.cuda() 64 | 65 | # calculating loss 66 | loss_value_seg = loss_seg(out_masks, new_label_masks) 67 | loss = loss_value_seg 68 | loss.backward() 69 | optimizer.step() 70 | 71 | avg_batch_time = (time.time() - train_time_sp) / (1 + batch_id_sp) 72 | log.info( 73 | 'Batch: {}-{} ({}), loss = {:.3f}, loss_seg = {:.3f}, avg_batch_time = {:.3f}'\ 74 | .format(epoch, batch_id, batch_id_sp, loss.item(), loss_value_seg.item(), avg_batch_time)) 75 | 76 | if not sets.ci_test: 77 | # save model 78 | if batch_id == 0 and batch_id_sp != 0 and batch_id_sp % save_interval == 0: 79 | #if batch_id_sp != 0 and batch_id_sp % save_interval == 0: 80 | model_save_path = '{}_epoch_{}_batch_{}.pth.tar'.format(save_folder, epoch, batch_id) 81 | model_save_dir = os.path.dirname(model_save_path) 82 | if not os.path.exists(model_save_dir): 83 | os.makedirs(model_save_dir) 84 | 85 | log.info('Save checkpoints: epoch = {}, batch_id = {}'.format(epoch, batch_id)) 86 | torch.save({ 87 | 'ecpoch': epoch, 88 | 'batch_id': batch_id, 89 | 'state_dict': model.state_dict(), 90 | 'optimizer': optimizer.state_dict()}, 91 | model_save_path) 92 | 93 | print('Finished training') 94 | if sets.ci_test: 95 | exit() 96 | 97 | 98 | if __name__ == '__main__': 99 | # settting 100 | sets = parse_opts() 101 | if sets.ci_test: 102 | sets.img_list = './toy_data/test_ci.txt' 103 | sets.n_epochs = 1 104 | sets.no_cuda = True 105 | sets.data_root = './toy_data' 106 | sets.pretrain_path = '' 107 
| sets.num_workers = 0 108 | sets.model_depth = 10 109 | sets.resnet_shortcut = 'A' 110 | sets.input_D = 14 111 | sets.input_H = 28 112 | sets.input_W = 28 113 | 114 | 115 | 116 | # getting model 117 | torch.manual_seed(sets.manual_seed) 118 | model, parameters = generate_model(sets) 119 | print (model) 120 | # optimizer 121 | if sets.ci_test: 122 | params = [{'params': parameters, 'lr': sets.learning_rate}] 123 | else: 124 | params = [ 125 | { 'params': parameters['base_parameters'], 'lr': sets.learning_rate }, 126 | { 'params': parameters['new_parameters'], 'lr': sets.learning_rate*100 } 127 | ] 128 | optimizer = torch.optim.SGD(params, momentum=0.9, weight_decay=1e-3) 129 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99) 130 | 131 | # train from resume 132 | if sets.resume_path: 133 | if os.path.isfile(sets.resume_path): 134 | print("=> loading checkpoint '{}'".format(sets.resume_path)) 135 | checkpoint = torch.load(sets.resume_path) 136 | model.load_state_dict(checkpoint['state_dict']) 137 | optimizer.load_state_dict(checkpoint['optimizer']) 138 | print("=> loaded checkpoint '{}' (epoch {})" 139 | .format(sets.resume_path, checkpoint['epoch'])) 140 | 141 | # getting data 142 | sets.phase = 'train' 143 | if sets.no_cuda: 144 | sets.pin_memory = False 145 | else: 146 | sets.pin_memory = True 147 | training_dataset = BrainS18Dataset(sets.data_root, sets.img_list, sets) 148 | data_loader = DataLoader(training_dataset, batch_size=sets.batch_size, shuffle=True, num_workers=sets.num_workers, pin_memory=sets.pin_memory) 149 | 150 | # training 151 | train(data_loader, model, optimizer, scheduler, total_epochs=sets.n_epochs, save_interval=sets.save_intervals, save_folder=sets.save_folder, sets=sets) 152 | -------------------------------------------------------------------------------- /utils/file_process.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import os 
import os.path as osp


def load_lines(file_path, encoding='utf-8'):
    """Read a text file and return its lines without line terminators.

    Args:
        file_path: path of the file to read.
        encoding: codec used to decode the file. It defaults to UTF-8 so the
            result does not depend on the platform's locale encoding (``open()``
            without ``encoding`` uses the locale's preferred encoding).

    Returns:
        A list of strings, one per line, split with ``str.splitlines`` semantics
        (terminators removed, no trailing empty entry). An empty file yields ``[]``.
    """
    with open(file_path, 'r', encoding=encoding) as fio:
        return fio.read().splitlines()


# -------------------------------------------------------------------------------- /utils/logger.py: --------------------------------------------------------------------------------
'''
Written by Whalechen
'''

import logging

logging.basicConfig(
    format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.DEBUG)

# Shared project logger. It is the root logger, so the basicConfig format and
# DEBUG level above apply to it directly.
log = logging.getLogger()
# --------------------------------------------------------------------------------