├── .gitignore
├── README.md
├── config.py
├── dataset.py
├── fig1.png
├── fig2.png
├── loss.py
├── metric.py
├── model.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | models/
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 | .vscode/settings.json
133 |
134 | tmp.py
135 | .vscode/launch.json
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FedCross
2 | Federated Cross Learning for Medical Image Segmentation
3 |
4 | This is a Python (PyTorch) implementation of the **Federated Cross Learning (FedCross)** method proposed in our paper [**"Federated Cross Learning for Medical Image Segmentation"**](https://openreview.net/forum?id=DrZbwobH_zo), published at the *Medical Imaging with Deep Learning 2023* (MIDL 2023) conference (Nashville, Tennessee, United States, Jul. 10-12, 2023). A preprint of this paper is also available on [arXiv](https://arxiv.org/abs/2204.02450).
5 |
6 | ## Citation
7 | *X. Xu, H. H. Deng, T. Chen, T. Kuang, J. C. Barber, D. Kim, J. Gateno, J. J. Xia, and P. Yan, "Federated Cross Learning for Medical Image Segmentation," in Medical Imaging with Deep Learning 2023. Nashville, Tennessee, United States, Jul. 10-12, 2023.*
8 |
9 | @inproceedings{Xu2023FedCross,
10 | title={Federated Cross Learning for Medical Image Segmentation},
11 | author={Xuanang Xu and Hannah H. Deng and Tianyi Chen and Tianshu Kuang and Joshua C. Barber and Daeseung Kim and Jaime Gateno and James J. Xia and Pingkun Yan},
12 | booktitle={Medical Imaging with Deep Learning},
13 | year={2023},
14 | url={https://openreview.net/forum?id=DrZbwobH_zo}
15 | }
16 |
17 | ## Update
18 | - **Nov 21, 2023**: Fixed a bug: the `read_image` function was missing from `utils.py`.
19 |
20 | ## Abstract
21 | Federated learning (FL) can collaboratively train deep learning models using isolated patient data owned by different hospitals for various clinical applications, including medical image segmentation. However, a major problem of FL is its performance degradation when dealing with data that are not independently and identically distributed (non-iid), which is often the case in medical images. In this paper, we first conduct a theoretical analysis on the FL algorithm to reveal the problem of model aggregation during training on non-iid data. With the insights gained through the analysis, we propose a simple yet effective method, federated cross learning (FedCross), to tackle this challenging problem. Unlike the conventional FL methods that combine multiple individually trained local models on a server node, our FedCross sequentially trains the global model across different clients in a round-robin manner, and thus the entire training procedure does not involve any model aggregation steps. To further improve its performance to be comparable with the centralized learning method, we combine the FedCross with an ensemble learning mechanism to compose a federated cross ensemble learning (FedCrossEns) method. Finally, we conduct extensive experiments using a set of public datasets. The experimental results show that the proposed FedCross training strategy outperforms the mainstream FL methods on non-iid data. In addition to improving the segmentation performance, our FedCrossEns can further provide a quantitative estimation of the model uncertainty, demonstrating the effectiveness and clinical significance of our designs. Source code is publicly available at [https://github.com/DIAL-RPI/FedCross](https://github.com/DIAL-RPI/FedCross).
22 |
23 | ## An illustration of non-iid problem in FL model aggregation
24 | The model aggregation process in FL may lead to a sub-optimal solution when dealing with non-iid data. In the figure below, (a) and (b) show that the locally trained models $\theta^{J+1}_k$ and $\theta^{J+1}_m$ each reach a minimum in the loss landscape of their own client dataset, $D_k$ and $D_m$, respectively. (c) shows that the aggregated FL model $\theta'$ lies at a non-minimal position in the global loss landscape of $D_k \cup D_m$.
25 |
26 |
27 | ## Method
28 | ### Training schemes of (a) FedAvg, (b) FedCross, and (c) FedCrossEns
29 |
30 |
31 | ## Contact
32 | You are welcome to contact us:
33 | - [xux12@rpi.edu](mailto:xux12@rpi.edu) ([Dr. Xuanang Xu](https://superxuang.github.io/))
34 |
--------------------------------------------------------------------------------
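The README contrasts conventional FL, where a server combines individually trained local models, with FedCross, where the global model is trained sequentially across clients in a round-robin manner with no aggregation step. The toy sketch below illustrates only that difference in control flow, following the README's description; it is not the repository's `train.py`. The helper `local_train` is hypothetical: it stands in for local epochs with a few gradient steps on a 1-D quadratic client loss, and FedAvg is assumed to use unweighted averaging.

```python
from typing import Sequence


def local_train(w: float, target: float, lr: float = 0.1, steps: int = 5) -> float:
    """Hypothetical local update: gradient descent on 0.5 * (w - target) ** 2."""
    for _ in range(steps):
        w -= lr * (w - target)  # gradient of the quadratic client loss
    return w


def fedavg(targets: Sequence[float], rounds: int, w0: float = 0.0) -> float:
    """FedAvg-style rounds: train copies of the global model locally, then average."""
    w = w0
    for _ in range(rounds):
        local_models = [local_train(w, t) for t in targets]  # all start from the same w
        w = sum(local_models) / len(local_models)  # server-side aggregation
    return w


def fedcross_sequential(targets: Sequence[float], rounds: int, w0: float = 0.0) -> float:
    """Round-robin training: one model is handed from client to client, no averaging."""
    w = w0
    for _ in range(rounds):
        for t in targets:  # fixed client order within each round
            w = local_train(w, t)
    return w


if __name__ == "__main__":
    client_targets = [1.0, 3.0]  # each client's loss is minimized at a different w
    print("FedAvg:     ", fedavg(client_targets, rounds=50))
    print("Round-robin:", fedcross_sequential(client_targets, rounds=50))
```

This toy gives every client the same curvature, so averaging happens to recover the global minimizer (w = 2) and it does not reproduce the aggregation problem shown in the figure; the round-robin model settles near 2.26, i.e. biased toward the last client visited in each round.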
/config.py:
--------------------------------------------------------------------------------
1 | cfg = {}
2 |
3 | cfg['cls_num'] = 1 # number of foreground classes (background excluded)
4 | cfg['gpu'] = '0,1,2,3' # GPU ids to use, e.g. '0' for a single GPU or '0,1,2,3' for multiple GPUs
5 | cfg['batch_size'] = 16 # training batch size
6 | cfg['test_batch_size'] = 16 # testing batch size
7 | cfg['lr'] = 0.01 # base learning rate
8 | cfg['model_path'] = '/set_your_own_model_path/models' # directory where the trained models and evaluation results are saved
9 | cfg['rs_size'] = [160,160,32] # resample size: [x, y, z]
10 | cfg['rs_spacing'] = [0.5,0.5,1.0] # resample spacing: [x, y, z]. A non-positive value means the spacing is adapted to fit the physical size: rs_size * rs_spacing = origin_size * origin_spacing
11 | cfg['rs_intensity'] = [-200.0, 400.0] # rescale intensity from [min, max] to [0, 1].
12 | cfg['cpu_thread'] = 4 # number of worker threads for data loading; zero means single-threaded loading.
13 | cfg['commu_times'] = 50 # number of communication rounds
14 | cfg['epoch_per_commu'] = 32 # number of local training epochs within one communication round
15 |
16 | # map labels of different client datasets to a uniform label map
17 | cfg['label_map'] = {
18 | 'MSD':{1:1, 2:1},
19 | 'NCI-ISBI':{1:1, 2:1},
20 | 'PROMISE12':{1:1},
21 | 'PROSTATEx':{1:1},
22 | }
23 |
24 | # samples to exclude, each given as [dataset_name, case_name]
25 | cfg['exclude_case'] = [
26 | ]
27 |
28 | # data path of each client dataset
29 | cfg['node_list'] = [
30 | ['Node-1', ['MSD'], ['/set_your_own_data_path/MSD-Prostate'], [19,3,10]], # 32 in total
31 | ['Node-2', ['NCI-ISBI'], ['/set_your_own_data_path/NCI-ISBI-Prostate'], [48,8,24]], # 80 in total
32 | ['Node-3', ['PROMISE12'], ['/set_your_own_data_path/PROMISE12'], [30,5,15]], # 50 in total
33 | ['Node-4', ['PROSTATEx'], ['/set_your_own_data_path/PROSTATEx'], [122,20,62]], # 204 in total
34 | ]
--------------------------------------------------------------------------------
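The `rs_spacing` comment above says a non-positive value triggers adaptive spacing so that `rs_size * rs_spacing = origin_size * origin_spacing`. The NumPy sketch below restates the per-axis rule that `resample()` in `dataset.py` applies when no origin is passed (fixed spacing is center-aligned with the original image; adaptive spacing keeps the original origin). `resample_grid` is a hypothetical helper name and the example scan geometry is made up.

```python
import numpy as np


def resample_grid(o_size, o_spacing, o_origin, size, spacing):
    """Per-axis output spacing/origin, mirroring dataset.resample() with origin=None."""
    o_size = np.asarray(o_size, dtype=float)
    o_spacing = np.asarray(o_spacing, dtype=float)
    o_origin = np.asarray(o_origin, dtype=float)
    size = np.asarray(size, dtype=float)
    spacing = np.asarray(spacing, dtype=float)

    new_spacing = np.empty(3)
    new_origin = np.empty(3)
    for i in range(3):
        if spacing[i] > 0:
            # fixed spacing: center the output grid on the original image
            new_spacing[i] = spacing[i]
            new_origin[i] = (o_origin[i] + 0.5 * o_size[i] * o_spacing[i]
                             - 0.5 * size[i] * spacing[i])
        else:
            # adaptive spacing: stretch the grid so it covers the same physical extent
            new_spacing[i] = o_size[i] * o_spacing[i] / size[i]
            new_origin[i] = o_origin[i]
    return new_spacing, new_origin


# hypothetical 320x320x20 scan with 0.5x0.5x3.0 mm voxels; adaptive spacing along z only
spacing, origin = resample_grid([320, 320, 20], [0.5, 0.5, 3.0], [0.0, 0.0, 0.0],
                                size=[160, 160, 32], spacing=[0.5, 0.5, -1.0])
print(spacing, origin)  # z spacing = 20 * 3.0 / 32 = 1.875 mm; x/y origin = 80 - 40 = 40 mm
```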
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | from torch.utils import data
5 | import itk
6 | import numpy as np
7 | import random
8 | import SimpleITK as sitk
9 | from scipy import stats
10 |
11 | def read_image(fname, imtype):
12 | reader = itk.ImageFileReader[imtype].New()
13 | reader.SetFileName(fname)
14 | reader.Update()
15 | image = reader.GetOutput()
16 | return image
17 |
18 | def image_2_array(image):
19 | arr = itk.GetArrayFromImage(image)
20 | return arr
21 |
22 | def array_2_image(arr, spacing, origin, imtype):
23 | image = itk.GetImageFromArray(arr)
24 | image.SetSpacing((spacing[0], spacing[1], spacing[2]))
25 | image.SetOrigin((origin[0], origin[1], origin[2]))
26 | cast = itk.CastImageFilter[type(image), imtype].New()
27 | cast.SetInput(image)
28 | cast.Update()
29 | image = cast.GetOutput()
30 | return image
31 |
32 | def scan_path(dataset_name, dataset_path):
33 | entries = []
34 | if dataset_name == 'MSD':
35 | for f in os.listdir('{}/imagesTr'.format(dataset_path)):
36 | if f.startswith('prostate_') and f.endswith('.nii.gz'):
37 | case_name = f.split('.nii.gz')[0]
38 | if os.path.isfile('{}/labelsTr/{}'.format(dataset_path, f)):
39 | image_name = '{}/imagesTr/{}'.format(dataset_path, f)
40 | label_name = '{}/labelsTr/{}'.format(dataset_path, f)
41 | entries.append([dataset_name, case_name, image_name, label_name])
42 | elif dataset_name == 'NCI-ISBI':
43 | for f in os.listdir('{}/image'.format(dataset_path)):
44 | if f.startswith('Prostate') and f.endswith('.nii.gz'):
45 | case_name = f.split('.nii.gz')[0]
46 | if os.path.isfile('{}/label/{}'.format(dataset_path, f)):
47 | image_name = '{}/image/{}'.format(dataset_path, f)
48 | label_name = '{}/label/{}'.format(dataset_path, f)
49 | entries.append([dataset_name, case_name, image_name, label_name])
50 | elif dataset_name == 'PROMISE12':
51 | for f in os.listdir('{}/image'.format(dataset_path)):
52 | if f.startswith('Case') and f.endswith('.nii.gz'):
53 | case_name = f.split('.nii.gz')[0]
54 | if os.path.isfile('{}/label/{}'.format(dataset_path, f)):
55 | image_name = '{}/image/{}'.format(dataset_path, f)
56 | label_name = '{}/label/{}'.format(dataset_path, f)
57 | entries.append([dataset_name, case_name, image_name, label_name])
58 | elif dataset_name == 'PROSTATEx':
59 | for f in os.listdir('{}/image'.format(dataset_path)):
60 | if f.startswith('ProstateX-') and f.endswith('.nii.gz'):
61 | case_name = f.split('.nii.gz')[0]
62 | if os.path.isfile('{}/label/{}'.format(dataset_path, f)):
63 | image_name = '{}/image/{}'.format(dataset_path, f)
64 | label_name = '{}/label/{}'.format(dataset_path, f)
65 | entries.append([dataset_name, case_name, image_name, label_name])
66 | return entries
67 |
68 | def create_folds(dataset_name, dataset_path, fold_name, fraction, exclude_case):
69 | fold_file_name = '{0:s}/data_split-{1:s}.txt'.format(sys.path[0], fold_name)
70 | folds = {}
71 | if os.path.exists(fold_file_name):
72 | with open(fold_file_name, 'r') as fold_file:
73 | strlines = fold_file.readlines()
74 | for strline in strlines:
75 | strline = strline.rstrip('\n')
76 | params = strline.split()
77 | fold_id = int(params[0])
78 | if fold_id not in folds:
79 | folds[fold_id] = []
80 | folds[fold_id].append([params[1], params[2], params[3], params[4]])
81 | else:
82 | entries = []
83 | for [d_name, d_path] in zip(dataset_name, dataset_path):
84 | entries.extend(scan_path(d_name, d_path))
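85 | # drop excluded cases; building a new list avoids skipping entries,
86 | # which happens when removing items from a list while iterating over it
87 | entries = [e for e in entries if e[0:2] not in exclude_case]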
88 | random.shuffle(entries)
89 |
90 | ptr = 0
91 | for fold_id in range(len(fraction)):
92 | folds[fold_id] = entries[ptr:ptr+fraction[fold_id]]
93 | ptr += fraction[fold_id]
94 |
95 | with open(fold_file_name, 'w') as fold_file:
96 | for fold_id in range(len(fraction)):
97 | for [d_name, case_name, image_path, label_path] in folds[fold_id]:
98 | fold_file.write('{0:d} {1:s} {2:s} {3:s} {4:s}\n'.format(fold_id, d_name, case_name, image_path, label_path))
99 |
100 | folds_size = [len(x) for x in folds.values()]
101 |
102 | return folds, folds_size
103 |
104 | def generate_transform(aug, min_offset, max_offset):
105 | if aug:
106 | min_rotate = -0.1 # [rad]
107 | max_rotate = 0.1 # [rad]
108 | t = itk.Euler3DTransform[itk.D].New()
109 | euler_parameters = t.GetParameters()
110 | euler_parameters = itk.OptimizerParameters[itk.D](t.GetNumberOfParameters())
111 | offset_x = min_offset[0] + random.random() * (max_offset[0] - min_offset[0]) # translate
112 | offset_y = min_offset[1] + random.random() * (max_offset[1] - min_offset[1]) # translate
113 | offset_z = min_offset[2] + random.random() * (max_offset[2] - min_offset[2]) # translate
114 | rotate_x = min_rotate + random.random() * (max_rotate - min_rotate) # rotate
115 | rotate_y = min_rotate + random.random() * (max_rotate - min_rotate) # rotate
116 | rotate_z = min_rotate + random.random() * (max_rotate - min_rotate) # rotate
117 | euler_parameters[0] = rotate_x # rotate
118 | euler_parameters[1] = rotate_y # rotate
119 | euler_parameters[2] = rotate_z # rotate
120 | euler_parameters[3] = offset_x # translate
121 | euler_parameters[4] = offset_y # translate
122 | euler_parameters[5] = offset_z # translate
123 | t.SetParameters(euler_parameters)
124 | else:
125 | offset_x = 0
126 | offset_y = 0
127 | offset_z = 0
128 | rotate_x = 0
129 | rotate_y = 0
130 | rotate_z = 0
131 | t = itk.IdentityTransform[itk.D, 3].New()
132 | return t, [offset_x, offset_y, offset_z, rotate_x, rotate_y, rotate_z]
133 |
134 | def resample(image, imtype, size, spacing, origin, transform, linear, dtype):
135 | o_origin = image.GetOrigin()
136 | o_spacing = image.GetSpacing()
137 | o_size = image.GetBufferedRegion().GetSize()
138 | output = {}
139 | output['org_size'] = np.array(o_size, dtype=int)
140 | output['org_spacing'] = np.array(o_spacing, dtype=float)
141 | output['org_origin'] = np.array(o_origin, dtype=float)
142 |
143 | if origin is None: # if no origin point specified, center align the resampled image with the original image
144 | new_size = np.zeros(3, dtype=int)
145 | new_spacing = np.zeros(3, dtype=float)
146 | new_origin = np.zeros(3, dtype=float)
147 | for i in range(3):
148 | new_size[i] = size[i]
149 | if spacing[i] > 0:
150 | new_spacing[i] = spacing[i]
151 | new_origin[i] = o_origin[i] + o_size[i]*o_spacing[i]*0.5 - size[i]*spacing[i]*0.5
152 | else:
153 | new_spacing[i] = o_size[i] * o_spacing[i] / size[i]
154 | new_origin[i] = o_origin[i]
155 | else:
156 | new_size = np.array(size, dtype=int)
157 | new_spacing = np.array(spacing, dtype=float)
158 | new_origin = np.array(origin, dtype=float)
159 |
160 | output['size'] = new_size
161 | output['spacing'] = new_spacing
162 | output['origin'] = new_origin
163 |
164 | resampler = itk.ResampleImageFilter[imtype, imtype].New()
165 | resampler.SetInput(image)
166 | resampler.SetSize((int(new_size[0]), int(new_size[1]), int(new_size[2])))
167 | resampler.SetOutputSpacing((float(new_spacing[0]), float(new_spacing[1]), float(new_spacing[2])))
168 | resampler.SetOutputOrigin((float(new_origin[0]), float(new_origin[1]), float(new_origin[2])))
169 | resampler.SetTransform(transform)
170 | if linear:
171 | resampler.SetInterpolator(itk.LinearInterpolateImageFunction[imtype, itk.D].New())
172 | else:
173 | resampler.SetInterpolator(itk.NearestNeighborInterpolateImageFunction[imtype, itk.D].New())
174 | resampler.SetDefaultPixelValue(0)
175 | resampler.Update()
176 | rs_image = resampler.GetOutput()
177 | image_array = itk.GetArrayFromImage(rs_image)
178 | image_array = image_array[np.newaxis, :].astype(dtype)
179 | output['array'] = image_array
180 |
181 | return output
182 |
183 | def zscore_normalize(x):
184 | y = (x - x.mean()) / x.std()
185 | return y
186 |
187 | def make_onehot(input, cls):
188 | oh = np.repeat(np.zeros_like(input), cls+1, axis=0)
189 | for i in range(cls+1):
190 | tmp = np.zeros_like(input)
191 | tmp[input==i] = 1
192 | oh[i,:] = tmp
193 | return oh
194 |
195 | def make_flag(cls, labelmap):
196 | flag = np.zeros([cls, 1])
197 | for key in labelmap:
198 | flag[labelmap[key]-1,0] = 1
199 | return flag
200 |
201 | def image2file(image, imtype, fname):
202 | writer = itk.ImageFileWriter[imtype].New()
203 | writer.SetInput(image)
204 | writer.SetFileName(fname)
205 | writer.Update()
206 |
207 | def array2file(array, size, origin, spacing, imtype, fname):
208 | image = itk.GetImageFromArray(array.reshape([size[2], size[1], size[0]]))
209 | image.SetSpacing((spacing[0], spacing[1], spacing[2]))
210 | image.SetOrigin((origin[0], origin[1], origin[2]))
211 | image2file(image, imtype=imtype, fname=fname)
212 |
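213 | # dataset of 3D image patches resampled from the original volumes: with aug_data, one center-aligned patch per case (listed 4 times,
214 | # randomly rotated and translated on each read); without aug_data, a grid of patches tiling the whole volume, the last one flagged by eof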
215 | class Dataset(data.Dataset):
216 | def __init__(self, ids, rs_size, rs_spacing, rs_intensity, label_map, cls_num, aug_data):
217 | self.ImageType = itk.Image[itk.F, 3]
218 | self.LabelType = itk.Image[itk.UC, 3]
219 | self.ids = []
220 | self.rs_size = np.array(rs_size)
221 | self.rs_spacing = np.array(rs_spacing, dtype=float)
222 | self.rs_intensity = rs_intensity
223 | self.label_map = label_map
224 | self.cls_num = cls_num
225 | self.aug_data = aug_data
226 | self.im_cache = {}
227 | self.lb_cache = {}
228 |
229 | for [d_name, casename, image_fn, label_fn] in ids:
230 | reader = sitk.ImageFileReader()
231 | reader.SetFileName(image_fn)
232 | reader.ReadImageInformation()
233 | image_size = np.array(reader.GetSize()[:3])
234 | image_origin = np.array(reader.GetOrigin()[:3], dtype=float)
235 | image_spacing = np.array(reader.GetSpacing()[:3], dtype=float)
236 | image_phy_size = image_size * image_spacing
237 | patch_phy_size = self.rs_size*self.rs_spacing
238 |
239 | if not aug_data:
240 | patch_num = (image_phy_size/patch_phy_size).astype(int)+1
241 | for p_z in range(patch_num[2]):
242 | for p_y in range(patch_num[1]):
243 | for p_x in range(patch_num[0]):
244 | patch_origin = np.zeros_like(image_origin)
245 | patch_origin[0] = image_origin[0]+p_x*patch_phy_size[0]
246 | patch_origin[1] = image_origin[1]+p_y*patch_phy_size[1]
247 | patch_origin[2] = image_origin[2]+p_z*patch_phy_size[2]
248 | patch_origin[0] = min(patch_origin[0], image_origin[0]+image_phy_size[0]-patch_phy_size[0])
249 | patch_origin[1] = min(patch_origin[1], image_origin[1]+image_phy_size[1]-patch_phy_size[1])
250 | patch_origin[2] = min(patch_origin[2], image_origin[2]+image_phy_size[2]-patch_phy_size[2])
251 | eof = (p_x == patch_num[0]-1) & (p_y == patch_num[1]-1) & (p_z == patch_num[2]-1)
252 | self.ids.append([d_name, casename, image_fn, label_fn, patch_origin, eof])
253 | else:
254 | patch_origin = np.zeros_like(image_origin)
255 | patch_origin = image_origin+0.5*image_phy_size-0.5*patch_phy_size
256 | repeat_time = 4
257 | for i in range(repeat_time):
258 | self.ids.append([d_name, casename, image_fn, label_fn, patch_origin, i==repeat_time-1])
259 |
260 |
261 | def __len__(self):
262 | return len(self.ids)
263 |
264 | def __getitem__(self, index):
265 | [d_name, casename, image_fn, label_fn, patch_origin, eof] = self.ids[index]
266 |
267 | if image_fn not in self.im_cache:
268 | src_image = read_image(fname=image_fn, imtype=self.ImageType)
269 | image_cache = {}
270 | image_cache['size'] = np.array(src_image.GetBufferedRegion().GetSize())
271 | image_cache['origin'] = np.array(src_image.GetOrigin(), dtype=float)
272 | image_cache['spacing'] = np.array(src_image.GetSpacing(), dtype=float)
273 | image_cache['array'] = zscore_normalize(image_2_array(src_image).copy())
274 | self.im_cache[image_fn] = image_cache
275 | image_cache = self.im_cache[image_fn]
276 |
277 | min_offset = image_cache['origin'] - patch_origin
278 | max_offset = image_cache['origin']+image_cache['spacing']*image_cache['size']-patch_origin-self.rs_size*self.rs_spacing
279 |
280 | t, _ = generate_transform(self.aug_data, min_offset, max_offset)
281 |
282 | src_image = array_2_image(image_cache['array'], image_cache['spacing'], image_cache['origin'], self.ImageType)
283 |
284 | image = resample(
285 | image=src_image, imtype=self.ImageType,
286 | size=self.rs_size, spacing=self.rs_spacing, origin=patch_origin,
287 | transform=t, linear=True, dtype=np.float32)
288 |
289 | if label_fn not in self.lb_cache:
290 | src_label = read_image(fname=label_fn, imtype=self.LabelType)
291 | label_cache = {}
292 | label_cache['origin'] = np.array(src_label.GetOrigin(), dtype=float)
293 | label_cache['spacing'] = np.array(src_label.GetSpacing(), dtype=float)
294 | label_cache['array'] = image_2_array(src_label).copy()
295 | self.lb_cache[label_fn] = label_cache
296 | label_cache = self.lb_cache[label_fn]
297 | src_label = array_2_image(label_cache['array'], label_cache['spacing'], label_cache['origin'], self.LabelType)
298 |
299 | label = resample(
300 | image=src_label, imtype=self.LabelType,
301 | size=self.rs_size, spacing=self.rs_spacing, origin=patch_origin,
302 | transform=t, linear=False, dtype=np.int64)
303 |
304 | tmp_array = np.zeros_like(label['array'])
305 | lmap = self.label_map[d_name]
306 | for key in lmap:
307 | tmp_array[label['array'] == key] = lmap[key]
308 | label['array'] = tmp_array
309 | #label_bin = make_onehot(label['array'], cls=self.cls_num)
310 | label_exist = make_flag(cls=self.cls_num, labelmap=self.label_map[d_name])
311 |
312 | image_tensor = torch.from_numpy(image['array'])
313 | label_tensor = torch.from_numpy(label['array'])
314 |
315 | output = {}
316 | output['data'] = image_tensor
317 | output['label'] = label_tensor
318 | output['label_exist'] = label_exist
319 | output['dataset'] = d_name
320 | output['case'] = casename
321 | output['label_fname'] = label_fn
322 | output['size'] = image['size']
323 | output['spacing'] = image['spacing']
324 | output['origin'] = image['origin']
325 | output['org_size'] = image['org_size']
326 | output['org_spacing'] = image['org_spacing']
327 | output['org_origin'] = image['org_origin']
328 | output['eof'] = eof
329 |
330 | return output
--------------------------------------------------------------------------------
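For inference (`aug_data=False`), `Dataset.__init__` tiles each volume with fixed-size patches in physical coordinates, clamps each patch origin so the patch ends no later than the image border, and flags the final patch of a case with `eof`. The NumPy sketch below restates that loop outside the class so the geometry can be checked in isolation. `tile_patch_origins` is a hypothetical helper, not part of the repository; the volume extent is made up, and the patch extent is `rs_size * rs_spacing` from `config.py`.

```python
import numpy as np


def tile_patch_origins(image_origin, image_phy_size, patch_phy_size):
    """Patch origins covering a volume, mirroring the aug_data=False branch of Dataset."""
    image_origin = np.asarray(image_origin, dtype=float)
    image_phy_size = np.asarray(image_phy_size, dtype=float)
    patch_phy_size = np.asarray(patch_phy_size, dtype=float)

    patch_num = (image_phy_size / patch_phy_size).astype(int) + 1
    last_start = image_origin + image_phy_size - patch_phy_size  # latest start inside the image
    patches = []
    for p_z in range(patch_num[2]):
        for p_y in range(patch_num[1]):
            for p_x in range(patch_num[0]):
                idx = np.array([p_x, p_y, p_z])
                origin = np.minimum(image_origin + idx * patch_phy_size, last_start)
                eof = bool(np.all(idx == patch_num - 1))  # True only for the case's last patch
                patches.append((origin, eof))
    return patches


patches = tile_patch_origins(image_origin=[0.0, 0.0, 0.0],
                             image_phy_size=[100.0, 80.0, 40.0],  # made-up extent in mm
                             patch_phy_size=[80.0, 80.0, 32.0])   # [160, 160, 32] * [0.5, 0.5, 1.0]
for origin, eof in patches:
    print(origin, eof)
```

With these numbers the y extent is an exact multiple of the patch size, so the `+ 1` in `patch_num` produces a second patch along y whose clamped origin coincides with the first one.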
/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DIAL-RPI/FedCross/d64c637783a30a35599680d7913673cd2ea67898/fig1.png
--------------------------------------------------------------------------------
/fig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DIAL-RPI/FedCross/d64c637783a30a35599680d7913673cd2ea67898/fig2.png
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | class BinaryDiceLoss(nn.Module):
6 | r"""Dice loss of a binary class
7 | Args:
8 | smooth: A float number to smooth the loss and avoid NaN errors, default: 1
9 | p: Exponent in the denominator \sum{x^p} + \sum{y^p}, default: 2
10 | predict: A tensor of shape [N, *]
11 | target: A tensor of the same shape as predict
12 | flag: None, or an array of shape [N, 1]; if given, only samples with flag[i, 0] > 0 contribute
13 | Returns:
14 | A scalar loss tensor, 1 - Dice, where the intersection and union are
15 | pooled over all (flagged) samples in the batch
16 | """
17 | def __init__(self, smooth=1, p=2):
18 | super(BinaryDiceLoss, self).__init__()
19 | self.smooth = smooth
20 | self.p = p
21 |
22 | def forward(self, predict, target, flag):
23 | assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
24 | predict = predict.contiguous().view(predict.shape[0], -1)
25 | target = target.contiguous().view(target.shape[0], -1)
26 |
27 | intersection = self.smooth
28 | union = self.smooth
29 | if flag is None:
30 | pd = predict
31 | gt = target
32 | intersection += torch.sum(pd*gt)*2
33 | union += torch.sum(pd.pow(self.p) + gt.pow(self.p))
34 | else:
35 | for i in range(target.shape[0]):
36 | if flag[i,0] > 0:
37 | pd = predict[i:i+1,:]
38 | gt = target[i:i+1,:]
39 | intersection += torch.sum(pd*gt)*2
40 | union += torch.sum(pd.pow(self.p) + gt.pow(self.p))
41 | dice = intersection / union
42 |
43 | loss = 1 - dice
44 | return loss
45 |
46 | class DiceLoss(nn.Module):
47 | """Dice loss; expects a one-hot encoded target
48 | Args:
49 | weight: An array of shape [num_classes,]
50 | ignore_index: class indices to ignore
51 | predict: A tensor of shape [N, C, *]
52 | target: A tensor of the same shape as predict
53 | other args pass to BinaryDiceLoss
54 | Return:
55 | same as BinaryDiceLoss
56 | """
57 | def __init__(self, weight=None, ignore_index=[], **kwargs):
58 | super(DiceLoss, self).__init__()
59 | self.kwargs = kwargs
60 | if weight is not None:
61 | self.weight = weight / weight.sum()
62 | else:
63 | self.weight = None
64 | self.ignore_index = ignore_index
65 |
66 | def forward(self, predict, target, flag=None):
67 | assert predict.shape == target.shape, 'predict & target shape do not match'
68 | dice = BinaryDiceLoss(**self.kwargs)
69 | total_loss = 0
70 | total_loss_num = 0
71 |
72 | for c in range(target.shape[1]):
73 | if c not in self.ignore_index:
74 | dice_loss = dice(predict[:, c], target[:, c], flag)
75 | if self.weight is not None:
76 | assert self.weight.shape[0] == target.shape[1], \
77 | 'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0])
78 | dice_loss *= self.weight[c]
79 | total_loss += dice_loss
80 | total_loss_num += 1
81 |
82 | if self.weight is not None:
83 | return total_loss
84 | elif total_loss_num > 0:
85 | return total_loss/total_loss_num
86 | else:
87 | return 0
88 |
89 | def make_onehot(input, cls):
90 | oh_list = []
91 | for c in range(cls):
92 | tmp = torch.zeros_like(input)
93 | tmp[input==c] = 1
94 | oh_list.append(tmp)
95 | oh = torch.cat(oh_list, dim=1)
96 | return oh
97 |
98 | def dice_and_ce_loss(prob, logit, target):
99 | cls_num = logit.shape[1] # number of output channels (classes including background)
100 | target_oh = make_onehot(target, cls=cls_num)
101 | ce_loss = nn.CrossEntropyLoss()
102 | dice_loss = DiceLoss()
103 | l_ce = ce_loss(logit, target.squeeze(dim=1))
104 | l_dice = dice_loss(prob, target_oh)
105 |
106 | return l_ce, l_dice
--------------------------------------------------------------------------------
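`BinaryDiceLoss` computes a soft Dice with additive smoothing: both the doubled intersection and the sum of p-th powers start at `smooth`, so an empty prediction against an empty target gives a loss of 0 instead of NaN. The NumPy sketch below restates the `flag=None` path for a quick sanity check; `soft_dice_loss` is a hypothetical name, and NumPy arrays stand in for the PyTorch tensors used in `loss.py`.

```python
import numpy as np


def soft_dice_loss(predict, target, smooth=1.0, p=2):
    """NumPy restatement of BinaryDiceLoss.forward with flag=None (pooled over the batch)."""
    predict = np.asarray(predict, dtype=float).reshape(len(predict), -1)
    target = np.asarray(target, dtype=float).reshape(len(target), -1)
    intersection = smooth + 2.0 * np.sum(predict * target)
    union = smooth + np.sum(predict ** p + target ** p)
    return 1.0 - intersection / union


gt = np.array([[0, 1, 1, 0]])       # one sample with four voxels
print(soft_dice_loss(gt, gt))       # perfect prediction: 1 - 5/5 = 0.0
print(soft_dice_loss(1 - gt, gt))   # fully wrong prediction: 1 - 1/5 = 0.8
```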
/metric.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import math
4 | # Note:
5 | # Using itk here causes a deadlock after the first training epoch
6 | # when data loading is multithreaded (DataLoader num_workers > 0); the reason is unknown.
7 | import SimpleITK as sitk
8 | from pandas import DataFrame, read_csv
9 | from utils import read_image
10 |
11 | def keep_largest_component(image, largest_n=1):
12 | c_filter = sitk.ConnectedComponentImageFilter()
13 | obj_arr = sitk.GetArrayFromImage(c_filter.Execute(image))
14 | obj_num = c_filter.GetObjectCount()
15 | tmp_arr = np.zeros_like(obj_arr)
16 |
17 | if obj_num > 0:
18 | obj_vol = np.zeros(obj_num, dtype=np.int64)
19 | for obj_id in range(obj_num):
20 | tmp_arr2 = np.zeros_like(obj_arr)
21 | tmp_arr2[obj_arr == obj_id+1] = 1
22 | obj_vol[obj_id] = np.sum(tmp_arr2)
23 |
24 | sorted_obj_id = np.argsort(obj_vol)[::-1]
25 |
26 | for i in range(min(largest_n, obj_num)):
27 | tmp_arr[obj_arr == sorted_obj_id[i]+1] = 1
28 |
29 | output = sitk.GetImageFromArray(tmp_arr)
30 | output.SetSpacing(image.GetSpacing())
31 | output.SetOrigin(image.GetOrigin())
32 | output.SetDirection(image.GetDirection())
33 |
34 | return output
35 |
36 | def cal_dsc(pd, gt):
37 | y = (np.sum(pd * gt) * 2 + 1) / (np.sum(pd * pd + gt * gt) + 1)
38 | return y
39 |
40 | def cal_asd(a, b):
41 | filter1 = sitk.SignedMaurerDistanceMapImageFilter()
42 | filter1.SetUseImageSpacing(True)
43 | filter1.SetSquaredDistance(False)
44 | a_dist = filter1.Execute(a)
45 |
46 | a_dist = sitk.GetArrayFromImage(a_dist)
47 | a_dist = np.abs(a_dist)
48 | a_edge = np.zeros(a_dist.shape, a_dist.dtype)
49 | a_edge[a_dist == 0] = 1
50 | a_num = np.sum(a_edge)
51 |
52 | filter2 = sitk.SignedMaurerDistanceMapImageFilter()
53 | filter2.SetUseImageSpacing(True)
54 | filter2.SetSquaredDistance(False)
55 | b_dist = filter2.Execute(b)
56 |
57 | b_dist = sitk.GetArrayFromImage(b_dist)
58 | b_dist = np.abs(b_dist)
59 | b_edge = np.zeros(b_dist.shape, b_dist.dtype)
60 | b_edge[b_dist == 0] = 1
61 | b_num = np.sum(b_edge)
62 |
63 | a_dist[b_edge == 0] = 0.0
64 | b_dist[a_edge == 0] = 0.0
65 |
66 | #a2b_mean_dist = np.sum(b_dist) / a_num
67 | #b2a_mean_dist = np.sum(a_dist) / b_num
68 | asd = (np.sum(a_dist) + np.sum(b_dist)) / (a_num + b_num)
69 |
70 | return asd
71 |
72 | def cal_hd(a, b):
73 | filter1 = sitk.HausdorffDistanceImageFilter()
74 | filter1.Execute(a, b)
75 | hd = filter1.GetHausdorffDistance()
76 |
77 | return hd
78 |
79 | def eval(pd_path, gt_entries, label_map, cls_num, metric_fn, calc_asd=True, keep_largest=False):
80 | result_lines = ''
81 | df_fn = '{}/{}.csv'.format(pd_path, metric_fn)
82 | if not os.path.exists(df_fn):
83 | results = []
84 | print_line = '\n --- Start calculating metrics --- '
85 | print(print_line)
86 | result_lines += '{}\n'.format(print_line)
    for [d_name, casename, gt_fname] in gt_entries:
        gt_label = read_image(fname=gt_fname)
        gt_array = sitk.GetArrayFromImage(gt_label)
        gt_array = gt_array.astype(dtype=np.uint8)

        # map dataset-specific label values to the unified target labels
        tmp_array = np.zeros_like(gt_array)
        lmap = label_map[d_name]
        tgt_labels = []
        for key in lmap:
            tmp_array[gt_array == key] = lmap[key]
            if lmap[key] not in tgt_labels:
                tgt_labels.append(lmap[key])
        gt_array = tmp_array

        pd_fname = '{}/{}@{}.nii.gz'.format(pd_path, d_name, casename)
        pd_array = sitk.GetArrayFromImage(read_image(fname=pd_fname))

        for c in tgt_labels:
            pd = np.zeros_like(pd_array, dtype=np.uint8)
            pd[pd_array == c] = 1
            pd_im = sitk.GetImageFromArray(pd)
            pd_im.SetSpacing(gt_label.GetSpacing())
            pd_im.SetOrigin(gt_label.GetOrigin())
            pd_im.SetDirection(gt_label.GetDirection())
            if keep_largest:
                pd_im = keep_largest_component(pd_im, largest_n=1)
            pd = sitk.GetArrayFromImage(pd_im)
            pd = pd.astype(dtype=np.uint8)
            pd = np.reshape(pd, -1)

            gt = np.zeros_like(gt_array, dtype=np.uint8)
            gt[gt_array == c] = 1
            gt_im = sitk.GetImageFromArray(gt)
            gt_im.SetSpacing(gt_label.GetSpacing())
            gt_im.SetOrigin(gt_label.GetOrigin())
            gt_im.SetDirection(gt_label.GetDirection())
            gt = np.reshape(gt, -1)

            dsc = cal_dsc(pd, gt)
            # surface distances are only computed when requested and when the
            # prediction is non-empty; otherwise ASD and HD are reported as 0
            if calc_asd and np.sum(pd) > 0:
                asd = cal_asd(pd_im, gt_im)
                hd = cal_hd(pd_im, gt_im)
            else:
                asd = 0
                hd = 0
            results.append([d_name, casename, c, dsc, asd, hd])

            print_line = ' --- {0:22s}@{1:22s}@{2:d}:\t\tDSC = {3:>5.2f}%\tASD = {4:>5.2f}mm\tHD = {5:>5.2f}mm'.format(d_name, casename, c, dsc*100.0, asd, hd)
            print(print_line)
            result_lines += '{}\n'.format(print_line)

    df = DataFrame(results, columns=['Dataset', 'Case', 'Class', 'DSC', 'ASD', 'HD'])
    df.to_csv(df_fn)
    df = read_csv(df_fn)

    dsc = []
    asd = []
    hd = []
    dsc_mean = 0
    asd_mean = 0
    hd_mean = 0
    # sort the dataset names so per-dataset results are reported in a deterministic
    # order (the iteration order of a set of strings can change between runs)
    d_list = sorted(set(df['Dataset'].tolist()))
    for d in d_list:
        dsc_m = df[df['Dataset'] == d]['DSC'].mean()
        dsc_v = df[df['Dataset'] == d]['DSC'].std()
        dsc.append([dsc_m, dsc_v])
        dsc_mean += dsc_m
        asd_m = df[df['Dataset'] == d]['ASD'].mean()
        asd_v = df[df['Dataset'] == d]['ASD'].std()
        asd.append([asd_m, asd_v])
        asd_mean += asd_m
        hd_m = df[df['Dataset'] == d]['HD'].mean()
        hd_v = df[df['Dataset'] == d]['HD'].std()
        hd.append([hd_m, hd_v])
        hd_mean += hd_m
        print_line = ' --- dataset {0:s}:\tDSC = {1:.2f}({2:.2f})%\tASD = {3:.2f}({4:.2f})mm\tHD = {5:.2f}({6:.2f})mm\tN={7:d}'.format(d, dsc_m*100.0, dsc_v*100.0, asd_m, asd_v, hd_m, hd_v, len(df[df['Dataset'] == d]['DSC']))
        print(print_line)
        result_lines += '{}\n'.format(print_line)
    # unweighted average over datasets (each dataset counts once, regardless of its size)
    dsc_mean = dsc_mean / len(d_list)
    asd_mean = asd_mean / len(d_list)
    hd_mean = hd_mean / len(d_list)
    dsc = np.array(dsc)
    asd = np.array(asd)
    hd = np.array(hd)

    print_line = ' --- dataset-avg:\tDSC = {0:.2f}%\tASD = {1:.2f}mm\tHD = {2:.2f}mm'.format(dsc_mean*100.0, asd_mean, hd_mean)
    print(print_line)
    result_lines += '{}\n'.format(print_line)
    print_line = ' --- Finish calculating metrics --- \n'
    print(print_line)
    result_lines += '{}\n'.format(print_line)

    result_fn = '{}/{}.txt'.format(pd_path, metric_fn)
    with open(result_fn, 'w') as result_file:
        result_file.write(result_lines)

    return dsc, asd, hd, dsc_mean, asd_mean, hd_mean

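
# Sketch (hypothetical helper, not called by eval() above): the per-dataset
# summary in eval() can also be written with a pandas groupby. It assumes a
# DataFrame with the 'Dataset', 'DSC', 'ASD' and 'HD' columns built above and,
# like eval(), gives each dataset equal weight in the overall average.
# groupby sorts the dataset names, so the row order is deterministic.
from pandas import DataFrame


def summarize_by_dataset(df: DataFrame):
    # mean and sample standard deviation (ddof=1, as with Series.std()) per dataset
    stats = df.groupby('Dataset')[['DSC', 'ASD', 'HD']].agg(['mean', 'std'])
    # unweighted mean of the per-dataset means, like dsc_mean / len(d_list) in eval()
    dataset_avg = stats.xs('mean', axis=1, level=1).mean()
    return stats, dataset_avg
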
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class double_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        y = self.conv(x)
        return y

class enc_block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(enc_block, self).__init__()
        self.conv = double_conv(in_ch, out_ch)
        self.down = nn.MaxPool3d(2)

    def forward(self, x):
        y_conv = self.conv(x)
        y = self.down(y_conv)
        return y, y_conv

class dec_block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(dec_block, self).__init__()
        self.conv = double_conv(in_ch, out_ch)
        self.up = nn.ConvTranspose3d(out_ch, out_ch, 2, stride=2)

    def forward(self, x):
        y_conv = self.conv(x)
        y = self.up(y_conv)
        return y, y_conv

def concatenate(x1, x2):
    diffZ = x2.size()[2] - x1.size()[2]
    diffY = x2.size()[3] - x1.size()[3]
    diffX = x2.size()[4] - x1.size()[4]
    x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
                    diffY // 2, diffY - diffY//2,
                    diffZ // 2, diffZ - diffZ//2))
    y = torch.cat([x2, x1], dim=1)
    return y

class UNet(nn.Module):
    def __init__(self, in_ch, base_ch, cls_num):
        super(UNet, self).__init__()
        self.in_ch = in_ch
        self.base_ch = base_ch
        self.cls_num = cls_num

        self.enc1 = enc_block(in_ch, base_ch)
        self.enc2 = enc_block(base_ch, base_ch*2)
        self.enc3 = enc_block(base_ch*2, base_ch*4)
        self.enc4 = enc_block(base_ch*4, base_ch*8)

        self.dec1 = dec_block(base_ch*8, base_ch*8)
        self.dec2 = dec_block(base_ch*8+base_ch*8, base_ch*4)
        self.dec3 = dec_block(base_ch*4+base_ch*4, base_ch*2)
        self.dec4 = dec_block(base_ch*2+base_ch*2, base_ch)
        self.lastconv = double_conv(base_ch+base_ch, base_ch)

        self.outconv = nn.Conv3d(base_ch, cls_num+1, 1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x, enc1_conv = self.enc1(x)
        x, enc2_conv = self.enc2(x)
        x, enc3_conv = self.enc3(x)
        x, enc4_conv = self.enc4(x)
        x, _ = self.dec1(x)
        x, _ = self.dec2(concatenate(x, enc4_conv))
        x, _ = self.dec3(concatenate(x, enc3_conv))
        x, _ = self.dec4(concatenate(x, enc2_conv))
        x = self.lastconv(concatenate(x, enc1_conv))
        x = self.outconv(x)
        y = self.softmax(x)

        if self.training:
            return y, x
        else:
            return y

    def description(self):
        return 'U-Net (input channel = {0:d}) for {1:d}-class segmentation (base channel = {2:d})'.format(self.in_ch, self.cls_num+1, self.base_ch)
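

# Sketch (hypothetical helper, not used by UNet): computes the pad tuple that
# concatenate() passes to F.pad for two 5-D shapes (N, C, Z, Y, X).
# F.pad reads the tuple starting from the last dimension, so the X amounts come
# first, then Y, then Z. An odd difference puts the extra voxel on the far side.
# When every spatial size is divisible by 16 (four 2x poolings), all amounts are 0.
def concat_pad_amounts(x1_shape, x2_shape):
    diff_z = x2_shape[2] - x1_shape[2]
    diff_y = x2_shape[3] - x1_shape[3]
    diff_x = x2_shape[4] - x1_shape[4]
    return (diff_x // 2, diff_x - diff_x // 2,
            diff_y // 2, diff_y - diff_y // 2,
            diff_z // 2, diff_z - diff_z // 2)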
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import random
import time
import torch
import torch.nn as nn
from torch import optim
from torch.utils import data
from dataset import create_folds, Dataset
from model import UNet
from loss import dice_and_ce_loss
from utils import resample_array, output2file
from metric import eval
from config import cfg

def initial_net(net):
    for m in net.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv3d):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm3d):
            nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

def initialization():

    # train_record[t][i]: index of the node whose training data local model i
    # uses in communication round t (each row is a random permutation of the nodes)
    train_record = []
    for i in range(cfg['commu_times']):
        order = [(j + i) % len(cfg['node_list']) for j in range(len(cfg['node_list']))]
        random.shuffle(order)
        train_record.append(order)

    nodes = []
    val_fold = None
    test_fold = None
    weight_sum = 0
    for node_id, [node_name, d_name, d_path, fraction] in enumerate(cfg['node_list']):

        folds, _ = create_folds(d_name, d_path, node_name, fraction, exclude_case=cfg['exclude_case'])

        # create training fold
        train_fold = folds[0]
        d_train = Dataset(train_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'], aug_data=True)
        dl_train = data.DataLoader(dataset=d_train, batch_size=cfg['batch_size'], shuffle=True, pin_memory=True, drop_last=False, num_workers=cfg['cpu_thread'])

        # create validation fold (pooled over all nodes)
        if val_fold is None:
            val_fold = folds[1]
        else:
            val_fold.extend(folds[1])

        # create testing fold (pooled over all nodes)
        if test_fold is None:
            test_fold = folds[2]
        else:
            test_fold.extend(folds[2])

        print('{0:s}: train = {1:d}'.format(node_name, len(d_train)))
        weight_sum += len(d_train)

        local_model = nn.DataParallel(module=UNet(in_ch=1, base_ch=32, cls_num=cfg['cls_num']))
        local_model.cuda()
        initial_net(local_model)

        optimizer = optim.SGD(local_model.parameters(), lr=cfg['lr'], momentum=0.99, nesterov=True)

        # polynomial ("poly") learning-rate decay over all local epochs of this model
        lambda_func = lambda epoch: (1 - epoch / (cfg['commu_times'] * cfg['epoch_per_commu']))**0.9
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_func)

        nodes.append([local_model, optimizer, scheduler, node_name, len(d_train), dl_train])

    d_val = Dataset(val_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'], aug_data=False)
    dl_val = data.DataLoader(dataset=d_val, batch_size=cfg['test_batch_size'], shuffle=False, pin_memory=True, drop_last=False, num_workers=cfg['cpu_thread'])

    d_test = Dataset(test_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'], aug_data=False)
    dl_test = data.DataLoader(dataset=d_test, batch_size=cfg['test_batch_size'], shuffle=False, pin_memory=True, drop_last=False, num_workers=cfg['cpu_thread'])

    print('All nodes: val/test = {0:d}/{1:d}'.format(len(d_val), len(d_test)))

    # normalize each node's weight to its share of all training samples
    for i in range(len(nodes)):
        nodes[i][4] = nodes[i][4] / weight_sum
        print('Weight of {0:s}: {1:f}'.format(nodes[i][3], nodes[i][4]))

    return nodes, train_record, dl_val, dl_test

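# Sketch (hypothetical helper, not called by train()): reproduces the training
# order that initialization() builds. Row t lists, for each local model i, the
# node whose data that model visits in communication round t. Each row is a
# permutation of the node indices, so no two models train on the same node in the
# same round. The seeded random.Random instance is an added assumption for
# reproducibility; initialization() itself uses the global random module.
import random


def make_train_record(commu_times, node_num, seed=0):
    rng = random.Random(seed)
    record = []
    for i in range(commu_times):
        order = [(j + i) % node_num for j in range(node_num)]
        rng.shuffle(order)
        record.append(order)
    return record
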
def exchange_local_models(nodes, train_record, commu_t, node_id):
    new_node_id = train_record[commu_t][node_id]
    return new_node_id, nodes[new_node_id][3], nodes[new_node_id][4], nodes[new_node_id][5]

def train_local_model(local_model, optimizer, scheduler, data_loader, epoch_num):
    train_loss = 0
    train_loss_num = 0
    for epoch_id in range(epoch_num):

        t0 = time.perf_counter()

        epoch_loss = 0
        epoch_loss_num = 0
        batch_id = 0
        for batch in data_loader:
            image = batch['data'].cuda()
            label = batch['label'].cuda()

            N = len(image)

            pred, pred_logit = local_model(image)

            print_line = 'Epoch {0:d}/{1:d} (train) --- Progress {2:5.2f}% (+{3:02d})'.format(
                epoch_id+1, epoch_num, 100.0 * batch_id * cfg['batch_size'] / len(data_loader.dataset), N)

            l_ce, l_dice = dice_and_ce_loss(pred, pred_logit, label)
            loss_sup = l_dice + l_ce
            epoch_loss += l_dice.item() + l_ce.item()
            epoch_loss_num += 1

            print_line += ' --- Loss: {0:.6f}({1:.6f}/{2:.6f})'.format(loss_sup.item(), l_dice.item(), l_ce.item())
            print(print_line)

            optimizer.zero_grad()
            loss_sup.backward()
            optimizer.step()

            del image, label, pred, pred_logit, loss_sup
            batch_id += 1

        train_loss += epoch_loss
        train_loss_num += epoch_loss_num
        epoch_loss = epoch_loss / epoch_loss_num
        lr = scheduler.get_last_lr()[0]

        print_line = 'Epoch {0:d}/{1:d} (train) --- Loss: {2:.6f} --- Lr: {3:.6f}'.format(epoch_id+1, epoch_num, epoch_loss, lr)
        print(print_line)

        scheduler.step()

        t1 = time.perf_counter()
        epoch_t = t1 - t0
        print("Epoch time cost: {h:>02d}:{m:>02d}:{s:>02d}\n".format(
            h=int(epoch_t) // 3600, m=(int(epoch_t) % 3600) // 60, s=int(epoch_t) % 60))

    train_loss = train_loss / train_loss_num

    return train_loss

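# Sketch (hypothetical helper, not called by train()): the multiplicative factor
# that the LambdaLR schedule in initialization() applies to cfg['lr']. This is the
# "poly" decay (1 - step / total_steps) ** 0.9, where total_steps is
# cfg['commu_times'] * cfg['epoch_per_commu']. Each local model's scheduler is
# stepped once per local epoch in train_local_model(), so it reaches 0 after
# the last round.
def poly_lr_factor(step, total_steps, power=0.9):
    return (1 - step / total_steps) ** power
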
# mode: 'val' or 'test'
# commu_iters: communication iteration index (only used when mode == 'val')
def eval_local_model(nodes, data_loader, result_path, mode, commu_iters):
    t0 = time.perf_counter()

    if mode == 'val':
        metric_fname = 'metric_validation-{0:04d}'.format(commu_iters+1)
        print('Validation ({0:d}/{1:d}) ...'.format(commu_iters+1, cfg['commu_times']))
    elif mode == 'test':
        metric_fname = 'metric_testing'
        print('Testing ...')
    else:
        raise ValueError("mode must be 'val' or 'test', got {!r}".format(mode))

    gt_entries = []
    output_buffer = None
    std_buffer = None
    for batch_id, batch in enumerate(data_loader):
        image = batch['data'].cuda()
        N = len(image)

        # ensemble the local models: accumulate E[y] and E[y^2] of the
        # foreground probability (channel 1) over all nodes
        Ey, Ey2 = None, None
        for [local_model, _, _, _, _, _] in nodes:

            local_model.eval()
            # gradients are not needed at inference time
            with torch.no_grad():
                prob = local_model(image)
            y = prob[:,1,:].detach().cpu().numpy().copy()

            if Ey is None:
                Ey = np.zeros_like(y)
                Ey2 = np.zeros_like(y)
            Ey = Ey + y
            Ey2 = Ey2 + y**2

            del prob

        Ey = Ey/len(nodes)
        Ey2 = Ey2/len(nodes)
        mask = (Ey.copy()>0.5).astype(dtype=np.uint8)

        # standard deviation across the local models, offset by +1.0 so that voxels
        # inside the resampled volume can be told apart from the zero fill of resample_array()
        stddev = Ey2-Ey**2
        stddev[stddev<0] = 0
        stddev = np.sqrt(stddev) + 1.0

        print_line = '{0:s} --- Progress {1:5.2f}% (+{2:d})'.format(
            mode, 100.0 * batch_id * cfg['test_batch_size'] / len(data_loader.dataset), N)
        print(print_line)

        for i in range(N):
            sample_mask = resample_array(
                mask[i,:], batch['size'][i].numpy(), batch['spacing'][i].numpy(), batch['origin'][i].numpy(),
                batch['org_size'][i].numpy(), batch['org_spacing'][i].numpy(), batch['org_origin'][i].numpy())
            if output_buffer is None:
                output_buffer = np.zeros_like(sample_mask, dtype=np.uint8)
            output_buffer[sample_mask > 0] = 0
            output_buffer = output_buffer + sample_mask

            if batch['eof'][i] == True:
                output2file(output_buffer, batch['org_size'][i].numpy(), batch['org_spacing'][i].numpy(), batch['org_origin'][i].numpy(),
                    '{0:s}/{1:s}@{2:s}.nii.gz'.format(result_path, batch['dataset'][i], batch['case'][i]))
                output_buffer = None
                gt_entries.append([batch['dataset'][i], batch['case'][i], batch['label_fname'][i]])

            if mode == 'test':
                sample_stddev = resample_array(
                    stddev[i,:], batch['size'][i].numpy(), batch['spacing'][i].numpy(), batch['origin'][i].numpy(),
                    batch['org_size'][i].numpy(), batch['org_spacing'][i].numpy(), batch['org_origin'][i].numpy(), linear=True)
                if std_buffer is None:
                    std_buffer = np.zeros_like(sample_stddev)
                std_buffer[sample_stddev >= 1.0] = 0
                sample_stddev[sample_stddev < 1.0] = 0
                sample_stddev = sample_stddev - 1.0
                sample_stddev[sample_stddev < 0.0] = 0
                std_buffer = std_buffer + sample_stddev

                if batch['eof'][i] == True:
                    output2file(std_buffer, batch['org_size'][i].numpy(), batch['org_spacing'][i].numpy(), batch['org_origin'][i].numpy(),
                        '{0:s}/{1:s}@{2:s}-std.nii.gz'.format(result_path, batch['dataset'][i], batch['case'][i]))
                    std_buffer = None

        del image

    # note: eval here is metric.eval (imported above), not the Python built-in
    seg_dsc, seg_asd, seg_hd, seg_dsc_m, seg_asd_m, seg_hd_m = eval(
        pd_path=result_path, gt_entries=gt_entries, label_map=cfg['label_map'], cls_num=cfg['cls_num'],
        metric_fn=metric_fname, calc_asd=(mode != 'val'), keep_largest=False)

    if mode == 'val':
        print_line = 'Validation result (iter = {0:d}/{1:d}) --- DSC {2:.2f} ({3:s})%'.format(
            commu_iters+1, cfg['commu_times'],
            seg_dsc_m*100.0, '/'.join(['%.2f']*len(seg_dsc[:,0])) % tuple(seg_dsc[:,0]*100.0))
    else:
        print_line = 'Testing results --- DSC {0:.2f} ({1:s})% --- ASD {2:.2f} ({3:s})mm --- HD {4:.2f} ({5:s})mm'.format(
            seg_dsc_m*100.0, '/'.join(['%.2f']*len(seg_dsc[:,0])) % tuple(seg_dsc[:,0]*100.0),
            seg_asd_m, '/'.join(['%.2f']*len(seg_asd[:,0])) % tuple(seg_asd[:,0]),
            seg_hd_m, '/'.join(['%.2f']*len(seg_hd[:,0])) % tuple(seg_hd[:,0]))
    print(print_line)
    t1 = time.perf_counter()
    eval_t = t1 - t0
    print("Evaluation time cost: {h:>02d}:{m:>02d}:{s:>02d}\n".format(
        h=int(eval_t) // 3600, m=(int(eval_t) % 3600) // 60, s=int(eval_t) % 60))

    return seg_dsc_m, seg_dsc

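# Sketch (hypothetical helper, not called by eval_local_model()): mirrors the
# ensemble statistics computed above. probs holds foreground probabilities
# stacked over the local models along axis 0. The mask thresholds the mean at 0.5.
# The standard deviation is offset by +1.0, presumably so that voxels inside the
# resampled volume stay >= 1.0 while resample_array() fills voxels outside it with
# 0. The test branch relies on that gap to drop padding before removing the offset.
import numpy as np


def ensemble_mask_and_std(probs):
    probs = np.asarray(probs, dtype=np.float64)
    mean = probs.mean(axis=0)                    # E[y]
    mean_sq = (probs ** 2).mean(axis=0)          # E[y^2]
    var = np.clip(mean_sq - mean ** 2, 0, None)  # clamp round-off negatives to 0
    mask = (mean > 0.5).astype(np.uint8)
    return mask, np.sqrt(var) + 1.0
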
def load_models(nodes, model_fname):
    # read the checkpoint once, then restore each node's model, optimizer and scheduler
    checkpoint = torch.load(model_fname)
    for node_id in range(len(nodes)):
        nodes[node_id][0].load_state_dict(checkpoint['local_model_{0:d}_state_dict'.format(node_id)])
        nodes[node_id][1].load_state_dict(checkpoint['local_model_{0:d}_optimizer'.format(node_id)])
        nodes[node_id][2].load_state_dict(checkpoint['local_model_{0:d}_scheduler'.format(node_id)])

def train():

    train_start_time = time.localtime()
    print("Start time: {start_time}\n".format(start_time=time.strftime("%Y-%m-%d %H:%M:%S", train_start_time)))
    time_stamp = time.strftime("%Y%m%d%H%M%S", train_start_time)

    # create directory for results storage
    store_dir = '{}/model_{}'.format(cfg['model_path'], time_stamp)
    loss_fn = '{}/loss.txt'.format(store_dir)
    val_result_path = '{}/results_val'.format(store_dir)
    os.makedirs(val_result_path, exist_ok=True)
    test_result_path = '{}/results_test'.format(store_dir)
    os.makedirs(test_result_path, exist_ok=True)

    print('Loading local data from each node ... \n')

    nodes, train_record, dl_val, dl_test = initialization()

    print("Training order:", train_record)

    best_val_acc = 0
    start_iter = 0
    acc_time = 0
    best_model_fn = '{0:s}/cp_commu_{1:04d}.pth.tar'.format(store_dir, 1)

    print()
    log_line = "Model: {}\nModel parameters: {}\nStart time: {}\nConfiguration:\n".format(
        nodes[0][0].module.description(),
        sum(x.numel() for x in nodes[0][0].parameters()),
        time.strftime("%Y-%m-%d %H:%M:%S", train_start_time))
    for cfg_key in cfg:
        log_line += ' --- {}: {}\n'.format(cfg_key, cfg[cfg_key])
    print(log_line)

    for commu_t in range(start_iter, cfg['commu_times'], 1):

        t0 = time.perf_counter()

        train_loss = []
        for i, [local_model, optimizer, scheduler, _, _, _] in enumerate(nodes):

            # local model i trains on the data of the node assigned to it for this round
            node_id, node_name, node_weight, dl_train = exchange_local_models(nodes, train_record, commu_t, i)

            print('Training ({0:d}/{1:d}) on Node: {2:s}\n'.format(commu_t+1, cfg['commu_times'], node_name))

            local_model.train()

            train_loss.append(train_local_model(local_model, optimizer, scheduler, dl_train, cfg['epoch_per_commu']))

        seg_dsc_m, seg_dsc = eval_local_model(nodes, dl_val, val_result_path, mode='val', commu_iters=commu_t)

        t1 = time.perf_counter()
        epoch_t = t1 - t0
        acc_time += epoch_t
        print("Iteration time cost: {h:>02d}:{m:>02d}:{s:>02d}\n".format(
            h=int(epoch_t) // 3600, m=(int(epoch_t) % 3600) // 60, s=int(epoch_t) % 60))

        loss_line = '{commu_iter:>04d}\t{train_loss:s}\t{seg_val_dsc:>8.6f}\t{seg_val_dsc_cls:s}'.format(
            commu_iter=commu_t+1, train_loss='\t'.join(['%8.6f']*len(train_loss)) % tuple(train_loss),
            seg_val_dsc=seg_dsc_m, seg_val_dsc_cls='\t'.join(['%8.6f']*len(seg_dsc[:,0])) % tuple(seg_dsc[:,0])
        )
        for [_, _, scheduler, _, _, _] in nodes:
            loss_line += '\t{node_lr:>8.6f}'.format(node_lr=scheduler.get_last_lr()[0])
        loss_line += '\n'

        with open(loss_fn, 'a') as loss_file:
            loss_file.write(loss_line)

        # save best model
        if commu_t == 0 or seg_dsc_m > best_val_acc:
            # remove former best model
            if os.path.exists(best_model_fn):
                os.remove(best_model_fn)
            # save current best model
            best_val_acc = seg_dsc_m
            best_model_fn = '{0:s}/cp_commu_{1:04d}.pth.tar'.format(store_dir, commu_t+1)
            best_model_cp = {
                'commu_iter':commu_t,
                'acc_time':acc_time,
                'time_stamp':time_stamp,
                'best_val_acc':best_val_acc,
                'best_model_filename':best_model_fn}
            for node_id, [local_model, optimizer, scheduler, _, _, _] in enumerate(nodes):
                best_model_cp['local_model_{0:d}_state_dict'.format(node_id)] = local_model.state_dict()
                best_model_cp['local_model_{0:d}_optimizer'.format(node_id)] = optimizer.state_dict()
                best_model_cp['local_model_{0:d}_scheduler'.format(node_id)] = scheduler.state_dict()
            torch.save(best_model_cp, best_model_fn)
            print('Best model (communication iteration = {}) saved.\n'.format(commu_t+1))

    print("Total training time: {h:>02d}:{m:>02d}:{s:>02d}\n\n".format(
        h=int(acc_time) // 3600, m=(int(acc_time) % 3600) // 60, s=int(acc_time) % 60))

    # test
    load_models(nodes, best_model_fn)
    eval_local_model(nodes, dl_test, test_result_path, mode='test', commu_iters=0)

    print("Finish time: {finish_time}\n\n".format(
        finish_time=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())))

if __name__ == '__main__':

    os.environ['CUDA_VISIBLE_DEVICES'] = cfg['gpu']

    train()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
# Note:
# Using itk here causes a deadlock after the first training epoch when the
# DataLoader runs with multiple workers (num_workers > 0); the reason is unknown,
# so SimpleITK is used instead.
import SimpleITK as sitk

def read_image(fname):
    reader = sitk.ImageFileReader()
    reader.SetFileName(fname)
    image = reader.Execute()
    return image

def resample_array(array, size, spacing, origin, size_rs, spacing_rs, origin_rs, transform=None, linear=False):
    # size/spacing/origin follow SimpleITK (x, y, z) order; the NumPy array is (z, y, x)
    array = np.reshape(array, [size[2], size[1], size[0]])
    image = sitk.GetImageFromArray(array)
    image.SetSpacing((float(spacing[0]), float(spacing[1]), float(spacing[2])))
    image.SetOrigin((float(origin[0]), float(origin[1]), float(origin[2])))

    resampler = sitk.ResampleImageFilter()
    resampler.SetSize((int(size_rs[0]), int(size_rs[1]), int(size_rs[2])))
    resampler.SetOutputSpacing((float(spacing_rs[0]), float(spacing_rs[1]), float(spacing_rs[2])))
    resampler.SetOutputOrigin((float(origin_rs[0]), float(origin_rs[1]), float(origin_rs[2])))
    if transform is not None:
        resampler.SetTransform(transform)
    else:
        resampler.SetTransform(sitk.Transform(3, sitk.sitkIdentity))
    # linear interpolation for continuous maps, nearest neighbour for label masks
    if linear:
        resampler.SetInterpolator(sitk.sitkLinear)
    else:
        resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    # voxels outside the input volume are filled with 0
    resampler.SetDefaultPixelValue(0)
    rs_image = resampler.Execute(image)
    rs_array = sitk.GetArrayFromImage(rs_image)

    return rs_array

def output2file(array, size, spacing, origin, fname):
    # the array keeps its own dtype; size/spacing/origin are in (x, y, z) order
    array = np.reshape(array, [size[2], size[1], size[0]])#.astype(dtype=np.uint8)
    image = sitk.GetImageFromArray(array)
    image.SetSpacing((float(spacing[0]), float(spacing[1]), float(spacing[2])))
    image.SetOrigin((float(origin[0]), float(origin[1]), float(origin[2])))

    writer = sitk.ImageFileWriter()
    writer.SetFileName(fname)
    writer.Execute(image)
--------------------------------------------------------------------------------