├── .idea
│   ├── SALMON.iml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── LICENSE
├── README.md
├── check_loader_patches.py
├── images
│   ├── image.gif
│   ├── label.gif
│   ├── prostate.gif
│   ├── prostate_inf.gif
│   ├── prostate_label.gif
│   ├── result.gif
│   ├── salmon.JPG
│   ├── salmon2.JPG
│   ├── salmon3.JPG
│   ├── salmon4.JPG
│   ├── salmon5.JPG
│   └── salmon6.JPG
├── init.py
├── installation_commands_.txt
├── monai 0.5.0
│   ├── LICENSE
│   ├── README.md
│   ├── check_loader_patches.py
│   ├── deprecated
│   │   ├── LICENSE
│   │   ├── README.md
│   │   ├── check_loader_patches.py
│   │   ├── check_resolution.py
│   │   ├── docker commands_.txt
│   │   ├── images
│   │   │   ├── image.gif
│   │   │   ├── label.gif
│   │   │   ├── prostate.gif
│   │   │   ├── prostate_inf.gif
│   │   │   ├── prostate_label.gif
│   │   │   ├── result.gif
│   │   │   ├── salmon.JPG
│   │   │   ├── salmon2.JPG
│   │   │   ├── salmon3.JPG
│   │   │   ├── salmon4.JPG
│   │   │   ├── salmon5.JPG
│   │   │   └── salmon6.JPG
│   │   ├── init.py
│   │   ├── multi_label_segmentation_example
│   │   │   ├── check_loader_patches.py
│   │   │   ├── check_resolution.py
│   │   │   ├── init.py
│   │   │   ├── networks.py
│   │   │   ├── organize_folder_structure.py
│   │   │   ├── predict_single_image.py
│   │   │   └── train.py
│   │   ├── networks.py
│   │   ├── organize_folder_structure.py
│   │   ├── predict_single_image.py
│   │   ├── requirements.txt
│   │   └── train.py
│   ├── init.py
│   ├── installation_commands_.txt
│   ├── multi_label_segmentation_example
│   │   ├── init.py
│   │   ├── predict_single_image.py
│   │   └── train.py
│   ├── networks.py
│   ├── organize_folder_structure.py
│   ├── predict_single_image.py
│   ├── requirements.txt
│   ├── train.py
│   └── utils.py
├── networks.py
├── organize_folder_structure.py
├── predict_single_image.py
├── requirements.txt
├── train.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 David Iommi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # SALMON v.2: Segmentation deep learning ALgorithm based on MONai toolbox
3 | - SALMON is a computational toolbox for segmentation using neural networks (3D patch-based segmentation)
4 | - SALMON is based on MONAI 0.7.0: a PyTorch-based, open-source framework for deep learning in healthcare imaging.
5 | (https://github.com/Project-MONAI/MONAI)
6 | (https://github.com/MIC-DKFZ/nnUNet)
7 | (https://arxiv.org/abs/2103.10504)
8 |
9 | This is my "open-box" version for when I want to modify the parameters for a particular task, whereas the two frameworks above are hard-coded. The "monai 0.5.0" folder contains the previous versions, based on the older MONAI release.
10 |
11 | *******************************************************************************
12 | ## Requirements
13 | Follow the steps in "installation_commands_.txt": installation via Anaconda, creating a virtual environment for the Python libraries and PyTorch/CUDA.
14 | *******************************************************************************
15 | ## Python scripts and their function
16 |
17 | - organize_folder_structure.py: Organizes the data in the folder structure (training, validation, testing) expected by the network.
18 | Labels are resampled and resized to the corresponding image, to avoid array-size conflicts (see the sketch after this list). You can also set a new image resolution for the dataset here.
19 |
20 | - init.py: List of options used to train the network.
21 |
22 | - check_loader_patches: Shows examples of the patches fed to the network during training.
23 |
24 | - networks.py: The architectures available for segmentation are nn-UNet and UNETR (based on vision transformers).
25 |
26 | - train.py: Runs the training.
27 |
28 | - predict_single_image.py: Launches inference on a single input image chosen by the user.
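
For reference, a minimal sketch of the label-to-image resampling that "organize_folder_structure.py" performs (an illustration with SimpleITK, not the script's exact code; the paths follow the example tree below):

```python
import SimpleITK as sitk

# Read one image/label pair (paths follow the Data_folder layout shown below).
image = sitk.ReadImage('./Data_folder/CT/1.nii')
label = sitk.ReadImage('./Data_folder/CT_labels/1.nii')

# Resample the label onto the image grid with nearest-neighbour interpolation,
# so no new label values are introduced and the arrays match voxel-for-voxel.
label = sitk.Resample(label, image, sitk.Transform(),
                      sitk.sitkNearestNeighbor, 0, label.GetPixelID())
```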
29 | *******************************************************************************
30 | ## Usage
31 | ### Folders structure:
32 |
33 | First use "organize_folder_structure.py" to organize the data.
34 | Modify the input parameters to point to the two dataset folders (images and labels). Set the resolution of the images here, before training.
35 |
36 | .
37 | ├── Data_folder
38 | | ├── CT
39 | | | ├── 1.nii
40 | | | ├── 2.nii
41 | | | └── 3.nii
42 | | ├── CT_labels
43 | | | ├── 1.nii
44 | | | ├── 2.nii
45 | | | └── 3.nii
46 |
47 | Data structure after running it:
48 |
49 | .
50 | ├── Data_folder
51 | | ├── CT
52 | | ├── CT_labels
53 | | ├── images
54 | | | ├── train
55 | | | | ├── image1.nii
56 | | | | └── image2.nii
57 | | | └── val
58 | | | | ├── image3.nii
59 | | | | └── image4.nii
60 | | | └── test
61 | | | | ├── image5.nii
62 | | | | └── image6.nii
63 | | ├── labels
64 | | | ├── train
65 | | | | ├── label1.nii
66 | | | | └── label2.nii
67 | | | └── val
68 | | | | ├── label3.nii
69 | | | | └── label4.nii
70 | | | └── test
71 | | | | ├── label5.nii
72 | | | | └── label6.nii
73 |
74 | *******************************************************************************
75 | ### Training:
76 | - Modify the "init.py" to set the parameters and start the training/testing on the data. Read the descriptions for each parameter.
77 | - Afterwards launch "train.py" for training. Tensorboard is available to monitor the training (a "runs" folder is created).
78 | - Check and modify the train_transforms applied to the images in "train.py" for your specific case (e.g. the last update adds a HU windowing for CT images; a sketch follows this list).
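
A sketch of such a CT windowing transform with MONAI (the window bounds here are illustrative; use the values that fit your data):

```python
from monai.transforms import ScaleIntensityRanged

# Clip CT intensities to a HU window and rescale them to [0, 1].
hu_window = ScaleIntensityRanged(
    keys=['image'], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True,
)
```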
79 |
80 | Sample images: the following images show the segmentation of the carotid artery from an MRI sequence
81 |
82 | 
83 |
84 | Sample images: the following images show the multi-label segmentation of the prostate transition zone and peripheral zone from an MRI sequence
85 |
86 |
87 |
88 | *******************************************************************************
89 | ### Inference:
90 | - Launch "predict_single_image.py" to test the network. Modify the parameters in the parse section to set the paths of the weights, the image to infer, and the result.
91 | - You can test the model on a new image with a different size and resolution from the training data. The script will resample it before inference and return a mask
92 | with the same size and resolution as the source image.
93 | *******************************************************************************
94 | ### Tips:
95 | - Use and modify "check_loader_patches.py" to check the patches fed during training.
96 | - The "networks.py" calls the nn-Unet, which adapts itself to the input data (resolution and patches size). The script also saves the graph of you network, so you can visualize it.
97 | - "networks.py" includes also UneTR (based on Visual transformers). This is experimental. For more info check (https://arxiv.org/abs/2103.10504) and https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unetr_btcv_segmentation_3d.ipynb
98 | - Is it possible to add other networks, but for segmentation the U-net architecture is the state of the art.
99 |
100 | ### Sample script inference
101 | - The label can be omitted (None) if you segment an unknown image. You have to add --resolution if you resampled the data during training (look at the argparse section in the code).
102 | ```console
103 | python predict_single_image.py --image './Data_folder/image.nii' --label './Data_folder/label.nii' --result './Data_folder/prova.nii' --weights './best_metric_model.pth'
104 | ```
105 | *******************************************************************************
106 |
107 | ### Some notes:
108 | - Tensorboard can show you all segmented channels, but for now the metric is the Mean-Dice (of all channels). If you want to evaluate the Dice score for each channel you
109 | have to modify the plot_dice function a bit. I will do it...one day...who knows...maybe not
110 | - The loss is the DiceLoss + CrossEntropy (a minimal sketch is below). You can modify it if you want to try others (https://docs.monai.io/en/latest/losses.html#diceloss)
111 |
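A minimal sketch of the combined loss with MONAI (the flags assume integer label maps and raw network logits; adjust them to your label format):

```python
from monai.losses import DiceCELoss

# Dice + CrossEntropy in one criterion: to_onehot_y one-hot encodes the
# integer label map, softmax is applied to the network logits.
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
```
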
112 | Check more examples at https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/.
113 |
114 | ### UneTR Notes from the authors:
115 |
116 | feature_size and pos_embed are the parameters that need to be changed to adapt it to your application of interest. The other parameters mentioned come from the Vision Transformer (ViT) default hyper-parameters (original architecture). In addition, the new revision of the UNETR paper with more descriptions is now publicly available. Please check it for more details:
117 | https://arxiv.org/pdf/2103.10504.pdf
118 |
119 | Now let's look at each of these hyper-parameters in order of importance (a configuration sketch follows the list):
120 |
121 | - feature_size : In UNETR, we multiply the size of the CNN-based features in the decoder by a factor of 2 at every resolution (just like the original U-Net paper). By default, we set this value to 16 (to make the entire network lighter). However, using larger values such as 32 can improve segmentation performance if GPU memory is not an issue. Figure 2 of the paper also shows this in detail.
122 |
123 | - pos_embed: this determines how the image is divided into non-overlapping patches. Essentially, there are 2 ways to achieve this (by setting it to conv or perceptron). Let's dive into them for more information:
124 | The first is directly applying a convolutional layer with the same stride and kernel size as the patch size, and with the feature size of the hidden size in the ViT model. The second is first breaking the image into patches by properly resizing the tensor (for which we use einops) and then feeding it into a perceptron (linear) layer with the hidden size of the ViT model. Our experiments show that for certain applications such as brain segmentation with multiple modalities (e.g. 4 modes such as T1, T2, etc.), using the convolutional layer works better, as it takes all modes into account concurrently. For CT images (e.g. BTCV multi-organ segmentation), we did not see any difference in performance between these two approaches.
125 |
126 | - hidden_size : this is the size of the hidden layers in the ViT encoder. We follow the original ViT model and set this value to 768. In addition, the hidden size should be divisible by the number of attention heads in the ViT model.
127 |
128 | - num_heads : in the multi-headed self-attention block, this is the number of attention heads. Following the ViT architecture, we set it to 12.
129 |
130 | - mlp_dim : this is the dimension of the multi-layer perceptrons (MLP) in the transformer encoder. Again, we follow the ViT model and set this to 3072 as default value to be consistent with their architecture.
131 |
132 |
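Putting the hyper-parameters above together, a UNETR instantiation with MONAI 0.7.0 looks roughly like this (in_channels, out_channels and img_size are placeholders for your own data; img_size must match the training patch size):

```python
from monai.networks.nets import UNETR

model = UNETR(
    in_channels=1,            # input modalities
    out_channels=2,           # segmentation classes (background included)
    img_size=(256, 256, 16),  # training patch size
    feature_size=16,          # decoder CNN features; try 32 if memory allows
    hidden_size=768,          # ViT hidden size (divisible by num_heads)
    mlp_dim=3072,             # transformer MLP dimension
    num_heads=12,             # attention heads
    pos_embed='perceptron',   # or 'conv', see the discussion above
    norm_name='instance',
    res_block=True,
    dropout_rate=0.0,
)
```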
--------------------------------------------------------------------------------
/check_loader_patches.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | import os
13 | import sys
14 | from glob import glob
15 | import tempfile
16 | import numpy as np
17 | import matplotlib.pyplot as plt
18 | import nibabel as nib
19 | import torch
20 | from torch.utils.data import DataLoader
21 | from init import Options
22 | import monai
23 | from monai.data import ArrayDataset, GridPatchDataset, create_test_image_3d
24 | from monai.transforms import (Compose, LoadImaged, AddChanneld, Transpose, Resized, CropForegroundd, CastToTyped,RandGaussianSmoothd,
25 | ScaleIntensityd, ToTensord, RandSpatialCropd, Rand3DElasticd, RandAffined, SpatialPadd,
26 | Spacingd, Orientationd, RandZoomd, ThresholdIntensityd, RandShiftIntensityd, RandGaussianNoised, BorderPadd,RandAdjustContrastd, NormalizeIntensityd,RandFlipd, ScaleIntensityRanged)
27 |
28 |
29 | class IndexTracker(object):
30 |     def __init__(self, ax, X):  # X: 3D numpy array indexed as (rows, cols, slices)
31 |         self.ax = ax
32 |         ax.set_title('use scroll wheel to navigate images')
33 |
34 | self.X = X
35 | rows, cols, self.slices = X.shape
36 | self.ind = self.slices//2
37 |
38 | self.im = ax.imshow(self.X[:, :, self.ind],cmap= 'gray')
39 | self.update()
40 |
41 | def onscroll(self, event):
42 | print("%s %s" % (event.button, event.step))
43 | if event.button == 'up':
44 | self.ind = (self.ind + 1) % self.slices
45 | else:
46 | self.ind = (self.ind - 1) % self.slices
47 | self.update()
48 |
49 | def update(self):
50 | self.im.set_data(self.X[:, :, self.ind])
51 | self.ax.set_ylabel('slice %s' % self.ind)
52 | self.im.axes.figure.canvas.draw()
53 |
54 |
55 | def plot3d(image):
56 | original=image
57 | original = np.rot90(original, k=-1)
58 | fig, ax = plt.subplots(1, 1)
59 | tracker = IndexTracker(ax, original)
60 | fig.canvas.mpl_connect('scroll_event', tracker.onscroll)
61 | plt.show()
62 |
63 |
64 | if __name__ == "__main__":
65 |
66 | opt = Options().parse()
67 |
68 | train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
69 | train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))
70 |
71 | data_dicts = [{'image': image_name, 'label': label_name}
72 | for image_name, label_name in zip(train_images, train_segs)]
73 |
74 | monai_transforms = [
75 |
76 | LoadImaged(keys=['image', 'label']),
77 | AddChanneld(keys=['image', 'label']),
78 | # Orientationd(keys=["image", "label"], axcodes="RAS"),
79 | # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),
80 | # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
81 | CropForegroundd(keys=['image', 'label'], source_key='image', start_coord_key='foreground_start_coord',
82 | end_coord_key='foreground_end_coord', ), # crop CropForeground
83 | NormalizeIntensityd(keys=['image']),
84 | ScaleIntensityd(keys=['image']),
85 | # Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),
86 |
87 | SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method= 'end'),
88 | RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
89 | ToTensord(keys=['image', 'label','foreground_start_coord', 'foreground_end_coord'],)
90 | ]
91 |
92 | transform = Compose(monai_transforms)
93 | check_ds = monai.data.Dataset(data=data_dicts, transform=transform)
94 | loader = DataLoader(check_ds, batch_size=1, shuffle=True, num_workers=0, pin_memory=torch.cuda.is_available())
95 | check_data = monai.utils.misc.first(loader)
96 | im, seg, coord1, coord2 = (check_data['image'][0], check_data['label'][0],check_data['foreground_start_coord'][0],
97 | check_data['foreground_end_coord'][0])
98 |
99 | print(im.shape, seg.shape, coord1, coord2)
100 |
101 | vol = im[0].numpy()
102 | mask = seg[0].numpy()
103 |
104 | print(vol.shape, mask.shape)
105 | plot3d(vol)
106 | plot3d(mask)
107 |
--------------------------------------------------------------------------------
/images/image.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/images/image.gif
--------------------------------------------------------------------------------
/images/label.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/images/label.gif
--------------------------------------------------------------------------------
/images/prostate.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/images/prostate.gif
--------------------------------------------------------------------------------
/images/prostate_inf.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/images/prostate_inf.gif
--------------------------------------------------------------------------------
/images/prostate_label.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/images/prostate_label.gif
--------------------------------------------------------------------------------
/images/result.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/images/result.gif
--------------------------------------------------------------------------------
/images/salmon.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/images/salmon.JPG
--------------------------------------------------------------------------------
/images/salmon2.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/images/salmon2.JPG
--------------------------------------------------------------------------------
/images/salmon3.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/images/salmon3.JPG
--------------------------------------------------------------------------------
/images/salmon4.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/images/salmon4.JPG
--------------------------------------------------------------------------------
/images/salmon5.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/images/salmon5.JPG
--------------------------------------------------------------------------------
/images/salmon6.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/images/salmon6.JPG
--------------------------------------------------------------------------------
/init.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 |
5 | class Options():
6 |
7 | """This class defines options used during both training and test time."""
8 |
9 | def __init__(self):
10 | """Reset the class; indicates the class hasn't been initailized"""
11 | self.initialized = False
12 |
13 | def initialize(self, parser):
14 |
15 | # basic parameters
16 | parser.add_argument('--images_folder', type=str, default='./Data_folder/images')
17 | parser.add_argument('--labels_folder', type=str, default='./Data_folder/labels')
18 | parser.add_argument('--increase_factor_data', default=1, help='Increase data number per epoch')
19 | parser.add_argument('--preload', type=str, default=None)
20 |         parser.add_argument('--gpu_ids', type=str, default='2,3', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
21 | parser.add_argument('--workers', default=8, type=int, help='number of data loading workers')
22 |
23 | # dataset parameters
24 | parser.add_argument('--network', default='unetr', help='nnunet, unetr')
25 | parser.add_argument('--patch_size', default=(256, 256, 16), help='Size of the patches extracted from the image')
26 | parser.add_argument('--spacing', default=[0.7, 0.7, 3], help='Original Resolution')
27 |         parser.add_argument('--resolution', default=None, help='New resolution; if set, the data are resampled during training. I suggest resampling in organize_folder_structure.py instead, since resampling at train time is slower')
28 | parser.add_argument('--batch_size', type=int, default=4, help='batch size, depends on your machine')
29 | parser.add_argument('--in_channels', default=1, type=int, help='Channels of the input')
30 | parser.add_argument('--out_channels', default=1, type=int, help='Channels of the output')
31 |
32 | # training parameters
33 | parser.add_argument('--epochs', default=1000, help='Number of epochs')
34 | parser.add_argument('--lr', default=0.01, help='Learning rate')
35 | parser.add_argument('--benchmark', default=True)
36 |
37 | # Inference
38 |         # This is just a trick to make the predict script work; do not touch it for training.
39 | parser.add_argument('--result', default=None, help='Keep this empty and go to predict_single_image script')
40 | parser.add_argument('--weights', default=None, help='Keep this empty and go to predict_single_image script')
41 |
42 | self.initialized = True
43 | return parser
44 |
45 |     def parse(self):
46 |         # build the parser on every call, so repeated parse() calls also work
47 |         parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
48 |         parser = self.initialize(parser)
49 |         opt = parser.parse_args()
50 |         # set gpu ids (must be exported before CUDA is initialized)
51 |         if opt.gpu_ids != '-1':
52 |             os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
53 |         return opt
54 |
55 |
56 |
57 |
58 |
59 |
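# Usage sketch (this is how train.py and check_loader_patches.py consume this
# class; a minimal example, assuming the defaults above):
#
#     from init import Options
#     opt = Options().parse()
#     print(opt.patch_size, opt.batch_size, opt.network)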
--------------------------------------------------------------------------------
/installation_commands_.txt:
--------------------------------------------------------------------------------
1 |
2 | 1) set up anaconda env: conda create -n monai_david python=3.8
3 | conda activate monai_david
4 |
5 | 2) install pytorch: conda install pytorch==1.5.0 torchvision==0.6.0 cudatoolkit=10.1 -c pytorch # this for cuda (check your cuda version)
6 | 2b) conda install pytorch==1.5.0 torchvision==0.6.0 cpuonly -c pytorch # this for cpu
7 |
8 | 3) conda install git pip
9 | 4) pip install git+https://github.com/davidiommi/MONAI_0_7_0.git # download the MONAI library
10 | 5) pip install -r requirements.txt # download the remaining libraries
11 |
12 |
13 |
--------------------------------------------------------------------------------
/monai 0.5.0/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 David Iommi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/monai 0.5.0/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # SALMON v.2: Segmentation deep learning ALgorithm based on MONai toolbox
3 | - SALMON is a computational toolbox for segmentation using neural networks (3D patch-based segmentation)
4 | - SALMON is based on NN-UNET and MONAI: PyTorch-based, open-source frameworks for deep learning in healthcare imaging.
5 | (https://github.com/Project-MONAI/MONAI)
6 | (https://github.com/MIC-DKFZ/nnUNet)
7 |
8 | This is my "open-box" version for when I want to modify the parameters for a particular task, whereas the two frameworks above are hard-coded.
9 |
10 | *******************************************************************************
11 | ## Requirements
12 | Follow the steps in "installation_commands_.txt": installation via Anaconda, creating a virtual environment for the Python libraries and PyTorch/CUDA.
13 | *******************************************************************************
14 | ## Python scripts and their function
15 |
16 | - organize_folder_structure.py: Organizes the data in the folder structure (training, validation, testing) expected by the network.
17 | Labels are resampled and resized to the corresponding image, to avoid array-size conflicts. You can also set a new image resolution for the dataset here.
18 |
19 | - init.py: List of options used to train the network.
20 |
21 | - check_loader_patches: Shows examples of the patches fed to the network during training.
22 |
23 | - networks.py: The architecture available for segmentation is a nn-Unet.
24 |
25 | - train.py: Runs the training.
26 |
27 | - predict_single_image.py: Launches inference on a single input image chosen by the user.
28 | *******************************************************************************
29 | ## Usage
30 | ### Folders structure:
31 |
32 | First use "organize_folder_structure.py" to organize the data.
33 | Modify the input parameters to point to the two dataset folders (images and labels). Set the resolution of the images here, before training.
34 |
35 | .
36 | ├── Data_folder
37 | | ├── CT
38 | | | ├── 1.nii
39 | | | ├── 2.nii
40 | | | └── 3.nii
41 | | ├── CT_labels
42 | | | ├── 1.nii
43 | | | ├── 2.nii
44 | | | └── 3.nii
45 |
46 | Data structure after running it:
47 |
48 | .
49 | ├── Data_folder
50 | | ├── CT
51 | | ├── CT_labels
52 | | ├── images
53 | | | ├── train
54 | | | | ├── image1.nii
55 | | | | └── image2.nii
56 | | | └── val
57 | | | | ├── image3.nii
58 | | | | └── image4.nii
59 | | | └── test
60 | | | | ├── image5.nii
61 | | | | └── image6.nii
62 | | ├── labels
63 | | | ├── train
64 | | | | ├── label1.nii
65 | | | | └── label2.nii
66 | | | └── val
67 | | | | ├── label3.nii
68 | | | | └── label4.nii
69 | | | └── test
70 | | | | ├── label5.nii
71 | | | | └── label6.nii
72 |
73 | *******************************************************************************
74 | ### Training:
75 | - Modify the "init.py" to set the parameters and start the training/testing on the data. Read the descriptions for each parameter.
76 | - Afterwards launch "train.py" for training. Tensorboard is available to monitor the training (a "runs" folder is created).
77 | - Check and modify the train_transforms applied to the images in "train.py" for your specific case (e.g. the last update adds a HU windowing for CT images).
78 |
79 | Sample images: the following images show the segmentation of the carotid artery from an MRI sequence
80 |
81 | 
82 |
83 | Sample images: the following images show the multi-label segmentation of the prostate transition zone and peripheral zone from an MRI sequence
84 |
85 |
86 |
87 | *******************************************************************************
88 | ### Inference:
89 | - Launch "predict_single_image.py" to test the network. Modify the parameters in the parse section to set the paths of the weights, the image to infer, and the result.
90 | - You can test the model on a new image with a different size and resolution from the training data. The script will resample it before inference and give you a mask
91 | with the same size and resolution as the source image.
92 | *******************************************************************************
93 | ### Tips:
94 | - Use and modify "check_loader_patches.py" to check the patches fed during training.
95 | - The "networks.py" calls the nn-Unet, which adapts itself to the input data (resolution and patches size). The script also saves the graph of you network, so you can visualize it.
96 | - Is it possible to add other networks, but for segmentation the U-net architecture is the state of the art.
97 |
98 | ### Sample script inference
99 | - The label can be omitted (None) if you segment an unknown image. You have to add --resolution if you resampled the data during training (look at the argparse section in the code).
100 | ```console
101 | python predict_single_image.py --image './Data_folder/image.nii' --label './Data_folder/label.nii' --result './Data_folder/prova.nii' --weights './best_metric_model.pth'
102 | ```
103 | *******************************************************************************
104 | ### Multi-channel segmentation:
105 |
106 | The subfolder "multi_label_segmentation_example" includes the modified code for the multi-label scenario.
107 | The example segments the prostate (1-channel input) into the transition zone and peripheral zone (2-channel output).
108 | The GIF files with some example images are shown above.
109 |
110 | Some notes:
111 | - You must add an additional channel for the background. Example: 0 background, 1 prostate, 2 prostate tumor = 3 output channels in total (a one-hot sketch follows these notes).
112 | - Tensorboard can show you all segmented channels, but for now the metric is the Mean-Dice (of all channels). If you want to evaluate the Dice score for each channel you
113 | have to modify the plot_dice function a bit. I will do it...one day...who knows...maybe not
114 | - The loss is the DiceLoss + CrossEntropy. You can modify it if you want to try others (https://docs.monai.io/en/latest/losses.html#diceloss)
115 |
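A minimal sketch of the background-channel handling with the MONAI 0.5 post-transforms (the class count follows the hypothetical 3-class example in the first note):

```python
from monai.transforms import AsDiscrete

# One-hot encode the integer label map (0 background, 1 prostate, 2 tumor).
post_label = AsDiscrete(to_onehot=True, n_classes=3)
# Argmax the network logits, then one-hot them the same way for the Dice metric.
post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=3)
```
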
116 | Check more examples at https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d.ipynb.
117 |
--------------------------------------------------------------------------------
/monai 0.5.0/check_loader_patches.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | import os
13 | import sys
14 | from glob import glob
15 | import tempfile
16 | import numpy as np
17 | import matplotlib.pyplot as plt
18 | import nibabel as nib
19 | import torch
20 | from torch.utils.data import DataLoader
21 | from init import Options
22 | import monai
23 | from monai.data import ArrayDataset, GridPatchDataset, create_test_image_3d
24 | from monai.transforms import (Compose, LoadImaged, AddChanneld, Transpose, Resized, CropForegroundd, CastToTyped,RandGaussianSmoothd,
25 | ScaleIntensityd, ToTensord, RandSpatialCropd, Rand3DElasticd, RandAffined, SpatialPadd,
26 | Spacingd, Orientationd, RandZoomd, ThresholdIntensityd, RandShiftIntensityd, RandGaussianNoised, BorderPadd,RandAdjustContrastd, NormalizeIntensityd,RandFlipd, ScaleIntensityRanged)
27 |
28 |
29 | class IndexTracker(object):
30 | def __init__(self, ax, X):
31 | self.ax = ax
32 | ax.set_title('use scroll wheel to navigate images')
33 |
34 | self.X = X
35 | rows, cols, self.slices = X.shape
36 | self.ind = self.slices//2
37 |
38 | self.im = ax.imshow(self.X[:, :, self.ind],cmap= 'gray')
39 | self.update()
40 |
41 | def onscroll(self, event):
42 | print("%s %s" % (event.button, event.step))
43 | if event.button == 'up':
44 | self.ind = (self.ind + 1) % self.slices
45 | else:
46 | self.ind = (self.ind - 1) % self.slices
47 | self.update()
48 |
49 | def update(self):
50 | self.im.set_data(self.X[:, :, self.ind])
51 | self.ax.set_ylabel('slice %s' % self.ind)
52 | self.im.axes.figure.canvas.draw()
53 |
54 |
55 | def plot3d(image):
56 | original=image
57 | original = np.rot90(original, k=-1)
58 | fig, ax = plt.subplots(1, 1)
59 | tracker = IndexTracker(ax, original)
60 | fig.canvas.mpl_connect('scroll_event', tracker.onscroll)
61 | plt.show()
62 |
63 |
64 | if __name__ == "__main__":
65 |
66 | opt = Options().parse()
67 |
68 | train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
69 | train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))
70 |
71 | data_dicts = [{'image': image_name, 'label': label_name}
72 | for image_name, label_name in zip(train_images, train_segs)]
73 |
74 | monai_transforms = [
75 |
76 | LoadImaged(keys=['image', 'label']),
77 | AddChanneld(keys=['image', 'label']),
78 | Orientationd(keys=["image", "label"], axcodes="RAS"),
79 | # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),
80 | # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
81 | CropForegroundd(keys=['image', 'label'], source_key='image', start_coord_key='foreground_start_coord',
82 | end_coord_key='foreground_end_coord', ), # crop CropForeground
83 | NormalizeIntensityd(keys=['image']),
84 | ScaleIntensityd(keys=['image']),
85 | # Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),
86 |
87 | SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method= 'end'),
88 | RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
89 | ToTensord(keys=['image', 'label','foreground_start_coord', 'foreground_end_coord'],)
90 | ]
91 |
92 | transform = Compose(monai_transforms)
93 | check_ds = monai.data.Dataset(data=data_dicts, transform=transform)
94 | loader = DataLoader(check_ds, batch_size=1, shuffle=True, num_workers=0, pin_memory=torch.cuda.is_available())
95 | check_data = monai.utils.misc.first(loader)
96 | im, seg, coord1, coord2 = (check_data['image'][0], check_data['label'][0],check_data['foreground_start_coord'][0],
97 | check_data['foreground_end_coord'][0])
98 |
99 | print(im.shape, seg.shape, coord1, coord2)
100 |
101 | vol = im[0].numpy()
102 | mask = seg[0].numpy()
103 |
104 | print(vol.shape, mask.shape)
105 | plot3d(vol)
106 | plot3d(mask)
107 |
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 David Iommi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # SALMON v.2: Segmentation deep learning ALgorithm based on MONai toolbox
3 | - SALMON is a computational toolbox for segmentation using neural networks (3D patches-based segmentation)
4 | - SALMON is based on NN-UNET and MONAI: PyTorch-based, open-source frameworks for deep learning in healthcare imaging.
5 | (https://github.com/Project-MONAI/MONAI)
6 | (https://github.com/MIC-DKFZ/nnUNet)
7 |
8 | This is my "open-box" version for when I want to modify the parameters for a particular task, whereas the two frameworks above are hard-coded.
9 |
10 | *******************************************************************************
11 | ## Requirements
12 | We download the official MONAI Docker image from DockerHub, with the latest MONAI version. Please visit https://docs.monai.io/en/latest/installation.html
13 | Additional packages can be installed with "pip install -r requirements.txt"
14 | *******************************************************************************
15 | ## Python scripts and their function
16 |
17 | - organize_folder_structure.py: Organizes the data in the folder structure (training, validation, testing) expected by the network. Labels are resampled and resized to the corresponding image, to avoid conflicts.
18 |
19 | - init.py: List of options used to train the network.
20 |
21 | - check_loader_patches: Shows examples of the patches fed to the network during training.
22 |
23 | - networks.py: The architecture available for segmentation is a nn-Unet.
24 |
25 | - train.py: Runs the training.
26 |
27 | - predict_single_image.py: Launches inference on a single input image chosen by the user.
28 | *******************************************************************************
29 | ## Usage
30 | ### Folders structure:
31 |
32 | First use "organize_folder_structure.py" to organize the data.
33 | Modify the input parameters to point to the two dataset folders (images and labels).
34 |
35 | .
36 | ├── Data_folder
37 | | ├── CT
38 | | | ├── 1.nii
39 | | | ├── 2.nii
40 | | | └── 3.nii
41 | | ├── CT_labels
42 | | | ├── 1.nii
43 | | | ├── 2.nii
44 | | | └── 3.nii
45 |
46 | Data structure after running it:
47 |
48 | .
49 | ├── Data_folder
50 | | ├── images
51 | | | ├── train
52 | | | | ├── image1.nii
53 | | | | └── image2.nii
54 | | | └── val
55 | | | | ├── image3.nii
56 | | | | └── image4.nii
57 | | | └── test
58 | | | | ├── image5.nii
59 | | | | └── image6.nii
60 | | ├── labels
61 | | | ├── train
62 | | | | ├── label1.nii
63 | | | | └── label2.nii
64 | | | └── val
65 | | | | ├── label3.nii
66 | | | | └── label4.nii
67 | | | └── test
68 | | | | ├── label5.nii
69 | | | | └── label6.nii
70 |
71 | *******************************************************************************
72 | ### Training:
73 | - Modify the "init.py" to set the parameters and start the training/testing on the data. Read the descriptions for each parameter.
74 | - Afterwards launch "train.py" for training. Tensorboard is available to monitor the training:
75 |
76 | 
77 |
78 | Sample images: on the left the image, in the middle the segmentation result, and on the right the true label.
79 | The following images show the segmentation of the carotid artery from an MR sequence
80 |
81 | 
82 |
83 | Sample images: on the left the image, in the middle the segmentation result, and on the right the true label.
84 | The following images show the multi-label segmentation of the prostate transition zone and peripheral zone from an MR sequence
85 |
86 | 
87 | *******************************************************************************
88 | ### Inference:
89 | Launch "predict_single_image.py" to test the network. Modify the parameters in the parse section to select the path of the weights, images to infer and result.
90 | *******************************************************************************
91 | ### Tips:
92 | - Use and modify "check_loader_patches.py" to check the patches fed during training.
93 | - "Organize_folder_structure.py" solves dimensionality conflicts if the label has different size and resolution of the image. Check on 3DSlicer or ITKSnap if your data are correctly centered-overlaid.
94 | - The "networks.py" calls the nn-Unet, which adapts itself to the input data (resolution and patches size). The script also saves the graph of you network, so you can visualize it.
95 | - Is it possible to add other networks, but for segmentation the U-net architecture is the state of the art.
96 | - During the training phase, the script crops the image background and pad the image if the post-cropping size is smaller than the patch size you set in init.py
97 |
98 |
99 | ### Sample script inference
100 | - The label can be omitted (None) if you segment an unknown image. You can add --resolution if you resampled the data during training (look at the argparse section in the code).
101 | ```console
102 | python predict_single_image.py --image './Data_folder/images/train/image13.nii' --label './Data_folder/labels/train/label13.nii' --result './Data_folder/results/train/prova.nii' --weights './best_metric_model.pth'
103 | ```
104 | *******************************************************************************
105 | ### Multi-channel segmentation:
106 |
107 | The subfolder "multi_label_segmentation_example" includes the modified code for the multi-label scenario.
108 | The example segments the prostate (1-channel input) into the transition zone and peripheral zone (2-channel output).
109 | The GIF files with some example images are shown above.
110 |
111 | Some notes:
112 | - You must add an additional channel for the background. Example: 0 background, 1 prostate, 2 prostate tumor = 3 output channels in total.
113 | - Tensorboard can show you all segmented channels, but for now the metric is the Mean-Dice (of all channels). If you want to evaluate the Dice score for each channel you
114 | have to modify the plot_dice function a bit. I will do it...one day...who knows...maybe not
115 | - The loss is the DiceLoss + CrossEntropy. You can modify it if you want to try others (https://docs.monai.io/en/latest/losses.html#diceloss)
116 |
117 | Check more examples at https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d.ipynb.
118 |
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/check_loader_patches.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | import os
13 | import sys
14 | from glob import glob
15 | import tempfile
16 | import numpy as np
17 | import matplotlib.pyplot as plt
18 | import nibabel as nib
19 | import torch
20 | from torch.utils.data import DataLoader
21 | from init import Options
22 | import monai
23 | from monai.data import ArrayDataset, GridPatchDataset, create_test_image_3d
24 | from monai.transforms import (Compose, LoadImaged, AddChanneld, Transpose, Resized, CropForegroundd, CastToTyped,RandGaussianSmoothd,
25 | ScaleIntensityd, ToTensord, RandSpatialCropd, Rand3DElasticd, RandAffined, SpatialPadd,
26 | Spacingd, Orientationd, RandZoomd, ThresholdIntensityd, RandShiftIntensityd, RandGaussianNoised, BorderPadd,RandAdjustContrastd, NormalizeIntensityd,RandFlipd, ScaleIntensityRanged)
27 |
28 |
29 | class IndexTracker(object):
30 | def __init__(self, ax, X):
31 | self.ax = ax
32 | ax.set_title('use scroll wheel to navigate images')
33 |
34 | self.X = X
35 | rows, cols, self.slices = X.shape
36 | self.ind = self.slices//2
37 |
38 | self.im = ax.imshow(self.X[:, :, self.ind],cmap= 'gray')
39 | self.update()
40 |
41 | def onscroll(self, event):
42 | print("%s %s" % (event.button, event.step))
43 | if event.button == 'up':
44 | self.ind = (self.ind + 1) % self.slices
45 | else:
46 | self.ind = (self.ind - 1) % self.slices
47 | self.update()
48 |
49 | def update(self):
50 | self.im.set_data(self.X[:, :, self.ind])
51 | self.ax.set_ylabel('slice %s' % self.ind)
52 | self.im.axes.figure.canvas.draw()
53 |
54 |
55 | def plot3d(image):
56 | original=image
57 | original = np.rot90(original, k=-1)
58 | fig, ax = plt.subplots(1, 1)
59 | tracker = IndexTracker(ax, original)
60 | fig.canvas.mpl_connect('scroll_event', tracker.onscroll)
61 | plt.show()
62 |
63 |
64 | if __name__ == "__main__":
65 |
66 | opt = Options().parse()
67 |
68 | train_images = sorted(glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
69 | train_segs = sorted(glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))
70 |
71 | data_dicts = [{'image': image_name, 'label': label_name}
72 | for image_name, label_name in zip(train_images, train_segs)]
73 |
74 | monai_transforms = [
75 |
76 | LoadImaged(keys=['image', 'label']),
77 | AddChanneld(keys=['image', 'label']),
78 | CropForegroundd(keys=['image', 'label'], source_key='image', start_coord_key='foreground_start_coord',
79 | end_coord_key='foreground_end_coord', ), # crop CropForeground
80 | ThresholdIntensityd(keys=['image'],threshold=-350, above=True, cval=-350),
81 | ThresholdIntensityd(keys=['image'], threshold=350, above=False, cval=350),
82 |
83 | NormalizeIntensityd(keys=['image']),
84 | ScaleIntensityd(keys=['image']),
85 | Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),
86 |
87 | SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method= 'end'),
88 | RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
89 | # Orientationd(keys=["image", "label"], axcodes="PIL"),
90 | ToTensord(keys=['image', 'label','foreground_start_coord', 'foreground_end_coord'],)
91 | ]
92 |
93 | transform = Compose(monai_transforms)
94 | check_ds = monai.data.Dataset(data=data_dicts, transform=transform)
95 | loader = DataLoader(check_ds, batch_size=1, shuffle=True, num_workers=0, pin_memory=torch.cuda.is_available())
96 | check_data = monai.utils.misc.first(loader)
97 | im, seg, coord1, coord2 = (check_data['image'][0], check_data['label'][0],check_data['foreground_start_coord'][0],
98 | check_data['foreground_end_coord'][0])
99 |
100 | print(im.shape, seg.shape, coord1, coord2)
101 |
102 |
103 | vol = im[0].numpy()
104 | mask = seg[0].numpy()
105 |
106 | print(vol.shape, mask.shape)
107 | plot3d(vol)
108 | plot3d(mask)
109 |
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/check_resolution.py:
--------------------------------------------------------------------------------
1 | # from NiftiDataset import *
2 | import argparse
3 | import SimpleITK as sitk
4 | import re
5 | import numpy as np
6 | import os
7 |
8 | '''Check whether the images and the labels have different sizes after (optionally) resampling them to the same resolution'''
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--images', default='./Data_folder/CT', help='path to the images')
12 | parser.add_argument('--labels', default='./Data_folder/CT_label', help='path to the labels')
13 | parser.add_argument("--resample", action='store_true', default=True, help='Decide or not to resample the images to a new resolution')
14 | parser.add_argument("--new_resolution", type=float, default=((1.3671875, 1.3671875, 3.0)), help='New resolution')
15 | args = parser.parse_args()
16 |
17 | def resize(img, new_size, interpolator):
18 | # img = sitk.ReadImage(img)
19 | dimension = img.GetDimension()
20 |
21 | # Physical image size corresponds to the largest physical size in the training set, or any other arbitrary size.
22 | reference_physical_size = np.zeros(dimension)
23 |
24 | reference_physical_size[:] = [(sz - 1) * spc if sz * spc > mx else mx for sz, spc, mx in
25 | zip(img.GetSize(), img.GetSpacing(), reference_physical_size)]
26 |
27 | # Create the reference image with a zero origin, identity direction cosine matrix and dimension
28 | reference_origin = np.zeros(dimension)
29 | reference_direction = np.identity(dimension).flatten()
30 | reference_size = new_size
31 | reference_spacing = [phys_sz / (sz - 1) for sz, phys_sz in zip(reference_size, reference_physical_size)]
32 |
33 | reference_image = sitk.Image(reference_size, img.GetPixelIDValue())
34 | reference_image.SetOrigin(reference_origin)
35 | reference_image.SetSpacing(reference_spacing)
36 | reference_image.SetDirection(reference_direction)
37 |
38 | # Always use the TransformContinuousIndexToPhysicalPoint to compute an indexed point's physical coordinates as
39 | # this takes into account size, spacing and direction cosines. For the vast majority of images the direction
40 | # cosines are the identity matrix, but when this isn't the case simply multiplying the central index by the
41 | # spacing will not yield the correct coordinates resulting in a long debugging session.
42 | reference_center = np.array(
43 | reference_image.TransformContinuousIndexToPhysicalPoint(np.array(reference_image.GetSize()) / 2.0))
44 |
45 | # Transform which maps from the reference_image to the current img with the translation mapping the image
46 | # origins to each other.
47 | transform = sitk.AffineTransform(dimension)
48 | transform.SetMatrix(img.GetDirection())
49 | transform.SetTranslation(np.array(img.GetOrigin()) - reference_origin)
50 | # Modify the transformation to align the centers of the original and reference image instead of their origins.
51 | centering_transform = sitk.TranslationTransform(dimension)
52 | img_center = np.array(img.TransformContinuousIndexToPhysicalPoint(np.array(img.GetSize()) / 2.0))
53 | centering_transform.SetOffset(np.array(transform.GetInverse().TransformPoint(img_center) - reference_center))
54 | centered_transform = sitk.Transform(transform)
55 | centered_transform.AddTransform(centering_transform)
56 | # Using the linear interpolator as these are intensity images, if there is a need to resample a ground truth
57 | # segmentation then the segmentation image should be resampled using the NearestNeighbor interpolator so that
58 | # no new labels are introduced.
59 |
60 | return sitk.Resample(img, reference_image, centered_transform, interpolator, 0.0)
61 |
62 |
63 | def resample_sitk_image(sitk_image, spacing=None, interpolator=None, fill_value=0):
64 | # https://github.com/SimpleITK/SlicerSimpleFilters/blob/master/SimpleFilters/SimpleFilters.py
65 | _SITK_INTERPOLATOR_DICT = {
66 | 'nearest': sitk.sitkNearestNeighbor,
67 | 'linear': sitk.sitkLinear,
68 | 'gaussian': sitk.sitkGaussian,
69 | 'label_gaussian': sitk.sitkLabelGaussian,
70 | 'bspline': sitk.sitkBSpline,
71 | 'hamming_sinc': sitk.sitkHammingWindowedSinc,
72 | 'cosine_windowed_sinc': sitk.sitkCosineWindowedSinc,
73 | 'welch_windowed_sinc': sitk.sitkWelchWindowedSinc,
74 | 'lanczos_windowed_sinc': sitk.sitkLanczosWindowedSinc
75 | }
76 |
77 | """Resamples an ITK image to a new grid. If no spacing is given,
78 | the resampling is done isotropically to the smallest value in the current
79 | spacing. This is usually the in-plane resolution. If not given, the
80 | interpolation is derived from the input data type. Binary input
81 | (e.g., masks) are resampled with nearest neighbors, otherwise linear
82 | interpolation is chosen.
83 | Parameters
84 | ----------
85 | sitk_image : SimpleITK image or str
86 | Either a SimpleITK image or a path to a SimpleITK readable file.
87 | spacing : tuple
88 |         Tuple of floats (the new voxel spacing).
89 | interpolator : str
90 | Either `nearest`, `linear` or None.
91 | fill_value : int
92 | Returns
93 | -------
94 | SimpleITK image.
95 | """
96 |
97 | if isinstance(sitk_image, str):
98 | sitk_image = sitk.ReadImage(sitk_image)
99 | num_dim = sitk_image.GetDimension()
100 |
101 | if not interpolator:
102 | interpolator = 'linear'
103 | pixelid = sitk_image.GetPixelIDValue()
104 |
105 | if pixelid not in [1, 2, 4]:
106 | raise NotImplementedError(
107 | 'Set `interpolator` manually, '
108 | 'can only infer for 8-bit unsigned or 16, 32-bit signed integers')
109 | if pixelid == 1: # 8-bit unsigned int
110 | interpolator = 'nearest'
111 |
112 | orig_pixelid = sitk_image.GetPixelIDValue()
113 | orig_origin = sitk_image.GetOrigin()
114 | orig_direction = sitk_image.GetDirection()
115 | orig_spacing = np.array(sitk_image.GetSpacing())
116 |     orig_size = np.array(sitk_image.GetSize(), dtype=int)  # np.int is deprecated in recent NumPy
117 |
118 | if not spacing:
119 | min_spacing = orig_spacing.min()
120 | new_spacing = [min_spacing] * num_dim
121 | else:
122 | new_spacing = [float(s) for s in spacing]
123 |
124 | assert interpolator in _SITK_INTERPOLATOR_DICT.keys(), \
125 | '`interpolator` should be one of {}'.format(_SITK_INTERPOLATOR_DICT.keys())
126 |
127 | sitk_interpolator = _SITK_INTERPOLATOR_DICT[interpolator]
128 |
129 | new_size = orig_size * (orig_spacing / new_spacing)
130 |     new_size = np.ceil(new_size).astype(int)  # Image dimensions are in integers (np.int is deprecated)
131 | new_size = [int(s) for s in new_size] # SimpleITK expects lists, not ndarrays
132 |
133 | resample_filter = sitk.ResampleImageFilter()
134 |
135 | resampled_sitk_image = resample_filter.Execute(sitk_image,
136 | new_size,
137 | sitk.Transform(),
138 | sitk_interpolator,
139 | orig_origin,
140 | new_spacing,
141 | orig_direction,
142 | fill_value,
143 | orig_pixelid)
144 |
145 | return resampled_sitk_image
146 |
147 |
148 |
149 | def numericalSort(value):
150 | numbers = re.compile(r'(\d+)')
151 | parts = numbers.split(value)
152 | parts[1::2] = map(int, parts[1::2])
153 | return parts
154 |
155 |
156 | def lstFiles(Path):
157 |
158 |     images_list = []  # create an empty list; the raw image file paths are stored here
159 | for dirName, subdirList, fileList in os.walk(Path):
160 | for filename in fileList:
161 | if ".nii.gz" in filename.lower():
162 | images_list.append(os.path.join(dirName, filename))
163 | elif ".nii" in filename.lower():
164 | images_list.append(os.path.join(dirName, filename))
165 | elif ".mhd" in filename.lower():
166 | images_list.append(os.path.join(dirName, filename))
167 |
168 | images_list = sorted(images_list, key=numericalSort)
169 |
170 | return images_list
171 |
172 | list_images = lstFiles(args.images)
173 | list_labels = lstFiles(args.labels)
174 |
175 | for i in range(len(list_images)):
176 |
177 | a = sitk.ReadImage(list_images[i])
178 |     if args.resample:
179 | a = resample_sitk_image(a, spacing=args.new_resolution, interpolator='linear')
180 | spacing1 = a.GetSpacing()
181 | a = sitk.GetArrayFromImage(a)
182 | a = np.transpose(a, axes=(2, 1, 0)) # reshape array from itk z,y,x to x,y,z
183 | a1 = a.shape
184 |
185 | b = sitk.ReadImage(list_labels[i])
186 |     if args.resample:
187 | b = resample_sitk_image(b, spacing=args.new_resolution, interpolator='nearest')
188 |
189 | b = resize(b,a1,sitk.sitkNearestNeighbor)
190 | spacing2 = b.GetSpacing()
191 | b = sitk.GetArrayFromImage(b)
192 | b = np.transpose(b, axes=(2, 1, 0)) # reshape array from itk z,y,x to x,y,z
193 | b1 = b.shape
194 |
195 | print(list_images[i], a1)
196 |
197 | if a1 != b1:
198 | print('Mismatch of size in ', list_images[i])
199 |
241 | # a=sitk.ReadImage('aaaaaa.nii')
242 | # a = sitk.GetArrayFromImage(a)
243 | # a = np.transpose(a, axes=(2, 1, 0)) # reshape array from itk z,y,x to x,y,z
244 | # result = np.rot90(a, k=-1)
245 | # fig, ax = plt.subplots(1, 1)
246 | # tracker = IndexTracker(ax, result)
247 | # fig.canvas.mpl_connect('scroll_event', tracker.onscroll)
248 | # plt.show()
249 |
250 | # a=sitk.ReadImage(labels[36])
251 | # a = sitk.GetArrayFromImage(a)
252 | # a = np.transpose(a, axes=(2, 1, 0)) # reshape array from itk z,y,x to x,y,z
253 | # result = np.rot90(a, k=-1)
254 | # fig, ax = plt.subplots(1, 1)
255 | # tracker = IndexTracker(ax, result)
256 | # fig.canvas.mpl_connect('scroll_event', tracker.onscroll)
257 | # plt.show()
258 |
259 |
260 |
261 |
--------------------------------------------------------------------------------
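
A minimal usage sketch for the two resampling helpers defined above (the input path is hypothetical; `resample_sitk_image` and `resize` are the functions from this file):

    import SimpleITK as sitk

    img = sitk.ReadImage('./Data_folder/CT/example.nii')       # hypothetical input volume
    iso = resample_sitk_image(img, spacing=(1.0, 1.0, 1.0),    # resample to 1 mm isotropic voxels
                              interpolator='linear')
    fixed = resize(iso, [256, 256, 64], sitk.sitkLinear)       # then force a fixed 256x256x64 grid
    print(img.GetSpacing(), '->', iso.GetSpacing(), fixed.GetSize())
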
/monai 0.5.0/deprecated/docker commands_.txt:
--------------------------------------------------------------------------------
1 |
2 |
3 | 2)
4 | docker run --gpus all --rm --name monai_david -v /ceph/zmpbmt.meduniwien.ac.at/p_OENB2018/Codes/:/data/tensorflow/ -ti --ipc=host projectmonai/monai:latest
5 |
6 |
7 | 3) cd /data/tensorflow/Segmentation_deep_learning/POODLE_FET
8 |
9 |
10 | 4) pip install -r requirements.txt
11 |
12 | docker stop
13 |
14 |
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/images/image.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/monai 0.5.0/deprecated/images/image.gif
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/images/label.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/monai 0.5.0/deprecated/images/label.gif
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/images/prostate.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/monai 0.5.0/deprecated/images/prostate.gif
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/images/prostate_inf.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/monai 0.5.0/deprecated/images/prostate_inf.gif
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/images/prostate_label.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/monai 0.5.0/deprecated/images/prostate_label.gif
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/images/result.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/monai 0.5.0/deprecated/images/result.gif
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/images/salmon.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/monai 0.5.0/deprecated/images/salmon.JPG
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/images/salmon2.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/monai 0.5.0/deprecated/images/salmon2.JPG
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/images/salmon3.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/monai 0.5.0/deprecated/images/salmon3.JPG
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/images/salmon4.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/monai 0.5.0/deprecated/images/salmon4.JPG
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/images/salmon5.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/monai 0.5.0/deprecated/images/salmon5.JPG
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/images/salmon6.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidiommi/Pytorch--3D-Medical-Images-Segmentation--SALMON/62d6e2b5ffcb7bde31675ca76e8bca25392bb988/monai 0.5.0/deprecated/images/salmon6.JPG
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/init.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import SimpleITK as sitk
4 |
5 | image = sitk.ReadImage('Data_folder/images/train/image0.nii')
6 | image_spacing = image.GetSpacing()
7 |
8 | class Options():
9 |
10 | """This class defines options used during both training and test time."""
11 |
12 | def __init__(self):
13 | """Reset the class; indicates the class hasn't been initailized"""
14 | self.initialized = False
15 |
16 | def initialize(self, parser):
17 |
18 | # basic parameters
19 | parser.add_argument('--images_folder', type=str, default='./Data_folder/images')
20 | parser.add_argument('--labels_folder', type=str, default='./Data_folder/labels')
21 |         parser.add_argument('--increase_factor_data', default=4, help='Factor by which to increase the number of samples per epoch')
22 | parser.add_argument('--preload', type=str, default=None)
23 | parser.add_argument('--gpu_ids', type=str, default='2,3', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
24 | parser.add_argument('--workers', default=8, type=int, help='number of data loading workers')
25 |
26 | # dataset parameters
27 | parser.add_argument('--patch_size', default=(128, 128, 64), help='Size of the patches extracted from the image')
28 | parser.add_argument('--spacing', default=image_spacing, help='Original Resolution')
29 | parser.add_argument('--resolution', default=None, help='New Resolution, if you want to resample the data')
30 | parser.add_argument('--batch_size', type=int, default=6, help='batch size')
31 | parser.add_argument('--in_channels', default=1, type=int, help='Channels of the input')
32 | parser.add_argument('--out_channels', default=1, type=int, help='Channels of the output')
33 |
34 | # training parameters
35 | parser.add_argument('--epochs', default=200, help='Number of epochs')
36 | parser.add_argument('--lr', default=0.001, help='Learning rate')
37 |
38 | # Inference
39 |         # This is just a trick to make the predict script work
40 | parser.add_argument('--result', default=None, help='Keep this empty and go to predict_single_image script')
41 | parser.add_argument('--weights', default=None, help='Keep this empty and go to predict_single_image script')
42 |
43 | self.initialized = True
44 | return parser
45 |
46 | def parse(self):
47 | if not self.initialized:
48 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
49 | parser = self.initialize(parser)
50 | opt = parser.parse_args()
51 | # set gpu ids
52 | if opt.gpu_ids != '-1':
53 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
54 | return opt
55 |
56 |
57 |
58 |
59 |
60 |
--------------------------------------------------------------------------------
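
For reference, a minimal sketch of how the Options class above is consumed by the training and checking scripts (note that importing init.py already reads 'Data_folder/images/train/image0.nii' at module level, so that file must exist):

    from init import Options

    opt = Options().parse()   # parses sys.argv and sets CUDA_VISIBLE_DEVICES
    print(opt.patch_size, opt.batch_size, opt.spacing)
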
/monai 0.5.0/deprecated/multi_label_segmentation_example/check_loader_patches.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | import os
13 | import sys
14 | from glob import glob
15 | import tempfile
16 | import numpy as np
17 | import matplotlib.pyplot as plt
18 | import nibabel as nib
19 | import torch
20 | from torch.utils.data import DataLoader
21 | from init import Options
22 | import monai
23 | from monai.data import ArrayDataset, GridPatchDataset, create_test_image_3d
24 | from monai.transforms import (Compose, RepeatChanneld, LoadImaged, AddChanneld, Transpose, Resized, CropForegroundd, CastToTyped,RandGaussianSmoothd,
25 | ScaleIntensityd, ToTensord, RandSpatialCropd, Rand3DElasticd, RandAffined, SpatialPadd,
26 | Spacingd, Orientationd, RandZoomd, ThresholdIntensityd, RandShiftIntensityd, RandGaussianNoised, BorderPadd,RandAdjustContrastd, NormalizeIntensityd,RandFlipd, ScaleIntensityRanged)
27 |
28 |
29 | class IndexTracker(object):
30 | def __init__(self, ax, X):
31 | self.ax = ax
32 | ax.set_title('use scroll wheel to navigate images')
33 |
34 | self.X = X
35 | rows, cols, self.slices = X.shape
36 | self.ind = self.slices//2
37 |
38 | self.im = ax.imshow(self.X[:, :, self.ind],cmap= 'gray')
39 | self.update()
40 |
41 | def onscroll(self, event):
42 | print("%s %s" % (event.button, event.step))
43 | if event.button == 'up':
44 | self.ind = (self.ind + 1) % self.slices
45 | else:
46 | self.ind = (self.ind - 1) % self.slices
47 | self.update()
48 |
49 | def update(self):
50 | self.im.set_data(self.X[:, :, self.ind])
51 | self.ax.set_ylabel('slice %s' % self.ind)
52 | self.im.axes.figure.canvas.draw()
53 |
54 |
55 | def plot3d(image):
56 | original=image
57 | original = np.rot90(original, k=-1)
58 | fig, ax = plt.subplots(1, 1)
59 | tracker = IndexTracker(ax, original)
60 | fig.canvas.mpl_connect('scroll_event', tracker.onscroll)
61 | plt.show()
62 |
63 |
64 | if __name__ == "__main__":
65 |
66 | opt = Options().parse()
67 |
68 | train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
69 | train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))
70 |
71 | data_dicts = [{'image': image_name, 'label': label_name}
72 | for image_name, label_name in zip(train_images, train_segs)]
73 |
74 | monai_transforms = [
75 |
76 | LoadImaged(keys=['image', 'label']),
77 | AddChanneld(keys=['image', 'label']),
78 | CropForegroundd(keys=['image', 'label'], source_key='image', start_coord_key='foreground_start_coord',
79 | end_coord_key='foreground_end_coord', ), # crop CropForeground
80 |
81 | NormalizeIntensityd(keys=['image']),
82 | ScaleIntensityd(keys=['image']),
83 | Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),
84 |
85 | SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method= 'end'),
86 | # RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
87 | # Orientationd(keys=["image", "label"], axcodes="PIL"),
88 | ToTensord(keys=['image', 'label','foreground_start_coord', 'foreground_end_coord'],)
89 | ]
90 |
91 | transform = Compose(monai_transforms)
92 | check_ds = monai.data.Dataset(data=data_dicts, transform=transform)
93 | loader = DataLoader(check_ds, batch_size=1, shuffle=True, num_workers=0, pin_memory=torch.cuda.is_available())
94 | check_data = monai.utils.misc.first(loader)
95 | im, seg, coord1, coord2 = (check_data['image'][0], check_data['label'][0],check_data['foreground_start_coord'][0],
96 | check_data['foreground_end_coord'][0])
97 |
98 | print(im.shape, seg.shape, coord1, coord2)
99 |
100 |
101 | vol = im[0].numpy()
102 | mask = seg[0].numpy()
103 |
104 | print(vol.shape, mask.shape)
105 | plot3d(vol)
106 | plot3d(mask)
107 |
--------------------------------------------------------------------------------
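
The script above is meant to be run directly, to eyeball one transformed image/label patch before training; an example invocation with the defaults from init.py spelled out (a display is required for the matplotlib scroll viewer):

    python check_loader_patches.py --images_folder ./Data_folder/images --labels_folder ./Data_folder/labels
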
/monai 0.5.0/deprecated/multi_label_segmentation_example/check_resolution.py:
--------------------------------------------------------------------------------
1 | # from NiftiDataset import *
2 | import argparse
3 | import SimpleITK as sitk
4 | import re
5 | import numpy as np
6 | import os
7 |
8 | '''Check whether the images and labels end up with different sizes after (optionally) resampling them to the same resolution'''
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--images', default='./Data_folder/CT', help='path to the images')
12 | parser.add_argument('--labels', default='./Data_folder/CT_label', help='path to the labels')
13 | parser.add_argument("--resample", action='store_true', default=True, help='Decide or not to resample the images to a new resolution')
14 | parser.add_argument("--new_resolution", type=float, default=((1.3671875, 1.3671875, 3.0)), help='New resolution')
15 | args = parser.parse_args()
16 |
17 | def resize(img, new_size, interpolator):
18 | # img = sitk.ReadImage(img)
19 | dimension = img.GetDimension()
20 |
21 | # Physical image size corresponds to the largest physical size in the training set, or any other arbitrary size.
22 | reference_physical_size = np.zeros(dimension)
23 |
24 | reference_physical_size[:] = [(sz - 1) * spc if sz * spc > mx else mx for sz, spc, mx in
25 | zip(img.GetSize(), img.GetSpacing(), reference_physical_size)]
26 |
27 | # Create the reference image with a zero origin, identity direction cosine matrix and dimension
28 | reference_origin = np.zeros(dimension)
29 | reference_direction = np.identity(dimension).flatten()
30 | reference_size = new_size
31 | reference_spacing = [phys_sz / (sz - 1) for sz, phys_sz in zip(reference_size, reference_physical_size)]
32 |
33 | reference_image = sitk.Image(reference_size, img.GetPixelIDValue())
34 | reference_image.SetOrigin(reference_origin)
35 | reference_image.SetSpacing(reference_spacing)
36 | reference_image.SetDirection(reference_direction)
37 |
38 | # Always use the TransformContinuousIndexToPhysicalPoint to compute an indexed point's physical coordinates as
39 | # this takes into account size, spacing and direction cosines. For the vast majority of images the direction
40 | # cosines are the identity matrix, but when this isn't the case simply multiplying the central index by the
41 | # spacing will not yield the correct coordinates resulting in a long debugging session.
42 | reference_center = np.array(
43 | reference_image.TransformContinuousIndexToPhysicalPoint(np.array(reference_image.GetSize()) / 2.0))
44 |
45 | # Transform which maps from the reference_image to the current img with the translation mapping the image
46 | # origins to each other.
47 | transform = sitk.AffineTransform(dimension)
48 | transform.SetMatrix(img.GetDirection())
49 | transform.SetTranslation(np.array(img.GetOrigin()) - reference_origin)
50 | # Modify the transformation to align the centers of the original and reference image instead of their origins.
51 | centering_transform = sitk.TranslationTransform(dimension)
52 | img_center = np.array(img.TransformContinuousIndexToPhysicalPoint(np.array(img.GetSize()) / 2.0))
53 | centering_transform.SetOffset(np.array(transform.GetInverse().TransformPoint(img_center) - reference_center))
54 | centered_transform = sitk.Transform(transform)
55 | centered_transform.AddTransform(centering_transform)
56 |     # Use the linear interpolator here because these are intensity images. If a
57 |     # ground-truth segmentation needs to be resampled, use the NearestNeighbor
58 |     # interpolator instead, so that no new labels are introduced.
59 |
60 | return sitk.Resample(img, reference_image, centered_transform, interpolator, 0.0)
61 |
62 |
63 | def resample_sitk_image(sitk_image, spacing=None, interpolator=None, fill_value=0):
64 |     """Resamples an ITK image to a new grid. If no spacing is given,
65 |     the resampling is done isotropically to the smallest value in the current
66 |     spacing. This is usually the in-plane resolution. If no interpolator is
67 |     given, it is derived from the input data type: binary inputs
68 |     (e.g., masks) are resampled with nearest neighbors, otherwise linear
69 |     interpolation is chosen.
70 |     Parameters
71 |     ----------
72 |     sitk_image : SimpleITK image or str
73 |         Either a SimpleITK image or a path to a SimpleITK readable file.
74 |     spacing : tuple
75 |         Tuple of floats with the new voxel spacing along each dimension.
76 |     interpolator : str
77 |         One of the keys of `_SITK_INTERPOLATOR_DICT` (e.g. `nearest`, `linear`) or None.
78 |     fill_value : int
79 |         Value assigned to voxels that fall outside the original image.
80 |     Returns
81 |     -------
82 |     SimpleITK image.
83 |     """
84 |     # https://github.com/SimpleITK/SlicerSimpleFilters/blob/master/SimpleFilters/SimpleFilters.py
85 |     _SITK_INTERPOLATOR_DICT = {
86 |         'nearest': sitk.sitkNearestNeighbor,
87 |         'linear': sitk.sitkLinear,
88 |         'gaussian': sitk.sitkGaussian,
89 |         'label_gaussian': sitk.sitkLabelGaussian,
90 |         'bspline': sitk.sitkBSpline,
91 |         'hamming_sinc': sitk.sitkHammingWindowedSinc,
92 |         'cosine_windowed_sinc': sitk.sitkCosineWindowedSinc,
93 |         'welch_windowed_sinc': sitk.sitkWelchWindowedSinc,
94 |         'lanczos_windowed_sinc': sitk.sitkLanczosWindowedSinc
95 |     }
96 |
97 | if isinstance(sitk_image, str):
98 | sitk_image = sitk.ReadImage(sitk_image)
99 | num_dim = sitk_image.GetDimension()
100 |
101 | if not interpolator:
102 | interpolator = 'linear'
103 | pixelid = sitk_image.GetPixelIDValue()
104 |
105 | if pixelid not in [1, 2, 4]:
106 | raise NotImplementedError(
107 | 'Set `interpolator` manually, '
108 | 'can only infer for 8-bit unsigned or 16, 32-bit signed integers')
109 | if pixelid == 1: # 8-bit unsigned int
110 | interpolator = 'nearest'
111 |
112 | orig_pixelid = sitk_image.GetPixelIDValue()
113 | orig_origin = sitk_image.GetOrigin()
114 | orig_direction = sitk_image.GetDirection()
115 | orig_spacing = np.array(sitk_image.GetSpacing())
116 |     orig_size = np.array(sitk_image.GetSize(), dtype=int)  # np.int was removed in NumPy >= 1.24
117 |
118 | if not spacing:
119 | min_spacing = orig_spacing.min()
120 | new_spacing = [min_spacing] * num_dim
121 | else:
122 | new_spacing = [float(s) for s in spacing]
123 |
124 | assert interpolator in _SITK_INTERPOLATOR_DICT.keys(), \
125 | '`interpolator` should be one of {}'.format(_SITK_INTERPOLATOR_DICT.keys())
126 |
127 | sitk_interpolator = _SITK_INTERPOLATOR_DICT[interpolator]
128 |
129 | new_size = orig_size * (orig_spacing / new_spacing)
130 |     new_size = np.ceil(new_size).astype(int)  # Image dimensions are integers; np.int was removed in NumPy >= 1.24
131 | new_size = [int(s) for s in new_size] # SimpleITK expects lists, not ndarrays
132 |
133 | resample_filter = sitk.ResampleImageFilter()
134 |
135 | resampled_sitk_image = resample_filter.Execute(sitk_image,
136 | new_size,
137 | sitk.Transform(),
138 | sitk_interpolator,
139 | orig_origin,
140 | new_spacing,
141 | orig_direction,
142 | fill_value,
143 | orig_pixelid)
144 |
145 | return resampled_sitk_image
146 |
147 |
148 |
149 | def numericalSort(value):
150 | numbers = re.compile(r'(\d+)')
151 | parts = numbers.split(value)
152 | parts[1::2] = map(int, parts[1::2])
153 | return parts
154 |
155 |
156 | def lstFiles(Path):
157 |
158 |     images_list = []  # create an empty list; the raw image file paths are stored here
159 | for dirName, subdirList, fileList in os.walk(Path):
160 | for filename in fileList:
161 | if ".nii.gz" in filename.lower():
162 | images_list.append(os.path.join(dirName, filename))
163 | elif ".nii" in filename.lower():
164 | images_list.append(os.path.join(dirName, filename))
165 | elif ".mhd" in filename.lower():
166 | images_list.append(os.path.join(dirName, filename))
167 |
168 | images_list = sorted(images_list, key=numericalSort)
169 |
170 | return images_list
171 |
172 | list_images = lstFiles(args.images)
173 | list_labels = lstFiles(args.labels)
174 |
175 | for i in range(len(list_images)):
176 |
177 | a = sitk.ReadImage(list_images[i])
178 |     if args.resample:
179 | a = resample_sitk_image(a, spacing=args.new_resolution, interpolator='linear')
180 | spacing1 = a.GetSpacing()
181 | a = sitk.GetArrayFromImage(a)
182 | a = np.transpose(a, axes=(2, 1, 0)) # reshape array from itk z,y,x to x,y,z
183 | a1 = a.shape
184 |
185 | b = sitk.ReadImage(list_labels[i])
186 |     if args.resample:
187 | b = resample_sitk_image(b, spacing=args.new_resolution, interpolator='nearest')
188 |
189 | b = resize(b,a1,sitk.sitkNearestNeighbor)
190 | spacing2 = b.GetSpacing()
191 | b = sitk.GetArrayFromImage(b)
192 | b = np.transpose(b, axes=(2, 1, 0)) # reshape array from itk z,y,x to x,y,z
193 | b1 = b.shape
194 |
195 | print(list_images[i], a1)
196 |
197 | if a1 != b1:
198 | print('Mismatch of size in ', list_images[i])
199 |
241 | # a=sitk.ReadImage('aaaaaa.nii')
242 | # a = sitk.GetArrayFromImage(a)
243 | # a = np.transpose(a, axes=(2, 1, 0)) # reshape array from itk z,y,x to x,y,z
244 | # result = np.rot90(a, k=-1)
245 | # fig, ax = plt.subplots(1, 1)
246 | # tracker = IndexTracker(ax, result)
247 | # fig.canvas.mpl_connect('scroll_event', tracker.onscroll)
248 | # plt.show()
249 |
250 | # a=sitk.ReadImage(labels[36])
251 | # a = sitk.GetArrayFromImage(a)
252 | # a = np.transpose(a, axes=(2, 1, 0)) # reshape array from itk z,y,x to x,y,z
253 | # result = np.rot90(a, k=-1)
254 | # fig, ax = plt.subplots(1, 1)
255 | # tracker = IndexTracker(ax, result)
256 | # fig.canvas.mpl_connect('scroll_event', tracker.onscroll)
257 | # plt.show()
258 |
259 |
260 |
261 |
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/multi_label_segmentation_example/init.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import SimpleITK as sitk
4 |
5 | image = sitk.ReadImage('Data_folder/images/train/image0.nii')
6 | image_spacing = image.GetSpacing()
7 |
8 | class Options():
9 |
10 | """This class defines options used during both training and test time."""
11 |
12 | def __init__(self):
13 | """Reset the class; indicates the class hasn't been initailized"""
14 | self.initialized = False
15 |
16 | def initialize(self, parser):
17 |
18 | # basic parameters
19 | parser.add_argument('--images_folder', type=str, default='./Data_folder/images')
20 | parser.add_argument('--labels_folder', type=str, default='./Data_folder/labels')
21 |         parser.add_argument('--increase_factor_data', default=3, help='Factor by which to increase the number of samples per epoch')
22 | parser.add_argument('--preload', type=str, default=None)
23 | parser.add_argument('--gpu_ids', type=str, default='0,1', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
24 | parser.add_argument('--workers', default=8, type=int, help='number of data loading workers')
25 |
26 | # dataset parameters
27 | parser.add_argument('--patch_size', default=(256, 256, 32), help='Size of the patches extracted from the image')
28 | parser.add_argument('--spacing', default=image_spacing, help='Original Resolution')
29 | parser.add_argument('--resolution', default=(0.6, 0.6, 3), help='New Resolution, if you want to resample the data')
30 | parser.add_argument('--batch_size', type=int, default=4, help='batch size')
31 | parser.add_argument('--in_channels', default=1, type=int, help='Channels of the input')
32 | parser.add_argument('--out_channels', default=3, type=int, help='Channels of the output')
33 |
34 | # training parameters
35 | parser.add_argument('--epochs', default=200, help='Number of epochs')
36 | parser.add_argument('--lr', default=0.01, help='Learning rate')
37 |
38 | # Inference
39 |         # This is just a trick to make the predict script work
40 | parser.add_argument('--result', default=None, help='Keep this empty and go to predict_single_image script')
41 | parser.add_argument('--weights', default=None, help='Keep this empty and go to predict_single_image script')
42 |
43 | self.initialized = True
44 | return parser
45 |
46 | def parse(self):
47 | if not self.initialized:
48 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
49 | parser = self.initialize(parser)
50 | opt = parser.parse_args()
51 | # set gpu ids
52 | if opt.gpu_ids != '-1':
53 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
54 | return opt
55 |
56 |
57 |
58 |
59 |
60 |
--------------------------------------------------------------------------------
/monai 0.5.0/deprecated/multi_label_segmentation_example/networks.py:
--------------------------------------------------------------------------------
1 | from train import *
2 | from torch.nn import init
3 | import monai
4 | from torch.optim import lr_scheduler
5 |
6 |
7 | def init_weights(net, init_type='normal', init_gain=0.02):
8 | """Initialize network weights.
9 | Parameters:
10 | net (network) -- network to be initialized
11 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
12 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
13 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
14 | work better for some applications. Feel free to try yourself.
15 | """
16 | def init_func(m): # define the initialization function
17 | classname = m.__class__.__name__
18 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
19 | if init_type == 'normal':
20 | init.normal_(m.weight.data, 0.0, init_gain)
21 | elif init_type == 'xavier':
22 | init.xavier_normal_(m.weight.data, gain=init_gain)
23 | elif init_type == 'kaiming':
24 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
25 | elif init_type == 'orthogonal':
26 | init.orthogonal_(m.weight.data, gain=init_gain)
27 | else:
28 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
29 | if hasattr(m, 'bias') and m.bias is not None:
30 | init.constant_(m.bias.data, 0.0)
31 | elif classname.find('BatchNorm3d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
32 | init.normal_(m.weight.data, 1.0, init_gain)
33 | init.constant_(m.bias.data, 0.0)
34 |
35 | # print('initialize network with %s' % init_type)
36 | net.apply(init_func) # apply the initialization function
37 |
38 |
39 | def get_scheduler(optimizer, opt):
40 | if opt.lr_policy == 'lambda':
41 | def lambda_rule(epoch):
42 | # lr_l = 1.0 - max(0, epoch + 1 - opt.epochs/2) / float(opt.epochs/2 + 1)
43 | lr_l = (1 - epoch / opt.epochs) ** 0.9
44 | return lr_l
45 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
46 | elif opt.lr_policy == 'step':
47 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
48 | elif opt.lr_policy == 'plateau':
49 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
50 | elif opt.lr_policy == 'cosine':
51 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.epochs, eta_min=0)
52 | else:
53 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
54 | return scheduler
55 |
56 |
57 | # update learning rate (called once every epoch)
58 | def update_learning_rate(scheduler, optimizer):
59 | scheduler.step()
60 | lr = optimizer.param_groups[0]['lr']
61 | # print('learning rate = %.7f' % lr)
62 |
63 |
64 | from torch.nn import Module, Sequential
65 | from torch.nn import Conv3d, ConvTranspose3d, BatchNorm3d, MaxPool3d, AvgPool1d, Dropout3d
66 | from torch.nn import ReLU, Sigmoid
67 | import torch
68 |
69 |
70 | def build_net():
71 |
72 | from init import Options
73 | opt = Options().parse()
74 | from monai.networks.layers import Norm
75 | from monai.networks.layers.factories import split_args
76 | act_type, args = split_args("RELU")
77 |
78 | # # create Unet
79 | # Unet = monai.networks.nets.UNet(
80 | # dimensions=3,
81 | # in_channels=opt.in_channels,
82 | # out_channels=opt.out_channels,
83 | # channels=(64, 128, 256, 512, 1024),
84 | # strides=(2, 2, 2, 2),
85 | # act=act_type,
86 | # num_res_units=3,
87 | # dropout=0.2,
88 | # norm=Norm.BATCH,
89 | #
90 | # )
91 |
92 | # create nn-Unet
93 | if opt.resolution is None:
94 | sizes, spacings = opt.patch_size, opt.spacing
95 | else:
96 | sizes, spacings = opt.patch_size, opt.resolution
97 |
98 | strides, kernels = [], []
99 |
100 | while True:
101 | spacing_ratio = [sp / min(spacings) for sp in spacings]
102 | stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
103 | kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
104 | if all(s == 1 for s in stride):
105 | break
106 | sizes = [i / j for i, j in zip(sizes, stride)]
107 | spacings = [i * j for i, j in zip(spacings, stride)]
108 | kernels.append(kernel)
109 | strides.append(stride)
110 | strides.insert(0, len(spacings) * [1])
111 | kernels.append(len(spacings) * [3])
112 |
113 | nn_Unet = monai.networks.nets.DynUNet(
114 | spatial_dims=3,
115 | in_channels=opt.in_channels,
116 | out_channels=opt.out_channels,
117 | kernel_size=kernels,
118 | strides=strides,
119 | upsample_kernel_size=strides[1:],
120 | res_block=True,
121 | )
122 |
123 | init_weights(nn_Unet, init_type='normal')
124 |
125 | return nn_Unet
126 |
127 |
128 | if __name__ == '__main__':
129 | import time
130 | import torch
132 | from torchsummaryX import summary
133 | from torch.nn import init
134 |
135 | opt = Options().parse()
136 |
137 | torch.cuda.set_device(0)
138 | network = build_net()
139 | net = network.cuda().eval()
140 |
141 |     data = torch.randn(1, int(opt.in_channels), int(opt.patch_size[0]), int(opt.patch_size[1]), int(opt.patch_size[2])).cuda()  # torch.autograd.Variable is deprecated; plain tensors suffice
142 |
143 | out = net(data)
144 |
145 | torch.onnx.export(net, data, "Unet_model_graph.onnx")
146 |
147 | summary(net,data)
148 | print("out size: {}".format(out.size()))
149 |
150 |
151 |
152 |
153 |
154 |
155 |
--------------------------------------------------------------------------------
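
To make the nn-Unet stride/kernel planning loop inside build_net() concrete, here is a standalone copy of that loop run on this example's defaults (patch_size=(256, 256, 32), resolution=(0.6, 0.6, 3)); it is an illustration, not an additional repository file:

    sizes, spacings = [256, 256, 32], [0.6, 0.6, 3.0]
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for ratio, size in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])
    # With the anisotropic 0.6 x 0.6 x 3 spacing, the first levels downsample
    # in-plane only (stride (2, 2, 1), kernel (3, 3, 1)); once the z spacing
    # ratio drops below 2, the loop switches to isotropic (2, 2, 2) strides.
    print(kernels)
    print(strides)
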
/monai 0.5.0/deprecated/multi_label_segmentation_example/organize_folder_structure.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import argparse
4 | import SimpleITK as sitk
5 | import numpy as np
6 | import random
7 |
8 |
9 | def resize(img, new_size, interpolator):
10 | # img = sitk.ReadImage(img)
11 | dimension = img.GetDimension()
12 |
13 | # Physical image size corresponds to the largest physical size in the training set, or any other arbitrary size.
14 | reference_physical_size = np.zeros(dimension)
15 |
16 | reference_physical_size[:] = [(sz - 1) * spc if sz * spc > mx else mx for sz, spc, mx in
17 | zip(img.GetSize(), img.GetSpacing(), reference_physical_size)]
18 |
19 | # Create the reference image with a zero origin, identity direction cosine matrix and dimension
20 | reference_origin = np.zeros(dimension)
21 | reference_direction = np.identity(dimension).flatten()
22 | reference_size = new_size
23 | reference_spacing = [phys_sz / (sz - 1) for sz, phys_sz in zip(reference_size, reference_physical_size)]
24 |
25 | reference_image = sitk.Image(reference_size, img.GetPixelIDValue())
26 | reference_image.SetOrigin(reference_origin)
27 | reference_image.SetSpacing(reference_spacing)
28 | reference_image.SetDirection(reference_direction)
29 |
30 | # Always use the TransformContinuousIndexToPhysicalPoint to compute an indexed point's physical coordinates as
31 | # this takes into account size, spacing and direction cosines. For the vast majority of images the direction
32 | # cosines are the identity matrix, but when this isn't the case simply multiplying the central index by the
33 | # spacing will not yield the correct coordinates resulting in a long debugging session.
34 | reference_center = np.array(
35 | reference_image.TransformContinuousIndexToPhysicalPoint(np.array(reference_image.GetSize()) / 2.0))
36 |
37 | # Transform which maps from the reference_image to the current img with the translation mapping the image
38 | # origins to each other.
39 | transform = sitk.AffineTransform(dimension)
40 | transform.SetMatrix(img.GetDirection())
41 | transform.SetTranslation(np.array(img.GetOrigin()) - reference_origin)
42 | # Modify the transformation to align the centers of the original and reference image instead of their origins.
43 | centering_transform = sitk.TranslationTransform(dimension)
44 | img_center = np.array(img.TransformContinuousIndexToPhysicalPoint(np.array(img.GetSize()) / 2.0))
45 | centering_transform.SetOffset(np.array(transform.GetInverse().TransformPoint(img_center) - reference_center))
46 | centered_transform = sitk.Transform(transform)
47 | centered_transform.AddTransform(centering_transform)
48 |     # Use the linear interpolator here because these are intensity images. If a
49 |     # ground-truth segmentation needs to be resampled, use the NearestNeighbor
50 |     # interpolator instead, so that no new labels are introduced.
51 |
52 | return sitk.Resample(img, reference_image, centered_transform, interpolator, 0.0)
53 |
54 |
55 | def resample_sitk_image(sitk_image, spacing=None, interpolator=None, fill_value=0):
56 |     """Resamples an ITK image to a new grid. If no spacing is given,
57 |     the resampling is done isotropically to the smallest value in the current
58 |     spacing. This is usually the in-plane resolution. If no interpolator is
59 |     given, it is derived from the input data type: binary inputs
60 |     (e.g., masks) are resampled with nearest neighbors, otherwise linear
61 |     interpolation is chosen.
62 |     Parameters
63 |     ----------
64 |     sitk_image : SimpleITK image or str
65 |         Either a SimpleITK image or a path to a SimpleITK readable file.
66 |     spacing : tuple
67 |         Tuple of floats with the new voxel spacing along each dimension.
68 |     interpolator : str
69 |         One of the keys of `_SITK_INTERPOLATOR_DICT` (e.g. `nearest`, `linear`) or None.
70 |     fill_value : int
71 |         Value assigned to voxels that fall outside the original image.
72 |     Returns
73 |     -------
74 |     SimpleITK image.
75 |     """
76 |     # https://github.com/SimpleITK/SlicerSimpleFilters/blob/master/SimpleFilters/SimpleFilters.py
77 |     _SITK_INTERPOLATOR_DICT = {
78 |         'nearest': sitk.sitkNearestNeighbor,
79 |         'linear': sitk.sitkLinear,
80 |         'gaussian': sitk.sitkGaussian,
81 |         'label_gaussian': sitk.sitkLabelGaussian,
82 |         'bspline': sitk.sitkBSpline,
83 |         'hamming_sinc': sitk.sitkHammingWindowedSinc,
84 |         'cosine_windowed_sinc': sitk.sitkCosineWindowedSinc,
85 |         'welch_windowed_sinc': sitk.sitkWelchWindowedSinc,
86 |         'lanczos_windowed_sinc': sitk.sitkLanczosWindowedSinc
87 |     }
88 |
89 | if isinstance(sitk_image, str):
90 | sitk_image = sitk.ReadImage(sitk_image)
91 | num_dim = sitk_image.GetDimension()
92 |
93 | if not interpolator:
94 | interpolator = 'linear'
95 | pixelid = sitk_image.GetPixelIDValue()
96 |
97 | if pixelid not in [1, 2, 4]:
98 | raise NotImplementedError(
99 | 'Set `interpolator` manually, '
100 | 'can only infer for 8-bit unsigned or 16, 32-bit signed integers')
101 | if pixelid == 1: # 8-bit unsigned int
102 | interpolator = 'nearest'
103 |
104 | orig_pixelid = sitk_image.GetPixelIDValue()
105 | orig_origin = sitk_image.GetOrigin()
106 | orig_direction = sitk_image.GetDirection()
107 | orig_spacing = np.array(sitk_image.GetSpacing())
108 |     orig_size = np.array(sitk_image.GetSize(), dtype=int)  # np.int was removed in NumPy >= 1.24
109 |
110 | if not spacing:
111 | min_spacing = orig_spacing.min()
112 | new_spacing = [min_spacing] * num_dim
113 | else:
114 | new_spacing = [float(s) for s in spacing]
115 |
116 | assert interpolator in _SITK_INTERPOLATOR_DICT.keys(), \
117 | '`interpolator` should be one of {}'.format(_SITK_INTERPOLATOR_DICT.keys())
118 |
119 | sitk_interpolator = _SITK_INTERPOLATOR_DICT[interpolator]
120 |
121 | new_size = orig_size * (orig_spacing / new_spacing)
122 |     new_size = np.ceil(new_size).astype(int)  # Image dimensions are integers; np.int was removed in NumPy >= 1.24
123 | new_size = [int(s) for s in new_size] # SimpleITK expects lists, not ndarrays
124 |
125 | resample_filter = sitk.ResampleImageFilter()
126 |
127 | resampled_sitk_image = resample_filter.Execute(sitk_image,
128 | new_size,
129 | sitk.Transform(),
130 | sitk_interpolator,
131 | orig_origin,
132 | new_spacing,
133 | orig_direction,
134 | fill_value,
135 | orig_pixelid)
136 |
137 | return resampled_sitk_image
138 |
139 |
140 | def numericalSort(value):
141 | numbers = re.compile(r'(\d+)')
142 | parts = numbers.split(value)
143 | parts[1::2] = map(int, parts[1::2])
144 | return parts
145 |
146 |
147 | def lstFiles(Path):
148 |
149 |     images_list = []  # create an empty list; the raw image file paths are stored here
150 | for dirName, subdirList, fileList in os.walk(Path):
151 | for filename in fileList:
152 | if ".nii.gz" in filename.lower():
153 | images_list.append(os.path.join(dirName, filename))
154 | elif ".nii" in filename.lower():
155 | images_list.append(os.path.join(dirName, filename))
156 | elif ".nrrd" in filename.lower():
157 | images_list.append(os.path.join(dirName, filename))
158 |
159 | images_list = sorted(images_list, key=numericalSort)
160 |
161 | return images_list
162 |
163 |
164 | def uniform_img_dimensions(image, label):
165 |
166 | image_array = sitk.GetArrayFromImage(image)
167 | image_array = np.transpose(image_array, axes=(2, 1, 0)) # reshape array from itk z,y,x to x,y,z
168 | image_shape = image_array.shape
169 |
170 | label = resample_sitk_image(label, spacing=image.GetSpacing(), interpolator='nearest')
171 | res = resize(label,image_shape,sitk.sitkNearestNeighbor)
172 | res = (np.rint(sitk.GetArrayFromImage(res)))
173 | res = sitk.GetImageFromArray(res.astype('uint8'))
174 | res.SetDirection(image.GetDirection())
175 | res.SetOrigin(image.GetOrigin())
176 | res.SetSpacing(image.GetSpacing())
177 | print(res.GetSize())
178 |
179 | return image, res
180 |
181 | parser = argparse.ArgumentParser()
182 | parser.add_argument('--images', default='./Data_folder/MR', help='path to the images')
183 | parser.add_argument('--labels', default='./Data_folder/MR_label', help='path to the labels')
184 | parser.add_argument('--split_val', default=8, help='number of images for validation')
185 | parser.add_argument('--split_test', default=7, help='number of images for testing')
186 | args = parser.parse_args()
187 |
188 | if __name__ == "__main__":
189 |
190 | list_images = lstFiles(args.images)
191 | list_labels = lstFiles(args.labels)
192 |
193 | # mapIndexPosition = list(zip(list_images, list_labels)) # shuffle order list
194 | # random.shuffle(mapIndexPosition)
195 | # list_images, list_labels = zip(*mapIndexPosition)
196 |
197 |     os.makedirs('./Data_folder/images', exist_ok=True)  # avoid crashing when the folder already exists
198 |     os.makedirs('./Data_folder/labels', exist_ok=True)
199 |
200 | # 1
201 | if not os.path.isdir('./Data_folder/images/train'):
202 | os.mkdir('./Data_folder/images/train/')
203 | # 2
204 | if not os.path.isdir('./Data_folder/images/val'):
205 | os.mkdir('./Data_folder/images/val')
206 |
207 | # 3
208 | if not os.path.isdir('./Data_folder/images/test'):
209 | os.mkdir('./Data_folder/images/test')
210 |
211 | # 4
212 | if not os.path.isdir('./Data_folder/labels/train'):
213 | os.mkdir('./Data_folder/labels/train')
214 |
215 | # 5
216 | if not os.path.isdir('./Data_folder/labels/val'):
217 | os.mkdir('./Data_folder/labels/val')
218 |
219 | # 6
220 | if not os.path.isdir('./Data_folder/labels/test'):
221 | os.mkdir('./Data_folder/labels/test')
222 |
223 | for i in range(len(list_images)-int(args.split_test + args.split_val)):
224 |
225 | a = list_images[int(args.split_test + args.split_val)+i]
226 | b = list_labels[int(args.split_test + args.split_val)+i]
227 |
228 | print(a)
229 |
230 | label = sitk.ReadImage(b)
231 | image = sitk.ReadImage(a)
232 |
233 | image, label = uniform_img_dimensions(image, label)
234 |
235 | image_directory = os.path.join('./Data_folder/images/train', f"image{i:d}.nii")
236 | label_directory = os.path.join('./Data_folder/labels/train', f"label{i:d}.nii")
237 |
238 | sitk.WriteImage(image, image_directory)
239 | sitk.WriteImage(label, label_directory)
240 |
241 | for i in range(int(args.split_val)):
242 |
243 | a = list_images[int(args.split_test)+i]
244 | b = list_labels[int(args.split_test)+i]
245 |
246 | print(a)
247 |
248 | label = sitk.ReadImage(b)
249 | image = sitk.ReadImage(a)
250 |
251 | image, label = uniform_img_dimensions(image, label)
252 |
253 | image_directory = os.path.join('./Data_folder/images/val', f"image{i:d}.nii")
254 | label_directory = os.path.join('./Data_folder/labels/val', f"label{i:d}.nii")
255 |
256 | sitk.WriteImage(image, image_directory)
257 | sitk.WriteImage(label, label_directory)
258 |
259 | for i in range(int(args.split_test)):
260 |
261 | a = list_images[i]
262 | b = list_labels[i]
263 |
264 | print(a)
265 |
266 | label = sitk.ReadImage(b)
267 | image = sitk.ReadImage(a)
268 |
269 | image, label = uniform_img_dimensions(image, label)
270 |
271 | image_directory = os.path.join('./Data_folder/images/test', f"image{i:d}.nii")
272 | label_directory = os.path.join('./Data_folder/labels/test', f"label{i:d}.nii")
273 |
274 | sitk.WriteImage(image, image_directory)
275 | sitk.WriteImage(label, label_directory)
276 |
277 |
--------------------------------------------------------------------------------
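
A quick sketch of what numericalSort above buys over plain lexicographic sorting (illustrative file names only):

    >>> sorted(['image10.nii', 'image2.nii', 'image1.nii'])
    ['image1.nii', 'image10.nii', 'image2.nii']
    >>> sorted(['image10.nii', 'image2.nii', 'image1.nii'], key=numericalSort)
    ['image1.nii', 'image2.nii', 'image10.nii']
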
/monai 0.5.0/deprecated/multi_label_segmentation_example/predict_single_image.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | from train import *
5 | import argparse
6 | from networks import *
7 | import SimpleITK as sitk
8 | from monai.inferers import sliding_window_inference
9 | from monai.metrics import DiceMetric
10 | from monai.data import NiftiSaver, create_test_image_3d, list_data_collate
11 | from collections import OrderedDict
12 | from organize_folder_structure import resize, resample_sitk_image, uniform_img_dimensions
13 |
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--image", type=str, default='./Data_folder/images/test/image0.nii')
17 | parser.add_argument("--label", type=str, default='./Data_folder/labels/test/label0.nii')
18 | parser.add_argument("--result", type=str, default='./Data_folder/test.nii', help='path to the .nii result to save')
19 | parser.add_argument("--weights", type=str, default='./best_metric_model.pth', help='network weights to load')
20 | parser.add_argument("--resolution", default=(0.6, 0.6, 3), help='New resolution if you want to resample')
21 | parser.add_argument("--out_channels", default=3, help='Number of labels')
22 | parser.add_argument("--patch_size", type=int, nargs=3, default=(256, 256, 32), help="Input dimension for the generator")
23 | parser.add_argument('--gpu_ids', type=str, default='2,3', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
24 | args = parser.parse_args()
25 |
26 |
27 | def new_state_dict(file_name):
28 | state_dict = torch.load(file_name)
29 | new_state_dict = OrderedDict()
30 | for k, v in state_dict.items():
31 | if k[:6] == 'module':
32 | name = k[7:]
33 | new_state_dict[name] = v
34 | else:
35 | new_state_dict[k] = v
36 | return new_state_dict
37 |
38 |
39 | def from_numpy_to_itk(image_np, image_itk):
40 |
41 | # read image file
42 | reader = sitk.ImageFileReader()
43 | reader.SetFileName(image_itk)
44 | image_itk = reader.Execute()
45 |
46 | image_np = np.transpose(image_np, (2, 1, 0))
47 | image = sitk.GetImageFromArray(image_np)
48 | image.SetDirection(image_itk.GetDirection())
49 | image.SetSpacing(image_itk.GetSpacing())
50 | image.SetOrigin(image_itk.GetOrigin())
51 | return image
52 |
53 |
54 | # function to keep track of the cropped area and coordinates
55 | def statistics_crop(image, resolution):
56 |
57 | files = [{"image": image}]
58 |
59 | reader = sitk.ImageFileReader()
60 | reader.SetFileName(image)
61 | image_itk = reader.Execute()
62 | original_resolution = image_itk.GetSpacing()
63 |
64 | # original size
65 | transforms = Compose([
66 | LoadImaged(keys=['image']),
67 | AddChanneld(keys=['image']),
68 | ToTensord(keys=['image'])])
69 | data = monai.data.Dataset(data=files, transform=transforms)
70 | loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
71 | loader = monai.utils.misc.first(loader)
72 | im, = (loader['image'][0])
73 | vol = im.numpy()
74 | original_shape = vol.shape
75 |
76 | # cropped foreground size
77 | transforms = Compose([
78 | LoadImaged(keys=['image']),
79 | AddChanneld(keys=['image']),
80 | CropForegroundd(keys=['image'], source_key='image', start_coord_key='foreground_start_coord',
81 | end_coord_key='foreground_end_coord', ), # crop CropForeground
82 | ToTensord(keys=['image', 'foreground_start_coord', 'foreground_end_coord'])])
83 | data = monai.data.Dataset(data=files, transform=transforms)
84 | loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
85 | loader = monai.utils.misc.first(loader)
86 | im, coord1, coord2 = (loader['image'][0], loader['foreground_start_coord'][0], loader['foreground_end_coord'][0])
87 | vol = im[0].numpy()
88 | coord1 = coord1.numpy()
89 | coord2 = coord2.numpy()
90 | crop_shape = vol.shape
91 |
92 | if resolution is not None:
93 |
94 | transforms = Compose([
95 | LoadImaged(keys=['image']),
96 | AddChanneld(keys=['image']),
97 | CropForegroundd(keys=['image'], source_key='image'), # crop CropForeground
98 | Spacingd(keys=['image'], pixdim=resolution, mode=('bilinear')), # resolution
99 | ToTensord(keys=['image'])])
100 |
101 | data = monai.data.Dataset(data=files, transform=transforms)
102 | loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
103 | loader = monai.utils.misc.first(loader)
104 | im, = (loader['image'][0])
105 | vol = im.numpy()
106 | resampled_size = vol.shape
107 |
108 | else:
109 |
110 | resampled_size = original_shape
111 |
112 | return original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution
113 |
114 |
115 | def segment(image, label, result, weights, resolution, patch_size, channels):
116 |
117 | logging.basicConfig(stream=sys.stdout, level=logging.INFO)
118 |
119 | if label is not None:
120 | files = [{"image": image, "label": label}]
121 | else:
122 | files = [{"image": image}]
123 |
124 | # original size, size after crop_background, cropped roi coordinates, cropped resampled roi size
125 | original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution = statistics_crop(image, resolution)
126 |
127 | # -------------------------------
128 |
129 | if label is not None:
130 | if resolution is not None:
131 |
132 | val_transforms = Compose([
133 | LoadImaged(keys=['image', 'label']),
134 | AddChanneld(keys=['image', 'label']),
135 | CropForegroundd(keys=['image', 'label'], source_key='image'), # crop CropForeground
136 |
137 | NormalizeIntensityd(keys=['image']), # intensity
138 | ScaleIntensityd(keys=['image']),
139 | Spacingd(keys=['image', 'label'], pixdim=resolution, mode=('bilinear', 'nearest')), # resolution
140 |
141 | SpatialPadd(keys=['image', 'label'], spatial_size=patch_size, method= 'end'),
142 | ToTensord(keys=['image', 'label'])])
143 | else:
144 |
145 | val_transforms = Compose([
146 | LoadImaged(keys=['image', 'label']),
147 | AddChanneld(keys=['image', 'label']),
148 | CropForegroundd(keys=['image', 'label'], source_key='image'), # crop CropForeground
149 |
150 | NormalizeIntensityd(keys=['image']), # intensity
151 | ScaleIntensityd(keys=['image']),
152 |
153 | SpatialPadd(keys=['image', 'label'], spatial_size=patch_size, method='end'), # pad if the image is smaller than patch
154 | ToTensord(keys=['image', 'label'])])
155 |
156 | else:
157 | if resolution is not None:
158 |
159 | val_transforms = Compose([
160 | LoadImaged(keys=['image']),
161 | AddChanneld(keys=['image']),
162 |
163 | CropForegroundd(keys=['image'], source_key='image'), # crop CropForeground
164 |
165 | NormalizeIntensityd(keys=['image']), # intensity
166 | ScaleIntensityd(keys=['image']),
167 | Spacingd(keys=['image'], pixdim=resolution, mode=('bilinear')), # resolution
168 |
169 | SpatialPadd(keys=['image'], spatial_size=patch_size, method= 'end'), # pad if the image is smaller than patch
170 | ToTensord(keys=['image'])])
171 | else:
172 |
173 | val_transforms = Compose([
174 | LoadImaged(keys=['image']),
175 | AddChanneld(keys=['image']),
176 | CropForegroundd(keys=['image'], source_key='image'), # crop CropForeground
177 |
178 | NormalizeIntensityd(keys=['image']), # intensity
179 | ScaleIntensityd(keys=['image']),
180 |
181 | SpatialPadd(keys=['image'], spatial_size=patch_size, method='end'), # pad if the image is smaller than patch
182 | ToTensord(keys=['image'])])
183 |
184 | val_ds = monai.data.Dataset(data=files, transform=val_transforms)
185 | val_loader = DataLoader(val_ds, batch_size=1, num_workers=0, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available())
186 |
187 | dice_metric = DiceMetric(include_background=False, reduction="mean")
188 | # post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
189 |     post_trans = AsDiscrete(argmax=True, to_onehot=True, n_classes=channels)  # use the out_channels argument instead of a hard-coded 3
190 |     post_label = AsDiscrete(to_onehot=True, n_classes=channels)
191 |
192 | # try to use all the available GPUs
193 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids # Multi-gpu selector for training
194 |
195 | if args.gpu_ids != '-1':
196 | num_gpus = len(args.gpu_ids.split(','))
197 | else:
198 | num_gpus = 0
199 | print('number of GPU:', num_gpus)
200 |
201 | if num_gpus > 1:
202 |
203 | # build the network
204 | net = build_net().cuda()
205 |
206 | net = torch.nn.DataParallel(net)
207 | net.load_state_dict(torch.load(weights))
208 |
209 | else:
210 |
211 | net = build_net().cuda()
212 | net.load_state_dict(new_state_dict(weights))
213 |
214 | # define sliding window size and batch size for windows inference
215 | roi_size = patch_size
216 | sw_batch_size = 4
217 |
218 | net.eval()
219 | with torch.no_grad():
220 |
221 | if label is None:
222 | for val_data in val_loader:
223 | val_images = val_data["image"].cuda()
224 | val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
225 | val_outputs = post_trans(val_outputs)
226 | # val_outputs = (val_outputs.sigmoid() >= 0.5).float()
227 |
228 | else:
229 | metric_sum = 0.0
230 | metric_count = 0
231 | for val_data in val_loader:
232 | val_images, val_labels = val_data["image"].cuda(), val_data["label"].cuda()
233 | val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
234 | val_outputs = post_trans(val_outputs)
235 | val_labels = post_label(val_labels)
236 | value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
237 | metric_count += len(value)
238 | metric_sum += value.item() * len(value)
239 | # val_outputs = (val_outputs.sigmoid() >= 0.5).float()
240 |
241 | metric = metric_sum / metric_count
242 | print("Evaluation Metric (Dice):", metric)
243 |
244 | result_array = val_outputs.squeeze().data.cpu().numpy()
245 |
246 | empty_array = np.zeros(result_array[0].shape)
247 | for i in range(channels):
248 | channel_i = result_array[i]
249 | if i == 0:
250 | channel_i = np.where(channel_i == 1, 0, channel_i)
251 | elif i > 0:
252 | channel_i = np.where(channel_i == 1, int(i), channel_i)
253 | empty_array = empty_array + channel_i
254 | result_array = empty_array
255 |
256 | # Remove the pad if the image was smaller than the patch in some directions
257 | result_array = result_array[0:resampled_size[0],0:resampled_size[1],0:resampled_size[2]]
258 |
259 | # resample back to the original resolution
260 | if resolution is not None:
261 |
262 | result_array_np = np.transpose(result_array, (2, 1, 0))
263 | result_array_temp = sitk.GetImageFromArray(result_array_np)
264 | result_array_temp.SetSpacing(resolution)
265 | label = resample_sitk_image(result_array_temp, spacing=original_resolution, interpolator='nearest')
266 | res = resize(label, crop_shape, sitk.sitkNearestNeighbor)
267 |
268 | result_array = np.transpose(np.rint(sitk.GetArrayFromImage(res)), axes=(2, 1, 0))
269 |
270 | # recover the cropped background before saving the image
271 | empty_array = np.zeros(original_shape)
272 | empty_array[coord1[0]:coord2[0],coord1[1]:coord2[1],coord1[2]:coord2[2]] = result_array
273 |
274 | result_seg = from_numpy_to_itk(empty_array, image)
275 |
276 | # save label
277 | writer = sitk.ImageFileWriter()
278 | writer.SetFileName(result)
279 | writer.Execute(result_seg)
280 | print("Saved Result at:", str(result))
281 |
282 |
283 | if __name__ == "__main__":
284 |
285 | segment(args.image, args.label, args.result, args.weights, args.resolution, args.patch_size, args.out_channels)
286 |
287 |
288 |
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 |
297 |
298 |
299 |
--------------------------------------------------------------------------------
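
An example invocation of the inference script above, with its argparse defaults spelled out explicitly (the weights file must come from a previous training run; adjust --gpu_ids to the GPUs actually available):

    python predict_single_image.py \
        --image ./Data_folder/images/test/image0.nii \
        --label ./Data_folder/labels/test/label0.nii \
        --result ./Data_folder/test.nii \
        --weights ./best_metric_model.pth \
        --gpu_ids 0
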
/monai 0.5.0/deprecated/networks.py:
--------------------------------------------------------------------------------
1 | from train import *
2 | from torch.nn import init
3 | import monai
4 | from torch.optim import lr_scheduler
5 |
6 |
7 | def init_weights(net, init_type='normal', init_gain=0.02):
8 | """Initialize network weights.
9 | Parameters:
10 | net (network) -- network to be initialized
11 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
12 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
13 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
14 | work better for some applications. Feel free to try yourself.
15 | """
16 | def init_func(m): # define the initialization function
17 | classname = m.__class__.__name__
18 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
19 | if init_type == 'normal':
20 | init.normal_(m.weight.data, 0.0, init_gain)
21 | elif init_type == 'xavier':
22 | init.xavier_normal_(m.weight.data, gain=init_gain)
23 | elif init_type == 'kaiming':
24 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
25 | elif init_type == 'orthogonal':
26 | init.orthogonal_(m.weight.data, gain=init_gain)
27 | else:
28 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
29 | if hasattr(m, 'bias') and m.bias is not None:
30 | init.constant_(m.bias.data, 0.0)
31 | elif classname.find('BatchNorm3d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
32 | init.normal_(m.weight.data, 1.0, init_gain)
33 | init.constant_(m.bias.data, 0.0)
34 |
35 | # print('initialize network with %s' % init_type)
36 | net.apply(init_func) # apply the initialization function
37 |
38 |
39 | def get_scheduler(optimizer, opt):
40 | if opt.lr_policy == 'lambda':
41 | def lambda_rule(epoch):
42 | # lr_l = 1.0 - max(0, epoch + 1 - opt.epochs/2) / float(opt.epochs/2 + 1)
43 | lr_l = (1 - epoch / opt.epochs) ** 0.9
44 | return lr_l
45 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
46 | elif opt.lr_policy == 'step':
47 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
48 | elif opt.lr_policy == 'plateau':
49 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
50 | elif opt.lr_policy == 'cosine':
51 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.epochs, eta_min=0)
52 | else:
53 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
54 | return scheduler
55 |
56 |
57 | # update learning rate (called once every epoch)
58 | def update_learning_rate(scheduler, optimizer):
59 | scheduler.step()
60 | lr = optimizer.param_groups[0]['lr']
61 | # print('learning rate = %.7f' % lr)
62 |
63 |
64 | from torch.nn import Module, Sequential
65 | from torch.nn import Conv3d, ConvTranspose3d, BatchNorm3d, MaxPool3d, AvgPool1d, Dropout3d
66 | from torch.nn import ReLU, Sigmoid
67 | import torch
68 |
69 |
70 | def build_net():
71 |
72 | from init import Options
73 | opt = Options().parse()
74 | from monai.networks.layers import Norm
75 | from monai.networks.layers.factories import split_args
76 | act_type, args = split_args("RELU")
77 |
78 | # # create Unet
79 | # Unet = monai.networks.nets.UNet(
80 | # dimensions=3,
81 | # in_channels=opt.in_channels,
82 | # out_channels=opt.out_channels,
83 | # channels=(64, 128, 256, 512, 1024),
84 | # strides=(2, 2, 2, 2),
85 | # act=act_type,
86 | # num_res_units=3,
87 | # dropout=0.2,
88 | # norm=Norm.BATCH,
89 | #
90 | # )
91 |
92 | # create nn-Unet
93 | if opt.resolution is None:
94 | sizes, spacings = opt.patch_size, opt.spacing
95 | else:
96 | sizes, spacings = opt.patch_size, opt.resolution
97 |
98 | strides, kernels = [], []
99 |
100 | while True:
101 | spacing_ratio = [sp / min(spacings) for sp in spacings]
102 | stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
103 | kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
104 | if all(s == 1 for s in stride):
105 | break
106 | sizes = [i / j for i, j in zip(sizes, stride)]
107 | spacings = [i * j for i, j in zip(spacings, stride)]
108 | kernels.append(kernel)
109 | strides.append(stride)
110 | strides.insert(0, len(spacings) * [1])
111 | kernels.append(len(spacings) * [3])
112 |
113 | nn_Unet = monai.networks.nets.DynUNet(
114 | spatial_dims=3,
115 | in_channels=opt.in_channels,
116 | out_channels=opt.out_channels,
117 | kernel_size=kernels,
118 | strides=strides,
119 | upsample_kernel_size=strides[1:],
120 | res_block=True,
121 | # act=act_type,
122 | # norm=Norm.BATCH,
123 | )
124 |
125 | init_weights(nn_Unet, init_type='normal')
126 |
127 | return nn_Unet
128 |
129 |
130 | if __name__ == '__main__':
131 | import time
132 | import torch
133 | from torch.autograd import Variable
134 | from torchsummaryX import summary
135 | from torch.nn import init
136 |
137 | opt = Options().parse()
138 |
139 | torch.cuda.set_device(0)
140 | network = build_net()
141 | net = network.cuda().eval()
142 |
143 | data = Variable(torch.randn(int(opt.batch_size), int(opt.in_channels), int(opt.patch_size[0]), int(opt.patch_size[1]), int(opt.patch_size[2]))).cuda()
144 |
145 | out = net(data)
146 |
147 | torch.onnx.export(net, data, "Unet_model_graph.onnx")
148 |
149 | summary(net,data)
150 | print("out size: {}".format(out.size()))
151 |
152 |
153 |
154 |
155 |
156 |
157 |
--------------------------------------------------------------------------------
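
Note on build_net above: the while-loop derives the per-level kernel sizes and strides for DynUNet from the patch size and voxel spacing. Each axis keeps being halved while it is still large enough (size >= 8) and not too anisotropic (spacing ratio <= 2). A standalone sketch of the same loop, assuming the repo defaults patch_size=(160, 160, 32) and spacing=[2.25, 2.25, 3] from init.py:

sizes, spacings = [160, 160, 32], [2.25, 2.25, 3]  # repo defaults from init.py
strides, kernels = [], []

while True:
    ratios = [sp / min(spacings) for sp in spacings]
    stride = [2 if r <= 2 and sz >= 8 else 1 for r, sz in zip(ratios, sizes)]
    kernel = [3 if r <= 2 else 1 for r in ratios]
    if all(s == 1 for s in stride):  # stop once no axis can be downsampled further
        break
    sizes = [sz / st for sz, st in zip(sizes, stride)]
    spacings = [sp * st for sp, st in zip(spacings, stride)]
    kernels.append(kernel)
    strides.append(stride)

strides.insert(0, [1, 1, 1])  # first level keeps full resolution
kernels.append([3, 3, 3])     # bottleneck kernel

print(strides)  # [[1,1,1], [2,2,2], [2,2,2], [2,2,2], [2,2,1], [2,2,1]]
print(kernels)  # six [3,3,3] entries; z stops being halved once only 4 slices remain
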
/monai 0.5.0/deprecated/organize_folder_structure.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import argparse
4 | import SimpleITK as sitk
5 | import numpy as np
6 | import random
7 |
8 |
9 | def resize(img, new_size, interpolator):
10 | # img = sitk.ReadImage(img)
11 | dimension = img.GetDimension()
12 |
13 | # Physical image size corresponds to the largest physical size in the training set, or any other arbitrary size.
14 | reference_physical_size = np.zeros(dimension)
15 |
16 | reference_physical_size[:] = [(sz - 1) * spc if sz * spc > mx else mx for sz, spc, mx in
17 | zip(img.GetSize(), img.GetSpacing(), reference_physical_size)]
18 |
19 | # Create the reference image with a zero origin, identity direction cosine matrix and dimension
20 | reference_origin = np.zeros(dimension)
21 | reference_direction = np.identity(dimension).flatten()
22 | reference_size = new_size
23 | reference_spacing = [phys_sz / (sz - 1) for sz, phys_sz in zip(reference_size, reference_physical_size)]
24 |
25 | reference_image = sitk.Image(reference_size, img.GetPixelIDValue())
26 | reference_image.SetOrigin(reference_origin)
27 | reference_image.SetSpacing(reference_spacing)
28 | reference_image.SetDirection(reference_direction)
29 |
30 | # Always use the TransformContinuousIndexToPhysicalPoint to compute an indexed point's physical coordinates as
31 | # this takes into account size, spacing and direction cosines. For the vast majority of images the direction
32 | # cosines are the identity matrix, but when this isn't the case simply multiplying the central index by the
33 | # spacing will not yield the correct coordinates resulting in a long debugging session.
34 | reference_center = np.array(
35 | reference_image.TransformContinuousIndexToPhysicalPoint(np.array(reference_image.GetSize()) / 2.0))
36 |
37 | # Transform which maps from the reference_image to the current img with the translation mapping the image
38 | # origins to each other.
39 | transform = sitk.AffineTransform(dimension)
40 | transform.SetMatrix(img.GetDirection())
41 | transform.SetTranslation(np.array(img.GetOrigin()) - reference_origin)
42 | # Modify the transformation to align the centers of the original and reference image instead of their origins.
43 | centering_transform = sitk.TranslationTransform(dimension)
44 | img_center = np.array(img.TransformContinuousIndexToPhysicalPoint(np.array(img.GetSize()) / 2.0))
45 | centering_transform.SetOffset(np.array(transform.GetInverse().TransformPoint(img_center) - reference_center))
46 | centered_transform = sitk.Transform(transform)
47 | centered_transform.AddTransform(centering_transform)
48 | # Using the linear interpolator as these are intensity images, if there is a need to resample a ground truth
49 | # segmentation then the segmentation image should be resampled using the NearestNeighbor interpolator so that
50 | # no new labels are introduced.
51 |
52 | return sitk.Resample(img, reference_image, centered_transform, interpolator, 0.0)
53 |
54 |
55 | def resample_sitk_image(sitk_image, spacing=None, interpolator=None, fill_value=0):
56 | # https://github.com/SimpleITK/SlicerSimpleFilters/blob/master/SimpleFilters/SimpleFilters.py
57 | _SITK_INTERPOLATOR_DICT = {
58 | 'nearest': sitk.sitkNearestNeighbor,
59 | 'linear': sitk.sitkLinear,
60 | 'gaussian': sitk.sitkGaussian,
61 | 'label_gaussian': sitk.sitkLabelGaussian,
62 | 'bspline': sitk.sitkBSpline,
63 | 'hamming_sinc': sitk.sitkHammingWindowedSinc,
64 | 'cosine_windowed_sinc': sitk.sitkCosineWindowedSinc,
65 | 'welch_windowed_sinc': sitk.sitkWelchWindowedSinc,
66 | 'lanczos_windowed_sinc': sitk.sitkLanczosWindowedSinc
67 | }
68 |
69 | """Resamples an ITK image to a new grid. If no spacing is given,
70 | the resampling is done isotropically to the smallest value in the current
71 | spacing. This is usually the in-plane resolution. If not given, the
72 | interpolation is derived from the input data type. Binary input
73 | (e.g., masks) are resampled with nearest neighbors, otherwise linear
74 | interpolation is chosen.
75 | Parameters
76 | ----------
77 | sitk_image : SimpleITK image or str
78 | Either a SimpleITK image or a path to a SimpleITK readable file.
79 | spacing : tuple
80 | Tuple of floats, the desired output spacing per dimension.
81 | interpolator : str
82 | Either `nearest`, `linear` or None.
83 | fill_value : int
84 | Returns
85 | -------
86 | SimpleITK image.
87 | """
88 |
89 | if isinstance(sitk_image, str):
90 | sitk_image = sitk.ReadImage(sitk_image)
91 | num_dim = sitk_image.GetDimension()
92 |
93 | if not interpolator:
94 | interpolator = 'linear'
95 | pixelid = sitk_image.GetPixelIDValue()
96 |
97 | if pixelid not in [1, 2, 4]:
98 | raise NotImplementedError(
99 | 'Set `interpolator` manually, '
100 | 'can only infer for 8-bit unsigned or 16, 32-bit signed integers')
101 | if pixelid == 1: # 8-bit unsigned int
102 | interpolator = 'nearest'
103 |
104 | orig_pixelid = sitk_image.GetPixelIDValue()
105 | orig_origin = sitk_image.GetOrigin()
106 | orig_direction = sitk_image.GetDirection()
107 | orig_spacing = np.array(sitk_image.GetSpacing())
108 | orig_size = np.array(sitk_image.GetSize(), dtype=int)  # builtin int: np.int was removed in NumPy 1.24
109 |
110 | if not spacing:
111 | min_spacing = orig_spacing.min()
112 | new_spacing = [min_spacing] * num_dim
113 | else:
114 | new_spacing = [float(s) for s in spacing]
115 |
116 | assert interpolator in _SITK_INTERPOLATOR_DICT.keys(), \
117 | '`interpolator` should be one of {}'.format(_SITK_INTERPOLATOR_DICT.keys())
118 |
119 | sitk_interpolator = _SITK_INTERPOLATOR_DICT[interpolator]
120 |
121 | new_size = orig_size * (orig_spacing / new_spacing)
122 | new_size = np.ceil(new_size).astype(int)  # Image dimensions are in integers (np.int was removed in NumPy 1.24)
123 | new_size = [int(s) for s in new_size] # SimpleITK expects lists, not ndarrays
124 |
125 | resample_filter = sitk.ResampleImageFilter()
126 |
127 | resampled_sitk_image = resample_filter.Execute(sitk_image,
128 | new_size,
129 | sitk.Transform(),
130 | sitk_interpolator,
131 | orig_origin,
132 | new_spacing,
133 | orig_direction,
134 | fill_value,
135 | orig_pixelid)
136 |
137 | return resampled_sitk_image
138 |
139 |
140 | def numericalSort(value):
141 | numbers = re.compile(r'(\d+)')
142 | parts = numbers.split(value)
143 | parts[1::2] = map(int, parts[1::2])
144 | return parts
145 |
146 |
147 | def lstFiles(Path):
148 |
149 | images_list = []  # create an empty list; the raw image data files are stored here
150 | for dirName, subdirList, fileList in os.walk(Path):
151 | for filename in fileList:
152 | if ".nii.gz" in filename.lower():
153 | images_list.append(os.path.join(dirName, filename))
154 | elif ".nii" in filename.lower():
155 | images_list.append(os.path.join(dirName, filename))
156 | elif ".mhd" in filename.lower():
157 | images_list.append(os.path.join(dirName, filename))
158 |
159 | images_list = sorted(images_list, key=numericalSort)
160 |
161 | return images_list
162 |
163 |
164 | def uniform_img_dimensions(image, label):
165 |
166 | image_array = sitk.GetArrayFromImage(image)
167 | image_array = np.transpose(image_array, axes=(2, 1, 0)) # reshape array from itk z,y,x to x,y,z
168 | image_shape = image_array.shape
169 |
170 | label = resample_sitk_image(label, spacing=image.GetSpacing(), interpolator='nearest')
171 | res = resize(label,image_shape,sitk.sitkNearestNeighbor)
172 | res = (np.rint(sitk.GetArrayFromImage(res)))
173 | res = sitk.GetImageFromArray(res.astype('uint8'))
174 | res.SetDirection(image.GetDirection())
175 | res.SetOrigin(image.GetOrigin())
176 | res.SetSpacing(image.GetSpacing())
177 | print(res.GetSize())
178 |
179 | return image, res
180 |
181 | parser = argparse.ArgumentParser()
182 | parser.add_argument('--images', default='./Data_folder/CT', help='path to the images')
183 | parser.add_argument('--labels', default='./Data_folder/CT_label', help='path to the labels')
184 | parser.add_argument('--split_val', default=7, help='number of images for validation')
185 | parser.add_argument('--split_test', default=3, help='number of images for testing')
186 | args = parser.parse_args()
187 |
188 | if __name__ == "__main__":
189 |
190 | list_images = lstFiles(args.images)
191 | list_labels = lstFiles(args.labels)
192 |
193 | mapIndexPosition = list(zip(list_images, list_labels))  # shuffle images and labels together so the pairs stay aligned
194 | random.shuffle(mapIndexPosition)
195 | list_images, list_labels = zip(*mapIndexPosition)
196 |
197 | os.mkdir('./Data_folder/images')
198 | os.mkdir('./Data_folder/labels')
199 |
200 | # 1
201 | if not os.path.isdir('./Data_folder/images/train'):
202 | os.mkdir('./Data_folder/images/train/')
203 | # 2
204 | if not os.path.isdir('./Data_folder/images/val'):
205 | os.mkdir('./Data_folder/images/val')
206 |
207 | # 3
208 | if not os.path.isdir('./Data_folder/images/test'):
209 | os.mkdir('./Data_folder/images/test')
210 |
211 | # 4
212 | if not os.path.isdir('./Data_folder/labels/train'):
213 | os.mkdir('./Data_folder/labels/train')
214 |
215 | # 5
216 | if not os.path.isdir('./Data_folder/labels/val'):
217 | os.mkdir('./Data_folder/labels/val')
218 |
219 | # 6
220 | if not os.path.isdir('./Data_folder/labels/test'):
221 | os.mkdir('./Data_folder/labels/test')
222 |
223 | for i in range(len(list_images)-int(args.split_test + args.split_val)):
224 |
225 | a = list_images[int(args.split_test + args.split_val)+i]
226 | b = list_labels[int(args.split_test + args.split_val)+i]
227 |
228 | print(a)
229 |
230 | label = sitk.ReadImage(b)
231 | image = sitk.ReadImage(a)
232 |
233 | image, label = uniform_img_dimensions(image, label)
234 |
235 | image_directory = os.path.join('./Data_folder/images/train', f"image{i:d}.nii")
236 | label_directory = os.path.join('./Data_folder/labels/train', f"label{i:d}.nii")
237 |
238 | sitk.WriteImage(image, image_directory)
239 | sitk.WriteImage(label, label_directory)
240 |
241 | for i in range(int(args.split_val)):
242 |
243 | a = list_images[int(args.split_test)+i]
244 | b = list_labels[int(args.split_test)+i]
245 |
246 | print(a)
247 |
248 | label = sitk.ReadImage(b)
249 | image = sitk.ReadImage(a)
250 |
251 | image, label = uniform_img_dimensions(image, label)
252 |
253 | image_directory = os.path.join('./Data_folder/images/val', f"image{i:d}.nii")
254 | label_directory = os.path.join('./Data_folder/labels/val', f"label{i:d}.nii")
255 |
256 | sitk.WriteImage(image, image_directory)
257 | sitk.WriteImage(label, label_directory)
258 |
259 | for i in range(int(args.split_test)):
260 |
261 | a = list_images[i]
262 | b = list_labels[i]
263 |
264 | print(a)
265 |
266 | label = sitk.ReadImage(b)
267 | image = sitk.ReadImage(a)
268 |
269 | image, label = uniform_img_dimensions(image, label)
270 |
271 | image_directory = os.path.join('./Data_folder/images/test', f"image{i:d}.nii")
272 | label_directory = os.path.join('./Data_folder/labels/test', f"label{i:d}.nii")
273 |
274 | sitk.WriteImage(image, image_directory)
275 | sitk.WriteImage(label, label_directory)
276 |
277 |
--------------------------------------------------------------------------------
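
Note on the three loops above: they partition the shuffled list purely by index. The first split_test entries become the test set, the next split_val entries the validation set, and everything after them the training set. A tiny sketch of the same arithmetic with hypothetical counts:

n_total, split_test, split_val = 20, 3, 7  # hypothetical dataset size, script-default splits

test_idx  = list(range(split_test))                                     # images 0..2
val_idx   = [split_test + i for i in range(split_val)]                  # images 3..9
train_idx = [split_test + split_val + i
             for i in range(n_total - (split_test + split_val))]        # images 10..19

assert sorted(test_idx + val_idx + train_idx) == list(range(n_total))  # no overlap, nothing lost
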
/monai 0.5.0/deprecated/predict_single_image.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | from train import *
5 | import argparse
6 | from networks import *
7 | import SimpleITK as sitk
8 | from monai.inferers import sliding_window_inference
9 | from monai.metrics import DiceMetric
10 | from monai.data import NiftiSaver, create_test_image_3d, list_data_collate
11 | from collections import OrderedDict
12 | from organize_folder_structure import resize, resample_sitk_image, uniform_img_dimensions
13 |
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--image", type=str, default='./Data_folder/images/test/image0.nii')
17 | parser.add_argument("--label", type=str, default='./Data_folder/labels/test/label0.nii')
18 | parser.add_argument("--result", type=str, default='./Data_folder/test.nii', help='path to the .nii result to save')
19 | parser.add_argument("--weights", type=str, default='./best_metric_model.pth', help='network weights to load')
20 | parser.add_argument("--resolution", default=[3,3,3], help='New resolution if you want to resample')
21 | parser.add_argument("--patch_size", type=int, nargs=3, default=(128, 128, 64), help="Input patch size for the network")
22 | parser.add_argument('--gpu_ids', type=str, default='2,3', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
23 | args = parser.parse_args()
24 |
25 |
26 | def new_state_dict(file_name):
27 | state_dict = torch.load(file_name)
28 | new_state_dict = OrderedDict()
29 | for k, v in state_dict.items():
30 | if k[:6] == 'module':
31 | name = k[7:]
32 | new_state_dict[name] = v
33 | else:
34 | new_state_dict[k] = v
35 | return new_state_dict
36 |
37 |
38 | def from_numpy_to_itk(image_np, image_itk):
39 |
40 | # read image file
41 | reader = sitk.ImageFileReader()
42 | reader.SetFileName(image_itk)
43 | image_itk = reader.Execute()
44 |
45 | image_np = np.transpose(image_np, (2, 1, 0))
46 | image = sitk.GetImageFromArray(image_np)
47 | image.SetDirection(image_itk.GetDirection())
48 | image.SetSpacing(image_itk.GetSpacing())
49 | image.SetOrigin(image_itk.GetOrigin())
50 | return image
51 |
52 |
53 | # function to keep track of the cropped area and coordinates
54 | def statistics_crop(image, resolution):
55 |
56 | files = [{"image": image}]
57 |
58 | reader = sitk.ImageFileReader()
59 | reader.SetFileName(image)
60 | image_itk = reader.Execute()
61 | original_resolution = image_itk.GetSpacing()
62 |
63 | # original size
64 | transforms = Compose([
65 | LoadImaged(keys=['image']),
66 | AddChanneld(keys=['image']),
67 | ToTensord(keys=['image'])])
68 | data = monai.data.Dataset(data=files, transform=transforms)
69 | loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
70 | loader = monai.utils.misc.first(loader)
71 | im, = (loader['image'][0])
72 | vol = im.numpy()
73 | original_shape = vol.shape
74 |
75 | # cropped foreground size
76 | transforms = Compose([
77 | LoadImaged(keys=['image']),
78 | AddChanneld(keys=['image']),
79 | CropForegroundd(keys=['image'], source_key='image', start_coord_key='foreground_start_coord',
80 | end_coord_key='foreground_end_coord', ), # crop CropForeground
81 | ToTensord(keys=['image', 'foreground_start_coord', 'foreground_end_coord'])])
82 | data = monai.data.Dataset(data=files, transform=transforms)
83 | loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
84 | loader = monai.utils.misc.first(loader)
85 | im, coord1, coord2 = (loader['image'][0], loader['foreground_start_coord'][0], loader['foreground_end_coord'][0])
86 | vol = im[0].numpy()
87 | coord1 = coord1.numpy()
88 | coord2 = coord2.numpy()
89 | crop_shape = vol.shape
90 |
91 | if resolution is not None:
92 |
93 | transforms = Compose([
94 | LoadImaged(keys=['image']),
95 | AddChanneld(keys=['image']),
96 | CropForegroundd(keys=['image'], source_key='image'), # crop CropForeground
97 | Spacingd(keys=['image'], pixdim=resolution, mode=('bilinear')), # resolution
98 | ToTensord(keys=['image'])])
99 |
100 | data = monai.data.Dataset(data=files, transform=transforms)
101 | loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
102 | loader = monai.utils.misc.first(loader)
103 | im, = (loader['image'][0])
104 | vol = im.numpy()
105 | resampled_size = vol.shape
106 |
107 | else:
108 |
109 | resampled_size = original_shape
110 |
111 | return original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution
112 |
113 |
114 | def segment(image, label, result, weights, resolution, patch_size):
115 |
116 | logging.basicConfig(stream=sys.stdout, level=logging.INFO)
117 |
118 | if label is not None:
119 | files = [{"image": image, "label": label}]
120 | else:
121 | files = [{"image": image}]
122 |
123 | # original size, size after crop_background, cropped roi coordinates, cropped resampled roi size
124 | original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution = statistics_crop(image, resolution)
125 |
126 | # -------------------------------
127 |
128 | if label is not None:
129 | if resolution is not None:
130 |
131 | val_transforms = Compose([
132 | LoadImaged(keys=['image', 'label']),
133 | AddChanneld(keys=['image', 'label']),
134 | CropForegroundd(keys=['image', 'label'], source_key='image'), # crop CropForeground
135 | ThresholdIntensityd(keys=['image'], threshold=-350, above=True, cval=-350), # Threshold CT
136 | ThresholdIntensityd(keys=['image'], threshold=350, above=False, cval=350),
137 |
138 | NormalizeIntensityd(keys=['image']), # intensity
139 | ScaleIntensityd(keys=['image']),
140 | Spacingd(keys=['image', 'label'], pixdim=resolution, mode=('bilinear', 'nearest')), # resolution
141 |
142 | SpatialPadd(keys=['image', 'label'], spatial_size=patch_size, method= 'end'),
143 | ToTensord(keys=['image', 'label'])])
144 | else:
145 |
146 | val_transforms = Compose([
147 | LoadImaged(keys=['image', 'label']),
148 | AddChanneld(keys=['image', 'label']),
149 | CropForegroundd(keys=['image', 'label'], source_key='image'), # crop CropForeground
150 | ThresholdIntensityd(keys=['image'], threshold=-350, above=True, cval=-350), # Threshold CT
151 | ThresholdIntensityd(keys=['image'], threshold=350, above=False, cval=350),
152 |
153 | NormalizeIntensityd(keys=['image']), # intensity
154 | ScaleIntensityd(keys=['image']),
155 |
156 | SpatialPadd(keys=['image', 'label'], spatial_size=patch_size, method='end'), # pad if the image is smaller than patch
157 | ToTensord(keys=['image', 'label'])])
158 |
159 | else:
160 | if resolution is not None:
161 |
162 | val_transforms = Compose([
163 | LoadImaged(keys=['image']),
164 | AddChanneld(keys=['image']),
165 |
166 | CropForegroundd(keys=['image'], source_key='image'), # crop CropForeground
167 | ThresholdIntensityd(keys=['image'], threshold=-350, above=True, cval=-350), # Threshold CT
168 | ThresholdIntensityd(keys=['image'], threshold=350, above=False, cval=350),
169 |
170 | NormalizeIntensityd(keys=['image']), # intensity
171 | ScaleIntensityd(keys=['image']),
172 | Spacingd(keys=['image'], pixdim=resolution, mode=('bilinear')), # resolution
173 |
174 | SpatialPadd(keys=['image'], spatial_size=patch_size, method= 'end'), # pad if the image is smaller than patch
175 | ToTensord(keys=['image'])])
176 | else:
177 |
178 | val_transforms = Compose([
179 | LoadImaged(keys=['image']),
180 | AddChanneld(keys=['image']),
181 | CropForegroundd(keys=['image'], source_key='image'), # crop CropForeground
182 | ThresholdIntensityd(keys=['image'], threshold=-350, above=True, cval=-350), # Threshold CT
183 | ThresholdIntensityd(keys=['image'], threshold=350, above=False, cval=350),
184 |
185 | NormalizeIntensityd(keys=['image']), # intensity
186 | ScaleIntensityd(keys=['image']),
187 |
188 | SpatialPadd(keys=['image'], spatial_size=patch_size, method='end'), # pad if the image is smaller than patch
189 | ToTensord(keys=['image'])])
190 |
191 | val_ds = monai.data.Dataset(data=files, transform=val_transforms)
192 | val_loader = DataLoader(val_ds, batch_size=1, num_workers=0, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available())
193 |
194 | dice_metric = DiceMetric(include_background=True, reduction="mean")
195 | post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
196 |
197 | # try to use all the available GPUs
198 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids # Multi-gpu selector for training
199 |
200 | if args.gpu_ids != '-1':
201 | num_gpus = len(args.gpu_ids.split(','))
202 | else:
203 | num_gpus = 0
204 | print('number of GPUs:', num_gpus)
205 |
206 | if num_gpus > 1:
207 |
208 | # build the network
209 | net = build_net().cuda()
210 |
211 | net = torch.nn.DataParallel(net)
212 | net.load_state_dict(torch.load(weights))
213 |
214 | else:
215 |
216 | net = build_net().cuda()
217 | net.load_state_dict(new_state_dict(weights))
218 |
219 | # define sliding window size and batch size for windows inference
220 | roi_size = patch_size
221 | sw_batch_size = 4
222 |
223 | net.eval()
224 | with torch.no_grad():
225 |
226 | if label is None:
227 | for val_data in val_loader:
228 | val_images = val_data["image"].cuda()
229 | val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
230 | val_outputs = post_trans(val_outputs)
231 | # val_outputs = (val_outputs.sigmoid() >= 0.5).float()
232 |
233 | else:
234 | metric_sum = 0.0
235 | metric_count = 0
236 | for val_data in val_loader:
237 | val_images, val_labels = val_data["image"].cuda(), val_data["label"].cuda()
238 | val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
239 | val_outputs = post_trans(val_outputs)
240 | value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
241 | metric_count += len(value)
242 | metric_sum += value.item() * len(value)
243 | # val_outputs = (val_outputs.sigmoid() >= 0.5).float()
244 |
245 | metric = metric_sum / metric_count
246 | print("Evaluation Metric (Dice):", metric)
247 |
248 | result_array = val_outputs.squeeze().data.cpu().numpy()
249 | # Remove the pad if the image was smaller than the patch in some directions
250 | result_array = result_array[0:resampled_size[0],0:resampled_size[1],0:resampled_size[2]]
251 |
252 | # resample back to the original resolution
253 | if resolution is not None:
254 |
255 | result_array_np = np.transpose(result_array, (2, 1, 0))
256 | result_array_temp = sitk.GetImageFromArray(result_array_np)
257 | result_array_temp.SetSpacing(resolution)
258 | label = resample_sitk_image(result_array_temp, spacing=original_resolution, interpolator='nearest')
259 | res = resize(label, crop_shape, sitk.sitkNearestNeighbor)
260 |
261 | result_array = np.transpose(np.rint(sitk.GetArrayFromImage(res)), axes=(2, 1, 0))
262 |
263 | # recover the cropped background before saving the image
264 | empty_array = np.zeros(original_shape)
265 | empty_array[coord1[0]:coord2[0],coord1[1]:coord2[1],coord1[2]:coord2[2]] = result_array
266 |
267 | result_seg = from_numpy_to_itk(empty_array, image)
268 |
269 | # save label
270 | writer = sitk.ImageFileWriter()
271 | writer.SetFileName(result)
272 | writer.Execute(result_seg)
273 | print("Saved Result at:", str(result))
274 |
275 |
276 | if __name__ == "__main__":
277 |
278 | segment(args.image, args.label, args.result, args.weights, args.resolution, args.patch_size)
279 |
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 |
289 |
290 |
291 |
292 |
--------------------------------------------------------------------------------
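
Note on new_state_dict above: it exists because torch.nn.DataParallel checkpoints store every parameter under a 'module.' prefix, which a bare single-GPU network does not expect. A self-contained sketch of the same rename on a hypothetical two-entry checkpoint:

from collections import OrderedDict

# hypothetical checkpoint: one DataParallel-prefixed key, one plain key
state_dict = OrderedDict([('module.conv.weight', 0), ('bias', 1)])

stripped = OrderedDict()
for k, v in state_dict.items():
    stripped[k[7:] if k.startswith('module.') else k] = v  # drop the 'module.' prefix when present

print(list(stripped))  # ['conv.weight', 'bias']
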
/monai 0.5.0/deprecated/requirements.txt:
--------------------------------------------------------------------------------
1 | simpleITK==1.2.4
2 | torchsummaryX
3 |
--------------------------------------------------------------------------------
/monai 0.5.0/init.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 |
5 | class Options():
6 |
7 | """This class defines options used during both training and test time."""
8 |
9 | def __init__(self):
10 | """Reset the class; indicates the class hasn't been initailized"""
11 | self.initialized = False
12 |
13 | def initialize(self, parser):
14 |
15 | # basic parameters
16 | parser.add_argument('--images_folder', type=str, default='./Data_folder/images')
17 | parser.add_argument('--labels_folder', type=str, default='./Data_folder/labels')
18 | parser.add_argument('--increase_factor_data', default=1, help='Factor to increase the number of samples drawn per epoch')
19 | parser.add_argument('--preload', type=str, default=None)
20 | parser.add_argument('--gpu_ids', type=str, default='0,1', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
21 | parser.add_argument('--workers', default=8, type=int, help='number of data loading workers')
22 |
23 | # dataset parameters
24 | parser.add_argument('--patch_size', default=(160, 160, 32), help='Size of the patches extracted from the image')
25 | parser.add_argument('--spacing', default=[2.25, 2.25, 3], help='Original Resolution')
26 | parser.add_argument('--resolution', default=None, help='New resolution to resample the data to during training. Resampling in organize_folder_structure.py is suggested instead, since resampling at train time is slower')
27 | parser.add_argument('--batch_size', type=int, default=4, help='batch size, depends on your machine')
28 | parser.add_argument('--in_channels', default=1, type=int, help='Channels of the input')
29 | parser.add_argument('--out_channels', default=1, type=int, help='Channels of the output')
30 |
31 | # training parameters
32 | parser.add_argument('--epochs', default=1000, help='Number of epochs')
33 | parser.add_argument('--lr', default=0.01, help='Learning rate')
34 |
35 | self.initialized = True
36 | return parser
37 |
38 | def parse(self):
39 | if not self.initialized:
40 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
41 | parser = self.initialize(parser)
42 | opt = parser.parse_args()
43 | # set gpu ids
44 | if opt.gpu_ids != '-1':
45 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
46 | return opt
47 |
48 |
49 |
50 |
51 |
52 |
--------------------------------------------------------------------------------
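
Note on Options above: every script in this folder shares its flags through this class, and parse() also exports CUDA_VISIBLE_DEVICES as a side effect. A minimal usage sketch, assuming it is run as a script from the repo root (argparse reads sys.argv):

from init import Options

opt = Options().parse()        # parses CLI flags and sets CUDA_VISIBLE_DEVICES
print(opt.patch_size)          # (160, 160, 32) unless overridden on the command line
print(opt.batch_size, opt.lr)  # 4  0.01
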
/monai 0.5.0/installation_commands_.txt:
--------------------------------------------------------------------------------
1 |
2 | 1) set up anaconda env: conda create -n monai_david python=3.8
3 | conda activate monai_david
4 |
5 | 2) install pytorch:      conda install pytorch==1.5.0 torchvision==0.6.0 cudatoolkit=10.1 -c pytorch  # this is for CUDA 10.1 (check your cuda version)
6 |
7 | 3) conda install git pip
8 | 4) pip install git+https://github.com/davidiommi/MONAI  # download libraries (use https; GitHub no longer serves git://)
9 | 5) pip install -r requirements.txt  # download libraries
10 |
11 |
12 |
--------------------------------------------------------------------------------
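
After step 5 it is worth confirming that the pinned stack imports cleanly and sees the GPU; a quick check along these lines (printed versions will vary with your install):

# quick sanity check after step 5
import torch, monai
print(torch.__version__, monai.__version__)
print("CUDA available:", torch.cuda.is_available())
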
/monai 0.5.0/multi_label_segmentation_example/init.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import SimpleITK as sitk
4 |
5 | image = sitk.ReadImage('Data_folder/images/train/image0.nii')
6 | image_spacing = image.GetSpacing()
7 |
8 | class Options():
9 |
10 | """This class defines options used during both training and test time."""
11 |
12 | def __init__(self):
13 | """Reset the class; indicates the class hasn't been initailized"""
14 | self.initialized = False
15 |
16 | def initialize(self, parser):
17 |
18 | # basic parameters
19 | parser.add_argument('--images_folder', type=str, default='./Data_folder/images')
20 | parser.add_argument('--labels_folder', type=str, default='./Data_folder/labels')
21 | parser.add_argument('--increase_factor_data', default=3, help='Factor to increase the number of samples drawn per epoch')
22 | parser.add_argument('--preload', type=str, default=None)
23 | parser.add_argument('--gpu_ids', type=str, default='0,1', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
24 | parser.add_argument('--workers', default=8, type=int, help='number of data loading workers')
25 |
26 | # dataset parameters
27 | parser.add_argument('--patch_size', default=(256, 256, 32), help='Size of the patches extracted from the image')
28 | parser.add_argument('--spacing', default=image_spacing, help='Original Resolution')
29 | parser.add_argument('--resolution', default=(0.6, 0.6, 3), help='New Resolution, if you want to resample the data')
30 | parser.add_argument('--batch_size', type=int, default=4, help='batch size')
31 | parser.add_argument('--in_channels', default=1, type=int, help='Channels of the input')
32 | parser.add_argument('--out_channels', default=3, type=int, help='Channels of the output')
33 |
34 | # training parameters
35 | parser.add_argument('--epochs', default=200, help='Number of epochs')
36 | parser.add_argument('--lr', default=0.01, help='Learning rate')
37 |
38 | # Inference
39 | # This is just a trick to make the predict script work
40 | parser.add_argument('--result', default=None, help='Keep this empty and go to predict_single_image script')
41 | parser.add_argument('--weights', default=None, help='Keep this empty and go to predict_single_image script')
42 |
43 | self.initialized = True
44 | return parser
45 |
46 | def parse(self):
47 | if not self.initialized:
48 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
49 | parser = self.initialize(parser)
50 | opt = parser.parse_args()
51 | # set gpu ids
52 | if opt.gpu_ids != '-1':
53 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
54 | return opt
55 |
56 |
57 |
58 |
59 |
60 |
--------------------------------------------------------------------------------
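
Note on this Options variant: it reads the spacing of image0 at import time to seed the --spacing default, so that file must already exist before any script imports init. A sketch of the same idea with an explicit fallback (the path and fallback spacing are hypothetical):

import os
import SimpleITK as sitk

path = 'Data_folder/images/train/image0.nii'  # hypothetical training image
if os.path.exists(path):
    image_spacing = sitk.ReadImage(path).GetSpacing()
else:
    image_spacing = (1.0, 1.0, 1.0)           # fallback so the module still imports
print(image_spacing)
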
/monai 0.5.0/multi_label_segmentation_example/predict_single_image.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | from train import *
5 | import argparse
6 | from networks import *
7 | import SimpleITK as sitk
8 | from monai.inferers import sliding_window_inference
9 | from monai.metrics import DiceMetric
10 | from monai.data import NiftiSaver, create_test_image_3d, list_data_collate
11 | from collections import OrderedDict
12 | from organize_folder_structure import resize, resample_sitk_image, uniform_img_dimensions
13 |
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--image", type=str, default='./Data_folder/images/test/image0.nii')
17 | parser.add_argument("--label", type=str, default='./Data_folder/labels/test/label0.nii')
18 | parser.add_argument("--result", type=str, default='./Data_folder/test.nii', help='path to the .nii result to save')
19 | parser.add_argument("--weights", type=str, default='./best_metric_model.pth', help='network weights to load')
20 | parser.add_argument("--resolution", default=(0.6, 0.6, 3), help='New resolution if you want to resample')
21 | parser.add_argument("--out_channels", default=3, help='Number of labels')
22 | parser.add_argument("--patch_size", type=int, nargs=3, default=(256, 256, 32), help="Input patch size for the network")
23 | parser.add_argument('--gpu_ids', type=str, default='2,3', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
24 | args = parser.parse_args()
25 |
26 |
27 | def new_state_dict(file_name):
28 | state_dict = torch.load(file_name)
29 | new_state_dict = OrderedDict()
30 | for k, v in state_dict.items():
31 | if k[:6] == 'module':
32 | name = k[7:]
33 | new_state_dict[name] = v
34 | else:
35 | new_state_dict[k] = v
36 | return new_state_dict
37 |
38 |
39 | def from_numpy_to_itk(image_np, image_itk):
40 |
41 | # read image file
42 | reader = sitk.ImageFileReader()
43 | reader.SetFileName(image_itk)
44 | image_itk = reader.Execute()
45 |
46 | image_np = np.transpose(image_np, (2, 1, 0))
47 | image = sitk.GetImageFromArray(image_np)
48 | image.SetDirection(image_itk.GetDirection())
49 | image.SetSpacing(image_itk.GetSpacing())
50 | image.SetOrigin(image_itk.GetOrigin())
51 | return image
52 |
53 |
54 | # function to keep track of the cropped area and coordinates
55 | def statistics_crop(image, resolution):
56 |
57 | files = [{"image": image}]
58 |
59 | reader = sitk.ImageFileReader()
60 | reader.SetFileName(image)
61 | image_itk = reader.Execute()
62 | original_resolution = image_itk.GetSpacing()
63 |
64 | # original size
65 | transforms = Compose([
66 | LoadImaged(keys=['image']),
67 | AddChanneld(keys=['image']),
68 | ToTensord(keys=['image'])])
69 | data = monai.data.Dataset(data=files, transform=transforms)
70 | loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
71 | loader = monai.utils.misc.first(loader)
72 | im, = (loader['image'][0])
73 | vol = im.numpy()
74 | original_shape = vol.shape
75 |
76 | # cropped foreground size
77 | transforms = Compose([
78 | LoadImaged(keys=['image']),
79 | AddChanneld(keys=['image']),
80 | CropForegroundd(keys=['image'], source_key='image', start_coord_key='foreground_start_coord',
81 | end_coord_key='foreground_end_coord', ), # crop CropForeground
82 | ToTensord(keys=['image', 'foreground_start_coord', 'foreground_end_coord'])])
83 | data = monai.data.Dataset(data=files, transform=transforms)
84 | loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
85 | loader = monai.utils.misc.first(loader)
86 | im, coord1, coord2 = (loader['image'][0], loader['foreground_start_coord'][0], loader['foreground_end_coord'][0])
87 | vol = im[0].numpy()
88 | coord1 = coord1.numpy()
89 | coord2 = coord2.numpy()
90 | crop_shape = vol.shape
91 |
92 | if resolution is not None:
93 |
94 | transforms = Compose([
95 | LoadImaged(keys=['image']),
96 | AddChanneld(keys=['image']),
97 | CropForegroundd(keys=['image'], source_key='image'), # crop CropForeground
98 | Spacingd(keys=['image'], pixdim=resolution, mode=('bilinear')), # resolution
99 | ToTensord(keys=['image'])])
100 |
101 | data = monai.data.Dataset(data=files, transform=transforms)
102 | loader = DataLoader(data, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
103 | loader = monai.utils.misc.first(loader)
104 | im, = (loader['image'][0])
105 | vol = im.numpy()
106 | resampled_size = vol.shape
107 |
108 | else:
109 |
110 | resampled_size = original_shape
111 |
112 | return original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution
113 |
114 |
115 | def segment(image, label, result, weights, resolution, patch_size, channels):
116 |
117 | logging.basicConfig(stream=sys.stdout, level=logging.INFO)
118 |
119 | if label is not None:
120 | files = [{"image": image, "label": label}]
121 | else:
122 | files = [{"image": image}]
123 |
124 | # original size, size after crop_background, cropped roi coordinates, cropped resampled roi size
125 | original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution = statistics_crop(image, resolution)
126 |
127 | # -------------------------------
128 |
129 | if label is not None:
130 | if resolution is not None:
131 |
132 | val_transforms = Compose([
133 | LoadImaged(keys=['image', 'label']),
134 | AddChanneld(keys=['image', 'label']),
135 | CropForegroundd(keys=['image', 'label'], source_key='image'), # crop CropForeground
136 |
137 | NormalizeIntensityd(keys=['image']), # intensity
138 | ScaleIntensityd(keys=['image']),
139 | Spacingd(keys=['image', 'label'], pixdim=resolution, mode=('bilinear', 'nearest')), # resolution
140 |
141 | SpatialPadd(keys=['image', 'label'], spatial_size=patch_size, method= 'end'),
142 | ToTensord(keys=['image', 'label'])])
143 | else:
144 |
145 | val_transforms = Compose([
146 | LoadImaged(keys=['image', 'label']),
147 | AddChanneld(keys=['image', 'label']),
148 | CropForegroundd(keys=['image', 'label'], source_key='image'), # crop CropForeground
149 |
150 | NormalizeIntensityd(keys=['image']), # intensity
151 | ScaleIntensityd(keys=['image']),
152 |
153 | SpatialPadd(keys=['image', 'label'], spatial_size=patch_size, method='end'), # pad if the image is smaller than patch
154 | ToTensord(keys=['image', 'label'])])
155 |
156 | else:
157 | if resolution is not None:
158 |
159 | val_transforms = Compose([
160 | LoadImaged(keys=['image']),
161 | AddChanneld(keys=['image']),
162 |
163 | CropForegroundd(keys=['image'], source_key='image'), # crop CropForeground
164 |
165 | NormalizeIntensityd(keys=['image']), # intensity
166 | ScaleIntensityd(keys=['image']),
167 | Spacingd(keys=['image'], pixdim=resolution, mode=('bilinear')), # resolution
168 |
169 | SpatialPadd(keys=['image'], spatial_size=patch_size, method= 'end'), # pad if the image is smaller than patch
170 | ToTensord(keys=['image'])])
171 | else:
172 |
173 | val_transforms = Compose([
174 | LoadImaged(keys=['image']),
175 | AddChanneld(keys=['image']),
176 | CropForegroundd(keys=['image'], source_key='image'), # crop CropForeground
177 |
178 | NormalizeIntensityd(keys=['image']), # intensity
179 | ScaleIntensityd(keys=['image']),
180 |
181 | SpatialPadd(keys=['image'], spatial_size=patch_size, method='end'), # pad if the image is smaller than patch
182 | ToTensord(keys=['image'])])
183 |
184 | val_ds = monai.data.Dataset(data=files, transform=val_transforms)
185 | val_loader = DataLoader(val_ds, batch_size=1, num_workers=0, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available())
186 |
187 | dice_metric = DiceMetric(include_background=False, reduction="mean")
188 | # post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
189 | post_trans = AsDiscrete(argmax=True, to_onehot=True, n_classes=3)
190 | post_label = AsDiscrete(to_onehot=True, n_classes=3)
191 |
192 | # try to use all the available GPUs
193 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids # Multi-gpu selector for training
194 |
195 | if args.gpu_ids != '-1':
196 | num_gpus = len(args.gpu_ids.split(','))
197 | else:
198 | num_gpus = 0
199 | print('number of GPUs:', num_gpus)
200 |
201 | if num_gpus > 1:
202 |
203 | # build the network
204 | net = build_net().cuda()
205 |
206 | net = torch.nn.DataParallel(net)
207 | net.load_state_dict(torch.load(weights))
208 |
209 | else:
210 |
211 | net = build_net().cuda()
212 | net.load_state_dict(new_state_dict(weights))
213 |
214 | # define sliding window size and batch size for windows inference
215 | roi_size = patch_size
216 | sw_batch_size = 4
217 |
218 | net.eval()
219 | with torch.no_grad():
220 |
221 | if label is None:
222 | for val_data in val_loader:
223 | val_images = val_data["image"].cuda()
224 | val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
225 | val_outputs = post_trans(val_outputs)
226 | # val_outputs = (val_outputs.sigmoid() >= 0.5).float()
227 |
228 | else:
229 | metric_sum = 0.0
230 | metric_count = 0
231 | for val_data in val_loader:
232 | val_images, val_labels = val_data["image"].cuda(), val_data["label"].cuda()
233 | val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
234 | val_outputs = post_trans(val_outputs)
235 | val_labels = post_label(val_labels)
236 | value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
237 | metric_count += len(value)
238 | metric_sum += value.item() * len(value)
239 | # val_outputs = (val_outputs.sigmoid() >= 0.5).float()
240 |
241 | metric = metric_sum / metric_count
242 | print("Evaluation Metric (Dice):", metric)
243 |
244 | result_array = val_outputs.squeeze().data.cpu().numpy()
245 |
246 | empty_array = np.zeros(result_array[0].shape)
247 |
248 | for i in range(channels): # MULTI LABEL segmentation part
249 | channel_i = result_array[i]
250 | if i == 0:
251 | channel_i = np.where(channel_i == 1, 0, channel_i)
252 | elif i > 0:
253 | channel_i = np.where(channel_i == 1, int(i), channel_i)
254 | empty_array = empty_array + channel_i
255 | result_array = empty_array
256 |
257 | # Remove the pad if the image was smaller than the patch in some directions
258 | result_array = result_array[0:resampled_size[0],0:resampled_size[1],0:resampled_size[2]]
259 |
260 | # resample back to the original resolution
261 | if resolution is not None:
262 |
263 | result_array_np = np.transpose(result_array, (2, 1, 0))
264 | result_array_temp = sitk.GetImageFromArray(result_array_np)
265 | result_array_temp.SetSpacing(resolution)
266 | label = resample_sitk_image(result_array_temp, spacing=original_resolution, interpolator='nearest')
267 | res = resize(label, crop_shape, sitk.sitkNearestNeighbor)
268 |
269 | result_array = np.transpose(np.rint(sitk.GetArrayFromImage(res)), axes=(2, 1, 0))
270 |
271 | # recover the cropped background before saving the image
272 | empty_array = np.zeros(original_shape)
273 | empty_array[coord1[0]:coord2[0],coord1[1]:coord2[1],coord1[2]:coord2[2]] = result_array
274 |
275 | result_seg = from_numpy_to_itk(empty_array, image)
276 |
277 | # save label
278 | writer = sitk.ImageFileWriter()
279 | writer.SetFileName(result)
280 | writer.Execute(result_seg)
281 | print("Saved Result at:", str(result))
282 |
283 |
284 | if __name__ == "__main__":
285 |
286 | parser = argparse.ArgumentParser()
287 | parser.add_argument("--image", type=str, default='./Data_folder/images/test/image0.nii')
288 | parser.add_argument("--label", type=str, default='./Data_folder/labels/test/label0.nii')
289 | parser.add_argument("--result", type=str, default='./Data_folder/test.nii', help='path to the .nii result to save')
290 | parser.add_argument("--weights", type=str, default='./best_metric_model.pth', help='network weights to load')
291 | parser.add_argument("--resolution", default=(0.6, 0.6, 3), help='New resolution if you want to resample')
292 | parser.add_argument("--out_channels", default=3, help='Number of labels')
293 | parser.add_argument("--patch_size", type=int, nargs=3, default=(256, 256, 32), help="Input patch size for the network")
294 | parser.add_argument('--gpu_ids', type=str, default='2,3', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
295 | args = parser.parse_args()
296 |
297 | segment(args.image, args.label, args.result, args.weights, args.resolution, args.patch_size, args.out_channels)
298 |
299 |
300 |
301 |
302 |
303 |
304 |
305 |
306 |
307 |
308 |
309 |
310 |
311 |
--------------------------------------------------------------------------------
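
Note on the channel-collapse loop in segment above: it turns the one-hot output of shape (channels, x, y, z) into a single integer label map by sending channel i to label i. Because post_trans already produces mutually exclusive one-hot channels, the loop is equivalent to an argmax over the channel axis; a sketch on a hypothetical 3-channel output:

import numpy as np

# hypothetical one-hot network output: 3 channels over a 2x2x1 volume
one_hot = np.zeros((3, 2, 2, 1))
one_hot[0, 0, 0, 0] = 1   # background voxel
one_hot[1, 0, 1, 0] = 1   # label-1 voxel
one_hot[2, 1, :, 0] = 1   # label-2 voxels

# the loop from the script: channel 0 -> 0, channel i -> i
label_map = np.zeros(one_hot[0].shape)
for i in range(one_hot.shape[0]):
    label_map += np.where(one_hot[i] == 1, i, 0)

assert np.array_equal(label_map, np.argmax(one_hot, axis=0))  # same result, one call
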
/monai 0.5.0/networks.py:
--------------------------------------------------------------------------------
1 | from train import *
2 | from torch.nn import init
3 | import monai
4 | from torch.optim import lr_scheduler
5 |
6 |
7 | def init_weights(net, init_type='normal', init_gain=0.02):
8 | """Initialize network weights.
9 | Parameters:
10 | net (network) -- network to be initialized
11 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
12 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
13 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
14 | work better for some applications. Feel free to try yourself.
15 | """
16 | def init_func(m): # define the initialization function
17 | classname = m.__class__.__name__
18 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
19 | if init_type == 'normal':
20 | init.normal_(m.weight.data, 0.0, init_gain)
21 | elif init_type == 'xavier':
22 | init.xavier_normal_(m.weight.data, gain=init_gain)
23 | elif init_type == 'kaiming':
24 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
25 | elif init_type == 'orthogonal':
26 | init.orthogonal_(m.weight.data, gain=init_gain)
27 | else:
28 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
29 | if hasattr(m, 'bias') and m.bias is not None:
30 | init.constant_(m.bias.data, 0.0)
31 | elif classname.find('BatchNorm3d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
32 | init.normal_(m.weight.data, 1.0, init_gain)
33 | init.constant_(m.bias.data, 0.0)
34 |
35 | # print('initialize network with %s' % init_type)
36 | net.apply(init_func) # apply the initialization function
37 |
38 |
39 | def get_scheduler(optimizer, opt):
40 | if opt.lr_policy == 'lambda':
41 | def lambda_rule(epoch):
42 | # lr_l = 1.0 - max(0, epoch + 1 - opt.epochs/2) / float(opt.epochs/2 + 1)
43 | lr_l = (1 - epoch / opt.epochs) ** 0.9
44 | return lr_l
45 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
46 | elif opt.lr_policy == 'step':
47 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
48 | elif opt.lr_policy == 'plateau':
49 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
50 | elif opt.lr_policy == 'cosine':
51 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.epochs, eta_min=0)
52 | else:
53 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
54 | return scheduler
55 |
56 |
57 | # update learning rate (called once every epoch)
58 | def update_learning_rate(scheduler, optimizer):
59 | scheduler.step()
60 | lr = optimizer.param_groups[0]['lr']
61 | # print('learning rate = %.7f' % lr)
62 |
63 |
64 | from torch.nn import Module, Sequential
65 | from torch.nn import Conv3d, ConvTranspose3d, BatchNorm3d, MaxPool3d, AvgPool1d, Dropout3d
66 | from torch.nn import ReLU, Sigmoid
67 | import torch
68 |
69 |
70 | def build_net():
71 |
72 | from init import Options
73 | opt = Options().parse()
74 | from monai.networks.layers import Norm
75 |
76 | # create nn-Unet
77 | if opt.resolution is None:
78 | sizes, spacings = opt.patch_size, opt.spacing
79 | else:
80 | sizes, spacings = opt.patch_size, opt.resolution
81 |
82 | strides, kernels = [], []
83 |
84 | while True:
85 | spacing_ratio = [sp / min(spacings) for sp in spacings]
86 | stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
87 | kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
88 | if all(s == 1 for s in stride):
89 | break
90 | sizes = [i / j for i, j in zip(sizes, stride)]
91 | spacings = [i * j for i, j in zip(spacings, stride)]
92 | kernels.append(kernel)
93 | strides.append(stride)
94 | strides.insert(0, len(spacings) * [1])
95 | kernels.append(len(spacings) * [3])
96 |
97 | # # create Unet
98 |
99 | nn_Unet = monai.networks.nets.DynUNet(
100 | spatial_dims=3,
101 | in_channels=opt.in_channels,
102 | out_channels=opt.out_channels,
103 | kernel_size=kernels,
104 | strides=strides,
105 | upsample_kernel_size=strides[1:],
106 | res_block=True,
107 | )
108 |
109 | init_weights(nn_Unet, init_type='normal')
110 |
111 | return nn_Unet
112 |
113 |
114 | if __name__ == '__main__':
115 | import time
116 | import torch
117 | from torch.autograd import Variable
118 | from torchsummaryX import summary
119 | from torch.nn import init
120 |
121 | opt = Options().parse()
122 |
123 | torch.cuda.set_device(0)
124 | network = build_net()
125 | net = network.cuda().eval()
126 |
127 | data = Variable(torch.randn(1, int(opt.in_channels), int(opt.patch_size[0]), int(opt.patch_size[1]), int(opt.patch_size[2]))).cuda()
128 |
129 | out = net(data)
130 |
131 | torch.onnx.export(net, data, "Unet_model_graph.onnx")
132 |
133 | summary(net,data)
134 | print("out size: {}".format(out.size()))
135 |
136 |
137 |
138 |
139 |
140 |
141 |
--------------------------------------------------------------------------------
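
A quick way to see what init_weights(net, init_type='normal') above does per layer: Conv and Linear weights are redrawn from N(0, init_gain) and biases are zeroed. A minimal sketch on a single Conv3d (the layer shape is arbitrary):

import torch
from torch.nn import Conv3d, init

conv = Conv3d(1, 8, kernel_size=3)
init.normal_(conv.weight.data, 0.0, 0.02)  # what init_type='normal' applies per Conv layer
init.constant_(conv.bias.data, 0.0)

print(conv.weight.data.std().item())       # close to 0.02
print(conv.bias.data.abs().max().item())   # 0.0
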
/monai 0.5.0/organize_folder_structure.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import argparse
4 | import SimpleITK as sitk
5 | import numpy as np
6 | import random
7 | from utils import *
8 |
9 |
10 | if __name__ == "__main__":
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--images', default='./Data_folder/CT', help='path to the images')
14 | parser.add_argument('--labels', default='./Data_folder/CT_label', help='path to the labels')
15 | parser.add_argument('--split_val', default=30, help='number of images for validation')
16 | parser.add_argument('--split_test', default=30, help='number of images for testing')
17 | parser.add_argument('--resolution', default=[2.25, 2.25, 3], help='New Resolution to resample the data to same spacing')
18 | parser.add_argument('--smooth', default=False, help='Set True if you want to smooth a bit the binary mask')
19 | args = parser.parse_args()
20 |
21 | list_images = lstFiles(args.images)
22 | list_labels = lstFiles(args.labels)
23 |
24 | mapIndexPosition = list(zip(list_images, list_labels))  # shuffle images and labels together so the pairs stay aligned
25 | random.shuffle(mapIndexPosition)
26 | list_images, list_labels = zip(*mapIndexPosition)
27 |
28 | os.mkdir('./Data_folder/images')
29 | os.mkdir('./Data_folder/labels')
30 |
31 | # 1
32 | if not os.path.isdir('./Data_folder/images/train'):
33 | os.mkdir('./Data_folder/images/train/')
34 | # 2
35 | if not os.path.isdir('./Data_folder/images/val'):
36 | os.mkdir('./Data_folder/images/val')
37 |
38 | # 3
39 | if not os.path.isdir('./Data_folder/images/test'):
40 | os.mkdir('./Data_folder/images/test')
41 |
42 | # 4
43 | if not os.path.isdir('./Data_folder/labels/train'):
44 | os.mkdir('./Data_folder/labels/train')
45 |
46 | # 5
47 | if not os.path.isdir('./Data_folder/labels/val'):
48 | os.mkdir('./Data_folder/labels/val')
49 |
50 | # 6
51 | if not os.path.isdir('./Data_folder/labels/test'):
52 | os.mkdir('./Data_folder/labels/test')
53 |
54 | for i in range(len(list_images)-int(args.split_test + args.split_val)):
55 |
56 | a = list_images[int(args.split_test + args.split_val)+i]
57 | b = list_labels[int(args.split_test + args.split_val)+i]
58 |
59 | print('train',i, a,b)
60 |
61 | label = sitk.ReadImage(b)
62 | image = sitk.ReadImage(a)
63 |
64 | image = resample_sitk_image(image, spacing=args.resolution, interpolator='linear', fill_value=0)
65 | image, label = uniform_img_dimensions(image, label, nearest=True)
66 | if args.smooth is True:
67 | label = gaussian2(label)
68 |
69 | image_directory = os.path.join('./Data_folder/images/train', f"image{i:d}.nii")
70 | label_directory = os.path.join('./Data_folder/labels/train', f"label{i:d}.nii")
71 |
72 | sitk.WriteImage(image, image_directory)
73 | sitk.WriteImage(label, label_directory)
74 |
75 | for i in range(int(args.split_val)):
76 |
77 | a = list_images[int(args.split_test)+i]
78 | b = list_labels[int(args.split_test)+i]
79 |
80 | print('val',i, a,b)
81 |
82 | label = sitk.ReadImage(b)
83 | image = sitk.ReadImage(a)
84 |
85 | image = resample_sitk_image(image, spacing=args.resolution, interpolator='linear', fill_value=0)
86 | image, label = uniform_img_dimensions(image, label, nearest=True)
87 | if args.smooth is True:
88 | label = gaussian2(label)
89 |
90 | image_directory = os.path.join('./Data_folder/images/val', f"image{i:d}.nii")
91 | label_directory = os.path.join('./Data_folder/labels/val', f"label{i:d}.nii")
92 |
93 | sitk.WriteImage(image, image_directory)
94 | sitk.WriteImage(label, label_directory)
95 |
96 | for i in range(int(args.split_test)):
97 |
98 | a = list_images[i]
99 | b = list_labels[i]
100 |
101 | print('test',i,a,b)
102 |
103 | label = sitk.ReadImage(b)
104 | image = sitk.ReadImage(a)
105 |
106 | image = resample_sitk_image(image, spacing=args.resolution, interpolator='linear', fill_value=0)
107 | image, label = uniform_img_dimensions(image, label, nearest=True)
108 | if args.smooth is True:
109 | label = gaussian2(label)
110 |
111 | image_directory = os.path.join('./Data_folder/images/test', f"image{i:d}.nii")
112 | label_directory = os.path.join('./Data_folder/labels/test', f"label{i:d}.nii")
113 |
114 | sitk.WriteImage(image, image_directory)
115 | sitk.WriteImage(label, label_directory)
116 |
117 |
--------------------------------------------------------------------------------
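
Note on resample_sitk_image (imported here from utils): it keeps the physical extent fixed while changing the voxel grid, so the output size follows new_size = ceil(orig_size * orig_spacing / new_spacing). A worked sketch with a hypothetical CT volume and the script's default target spacing [2.25, 2.25, 3]:

import numpy as np

orig_size    = np.array([512, 512, 120])    # hypothetical CT grid
orig_spacing = np.array([0.75, 0.75, 2.5])  # mm per voxel
new_spacing  = np.array([2.25, 2.25, 3.0])  # target spacing from --resolution

new_size = np.ceil(orig_size * (orig_spacing / new_spacing)).astype(int)
print(new_size)  # [171 171 100]
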
/monai 0.5.0/predict_single_image.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | from utils import *
5 | import argparse
6 | from networks import *
7 | from monai.inferers import sliding_window_inference
8 | from monai.metrics import DiceMetric
9 | from monai.data import NiftiSaver, create_test_image_3d, list_data_collate
10 |
11 |
12 | def segment(image, label, result, weights, resolution, patch_size, gpu_ids):
13 |
14 | logging.basicConfig(stream=sys.stdout, level=logging.INFO)
15 |
16 | if label is not None:
17 | uniform_img_dimensions_internal(image, label, True)
18 | files = [{"image": image, "label": label}]
19 | else:
20 | files = [{"image": image}]
21 |
22 | # original size, size after crop_background, cropped roi coordinates, cropped resampled roi size
23 | original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution = statistics_crop(image, resolution)
24 |
25 | # -------------------------------
26 |
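   | # Four preprocessing branches follow, selected by whether a reference label
   | # was passed and whether the data must be resampled to the training
   | # resolution; all four share the same CT clamp / CropForeground /
   | # NormalizeIntensity / ScaleIntensity chain, so inference preprocessing
   | # mirrors the training pipeline.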
27 | if label is not None:
28 | if resolution is not None:
29 |
30 | val_transforms = Compose([
31 | LoadImaged(keys=['image', 'label']),
32 | AddChanneld(keys=['image', 'label']),
33 | ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135), # Threshold CT
34 | ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
35 | CropForegroundd(keys=['image', 'label'], source_key='image'), # crop CropForeground
36 |
37 | NormalizeIntensityd(keys=['image']), # intensity
38 | ScaleIntensityd(keys=['image']),
39 | Spacingd(keys=['image', 'label'], pixdim=resolution, mode=('bilinear', 'nearest')), # resolution
40 |
41 | SpatialPadd(keys=['image', 'label'], spatial_size=patch_size, method= 'end'),
42 | ToTensord(keys=['image', 'label'])])
43 | else:
44 |
45 | val_transforms = Compose([
46 | LoadImaged(keys=['image', 'label']),
47 | AddChanneld(keys=['image', 'label']),
48 | ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135), # Threshold CT
49 | ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
50 | CropForegroundd(keys=['image', 'label'], source_key='image'), # crop CropForeground
51 |
52 | NormalizeIntensityd(keys=['image']), # intensity
53 | ScaleIntensityd(keys=['image']),
54 |
55 | SpatialPadd(keys=['image', 'label'], spatial_size=patch_size, method='end'), # pad if the image is smaller than patch
56 | ToTensord(keys=['image', 'label'])])
57 |
58 | else:
59 | if resolution is not None:
60 |
61 | val_transforms = Compose([
62 | LoadImaged(keys=['image']),
63 | AddChanneld(keys=['image']),
64 | ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135), # Threshold CT
65 | ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
66 | CropForegroundd(keys=['image'], source_key='image'), # crop CropForeground
67 |
68 | NormalizeIntensityd(keys=['image']), # intensity
69 | ScaleIntensityd(keys=['image']),
70 | Spacingd(keys=['image'], pixdim=resolution, mode=('bilinear')), # resolution
71 |
72 | SpatialPadd(keys=['image'], spatial_size=patch_size, method= 'end'), # pad if the image is smaller than patch
73 | ToTensord(keys=['image'])])
74 | else:
75 |
76 | val_transforms = Compose([
77 | LoadImaged(keys=['image']),
78 | AddChanneld(keys=['image']),
79 | ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135), # Threshold CT
80 | ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
81 | CropForegroundd(keys=['image'], source_key='image'), # crop CropForeground
82 |
83 | NormalizeIntensityd(keys=['image']), # intensity
84 | ScaleIntensityd(keys=['image']),
85 |
86 | SpatialPadd(keys=['image'], spatial_size=patch_size, method='end'), # pad if the image is smaller than patch
87 | ToTensord(keys=['image'])])
88 |
89 | val_ds = monai.data.Dataset(data=files, transform=val_transforms)
90 | val_loader = DataLoader(val_ds, batch_size=1, num_workers=0, collate_fn=list_data_collate, pin_memory=False)
91 |
92 | dice_metric = DiceMetric(include_background=True, reduction="mean")
93 | post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
94 |
95 | if gpu_ids != '-1':
96 |
97 | # try to use all the available GPUs
98 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids
99 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
100 |
101 | else:
102 | device = torch.device("cpu")
103 |
104 | net = build_net()
105 | net = net.to(device)
106 |
107 | if gpu_ids == '-1':
108 |
109 | net.load_state_dict(new_state_dict_cpu(weights))
110 |
111 | else:
112 |
113 | net.load_state_dict(new_state_dict(weights))
114 |
115 | # define sliding window size and batch size for windows inference
116 | roi_size = patch_size
117 | sw_batch_size = 4
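   | # sliding_window_inference tiles the volume into roi_size windows, runs
   | # sw_batch_size windows per forward pass, and blends the overlapping
   | # window outputs back into a full-size prediction.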
118 |
119 | net.eval()
120 | with torch.no_grad():
121 |
122 | if label is None:
123 | for val_data in val_loader:
124 | val_images = val_data["image"].to(device)
125 | val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
126 | val_outputs = post_trans(val_outputs)
127 | # val_outputs = (val_outputs.sigmoid() >= 0.5).float()
128 |
129 | else:
130 | metric_sum = 0.0
131 | metric_count = 0
132 | for val_data in val_loader:
133 | val_images, val_labels = val_data["image"].to(device), val_data["label"].to(device)
134 | val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
135 | val_outputs = post_trans(val_outputs)
136 | value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
137 | metric_count += len(value)
138 | metric_sum += value.item() * len(value)
139 | # val_outputs = (val_outputs.sigmoid() >= 0.5).float()
140 |
141 | metric = metric_sum / metric_count
142 | print("Evaluation Metric (Dice):", metric)
143 |
144 | result_array = val_outputs.squeeze().data.cpu().numpy()
145 | # Remove the pad if the image was smaller than the patch in some directions
146 | result_array = result_array[0:resampled_size[0],0:resampled_size[1],0:resampled_size[2]]
147 |
148 | # resample back to the original resolution
149 | if resolution is not None:
150 |
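   | # SimpleITK's GetImageFromArray expects (z, y, x) ordering, hence the
   | # transpose; the mask is written out at the training spacing and then
   | # resampled back to the original resolution with nearest-neighbour
   | # interpolation so it stays binary.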
151 | result_array_np = np.transpose(result_array, (2, 1, 0))
152 | result_array_temp = sitk.GetImageFromArray(result_array_np)
153 | result_array_temp.SetSpacing(resolution)
154 |
155 | # save temporary label
156 | writer = sitk.ImageFileWriter()
157 | writer.SetFileName('temp_seg.nii')
158 | writer.Execute(result_array_temp)
159 |
160 | files = [{"image": 'temp_seg.nii'}]
161 |
162 | files_transforms = Compose([
163 | LoadImaged(keys=['image']),
164 | AddChanneld(keys=['image']),
165 | Spacingd(keys=['image'], pixdim=original_resolution, mode=('nearest')),
166 | Resized(keys=['image'], spatial_size=crop_shape, mode=('nearest')),
167 | ])
168 |
169 | files_ds = Dataset(data=files, transform=files_transforms)
170 | files_loader = DataLoader(files_ds, batch_size=1, num_workers=0)
171 |
172 | for files_data in files_loader:
173 | files_images = files_data["image"]
174 |
175 | res = files_images.squeeze().data.numpy()
176 |
177 | result_array = np.rint(res)
178 |
179 | os.remove('./temp_seg.nii')
180 |
181 | # recover the cropped background before saving the image
182 | empty_array = np.zeros(original_shape)
183 | empty_array[coord1[0]:coord2[0],coord1[1]:coord2[1],coord1[2]:coord2[2]] = result_array
184 |
185 | result_seg = from_numpy_to_itk(empty_array, image)
186 |
187 | # save label
188 | writer = sitk.ImageFileWriter()
189 | writer.SetFileName(result)
190 | writer.Execute(result_seg)
191 | print("Saved Result at:", str(result))
192 |
193 |
194 | if __name__ == "__main__":
195 |
196 | parser = argparse.ArgumentParser()
197 | parser.add_argument("--image", type=str, default='./Data_folder/CT/0.nii', help='source image' )
198 | parser.add_argument("--label", type=str, default=None, help='source label, if you want to compute dice. None for new case')
199 | parser.add_argument("--result", type=str, default='./Data_folder/test_0.nii', help='path to the .nii result to save')
200 | parser.add_argument("--weights", type=str, default='./best_metric_model.pth', help='network weights to load')
201 | parser.add_argument("--resolution", default=[2.25, 2.25, 3], help='Resolution used in training phase')
202 | parser.add_argument("--patch_size", type=int, nargs=3, default=(160, 160, 32), help="Input dimension for the generator, same of training")
203 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
204 | args = parser.parse_args()
   | if args.label == 'None': args.label = None  # argparse delivers the string 'None'; map it to a real None as the help text promises
205 |
206 | segment(args.image, args.label, args.result, args.weights, args.resolution, args.patch_size, args.gpu_ids)
207 |
--------------------------------------------------------------------------------
/monai 0.5.0/requirements.txt:
--------------------------------------------------------------------------------
1 | SimpleITK==2.1.0
2 | torchsummaryX
3 | nibabel
4 | pillow
5 | tensorboard
6 | gdown
7 | pytorch-ignite==0.4.4
8 | itk
9 | tqdm
10 | lmdb
11 | psutil
12 | pandas
13 | einops
14 | scikit-image
15 |
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
1 | # from train import *
2 | from torch.nn import init
3 | from init import Options
4 | import monai
5 | from torch.optim import lr_scheduler
6 |
7 |
8 | def init_weights(net, init_type='normal', init_gain=0.02):
9 | """Initialize network weights.
10 | Parameters:
11 | net (network) -- network to be initialized
12 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
13 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
14 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
15 | work better for some applications. Feel free to try yourself.
16 | """
17 | def init_func(m): # define the initialization function
18 | classname = m.__class__.__name__
19 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
20 | if init_type == 'normal':
21 | init.normal_(m.weight.data, 0.0, init_gain)
22 | elif init_type == 'xavier':
23 | init.xavier_normal_(m.weight.data, gain=init_gain)
24 | elif init_type == 'kaiming':
25 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
26 | elif init_type == 'orthogonal':
27 | init.orthogonal_(m.weight.data, gain=init_gain)
28 | else:
29 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
30 | if hasattr(m, 'bias') and m.bias is not None:
31 | init.constant_(m.bias.data, 0.0)
32 | elif classname.find('BatchNorm3d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
33 | init.normal_(m.weight.data, 1.0, init_gain)
34 | init.constant_(m.bias.data, 0.0)
35 |
36 | # print('initialize network with %s' % init_type)
37 | net.apply(init_func) # apply the initialization function
38 |
39 |
40 | def get_scheduler(optimizer, opt):
41 | if opt.lr_policy == 'lambda':
42 | def lambda_rule(epoch):
43 | # lr_l = 1.0 - max(0, epoch + 1 - opt.epochs/2) / float(opt.epochs/2 + 1)
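   | # polynomial ('poly') decay, as in the nnU-Net training recipe:
   | # the base lr is scaled by (1 - epoch / max_epochs) ** 0.9 each epoch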
44 | lr_l = (1 - epoch / opt.epochs) ** 0.9
45 | return lr_l
46 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
47 | elif opt.lr_policy == 'step':
48 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
49 | elif opt.lr_policy == 'plateau':
50 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
51 | elif opt.lr_policy == 'cosine':
52 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.epochs, eta_min=0)
53 | else:
54 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
55 | return scheduler
56 |
57 |
58 | # update learning rate (called once every epoch)
59 | def update_learning_rate(scheduler, optimizer):
60 | scheduler.step()
61 | lr = optimizer.param_groups[0]['lr']
62 | # print('learning rate = %.7f' % lr)
63 |
64 |
65 | from torch.nn import Module, Sequential
66 | from torch.nn import Conv3d, ConvTranspose3d, BatchNorm3d, MaxPool3d, AvgPool1d, Dropout3d
67 | from torch.nn import ReLU, Sigmoid
68 | import torch
69 |
70 |
71 | def build_net():
72 |
73 | from init import Options
74 | opt = Options().parse()
75 | from monai.networks.layers import Norm
76 |
77 | # create nn-Unet
78 | if opt.resolution is None:
79 | sizes, spacings = opt.patch_size, opt.spacing
80 | else:
81 | sizes, spacings = opt.patch_size, opt.resolution
82 |
83 | strides, kernels = [], []
84 |
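   | # nnU-Net-style heuristic: derive one (kernel, stride) pair per level.
   | # An axis is downsampled (stride 2) while its spacing stays within 2x of
   | # the finest axis and its size is still >= 8 voxels; the loop stops once
   | # no axis can be halved, which fixes the depth of the DynUNet.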
85 | while True:
86 | spacing_ratio = [sp / min(spacings) for sp in spacings]
87 | stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
88 | kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
89 | if all(s == 1 for s in stride):
90 | break
91 | sizes = [i / j for i, j in zip(sizes, stride)]
92 | spacings = [i * j for i, j in zip(spacings, stride)]
93 | kernels.append(kernel)
94 | strides.append(stride)
95 | strides.insert(0, len(spacings) * [1])
96 | kernels.append(len(spacings) * [3])
97 |
98 | # # create Unet
99 |
100 | nn_Unet = monai.networks.nets.DynUNet(
101 | spatial_dims=3,
102 | in_channels=opt.in_channels,
103 | out_channels=opt.out_channels,
104 | kernel_size=kernels,
105 | strides=strides,
106 | upsample_kernel_size=strides[1:],
107 | res_block=True,
108 | )
109 |
110 | init_weights(nn_Unet, init_type='normal')
111 |
112 | return nn_Unet
113 |
114 |
115 | def build_UNETR():
116 |
117 | from init import Options
118 | opt = Options().parse()
119 |
120 | # create UneTR
121 |
122 | UneTR = monai.networks.nets.UNETR(
123 | in_channels=opt.in_channels,
124 | out_channels=opt.out_channels,
125 | img_size=opt.patch_size,
126 | feature_size=32,
127 | hidden_size=768,
128 | mlp_dim=3072,
129 | num_heads=12,
130 | pos_embed="conv",
131 | norm_name="instance",
132 | res_block=True,
133 | dropout_rate=0.0,
134 | )
135 |
136 | init_weights(UneTR, init_type='normal')
137 |
138 | return UneTR
139 |
140 |
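   | # Quick smoke test: build one of the networks, push a random patch through
   | # it, and print a torchsummaryX layer table to verify shapes and parameters.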
141 | if __name__ == '__main__':
142 | import time
143 | import torch
145 | from torchsummaryX import summary
146 | from torch.nn import init
147 |
148 | opt = Options().parse()
149 |
150 | torch.cuda.set_device(0)
151 | # network = build_net()
152 | network = build_UNETR()
153 | net = network.cuda().eval()
154 |
155 | data = torch.randn(1, int(opt.in_channels), int(opt.patch_size[0]), int(opt.patch_size[1]), int(opt.patch_size[2])).cuda()
156 |
157 | out = net(data)
158 |
159 | # torch.onnx.export(net, data, "Unet_model_graph.onnx")
160 |
161 | summary(net,data)
162 | print("out size: {}".format(out.size()))
163 |
--------------------------------------------------------------------------------
/organize_folder_structure.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import argparse
4 | import SimpleITK as sitk
5 | import numpy as np
6 | import random
7 | from utils import *
8 |
9 |
10 | if __name__ == "__main__":
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--images', default='./Data_folder/T2', help='path to the images')
14 | parser.add_argument('--labels', default='./Data_folder/T2_labels', help='path to the labels')
15 | parser.add_argument('--split_val', default=20, type=int, help='number of images for validation')
16 | parser.add_argument('--split_test', default=19, type=int, help='number of images for testing')
17 | parser.add_argument('--resolution', nargs=3, type=float, default=[0.7, 0.7, 3], help='new resolution (x y z) to resample all volumes to the same spacing')
18 | parser.add_argument('--smooth', action='store_true', help='apply slight smoothing to the binary masks')
19 | args = parser.parse_args()
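   | # Example invocation (values match the defaults above):
   | #   python organize_folder_structure.py --images ./Data_folder/T2 \
   | #       --labels ./Data_folder/T2_labels --split_val 20 --split_test 19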
20 |
21 | list_images = lstFiles(args.images)
22 | list_labels = lstFiles(args.labels)
23 |
24 | mapIndexPosition = list(zip(list_images, list_labels)) # shuffle order list
25 | random.shuffle(mapIndexPosition)
26 | list_images, list_labels = zip(*mapIndexPosition)
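   | # note: no random seed is fixed, so the split differs on every run;
   | # shuffling the zipped pairs keeps each image aligned with its label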
27 |
28 | os.makedirs('./Data_folder/images', exist_ok=True)
29 | os.makedirs('./Data_folder/labels', exist_ok=True)
30 |
31 | # 1
32 | if not os.path.isdir('./Data_folder/images/train'):
33 | os.mkdir('./Data_folder/images/train/')
34 | # 2
35 | if not os.path.isdir('./Data_folder/images/val'):
36 | os.mkdir('./Data_folder/images/val')
37 |
38 | # 3
39 | if not os.path.isdir('./Data_folder/images/test'):
40 | os.mkdir('./Data_folder/images/test')
41 |
42 | # 4
43 | if not os.path.isdir('./Data_folder/labels/train'):
44 | os.mkdir('./Data_folder/labels/train')
45 |
46 | # 5
47 | if not os.path.isdir('./Data_folder/labels/val'):
48 | os.mkdir('./Data_folder/labels/val')
49 |
50 | # 6
51 | if not os.path.isdir('./Data_folder/labels/test'):
52 | os.mkdir('./Data_folder/labels/test')
53 |
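   | # After the shuffle: indices [0, split_test) become the test set,
   | # [split_test, split_test + split_val) the validation set, and all
   | # remaining cases the training set.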
54 | for i in range(len(list_images)-int(args.split_test + args.split_val)):
55 |
56 | a = list_images[int(args.split_test + args.split_val)+i]
57 | b = list_labels[int(args.split_test + args.split_val)+i]
58 |
59 | print('train',i, a,b)
60 |
61 | label = sitk.ReadImage(b)
62 | image = sitk.ReadImage(a)
63 |
64 | image = resample_sitk_image(image, spacing=args.resolution, interpolator='linear', fill_value=0)
65 | image, label = uniform_img_dimensions(image, label, nearest=True)
66 | if args.smooth is True:
67 | label = gaussian2(label)
68 |
69 | image_directory = os.path.join('./Data_folder/images/train', f"image{i:d}.nii")
70 | label_directory = os.path.join('./Data_folder/labels/train', f"label{i:d}.nii")
71 |
72 | sitk.WriteImage(image, image_directory)
73 | sitk.WriteImage(label, label_directory)
74 |
75 | for i in range(int(args.split_val)):
76 |
77 | a = list_images[int(args.split_test)+i]
78 | b = list_labels[int(args.split_test)+i]
79 |
80 | print('val',i, a,b)
81 |
82 | label = sitk.ReadImage(b)
83 | image = sitk.ReadImage(a)
84 |
85 | image = resample_sitk_image(image, spacing=args.resolution, interpolator='linear', fill_value=0)
86 | image, label = uniform_img_dimensions(image, label, nearest=True)
87 | if args.smooth is True:
88 | label = gaussian2(label)
89 |
90 | image_directory = os.path.join('./Data_folder/images/val', f"image{i:d}.nii")
91 | label_directory = os.path.join('./Data_folder/labels/val', f"label{i:d}.nii")
92 |
93 | sitk.WriteImage(image, image_directory)
94 | sitk.WriteImage(label, label_directory)
95 |
96 | for i in range(int(args.split_test)):
97 |
98 | a = list_images[i]
99 | b = list_labels[i]
100 |
101 | print('test',i,a,b)
102 |
103 | label = sitk.ReadImage(b)
104 | image = sitk.ReadImage(a)
105 |
106 | image = resample_sitk_image(image, spacing=args.resolution, interpolator='linear', fill_value=0)
107 | image, label = uniform_img_dimensions(image, label, nearest=True)
108 | if args.smooth is True:
109 | label = gaussian2(label)
110 |
111 | image_directory = os.path.join('./Data_folder/images/test', f"image{i:d}.nii")
112 | label_directory = os.path.join('./Data_folder/labels/test', f"label{i:d}.nii")
113 |
114 | sitk.WriteImage(image, image_directory)
115 | sitk.WriteImage(label, label_directory)
116 |
117 |
--------------------------------------------------------------------------------
/predict_single_image.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | from utils import *
5 | import argparse
6 | from networks import build_net, build_UNETR
7 | from monai.inferers import sliding_window_inference
8 | from monai.metrics import DiceMetric
9 | from monai.data import NiftiSaver, create_test_image_3d, list_data_collate, decollate_batch
10 | from monai.transforms import (EnsureType, Compose, LoadImaged, AddChanneld, Transpose,Activations,AsDiscrete, RandGaussianSmoothd, CropForegroundd, SpatialPadd,
11 | ScaleIntensityd, ToTensord, RandSpatialCropd, Rand3DElasticd, RandAffined, RandZoomd,
12 | Spacingd, Orientationd, Resized, ThresholdIntensityd, RandShiftIntensityd, BorderPadd, RandGaussianNoised, RandAdjustContrastd,NormalizeIntensityd,RandFlipd)
13 |
14 |
15 | def segment(image, label, result, weights, resolution, patch_size, network, gpu_ids):
16 |
17 | logging.basicConfig(stream=sys.stdout, level=logging.INFO)
18 |
19 | if label is not None:
20 | uniform_img_dimensions_internal(image, label, True)
21 | files = [{"image": image, "label": label}]
22 | else:
23 | files = [{"image": image}]
24 |
25 | # original size, size after crop_background, cropped roi coordinates, cropped resampled roi size
26 | original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution = statistics_crop(image, resolution)
27 |
28 | # -------------------------------
29 |
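   | # One of four validation pipelines is built below, depending on whether a
   | # label was supplied and whether resampling is needed; the CT intensity
   | # clamp is left commented out here, presumably because the default inputs
   | # are T2 MR volumes.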
30 | if label is not None:
31 | if resolution is not None:
32 |
33 | val_transforms = Compose([
34 | LoadImaged(keys=['image', 'label']),
35 | AddChanneld(keys=['image', 'label']),
36 | # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135), # Threshold CT
37 | # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
38 | CropForegroundd(keys=['image', 'label'], source_key='image'), # crop CropForeground
39 |
40 | NormalizeIntensityd(keys=['image']), # intensity
41 | ScaleIntensityd(keys=['image']),
42 | Spacingd(keys=['image', 'label'], pixdim=resolution, mode=('bilinear', 'nearest')), # resolution
43 |
44 | SpatialPadd(keys=['image', 'label'], spatial_size=patch_size, method= 'end'),
45 | ToTensord(keys=['image', 'label'])])
46 | else:
47 |
48 | val_transforms = Compose([
49 | LoadImaged(keys=['image', 'label']),
50 | AddChanneld(keys=['image', 'label']),
51 | # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135), # Threshold CT
52 | # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
53 | CropForegroundd(keys=['image', 'label'], source_key='image'), # crop CropForeground
54 |
55 | NormalizeIntensityd(keys=['image']), # intensity
56 | ScaleIntensityd(keys=['image']),
57 |
58 | SpatialPadd(keys=['image', 'label'], spatial_size=patch_size, method='end'), # pad if the image is smaller than patch
59 | ToTensord(keys=['image', 'label'])])
60 |
61 | else:
62 | if resolution is not None:
63 |
64 | val_transforms = Compose([
65 | LoadImaged(keys=['image']),
66 | AddChanneld(keys=['image']),
67 | # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135), # Threshold CT
68 | # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
69 | CropForegroundd(keys=['image'], source_key='image'), # crop CropForeground
70 |
71 | NormalizeIntensityd(keys=['image']), # intensity
72 | ScaleIntensityd(keys=['image']),
73 | Spacingd(keys=['image'], pixdim=resolution, mode=('bilinear')), # resolution
74 |
75 | SpatialPadd(keys=['image'], spatial_size=patch_size, method= 'end'), # pad if the image is smaller than patch
76 | ToTensord(keys=['image'])])
77 | else:
78 |
79 | val_transforms = Compose([
80 | LoadImaged(keys=['image']),
81 | AddChanneld(keys=['image']),
82 | # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135), # Threshold CT
83 | # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
84 | CropForegroundd(keys=['image'], source_key='image'), # crop CropForeground
85 |
86 | NormalizeIntensityd(keys=['image']), # intensity
87 | ScaleIntensityd(keys=['image']),
88 |
89 | SpatialPadd(keys=['image'], spatial_size=patch_size, method='end'), # pad if the image is smaller than patch
90 | ToTensord(keys=['image'])])
91 |
92 | val_ds = monai.data.Dataset(data=files, transform=val_transforms)
93 | val_loader = DataLoader(val_ds, batch_size=1, num_workers=0, collate_fn=list_data_collate, pin_memory=False)
94 |
95 | dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
96 | post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
97 |
98 | if gpu_ids != '-1':
99 |
100 | # try to use all the available GPUs
101 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids
102 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
103 |
104 | else:
105 | device = torch.device("cpu")
106 |
107 | # build the network
108 | if network == 'nnunet':
109 | net = build_net()  # nn-UNet (DynUNet)
110 | elif network == 'unetr':
111 | net = build_UNETR()  # UNETR
   | else:
   | raise ValueError("unknown --network '%s': choose 'nnunet' or 'unetr'" % network)
112 |
113 | net = net.to(device)
114 |
115 | if gpu_ids == '-1':
116 |
117 | net.load_state_dict(new_state_dict_cpu(weights))
118 |
119 | else:
120 |
121 | net.load_state_dict(new_state_dict(weights))
122 |
123 | # define sliding window size and batch size for windows inference
124 | roi_size = patch_size
125 | sw_batch_size = 4
126 |
127 | net.eval()
128 | with torch.no_grad():
129 |
130 | if label is None:
131 | for val_data in val_loader:
132 | val_images = val_data["image"].to(device)
133 | val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
134 | val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
135 |
136 | else:
137 | for val_data in val_loader:
138 | val_images, val_labels = val_data["image"].to(device), val_data["label"].to(device)
139 | val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
140 | val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
141 | dice_metric(y_pred=val_outputs, y=val_labels)
142 |
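   | # DiceMetric buffers one score per case across the loop;
   | # aggregate() then reduces the buffer to the mean Dice.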
143 | metric = dice_metric.aggregate().item()
144 | print("Evaluation Metric (Dice):", metric)
145 |
146 | result_array = val_outputs[0].squeeze().data.cpu().numpy()
147 | # Remove the pad if the image was smaller than the patch in some directions
148 | result_array = result_array[0:resampled_size[0],0:resampled_size[1],0:resampled_size[2]]
149 |
150 | # resample back to the original resolution
151 | if resolution is not None:
152 |
153 | result_array_np = np.transpose(result_array, (2, 1, 0))
154 | result_array_temp = sitk.GetImageFromArray(result_array_np)
155 | result_array_temp.SetSpacing(resolution)
156 |
157 | # save temporary label
158 | writer = sitk.ImageFileWriter()
159 | writer.SetFileName('temp_seg.nii')
160 | writer.Execute(result_array_temp)
161 |
162 | files = [{"image": 'temp_seg.nii'}]
163 |
164 | files_transforms = Compose([
165 | LoadImaged(keys=['image']),
166 | AddChanneld(keys=['image']),
167 | Spacingd(keys=['image'], pixdim=original_resolution, mode=('nearest')),
168 | Resized(keys=['image'], spatial_size=crop_shape, mode=('nearest')),
169 | ])
170 |
171 | files_ds = Dataset(data=files, transform=files_transforms)
172 | files_loader = DataLoader(files_ds, batch_size=1, num_workers=0)
173 |
174 | for files_data in files_loader:
175 | files_images = files_data["image"]
176 |
177 | res = files_images.squeeze().data.numpy()
178 |
179 | result_array = np.rint(res)
180 |
181 | os.remove('./temp_seg.nii')
182 |
183 | # recover the cropped background before saving the image
184 | empty_array = np.zeros(original_shape)
185 | empty_array[coord1[0]:coord2[0],coord1[1]:coord2[1],coord1[2]:coord2[2]] = result_array
186 |
187 | result_seg = from_numpy_to_itk(empty_array, image)
188 |
189 | # save label
190 | writer = sitk.ImageFileWriter()
191 | writer.SetFileName(result)
192 | writer.Execute(result_seg)
193 | print("Saved Result at:", str(result))
194 |
195 |
196 | if __name__ == "__main__":
197 |
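   | # Example (defaults shown; pass "--label None" to skip the Dice computation):
   | #   python predict_single_image.py --image ./Data_folder/T2/3.nii \
   | #       --label ./Data_folder/T2_labels/3.nii --result ./Data_folder/test_0.nii \
   | #       --weights ./best_metric_model.pth --network unetr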
198 | parser = argparse.ArgumentParser()
199 | parser.add_argument("--image", type=str, default='./Data_folder/T2/3.nii', help='source image' )
200 | parser.add_argument("--label", type=str, default='./Data_folder/T2_labels/3.nii', help='source label, if you want to compute dice. None for new case')
201 | parser.add_argument("--result", type=str, default='./Data_folder/test_0.nii', help='path to the .nii result to save')
202 | parser.add_argument("--weights", type=str, default='./best_metric_model.pth', help='network weights to load')
203 | parser.add_argument("--resolution", default=[0.7, 0.7, 3], help='Resolution used in training phase')
204 | parser.add_argument("--patch_size", type=int, nargs=3, default=(256, 256, 16), help="Input dimension for the generator, same of training")
205 | parser.add_argument('--network', default='unetr', help='nnunet, unetr')
206 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
207 | args = parser.parse_args()
   | if args.label == 'None': args.label = None  # treat the CLI string 'None' as "no label", as the help text promises
208 |
209 | segment(args.image, args.label, args.result, args.weights, args.resolution, args.patch_size, args.network, args.gpu_ids)
210 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | SimpleITK==2.1.0
2 | torchsummaryX
3 | nibabel
4 | pillow
5 | tensorboard
6 | gdown
7 | pytorch-ignite==0.4.4
8 | itk
9 | tqdm
10 | lmdb
11 | psutil
12 | pandas
13 | einops
14 | scikit-image
15 |
--------------------------------------------------------------------------------