├── .gitignore ├── README.md ├── config.py ├── dataset.py ├── fig1.png ├── fig2.png ├── loss.py ├── metric.py ├── model.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | models/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | .vscode/settings.json 133 | 134 | tmp.py 135 | .vscode/launch.json 136 | 137 | models/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FedCross 2 | Federated Cross Learning for Medical Image Segmentation 3 | 4 | This is a Python (PyTorch) implementation of the **Federated Cross Learning (FedCross)** method proposed in our paper [**"Federated Cross Learning for Medical Image Segmentation"**](https://openreview.net/forum?id=DrZbwobH_zo), published at the *Medical Imaging with Deep Learning 2023* conference (Nashville, Tennessee, United States, Jul. 10-12, 2023). A preprint version of this paper is also available on [arXiv](https://arxiv.org/abs/2204.02450). 5 | 6 | ## Citation 7 | *X. Xu, H. H. Deng, T. Chen, T. Kuang, J. C. Barber, D. Kim, J. Gateno, J. J. Xia, and P. Yan, "Federated Cross Learning for Medical Image Segmentation," in Medical Imaging with Deep Learning 2023. Nashville, Tennessee, United States, Jul. 
10-12, 2023.* 8 | 9 | @inproceedings{Xu2023FedCross, 10 | title={Federated Cross Learning for Medical Image Segmentation}, 11 | author={Xuanang Xu and Hannah H. Deng and Tianyi Chen and Tianshu Kuang and Joshua C. Barber and Daeseung Kim and Jaime Gateno and James J. Xia and Pingkun Yan}, 12 | booktitle={Medical Imaging with Deep Learning}, 13 | year={2023}, 14 | url={https://openreview.net/forum?id=DrZbwobH_zo} 15 | } 16 | 17 | ## Update 18 | - **Nov 21, 2023**: Fix a bug (`read_image` function was missing in `utils.py` file). 19 | 20 | ## Abstract 21 | Federated learning (FL) can collaboratively train deep learning models using isolated patient data owned by different hospitals for various clinical applications, including medical image segmentation. However, a major problem of FL is its performance degradation when dealing with data that are not independently and identically distributed (non-iid), which is often the case in medical images. In this paper, we first conduct a theoretical analysis on the FL algorithm to reveal the problem of model aggregation during training on non-iid data. With the insights gained through the analysis, we propose a simple yet effective method, federated cross learning (FedCross), to tackle this challenging problem. Unlike the conventional FL methods that combine multiple individually trained local models on a server node, our FedCross sequentially trains the global model across different clients in a round-robin manner, and thus the entire training procedure does not involve any model aggregation steps. To further improve its performance to be comparable with the centralized learning method, we combine the FedCross with an ensemble learning mechanism to compose a federated cross ensemble learning (FedCrossEns) method. Finally, we conduct extensive experiments using a set of public datasets. The experimental results show that the proposed FedCross training strategy outperforms the mainstream FL methods on non-iid data. 
In addition to improving the segmentation performance, our FedCrossEns can further provide a quantitative estimation of the model uncertainty, demonstrating the effectiveness and clinical significance of our designs. Source code is publicly available at [https://github.com/DIAL-RPI/FedCross](https://github.com/DIAL-RPI/FedCross). 22 | 23 | ## An illustration of the non-iid problem in FL model aggregation 24 | The model aggregation process in FL may lead to a sub-optimal solution when dealing with non-iid data. In the figure below, (a) and (b) show that the locally trained models $θ^{J+1}_k$ and $θ^{J+1}_m$ each reach a minimum in the loss landscape of their own client dataset, $D_k$ and $D_m$, respectively. (c) indicates that the aggregated FL model $θ^′$ is located at a non-minimal position in the global loss landscape of $D_k \cup D_m$. 25 | 26 | 27 | ## Method 28 | ### Training schemes of (a) FedAvg, (b) FedCross, and (c) FedCrossEns 29 | 30 | 31 | ## Contact 32 | You are welcome to contact us: 33 | - [xux12@rpi.edu](mailto:xux12@rpi.edu) ([Dr. Xuanang Xu](https://superxuang.github.io/)) 34 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | cfg = {} 2 | 3 | cfg['cls_num'] = 1 4 | cfg['gpu'] = '0,1,2,3' # to use multiple GPUs: cfg['gpu'] = '0,1,2,3' 5 | cfg['batch_size'] = 16 # training batch size 6 | cfg['test_batch_size'] = 16 # testing batch size 7 | cfg['lr'] = 0.01 # base learning rate 8 | cfg['model_path'] = '/set_your_own_model_path/models' # path where the trained model and evaluation results are saved 9 | cfg['rs_size'] = [160,160,32] # resample size: [x, y, z] 10 | cfg['rs_spacing'] = [0.5,0.5,1.0] # resample spacing: [x, y, z]. 
non-positive value means adaptive spacing fit the physical size: rs_size * rs_spacing = origin_size * origin_spacing 11 | cfg['rs_intensity'] = [-200.0, 400.0] # rescale intensity from [min, max] to [0, 1]. 12 | cfg['cpu_thread'] = 4 # multi-thread for data loading. zero means single thread. 13 | cfg['commu_times'] = 50 # number of communication rounds 14 | cfg['epoch_per_commu'] = 32 # number of local training epochs within one communication round 15 | 16 | # map labels of different client datasets to a uniform label map 17 | cfg['label_map'] = { 18 | 'MSD':{1:1, 2:1}, 19 | 'NCI-ISBI':{1:1, 2:1}, 20 | 'PROMISE12':{1:1}, 21 | 'PROSTATEx':{1:1}, 22 | } 23 | 24 | # exclude any samples in the form of '[dataset_name, case_name]' 25 | cfg['exclude_case'] = [ 26 | ] 27 | 28 | # data path of each client dataset 29 | cfg['node_list'] = [ 30 | ['Node-1', ['MSD'], ['/set_your_own_data_path/MSD-Prostate'], [19,3,10]], # 32 in total 31 | ['Node-2', ['NCI-ISBI'], ['/set_your_own_data_path/NCI-ISBI-Prostate'], [48,8,24]], # 80 in total 32 | ['Node-3', ['PROMISE12'], ['/set_your_own_data_path/PROMISE12'], [30,5,15]], # 50 in total 33 | ['Node-4', ['PROSTATEx'], ['/set_your_own_data_path/PROSTATEx'], [122,20,62]], # 204 in total 34 | ] -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | from torch.utils import data 5 | import itk 6 | import numpy as np 7 | import random 8 | import SimpleITK as sitk 9 | from scipy import stats 10 | 11 | def read_image(fname, imtype): 12 | reader = itk.ImageFileReader[imtype].New() 13 | reader.SetFileName(fname) 14 | reader.Update() 15 | image = reader.GetOutput() 16 | return image 17 | 18 | def image_2_array(image): 19 | arr = itk.GetArrayFromImage(image) 20 | return arr 21 | 22 | def array_2_image(arr, spacing, origin, imtype): 23 | image = itk.GetImageFromArray(arr) 24 | 
image.SetSpacing((spacing[0], spacing[1], spacing[2])) 25 | image.SetOrigin((origin[0], origin[1], origin[2])) 26 | cast = itk.CastImageFilter[type(image), imtype].New() 27 | cast.SetInput(image) 28 | cast.Update() 29 | image = cast.GetOutput() 30 | return image 31 | 32 | def scan_path(dataset_name, dataset_path): 33 | entries = [] 34 | if dataset_name == 'MSD': 35 | for f in os.listdir('{}/imagesTr'.format(dataset_path)): 36 | if f.startswith('prostate_') and f.endswith('.nii.gz'): 37 | case_name = f.split('.nii.gz')[0] 38 | if os.path.isfile('{}/labelsTr/{}'.format(dataset_path, f)): 39 | image_name = '{}/imagesTr/{}'.format(dataset_path, f) 40 | label_name = '{}/labelsTr/{}'.format(dataset_path, f) 41 | entries.append([dataset_name, case_name, image_name, label_name]) 42 | elif dataset_name == 'NCI-ISBI': 43 | for f in os.listdir('{}/image'.format(dataset_path)): 44 | if f.startswith('Prostate') and f.endswith('.nii.gz'): 45 | case_name = f.split('.nii.gz')[0] 46 | if os.path.isfile('{}/label/{}'.format(dataset_path, f)): 47 | image_name = '{}/image/{}'.format(dataset_path, f) 48 | label_name = '{}/label/{}'.format(dataset_path, f) 49 | entries.append([dataset_name, case_name, image_name, label_name]) 50 | elif dataset_name == 'PROMISE12': 51 | for f in os.listdir('{}/image'.format(dataset_path)): 52 | if f.startswith('Case') and f.endswith('.nii.gz'): 53 | case_name = f.split('.nii.gz')[0] 54 | if os.path.isfile('{}/label/{}'.format(dataset_path, f)): 55 | image_name = '{}/image/{}'.format(dataset_path, f) 56 | label_name = '{}/label/{}'.format(dataset_path, f) 57 | entries.append([dataset_name, case_name, image_name, label_name]) 58 | elif dataset_name == 'PROSTATEx': 59 | for f in os.listdir('{}/image'.format(dataset_path)): 60 | if f.startswith('ProstateX-') and f.endswith('.nii.gz'): 61 | case_name = f.split('.nii.gz')[0] 62 | if os.path.isfile('{}/label/{}'.format(dataset_path, f)): 63 | image_name = '{}/image/{}'.format(dataset_path, f) 64 | label_name = 
'{}/label/{}'.format(dataset_path, f) 65 | entries.append([dataset_name, case_name, image_name, label_name]) 66 | return entries 67 | 68 | def create_folds(dataset_name, dataset_path, fold_name, fraction, exclude_case): 69 | fold_file_name = '{0:s}/data_split-{1:s}.txt'.format(sys.path[0], fold_name) 70 | folds = {} 71 | if os.path.exists(fold_file_name): 72 | with open(fold_file_name, 'r') as fold_file: 73 | strlines = fold_file.readlines() 74 | for strline in strlines: 75 | strline = strline.rstrip('\n') 76 | params = strline.split() 77 | fold_id = int(params[0]) 78 | if fold_id not in folds: 79 | folds[fold_id] = [] 80 | folds[fold_id].append([params[1], params[2], params[3], params[4]]) 81 | else: 82 | entries = [] 83 | for [d_name, d_path] in zip(dataset_name, dataset_path): 84 | entries.extend(scan_path(d_name, d_path)) 85 | # build a new list: calling remove() while iterating over the same list skips elements 86 | entries = [e for e in entries 87 | if e[0:2] not in exclude_case] 88 | random.shuffle(entries) 89 | 90 | ptr = 0 91 | for fold_id in range(len(fraction)): 92 | folds[fold_id] = entries[ptr:ptr+fraction[fold_id]] 93 | ptr += fraction[fold_id] 94 | 95 | with open(fold_file_name, 'w') as fold_file: 96 | for fold_id in range(len(fraction)): 97 | for [d_name, case_name, image_path, label_path] in folds[fold_id]: 98 | fold_file.write('{0:d} {1:s} {2:s} {3:s} {4:s}\n'.format(fold_id, d_name, case_name, image_path, label_path)) 99 | 100 | folds_size = [len(x) for x in folds.values()] 101 | 102 | return folds, folds_size 103 | 104 | def generate_transform(aug, min_offset, max_offset): 105 | if aug: 106 | min_rotate = -0.1 # [rad] 107 | max_rotate = 0.1 # [rad] 108 | t = itk.Euler3DTransform[itk.D].New() 109 | euler_parameters = t.GetParameters() 110 | euler_parameters = itk.OptimizerParameters[itk.D](t.GetNumberOfParameters()) 111 | offset_x = min_offset[0] + random.random() * (max_offset[0] - min_offset[0]) # translate 112 | offset_y = min_offset[1] + random.random() * (max_offset[1] - min_offset[1]) # translate 113 | offset_z = min_offset[2] + 
random.random() * (max_offset[2] - min_offset[2]) # translate 114 | rotate_x = min_rotate + random.random() * (max_rotate - min_rotate) # rotate 115 | rotate_y = min_rotate + random.random() * (max_rotate - min_rotate) # rotate 116 | rotate_z = min_rotate + random.random() * (max_rotate - min_rotate) # rotate 117 | euler_parameters[0] = rotate_x # rotate 118 | euler_parameters[1] = rotate_y # rotate 119 | euler_parameters[2] = rotate_z # rotate 120 | euler_parameters[3] = offset_x # translate 121 | euler_parameters[4] = offset_y # translate 122 | euler_parameters[5] = offset_z # translate 123 | t.SetParameters(euler_parameters) 124 | else: 125 | offset_x = 0 126 | offset_y = 0 127 | offset_z = 0 128 | rotate_x = 0 129 | rotate_y = 0 130 | rotate_z = 0 131 | t = itk.IdentityTransform[itk.D, 3].New() 132 | return t, [offset_x, offset_y, offset_z, rotate_x, rotate_y, rotate_z] 133 | 134 | def resample(image, imtype, size, spacing, origin, transform, linear, dtype): 135 | o_origin = image.GetOrigin() 136 | o_spacing = image.GetSpacing() 137 | o_size = image.GetBufferedRegion().GetSize() 138 | output = {} 139 | output['org_size'] = np.array(o_size, dtype=int) 140 | output['org_spacing'] = np.array(o_spacing, dtype=float) 141 | output['org_origin'] = np.array(o_origin, dtype=float) 142 | 143 | if origin is None: # if no origin point specified, center align the resampled image with the original image 144 | new_size = np.zeros(3, dtype=int) 145 | new_spacing = np.zeros(3, dtype=float) 146 | new_origin = np.zeros(3, dtype=float) 147 | for i in range(3): 148 | new_size[i] = size[i] 149 | if spacing[i] > 0: 150 | new_spacing[i] = spacing[i] 151 | new_origin[i] = o_origin[i] + o_size[i]*o_spacing[i]*0.5 - size[i]*spacing[i]*0.5 152 | else: 153 | new_spacing[i] = o_size[i] * o_spacing[i] / size[i] 154 | new_origin[i] = o_origin[i] 155 | else: 156 | new_size = np.array(size, dtype=int) 157 | new_spacing = np.array(spacing, dtype=float) 158 | new_origin = np.array(origin, 
dtype=float) 159 | 160 | output['size'] = new_size 161 | output['spacing'] = new_spacing 162 | output['origin'] = new_origin 163 | 164 | resampler = itk.ResampleImageFilter[imtype, imtype].New() 165 | resampler.SetInput(image) 166 | resampler.SetSize((int(new_size[0]), int(new_size[1]), int(new_size[2]))) 167 | resampler.SetOutputSpacing((float(new_spacing[0]), float(new_spacing[1]), float(new_spacing[2]))) 168 | resampler.SetOutputOrigin((float(new_origin[0]), float(new_origin[1]), float(new_origin[2]))) 169 | resampler.SetTransform(transform) 170 | if linear: 171 | resampler.SetInterpolator(itk.LinearInterpolateImageFunction[imtype, itk.D].New()) 172 | else: 173 | resampler.SetInterpolator(itk.NearestNeighborInterpolateImageFunction[imtype, itk.D].New()) 174 | resampler.SetDefaultPixelValue(0) 175 | resampler.Update() 176 | rs_image = resampler.GetOutput() 177 | image_array = itk.GetArrayFromImage(rs_image) 178 | image_array = image_array[np.newaxis, :].astype(dtype) 179 | output['array'] = image_array 180 | 181 | return output 182 | 183 | def zscore_normalize(x): 184 | y = (x - x.mean()) / x.std() 185 | return y 186 | 187 | def make_onehot(input, cls): 188 | oh = np.repeat(np.zeros_like(input), cls+1, axis=0) 189 | for i in range(cls+1): 190 | tmp = np.zeros_like(input) 191 | tmp[input==i] = 1 192 | oh[i,:] = tmp 193 | return oh 194 | 195 | def make_flag(cls, labelmap): 196 | flag = np.zeros([cls, 1]) 197 | for key in labelmap: 198 | flag[labelmap[key]-1,0] = 1 199 | return flag 200 | 201 | def image2file(image, imtype, fname): 202 | writer = itk.ImageFileWriter[imtype].New() 203 | writer.SetInput(image) 204 | writer.SetFileName(fname) 205 | writer.Update() 206 | 207 | def array2file(array, size, origin, spacing, imtype, fname): 208 | image = itk.GetImageFromArray(array.reshape([size[2], size[1], size[0]])) 209 | image.SetSpacing((spacing[0], spacing[1], spacing[2])) 210 | image.SetOrigin((origin[0], origin[1], origin[2])) 211 | image2file(image, imtype=imtype, 
fname=fname) 212 | 213 | # dataset of 3D image volume 214 | # 3D volumes are resampled from and center-aligned with the original images 215 | class Dataset(data.Dataset): 216 | def __init__(self, ids, rs_size, rs_spacing, rs_intensity, label_map, cls_num, aug_data): 217 | self.ImageType = itk.Image[itk.F, 3] 218 | self.LabelType = itk.Image[itk.UC, 3] 219 | self.ids = [] 220 | self.rs_size = np.array(rs_size) 221 | self.rs_spacing = np.array(rs_spacing, dtype=float) 222 | self.rs_intensity = rs_intensity 223 | self.label_map = label_map 224 | self.cls_num = cls_num 225 | self.aug_data = aug_data 226 | self.im_cache = {} 227 | self.lb_cache = {} 228 | 229 | for [d_name, casename, image_fn, label_fn] in ids: 230 | reader = sitk.ImageFileReader() 231 | reader.SetFileName(image_fn) 232 | reader.ReadImageInformation() 233 | image_size = np.array(reader.GetSize()[:3]) 234 | image_origin = np.array(reader.GetOrigin()[:3], dtype=float) 235 | image_spacing = np.array(reader.GetSpacing()[:3], dtype=float) 236 | image_phy_size = image_size * image_spacing 237 | patch_phy_size = self.rs_size*self.rs_spacing 238 | 239 | if not aug_data: 240 | patch_num = (image_phy_size/patch_phy_size).astype(int)+1 241 | for p_z in range(patch_num[2]): 242 | for p_y in range(patch_num[1]): 243 | for p_x in range(patch_num[0]): 244 | patch_origin = np.zeros_like(image_origin) 245 | patch_origin[0] = image_origin[0]+p_x*patch_phy_size[0] 246 | patch_origin[1] = image_origin[1]+p_y*patch_phy_size[1] 247 | patch_origin[2] = image_origin[2]+p_z*patch_phy_size[2] 248 | patch_origin[0] = min(patch_origin[0], image_origin[0]+image_phy_size[0]-patch_phy_size[0]) 249 | patch_origin[1] = min(patch_origin[1], image_origin[1]+image_phy_size[1]-patch_phy_size[1]) 250 | patch_origin[2] = min(patch_origin[2], image_origin[2]+image_phy_size[2]-patch_phy_size[2]) 251 | eof = (p_x == patch_num[0]-1) & (p_y == patch_num[1]-1) & (p_z == patch_num[2]-1) 252 | self.ids.append([d_name, casename, image_fn, 
label_fn, patch_origin, eof]) 253 | else: 254 | patch_origin = np.zeros_like(image_origin) 255 | patch_origin = image_origin+0.5*image_phy_size-0.5*patch_phy_size 256 | repeat_time = 4 257 | for i in range(repeat_time): 258 | self.ids.append([d_name, casename, image_fn, label_fn, patch_origin, i==repeat_time-1]) 259 | 260 | 261 | def __len__(self): 262 | return len(self.ids) 263 | 264 | def __getitem__(self, index): 265 | [d_name, casename, image_fn, label_fn, patch_origin, eof] = self.ids[index] 266 | 267 | if image_fn not in self.im_cache: 268 | src_image = read_image(fname=image_fn, imtype=self.ImageType) 269 | image_cache = {} 270 | image_cache['size'] = np.array(src_image.GetBufferedRegion().GetSize()) 271 | image_cache['origin'] = np.array(src_image.GetOrigin(), dtype=float) 272 | image_cache['spacing'] = np.array(src_image.GetSpacing(), dtype=float) 273 | image_cache['array'] = zscore_normalize(image_2_array(src_image).copy()) 274 | self.im_cache[image_fn] = image_cache 275 | image_cache = self.im_cache[image_fn] 276 | 277 | min_offset = image_cache['origin'] - patch_origin 278 | max_offset = image_cache['origin']+image_cache['spacing']*image_cache['size']-patch_origin-self.rs_size*self.rs_spacing 279 | 280 | t, _ = generate_transform(self.aug_data, min_offset, max_offset) 281 | 282 | src_image = array_2_image(image_cache['array'], image_cache['spacing'], image_cache['origin'], self.ImageType) 283 | 284 | image = resample( 285 | image=src_image, imtype=self.ImageType, 286 | size=self.rs_size, spacing=self.rs_spacing, origin=patch_origin, 287 | transform=t, linear=True, dtype=np.float32) 288 | 289 | if label_fn not in self.lb_cache: 290 | src_label = read_image(fname=label_fn, imtype=self.LabelType) 291 | label_cache = {} 292 | label_cache['origin'] = np.array(src_label.GetOrigin(), dtype=float) 293 | label_cache['spacing'] = np.array(src_label.GetSpacing(), dtype=float) 294 | label_cache['array'] = image_2_array(src_label).copy() 295 | 
self.lb_cache[label_fn] = label_cache 296 | label_cache = self.lb_cache[label_fn] 297 | src_label = array_2_image(label_cache['array'], label_cache['spacing'], label_cache['origin'], self.LabelType) 298 | 299 | label = resample( 300 | image=src_label, imtype=self.LabelType, 301 | size=self.rs_size, spacing=self.rs_spacing, origin=patch_origin, 302 | transform=t, linear=False, dtype=np.int64) 303 | 304 | tmp_array = np.zeros_like(label['array']) 305 | lmap = self.label_map[d_name] 306 | for key in lmap: 307 | tmp_array[label['array'] == key] = lmap[key] 308 | label['array'] = tmp_array 309 | #label_bin = make_onehot(label['array'], cls=self.cls_num) 310 | label_exist = make_flag(cls=self.cls_num, labelmap=self.label_map[d_name]) 311 | 312 | image_tensor = torch.from_numpy(image['array']) 313 | label_tensor = torch.from_numpy(label['array']) 314 | 315 | output = {} 316 | output['data'] = image_tensor 317 | output['label'] = label_tensor 318 | output['label_exist'] = label_exist 319 | output['dataset'] = d_name 320 | output['case'] = casename 321 | output['label_fname'] = label_fn 322 | output['size'] = image['size'] 323 | output['spacing'] = image['spacing'] 324 | output['origin'] = image['origin'] 325 | output['org_size'] = image['org_size'] 326 | output['org_spacing'] = image['org_spacing'] 327 | output['org_origin'] = image['org_origin'] 328 | output['eof'] = eof 329 | 330 | return output -------------------------------------------------------------------------------- /fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DIAL-RPI/FedCross/d64c637783a30a35599680d7913673cd2ea67898/fig1.png -------------------------------------------------------------------------------- /fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DIAL-RPI/FedCross/d64c637783a30a35599680d7913673cd2ea67898/fig2.png 
-------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | class BinaryDiceLoss(nn.Module): 6 | """Dice loss of binary class 7 | Args: 8 | smooth: A float number to smooth loss, and avoid NaN error, default: 1 9 | p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2 10 | predict: A tensor of shape [N, *] 11 | target: A tensor of shape same with predict 12 | Returns: 13 | Loss tensor according to arg reduction 14 | Raise: 15 | Exception if unexpected reduction 16 | """ 17 | def __init__(self, smooth=1, p=2): 18 | super(BinaryDiceLoss, self).__init__() 19 | self.smooth = smooth 20 | self.p = p 21 | 22 | def forward(self, predict, target, flag): 23 | assert predict.shape[0] == target.shape[0], "predict & target batch size don't match" 24 | predict = predict.contiguous().view(predict.shape[0], -1) 25 | target = target.contiguous().view(target.shape[0], -1) 26 | 27 | intersection = self.smooth 28 | union = self.smooth 29 | if flag is None: 30 | pd = predict 31 | gt = target 32 | intersection += torch.sum(pd*gt)*2 33 | union += torch.sum(pd.pow(self.p) + gt.pow(self.p)) 34 | else: 35 | for i in range(target.shape[0]): 36 | if flag[i,0] > 0: 37 | pd = predict[i:i+1,:] 38 | gt = target[i:i+1,:] 39 | intersection += torch.sum(pd*gt)*2 40 | union += torch.sum(pd.pow(self.p) + gt.pow(self.p)) 41 | dice = intersection / union 42 | 43 | loss = 1 - dice 44 | return loss 45 | 46 | class DiceLoss(nn.Module): 47 | """Dice loss, need one hot encode input 48 | Args: 49 | weight: An array of shape [num_classes,] 50 | ignore_index: class index to ignore 51 | predict: A tensor of shape [N, C, *] 52 | target: A tensor of same shape with predict 53 | other args pass to BinaryDiceLoss 54 | Return: 55 | same as BinaryDiceLoss 56 | """ 57 | def __init__(self, weight=None, ignore_index=[], **kwargs): 58 | 
super(DiceLoss, self).__init__() 59 | self.kwargs = kwargs 60 | if weight is not None: 61 | self.weight = weight / weight.sum() 62 | else: 63 | self.weight = None 64 | self.ignore_index = ignore_index 65 | 66 | def forward(self, predict, target, flag=None): 67 | assert predict.shape == target.shape, 'predict & target shape do not match' 68 | dice = BinaryDiceLoss(**self.kwargs) 69 | total_loss = 0 70 | total_loss_num = 0 71 | 72 | for c in range(target.shape[1]): 73 | if c not in self.ignore_index: 74 | dice_loss = dice(predict[:, c], target[:, c], flag) 75 | if self.weight is not None: 76 | assert self.weight.shape[0] == target.shape[1], \ 77 | 'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0]) 78 | dice_loss *= self.weight[c] 79 | total_loss += dice_loss 80 | total_loss_num += 1 81 | 82 | if self.weight is not None: 83 | return total_loss 84 | elif total_loss_num > 0: 85 | return total_loss/total_loss_num 86 | else: 87 | return 0 88 | 89 | def make_onehot(input, cls): 90 | oh_list = [] 91 | for c in range(cls): 92 | tmp = torch.zeros_like(input) 93 | tmp[input==c] = 1 94 | oh_list.append(tmp) 95 | oh = torch.cat(oh_list, dim=1) 96 | return oh 97 | 98 | def dice_and_ce_loss(prob, logit, target): 99 | cls_num = np.sum(logit.shape[1]) 100 | target_oh = make_onehot(target, cls=cls_num) 101 | ce_loss = nn.CrossEntropyLoss() 102 | dice_loss = DiceLoss() 103 | l_ce = ce_loss(logit, target.squeeze(dim=1)) 104 | l_dice = dice_loss(prob, target_oh) 105 | 106 | return l_ce, l_dice -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import math 4 | # Note: 5 | # Use itk here will cause deadlock after the first training epoch 6 | # when using multithread (dataloader num_workers > 0) but reason unknown 7 | import SimpleITK as sitk 8 | from pandas import DataFrame, read_csv 9 | from 
utils import read_image 10 | 11 | def keep_largest_component(image, largest_n=1): 12 | c_filter = sitk.ConnectedComponentImageFilter() 13 | obj_arr = sitk.GetArrayFromImage(c_filter.Execute(image)) 14 | obj_num = c_filter.GetObjectCount() 15 | tmp_arr = np.zeros_like(obj_arr) 16 | 17 | if obj_num > 0: 18 | obj_vol = np.zeros(obj_num, dtype=np.int64) 19 | for obj_id in range(obj_num): 20 | tmp_arr2 = np.zeros_like(obj_arr) 21 | tmp_arr2[obj_arr == obj_id+1] = 1 22 | obj_vol[obj_id] = np.sum(tmp_arr2) 23 | 24 | sorted_obj_id = np.argsort(obj_vol)[::-1] 25 | 26 | for i in range(min(largest_n, obj_num)): 27 | tmp_arr[obj_arr == sorted_obj_id[i]+1] = 1 28 | 29 | output = sitk.GetImageFromArray(tmp_arr) 30 | output.SetSpacing(image.GetSpacing()) 31 | output.SetOrigin(image.GetOrigin()) 32 | output.SetDirection(image.GetDirection()) 33 | 34 | return output 35 | 36 | def cal_dsc(pd, gt): 37 | y = (np.sum(pd * gt) * 2 + 1) / (np.sum(pd * pd + gt * gt) + 1) 38 | return y 39 | 40 | def cal_asd(a, b): 41 | filter1 = sitk.SignedMaurerDistanceMapImageFilter() 42 | filter1.SetUseImageSpacing(True) 43 | filter1.SetSquaredDistance(False) 44 | a_dist = filter1.Execute(a) 45 | 46 | a_dist = sitk.GetArrayFromImage(a_dist) 47 | a_dist = np.abs(a_dist) 48 | a_edge = np.zeros(a_dist.shape, a_dist.dtype) 49 | a_edge[a_dist == 0] = 1 50 | a_num = np.sum(a_edge) 51 | 52 | filter2 = sitk.SignedMaurerDistanceMapImageFilter() 53 | filter2.SetUseImageSpacing(True) 54 | filter2.SetSquaredDistance(False) 55 | b_dist = filter2.Execute(b) 56 | 57 | b_dist = sitk.GetArrayFromImage(b_dist) 58 | b_dist = np.abs(b_dist) 59 | b_edge = np.zeros(b_dist.shape, b_dist.dtype) 60 | b_edge[b_dist == 0] = 1 61 | b_num = np.sum(b_edge) 62 | 63 | a_dist[b_edge == 0] = 0.0 64 | b_dist[a_edge == 0] = 0.0 65 | 66 | #a2b_mean_dist = np.sum(b_dist) / a_num 67 | #b2a_mean_dist = np.sum(a_dist) / b_num 68 | asd = (np.sum(a_dist) + np.sum(b_dist)) / (a_num + b_num) 69 | 70 | return asd 71 | 72 | def cal_hd(a, b): 73 | 
filter1 = sitk.HausdorffDistanceImageFilter() 74 | filter1.Execute(a, b) 75 | hd = filter1.GetHausdorffDistance() 76 | 77 | return hd 78 | 79 | def eval(pd_path, gt_entries, label_map, cls_num, metric_fn, calc_asd=True, keep_largest=False): 80 | result_lines = '' 81 | df_fn = '{}/{}.csv'.format(pd_path, metric_fn) 82 | if not os.path.exists(df_fn): 83 | results = [] 84 | print_line = '\n --- Start calculating metrics --- ' 85 | print(print_line) 86 | result_lines += '{}\n'.format(print_line) 87 | for [d_name, casename, gt_fname] in gt_entries: 88 | gt_label = read_image(fname=gt_fname) 89 | gt_array = sitk.GetArrayFromImage(gt_label) 90 | gt_array = gt_array.astype(dtype=np.uint8) 91 | 92 | # map labels 93 | tmp_array = np.zeros_like(gt_array) 94 | lmap = label_map[d_name] 95 | tgt_labels = [] 96 | for key in lmap: 97 | tmp_array[gt_array == key] = lmap[key] 98 | if lmap[key] not in tgt_labels: 99 | tgt_labels.append(lmap[key]) 100 | gt_array = tmp_array 101 | 102 | pd_fname = '{}/{}@{}.nii.gz'.format(pd_path, d_name, casename) 103 | pd_array = sitk.GetArrayFromImage(read_image(fname=pd_fname)) 104 | 105 | for c in tgt_labels: 106 | pd = np.zeros_like(pd_array, dtype=np.uint8) 107 | pd[pd_array == c] = 1 108 | pd_im = sitk.GetImageFromArray(pd) 109 | pd_im.SetSpacing(gt_label.GetSpacing()) 110 | pd_im.SetOrigin(gt_label.GetOrigin()) 111 | pd_im.SetDirection(gt_label.GetDirection()) 112 | if keep_largest: 113 | pd_im = keep_largest_component(pd_im, largest_n=1) 114 | pd = sitk.GetArrayFromImage(pd_im) 115 | pd = pd.astype(dtype=np.uint8) 116 | pd = np.reshape(pd, -1) 117 | 118 | gt = np.zeros_like(gt_array, dtype=np.uint8) 119 | gt[gt_array == c] = 1 120 | gt_im = sitk.GetImageFromArray(gt) 121 | gt_im.SetSpacing(gt_label.GetSpacing()) 122 | gt_im.SetOrigin(gt_label.GetOrigin()) 123 | gt_im.SetDirection(gt_label.GetDirection()) 124 | gt = np.reshape(gt, -1) 125 | 126 | dsc = cal_dsc(pd, gt) 127 | if calc_asd and np.sum(pd) > 0: 128 | asd = cal_asd(pd_im, gt_im) 129 
| hd = cal_hd(pd_im, gt_im) 130 | else: 131 | asd = 0 132 | hd = 0 133 | results.append([d_name, casename, c, dsc, asd, hd]) 134 | 135 | print_line = ' --- {0:22s}@{1:22s}@{2:d}:\t\tDSC = {3:>5.2f}%\tASD = {4:>5.2f}mm\tHD = {5:>5.2f}mm'.format(d_name, casename, c, dsc*100.0, asd, hd) 136 | print(print_line) 137 | result_lines += '{}\n'.format(print_line) 138 | 139 | df = DataFrame(results, columns=['Dataset', 'Case', 'Class', 'DSC', 'ASD', 'HD']) 140 | df.to_csv(df_fn) 141 | df = read_csv(df_fn) 142 | 143 | 144 | dsc = [] 145 | asd = [] 146 | hd = [] 147 | dsc_mean = 0 148 | asd_mean = 0 149 | hd_mean = 0 150 | d_list = [d for d in set(df['Dataset'].tolist())] 151 | for d in d_list: 152 | dsc_m = df[df['Dataset'] == d]['DSC'].mean() 153 | dsc_v = df[df['Dataset'] == d]['DSC'].std() 154 | dsc.append([dsc_m, dsc_v]) 155 | dsc_mean += dsc_m 156 | asd_m = df[df['Dataset'] == d]['ASD'].mean() 157 | asd_v = df[df['Dataset'] == d]['ASD'].std() 158 | asd.append([asd_m, asd_v]) 159 | asd_mean += asd_m 160 | hd_m = df[df['Dataset'] == d]['HD'].mean() 161 | hd_v = df[df['Dataset'] == d]['HD'].std() 162 | hd.append([hd_m, hd_v]) 163 | hd_mean += hd_m 164 | print_line = ' --- dataset {0:s}:\tDSC = {1:.2f}({2:.2f})%\tASD = {3:.2f}({4:.2f})mm\tHD = {5:.2f}({6:.2f})mm\tN={7:d}'.format(d, dsc_m*100.0, dsc_v*100.0, asd_m, asd_v, hd_m, hd_v, len(df[df['Dataset'] == d]['DSC'])) 165 | print(print_line) 166 | result_lines += '{}\n'.format(print_line) 167 | dsc_mean = dsc_mean / len(d_list) 168 | asd_mean = asd_mean / len(d_list) 169 | hd_mean = hd_mean / len(d_list) 170 | dsc = np.array(dsc) 171 | asd = np.array(asd) 172 | hd = np.array(hd) 173 | 174 | print_line = ' --- dataset-avg:\tDSC = {0:.2f}%\tASD = {1:.2f}mm\tHD = {2:.2f}mm'.format(dsc_mean*100.0, asd_mean, hd_mean) 175 | print(print_line) 176 | result_lines += '{}\n'.format(print_line) 177 | print_line = ' --- Finish calculating metrics --- \n' 178 | print(print_line) 179 | result_lines += '{}\n'.format(print_line) 180 | 181 | 
    result_fn = '{}/{}.txt'.format(pd_path, metric_fn)
    with open(result_fn, 'w') as result_file:
        result_file.write(result_lines)

    return dsc, asd, hd, dsc_mean, asd_mean, hd_mean

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class double_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        y = self.conv(x)
        return y

class enc_block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(enc_block, self).__init__()
        self.conv = double_conv(in_ch, out_ch)
        self.down = nn.MaxPool3d(2)

    def forward(self, x):
        y_conv = self.conv(x)
        y = self.down(y_conv)
        return y, y_conv

class dec_block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(dec_block, self).__init__()
        self.conv = double_conv(in_ch, out_ch)
        self.up = nn.ConvTranspose3d(out_ch, out_ch, 2, stride=2)

    def forward(self, x):
        y_conv = self.conv(x)
        y = self.up(y_conv)
        return y, y_conv

def concatenate(x1, x2):
    diffZ = x2.size()[2] - x1.size()[2]
    diffY = x2.size()[3] - x1.size()[3]
    diffX = x2.size()[4] - x1.size()[4]
    x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
                    diffY // 2, diffY - diffY//2,
                    diffZ // 2, diffZ - diffZ//2))
    y = torch.cat([x2, x1], dim=1)
    return y

class UNet(nn.Module):
    def __init__(self, in_ch, base_ch, cls_num):
        super(UNet, self).__init__()
        self.in_ch = in_ch
        self.base_ch = base_ch
        self.cls_num = cls_num

        self.enc1 = enc_block(in_ch, base_ch)
        self.enc2 = enc_block(base_ch, base_ch*2)
        self.enc3 = enc_block(base_ch*2, base_ch*4)
        self.enc4 = enc_block(base_ch*4, base_ch*8)

        self.dec1 = dec_block(base_ch*8, base_ch*8)
        self.dec2 = dec_block(base_ch*8+base_ch*8, base_ch*4)
        self.dec3 = dec_block(base_ch*4+base_ch*4, base_ch*2)
        self.dec4 = dec_block(base_ch*2+base_ch*2, base_ch)
        self.lastconv = double_conv(base_ch+base_ch, base_ch)

        self.outconv = nn.Conv3d(base_ch, cls_num+1, 1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x, enc1_conv = self.enc1(x)
        x, enc2_conv = self.enc2(x)
        x, enc3_conv = self.enc3(x)
        x, enc4_conv = self.enc4(x)
        x, _ = self.dec1(x)
        x, _ = self.dec2(concatenate(x, enc4_conv))
        x, _ = self.dec3(concatenate(x, enc3_conv))
        x, _ = self.dec4(concatenate(x, enc2_conv))
        x = self.lastconv(concatenate(x, enc1_conv))
        x = self.outconv(x)
        y = self.softmax(x)

        if self.training:
            return y, x
        else:
            return y

    def description(self):
        return 'U-Net (input channel = {0:d}) for {1:d}-class segmentation (base channel = {2:d})'.format(self.in_ch, self.cls_num+1, self.base_ch)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import random
import time
import torch
import torch.nn as nn
from torch import optim
from torch.utils import data
from dataset import create_folds, Dataset
from model import UNet
from loss import dice_and_ce_loss
from utils import resample_array, output2file
from metric import eval
from config import cfg

def initial_net(net):
    for m in net.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m,
            nn.Conv3d):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm3d):
            nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

def initialization():

    train_record = []
    for i in range(cfg['commu_times']):
        order = [(j + i) % len(cfg['node_list']) for j in range(len(cfg['node_list']))]
        random.shuffle(order)
        train_record.append(order)

    nodes = []
    val_fold = None
    test_fold = None
    weight_sum = 0
    for node_id, [node_name, d_name, d_path, fraction] in enumerate(cfg['node_list']):

        folds, _ = create_folds(d_name, d_path, node_name, fraction, exclude_case=cfg['exclude_case'])

        # create training fold
        train_fold = folds[0]
        d_train = Dataset(train_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'], aug_data=True)
        dl_train = data.DataLoader(dataset=d_train, batch_size=cfg['batch_size'], shuffle=True, pin_memory=True, drop_last=False, num_workers=cfg['cpu_thread'])

        # create validation fold
        if val_fold is None:
            val_fold = folds[1]
        else:
            val_fold.extend(folds[1])

        # create testing fold
        if test_fold is None:
            test_fold = folds[2]
        else:
            test_fold.extend(folds[2])

        print('{0:s}: train = {1:d}'.format(node_name, len(d_train)))
        weight_sum += len(d_train)

        local_model = nn.DataParallel(module=UNet(in_ch=1, base_ch=32, cls_num=cfg['cls_num']))
        local_model.cuda()
        initial_net(local_model)

        optimizer = optim.SGD(local_model.parameters(), lr=cfg['lr'], momentum=0.99, nesterov=True)

        lambda_func = lambda epoch: (1 - epoch / (cfg['commu_times'] * cfg['epoch_per_commu']))**0.9
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_func)

        nodes.append([local_model, optimizer, scheduler, node_name, len(d_train), dl_train])

    d_val = Dataset(val_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'], aug_data=False)
    dl_val = data.DataLoader(dataset=d_val, batch_size=cfg['test_batch_size'], shuffle=False, pin_memory=True, drop_last=False, num_workers=cfg['cpu_thread'])

    d_test = Dataset(test_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'], aug_data=False)
    dl_test = data.DataLoader(dataset=d_test, batch_size=cfg['test_batch_size'], shuffle=False, pin_memory=True, drop_last=False, num_workers=cfg['cpu_thread'])

    print('{0:s}: val/test = {1:d}/{2:d}'.format(node_name, len(d_val), len(d_test)))

    for i in range(len(nodes)):
        nodes[i][4] = nodes[i][4] / weight_sum
        print('Weight of {0:s}: {1:f}'.format(nodes[i][3], nodes[i][4]))

    return nodes, train_record, dl_val, dl_test

def exchange_local_models(nodes, train_record, commu_t, node_id):
    new_node_id = train_record[commu_t][node_id]
    return new_node_id, nodes[new_node_id][3], nodes[new_node_id][4], nodes[new_node_id][5]

def train_local_model(local_model, optimizer, scheduler, data_loader, epoch_num):
    train_loss = 0
    train_loss_num = 0
    for epoch_id in range(epoch_num):

        t0 = time.perf_counter()

        epoch_loss = 0
        epoch_loss_num = 0
        batch_id = 0
        for batch in data_loader:
            image = batch['data'].cuda()
            label = batch['label'].cuda()

            N = len(image)

            pred, pred_logit = local_model(image)

            print_line = 'Epoch {0:d}/{1:d} (train) --- Progress {2:5.2f}% (+{3:02d})'.format(
                epoch_id+1, epoch_num, 100.0 * batch_id * cfg['batch_size'] / len(data_loader.dataset), N)

            l_ce, l_dice = dice_and_ce_loss(pred, pred_logit, label)
            loss_sup = l_dice + l_ce
            epoch_loss += l_dice.item() + l_ce.item()
            epoch_loss_num += 1

            print_line += ' --- Loss: {0:.6f}({1:.6f}/{2:.6f})'.format(loss_sup.item(), l_dice.item(), l_ce.item())
            print(print_line)

            optimizer.zero_grad()
            loss_sup.backward()
            optimizer.step()

            del image, label, pred, pred_logit, loss_sup
            batch_id += 1

        train_loss += epoch_loss
        train_loss_num += epoch_loss_num
        epoch_loss = epoch_loss / epoch_loss_num
        lr = scheduler.get_last_lr()[0]

        print_line = 'Epoch {0:d}/{1:d} (train) --- Loss: {2:.6f} --- Lr: {3:.6f}'.format(epoch_id+1, epoch_num, epoch_loss, lr)
        print(print_line)

        scheduler.step()

        t1 = time.perf_counter()
        epoch_t = t1 - t0
        print("Epoch time cost: {h:>02d}:{m:>02d}:{s:>02d}\n".format(
            h=int(epoch_t) // 3600, m=(int(epoch_t) % 3600) // 60, s=int(epoch_t) % 60))

    train_loss = train_loss / train_loss_num

    return train_loss

# mode: 'val' or 'test'
# commu_iters: communication iteration index, only available when mode == 'val'
def eval_local_model(nodes, data_loader, result_path, mode, commu_iters):
    t0 = time.perf_counter()

    if mode == 'val':
        metric_fname = 'metric_validation-{0:04d}'.format(commu_iters+1)
        print('Validation ({0:d}/{1:d}) ...'.format(commu_iters+1, cfg['commu_times']))
    elif mode == 'test':
        metric_fname = 'metric_testing'
        print('Testing ...')

    gt_entries = []
    output_buffer = None
    std_buffer = None
    for batch_id, batch in enumerate(data_loader):
        image = batch['data'].cuda()
        N = len(image)

        Ey, Ey2 = None, None
        for [local_model,
             _, _, _, _, _] in nodes:

            local_model.eval()
            prob = local_model(image)
            y = prob[:,1,:].detach().cpu().numpy().copy()

            # accumulate the first and second moments of the foreground probability over all local models
            if Ey is None:
                Ey = np.zeros_like(y)
                Ey2 = np.zeros_like(y)
            Ey = Ey + y
            Ey2 = Ey2 + y**2

            del prob

        Ey = Ey/len(nodes)
        Ey2 = Ey2/len(nodes)
        mask = (Ey.copy()>0.5).astype(dtype=np.uint8)

        # per-voxel std-dev across models; the +1.0 offset lets voxels covered by this patch (>= 1.0)
        # be told apart from the zero fill added by resampling
        stddev = Ey2-Ey**2
        stddev[stddev<0] = 0
        stddev = np.sqrt(stddev) + 1.0

        print_line = '{0:s} --- Progress {1:5.2f}% (+{2:d})'.format(
            mode, 100.0 * batch_id * cfg['test_batch_size'] / len(data_loader.dataset), N)
        print(print_line)

        for i in range(N):
            sample_mask = resample_array(
                mask[i,:], batch['size'][i].numpy(), batch['spacing'][i].numpy(), batch['origin'][i].numpy(),
                batch['org_size'][i].numpy(), batch['org_spacing'][i].numpy(), batch['org_origin'][i].numpy())
            if output_buffer is None:
                output_buffer = np.zeros_like(sample_mask, dtype=np.uint8)
            output_buffer[sample_mask > 0] = 0
            output_buffer = output_buffer + sample_mask

            if batch['eof'][i] == True:
                output2file(output_buffer, batch['org_size'][i].numpy(), batch['org_spacing'][i].numpy(), batch['org_origin'][i].numpy(),
                    '{0:s}/{1:s}@{2:s}.nii.gz'.format(result_path, batch['dataset'][i], batch['case'][i]))
                output_buffer = None
                gt_entries.append([batch['dataset'][i], batch['case'][i], batch['label_fname'][i]])

            if mode == 'test':
                sample_stddev = resample_array(
                    stddev[i,:], batch['size'][i].numpy(), batch['spacing'][i].numpy(), batch['origin'][i].numpy(),
                    batch['org_size'][i].numpy(), batch['org_spacing'][i].numpy(), batch['org_origin'][i].numpy(), linear=True)
                if std_buffer is None:
                    std_buffer = np.zeros_like(sample_stddev)
                std_buffer[sample_stddev >= 1.0] = 0
                sample_stddev[sample_stddev < 1.0] = 0
                sample_stddev = sample_stddev - 1.0
                sample_stddev[sample_stddev < 0.0] = 0
                std_buffer = std_buffer + sample_stddev

                if batch['eof'][i] == True:
                    output2file(std_buffer, batch['org_size'][i].numpy(), batch['org_spacing'][i].numpy(), batch['org_origin'][i].numpy(),
                        '{0:s}/{1:s}@{2:s}-std.nii.gz'.format(result_path, batch['dataset'][i], batch['case'][i]))
                    std_buffer = None

        del image

    seg_dsc, seg_asd, seg_hd, seg_dsc_m, seg_asd_m, seg_hd_m = eval(
        pd_path=result_path, gt_entries=gt_entries, label_map=cfg['label_map'], cls_num=cfg['cls_num'],
        metric_fn=metric_fname, calc_asd=(mode != 'val'), keep_largest=False)

    if mode == 'val':
        print_line = 'Validation result (iter = {0:d}/{1:d}) --- DSC {2:.2f} ({3:s})%'.format(
            commu_iters+1, cfg['commu_times'],
            seg_dsc_m*100.0, '/'.join(['%.2f']*len(seg_dsc[:,0])) % tuple(seg_dsc[:,0]*100.0))
    else:
        print_line = 'Testing results --- DSC {0:.2f} ({1:s})% --- ASD {2:.2f} ({3:s})mm --- HD {4:.2f} ({5:s})mm'.format(
            seg_dsc_m*100.0, '/'.join(['%.2f']*len(seg_dsc[:,0])) % tuple(seg_dsc[:,0]*100.0),
            seg_asd_m, '/'.join(['%.2f']*len(seg_asd[:,0])) % tuple(seg_asd[:,0]),
            seg_hd_m, '/'.join(['%.2f']*len(seg_hd[:,0])) % tuple(seg_hd[:,0]))
    print(print_line)
    t1 = time.perf_counter()
    eval_t = t1 - t0
    print("Evaluation time cost: {h:>02d}:{m:>02d}:{s:>02d}\n".format(
        h=int(eval_t) // 3600, m=(int(eval_t) % 3600) // 60, s=int(eval_t) % 60))

    return seg_dsc_m, seg_dsc

def load_models(nodes, model_fname):
    # read the checkpoint file once instead of once per state dict
    checkpoint = torch.load(model_fname)
    for node_id in range(len(nodes)):
        nodes[node_id][0].load_state_dict(checkpoint['local_model_{0:d}_state_dict'.format(node_id)])
        nodes[node_id][1].load_state_dict(checkpoint['local_model_{0:d}_optimizer'.format(node_id)])
        nodes[node_id][2].load_state_dict(checkpoint['local_model_{0:d}_scheduler'.format(node_id)])

def train():
    train_start_time = time.localtime()
    print("Start time: {start_time}\n".format(start_time=time.strftime("%Y-%m-%d %H:%M:%S", train_start_time)))
    time_stamp = time.strftime("%Y%m%d%H%M%S", train_start_time)

    # create directory for results storage
    store_dir = '{}/model_{}'.format(cfg['model_path'], time_stamp)
    loss_fn = '{}/loss.txt'.format(store_dir)
    val_result_path = '{}/results_val'.format(store_dir)
    os.makedirs(val_result_path, exist_ok=True)
    test_result_path = '{}/results_test'.format(store_dir)
    os.makedirs(test_result_path, exist_ok=True)

    print('Loading local data from each node ... \n')

    nodes, train_record, dl_val, dl_test = initialization()

    print("Training order:", train_record)

    best_val_acc = 0
    start_iter = 0
    acc_time = 0
    best_model_fn = '{0:s}/cp_commu_{1:04d}.pth.tar'.format(store_dir, 1)

    print()
    log_line = "Model: {}\nModel parameters: {}\nStart time: {}\nConfiguration:\n".format(
        nodes[0][0].module.description(),
        sum(x.numel() for x in nodes[0][0].parameters()),
        time.strftime("%Y-%m-%d %H:%M:%S", train_start_time))
    for cfg_key in cfg:
        log_line += ' --- {}: {}\n'.format(cfg_key, cfg[cfg_key])
    print(log_line)

    for commu_t in range(start_iter, cfg['commu_times'], 1):

        t0 = time.perf_counter()

        train_loss = []
        for i, [local_model, optimizer, scheduler, _, _, _] in enumerate(nodes):

            node_id, node_name, node_weight, dl_train = exchange_local_models(nodes, train_record, commu_t, i)

            print('Training ({0:d}/{1:d}) on Node: {2:s}\n'.format(commu_t+1, cfg['commu_times'], node_name))

            local_model.train()

            train_loss.append(train_local_model(local_model, optimizer, scheduler, dl_train, cfg['epoch_per_commu']))

        seg_dsc_m, seg_dsc = eval_local_model(nodes, dl_val, val_result_path,
            mode='val', commu_iters=commu_t)

        t1 = time.perf_counter()
        epoch_t = t1 - t0
        acc_time += epoch_t
        print("Iteration time cost: {h:>02d}:{m:>02d}:{s:>02d}\n".format(
            h=int(epoch_t) // 3600, m=(int(epoch_t) % 3600) // 60, s=int(epoch_t) % 60))

        loss_line = '{commu_iter:>04d}\t{train_loss:s}\t{seg_val_dsc:>8.6f}\t{seg_val_dsc_cls:s}'.format(
            commu_iter=commu_t+1, train_loss='\t'.join(['%8.6f']*len(train_loss)) % tuple(train_loss),
            seg_val_dsc=seg_dsc_m, seg_val_dsc_cls='\t'.join(['%8.6f']*len(seg_dsc[:,0])) % tuple(seg_dsc[:,0])
            )
        for [_, _, scheduler, _, _, _] in nodes:
            loss_line += '\t{node_lr:>8.6f}'.format(node_lr=scheduler.get_last_lr()[0])
        loss_line += '\n'

        with open(loss_fn, 'a') as loss_file:
            loss_file.write(loss_line)

        # save best model
        if commu_t == 0 or seg_dsc_m > best_val_acc:
            # remove former best model
            if os.path.exists(best_model_fn):
                os.remove(best_model_fn)
            # save current best model
            best_val_acc = seg_dsc_m
            best_model_fn = '{0:s}/cp_commu_{1:04d}.pth.tar'.format(store_dir, commu_t+1)
            best_model_cp = {
                'commu_iter':commu_t,
                'acc_time':acc_time,
                'time_stamp':time_stamp,
                'best_val_acc':best_val_acc,
                'best_model_filename':best_model_fn}
            for node_id, [local_model, optimizer, scheduler, _, _, _] in enumerate(nodes):
                best_model_cp['local_model_{0:d}_state_dict'.format(node_id)] = local_model.state_dict()
                best_model_cp['local_model_{0:d}_optimizer'.format(node_id)] = optimizer.state_dict()
                best_model_cp['local_model_{0:d}_scheduler'.format(node_id)] = scheduler.state_dict()
            torch.save(best_model_cp, best_model_fn)
            print('Best model (communication iteration = {}) saved.\n'.format(commu_t+1))

    print("Total training time: {h:>02d}:{m:>02d}:{s:>02d}\n\n".format(
        h=int(acc_time) // 3600, m=(int(acc_time) % 3600) // 60,
        s=int(acc_time) % 60))

    # test
    load_models(nodes, best_model_fn)
    eval_local_model(nodes, dl_test, test_result_path, mode='test', commu_iters=0)

    print("Finish time: {finish_time}\n\n".format(
        finish_time=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))

if __name__ == '__main__':

    os.environ['CUDA_VISIBLE_DEVICES'] = cfg['gpu']

    train()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
# Note:
# Using itk here causes a deadlock after the first training epoch
# when the DataLoader runs with num_workers > 0; the reason is unknown.
import SimpleITK as sitk

def read_image(fname):
    reader = sitk.ImageFileReader()
    reader.SetFileName(fname)
    image = reader.Execute()
    return image

def resample_array(array, size, spacing, origin, size_rs, spacing_rs, origin_rs, transform=None, linear=False):
    array = np.reshape(array, [size[2], size[1], size[0]])
    image = sitk.GetImageFromArray(array)
    image.SetSpacing((float(spacing[0]), float(spacing[1]), float(spacing[2])))
    image.SetOrigin((float(origin[0]), float(origin[1]), float(origin[2])))

    resampler = sitk.ResampleImageFilter()
    resampler.SetSize((int(size_rs[0]), int(size_rs[1]), int(size_rs[2])))
    resampler.SetOutputSpacing((float(spacing_rs[0]), float(spacing_rs[1]), float(spacing_rs[2])))
    resampler.SetOutputOrigin((float(origin_rs[0]), float(origin_rs[1]), float(origin_rs[2])))
    if transform is not None:
        resampler.SetTransform(transform)
    else:
        resampler.SetTransform(sitk.Transform(3, sitk.sitkIdentity))
    if linear:
        resampler.SetInterpolator(sitk.sitkLinear)
    else:
        resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    resampler.SetDefaultPixelValue(0)
    rs_image = resampler.Execute(image)
    rs_array = sitk.GetArrayFromImage(rs_image)

    return rs_array

def output2file(array, size, spacing, origin, fname):
    array = np.reshape(array, [size[2], size[1], size[0]])#.astype(dtype=np.uint8)
    image = sitk.GetImageFromArray(array)
    image.SetSpacing((float(spacing[0]), float(spacing[1]), float(spacing[2])))
    image.SetOrigin((float(origin[0]), float(origin[1]), float(origin[2])))

    writer = sitk.ImageFileWriter()
    writer.SetFileName(fname)
    writer.Execute(image)
--------------------------------------------------------------------------------
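The ensembling step inside `eval_local_model` in train.py averages the foreground probability of every local model and derives a per-voxel standard deviation from the first and second moments. The sketch below isolates that arithmetic in plain NumPy so it can be checked without a GPU or a dataset. It is an illustration only and not part of the repository: the function names `ensemble_mean_std` and `ensemble_mask` are hypothetical, and the `+ 1.0` offset that train.py adds before resampling is deliberately left out.

```python
import numpy as np

def ensemble_mean_std(probs):
    """Return the per-voxel mean and standard deviation over a list of probability maps.

    probs: sequence of equally shaped arrays, one per local model
           (hypothetical stand-in for the `prob[:, 1, :]` outputs in train.py).
    """
    ey = np.zeros_like(probs[0], dtype=np.float64)    # running sum of y
    ey2 = np.zeros_like(probs[0], dtype=np.float64)   # running sum of y**2
    for y in probs:
        ey += y
        ey2 += y ** 2
    ey /= len(probs)
    ey2 /= len(probs)
    # Var[y] = E[y^2] - E[y]^2; clip tiny negative values caused by rounding
    var = np.clip(ey2 - ey ** 2, 0.0, None)
    return ey, np.sqrt(var)

def ensemble_mask(probs, threshold=0.5):
    """Binary mask obtained by thresholding the ensemble mean (0.5 as in train.py)."""
    mean, _ = ensemble_mean_std(probs)
    return (mean > threshold).astype(np.uint8)

if __name__ == '__main__':
    # two toy "models" disagreeing on the first voxel and agreeing on the second
    toy = [np.array([0.2, 0.9]), np.array([0.4, 0.9])]
    mean, std = ensemble_mean_std(toy)
    print(mean, std, ensemble_mask(toy))
```

Voxels where the models agree get a standard deviation near zero, while disagreement shows up as a larger value, which is what the `-std.nii.gz` outputs written in test mode are built from.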