├── LICENSE
├── README.md
├── apple_mps_test.py
├── images
│   ├── graphs.PNG
│   └── liver_segmentation.PNG
├── preporcess.py
├── requirements.txt
├── testing.ipynb
├── train.py
└── utilities.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Mokhtari Mohammed El Amine

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
[![GitHub issues](https://img.shields.io/github/issues/amine0110/Liver-Segmentation-Using-Monai-and-PyTorch)](https://github.com/amine0110/Liver-Segmentation-Using-Monai-and-PyTorch/issues) [![GitHub stars](https://img.shields.io/github/stars/amine0110/Liver-Segmentation-Using-Monai-and-PyTorch)](https://github.com/amine0110/Liver-Segmentation-Using-Monai-and-PyTorch/stargazers) [![GitHub license](https://img.shields.io/github/license/amine0110/Liver-Segmentation-Using-Monai-and-PyTorch)](https://github.com/amine0110/Liver-Segmentation-Using-Monai-and-PyTorch) [![GitHub forks](https://img.shields.io/github/forks/amine0110/Liver-Segmentation-Using-Monai-and-PyTorch)](https://github.com/amine0110/Liver-Segmentation-Using-Monai-and-PyTorch/network) ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torch) [![YouTube Video Views](https://img.shields.io/youtube/views/AU4KlXKKnac?style=social)](https://youtu.be/AU4KlXKKnac) ![GitHub watchers](https://img.shields.io/github/watchers/amine0110/Liver-Segmentation-Using-Monai-and-PyTorch?style=social)
# Liver Segmentation Using Monai and PyTorch
This repo contains all the Python files you need to perform liver segmentation with Monai and PyTorch, and you can use the same code to segment other organs as well.

Link to the original course [here](https://www.learn.pycad.co/course/liver-segmentation).

![Output image](https://github.com/amine0110/Liver-Segmentation-Using-Monai-and-PyTorch/blob/main/images/liver_segmentation.PNG)

In this project, you will find some scripts that I wrote myself and others that I took from Monai's tutorials. For this reason, I recommend looking at their original repo and [website](https://monai.io/) for more information.

## Cloning the repo
Start by cloning this repo into your workspace, then experiment with the functions to build your project.
```
git clone https://github.com/amine0110/Liver-Segmentation-Using-Monai-and-PyTorch
```
```
cd ./Liver-Segmentation-Using-Monai-and-PyTorch
```
## Packages that need to be installed:
```
pip install monai
```
```
pip install -r requirements.txt
```
## Showing a patient from the dataset
One of the most common questions I had when working with medical imaging was how to display a patient. To address this, I wrote a dedicated function that shows a patient from the training or testing dataset:

```Python
import matplotlib.pyplot as plt
from monai.utils import first


def show_patient(data, SLICE_NUMBER=1, train=True, test=False):
    """
    Show one slice of one patient from your datasets, so that you can check whether it looks right
    or whether you need to change/delete something.
    `data`: the (train_loader, test_loader) pair returned by `prepare()`. Call `prepare()` first with the
    transforms you want, then pass its output to this function to visualize the patient with those
    transforms applied.
    `SLICE_NUMBER`: the index of the slice to display (along the last spatial axis).
    `train`: display a patient from the training data (True by default).
    `test`: display a patient from the testing data.
    """

    check_patient_train, check_patient_test = data

    # Take the first batch of each loader; tensors are indexed [batch, channel, x, y, slice].
    view_train_patient = first(check_patient_train)
    view_test_patient = first(check_patient_test)

    if train:
        plt.figure("Visualization Train", (12, 6))
        plt.subplot(1, 2, 1)
        plt.title(f"vol {SLICE_NUMBER}")
        plt.imshow(view_train_patient["vol"][0, 0, :, :, SLICE_NUMBER], cmap="gray")

        plt.subplot(1, 2, 2)
        plt.title(f"seg {SLICE_NUMBER}")
        plt.imshow(view_train_patient["seg"][0, 0, :, :, SLICE_NUMBER])
        plt.show()

    if test:
        plt.figure("Visualization Test", (12, 6))
        plt.subplot(1, 2, 1)
        plt.title(f"vol {SLICE_NUMBER}")
        plt.imshow(view_test_patient["vol"][0, 0, :, :, SLICE_NUMBER], cmap="gray")

        plt.subplot(1, 2, 2)
        plt.title(f"seg {SLICE_NUMBER}")
        plt.imshow(view_test_patient["seg"][0, 0, :, :, SLICE_NUMBER])
        plt.show()

```

Before calling this function, you need to preprocess your data. This function is meant to help you visualize your patients after the different transforms have been applied, so that you can tell whether you need to change some parameters.
The preprocessing is done by the function `prepare()`, which you can find in the `preporcess.py` file.

## Training
Once you understand the preprocessing, you can import the 3D `UNet` from Monai and define the model parameters (spatial dimensions, input channels, output channels...).

```Python
import torch
from monai.networks.nets import UNet
from monai.networks.layers import Norm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = UNet(
    spatial_dims=3,  # 3D network (older MONAI versions called this argument `dimensions`)
    in_channels=1,   # one CT intensity channel
    out_channels=2,  # background + liver
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
```

To run the training, use the `train.py` script, which calls the `train` function that I wrote following the same principles as Monai's tutorials.
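
`train()` prints a Dice score for every batch and every epoch. As a reference for what that number measures, here is a minimal NumPy sketch of the plain (hard) Dice coefficient between two binary masks, 2 * |A & B| / (|A| + |B|). It is an illustration, not the repo's metric: `dice_metric` in `utilities.py` computes `1 - DiceLoss(...)` from MONAI on sigmoid-activated network outputs with squared denominators, so its values will differ somewhat from this version. The function name `dice_coefficient` and the `eps` smoothing term are assumptions of this sketch.

```Python
import numpy as np


def dice_coefficient(pred_mask, true_mask, eps=1e-5):
    """Hard Dice score 2*|A & B| / (|A| + |B|) for two binary masks of the same shape.

    Hypothetical helper for illustration; `eps` keeps the ratio defined when
    both masks are empty (the score is then 1.0).
    """
    pred = np.asarray(pred_mask, dtype=bool)
    true = np.asarray(true_mask, dtype=bool)
    if pred.shape != true.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {true.shape}")
    intersection = np.logical_and(pred, true).sum()
    return float((2.0 * intersection + eps) / (pred.sum() + true.sum() + eps))


# Two toy 4x4x2 "volumes": each has 16 foreground voxels and they share 8.
pred = np.zeros((4, 4, 2), dtype=bool)
pred[:2] = True
true = np.zeros((4, 4, 2), dtype=bool)
true[1:3] = True
print(round(dice_coefficient(pred, true), 3))  # 0.5
```

To compare this with the network output, one option (an assumption, not what the repo does) is to take `outputs.argmax(dim=1)` as the predicted mask and compare it with the binarized label after moving both to NumPy.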

## Testing the model
To test the model, use the Jupyter notebook `testing.ipynb`, which contains all the code you need. It has a section that plots the training/testing curves for the loss and the Dice coefficient, and a section that shows the model's output on one of the test patients.

![Output image](https://github.com/amine0110/Liver-Segmentation-Using-Monai-and-PyTorch/blob/main/images/graphs.PNG)

----------------------------------------------------------------------------------------------------------------------------------
Before using the code, I recommend that you watch my course, in which I explain everything in this repo, or at the very least read my blog posts, in which I explain how to use the various scripts so that you don't get confused.

You can read the tutorial in my blog post series, starting with [this one.](https://pycad.co/liver-segmentation-part-1/)

## Conversion tools

For the different data preparation steps that I talked about in the course/blogs, I have created a simple GUI that helps you do the conversion in a few clicks. See more information at [this link](https://pycad.co/pycad-convert/).

![154864750-c55a3129-67c7-438a-8549-e2c45c433048](https://user-images.githubusercontent.com/37108394/156251291-a0911b63-41b6-4c8a-820b-a9bfec5e452b.png)

## 📩 Newsletter
Stay up-to-date on the latest in computer vision and medical imaging! Subscribe to my newsletter for insights and analysis on the cutting-edge developments in this exciting field.

https://pycad.co/join-us/

## 🆕 NEW

Learn how to effectively manage and process DICOM files in Python with our comprehensive course, designed to equip you with the skills and knowledge you need to succeed.

https://www.learn.pycad.co/course/dicom-simplified


--------------------------------------------------------------------------------
/apple_mps_test.py:
--------------------------------------------------------------------------------
import torch

# Check whether PyTorch can use Apple's Metal Performance Shaders (MPS) backend on this machine.
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")

else:
    print("All ok")
--------------------------------------------------------------------------------
/images/graphs.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amine0110/Liver-Segmentation-Using-Monai-and-PyTorch/8a29106bff02f9cbae9c41098929dd2ee9c3f1a5/images/graphs.PNG
--------------------------------------------------------------------------------
/images/liver_segmentation.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amine0110/Liver-Segmentation-Using-Monai-and-PyTorch/8a29106bff02f9cbae9c41098929dd2ee9c3f1a5/images/liver_segmentation.PNG
--------------------------------------------------------------------------------
/preporcess.py:
--------------------------------------------------------------------------------
import os
from glob import glob
#import shutil
#from tqdm import tqdm
#import dicom2nifti
#import numpy as np
#import nibabel as nib
from monai.transforms import (
    Compose,
    EnsureChannelFirstD,
    LoadImaged,
    Resized,
    ToTensord,
    Spacingd,
    Orientationd,
    ScaleIntensityRanged,
    CropForegroundd,

)
from monai.data import DataLoader, Dataset, CacheDataset
from monai.utils import set_determinism


def prepare(in_dir, pixdim=(1.5, 1.5, 1.0), a_min=-200, a_max=200, spatial_size=[128, 128, 64], cache=True):
    """
    Build the training and testing data loaders and return them as (train_loader, test_loader).

    `in_dir` must contain four folders of NIfTI files (*.nii.gz): TrainVolumes, TrainSegmentation,
    TestVolumes and TestSegmentation. Volumes and segmentations are paired by position after
    sorting the filenames.
    `pixdim`: target voxel spacing used by Spacingd.
    `a_min`, `a_max`: intensity window mapped to [0, 1] by ScaleIntensityRanged (values outside are clipped).
    `spatial_size`: output size used by Resized.
    `cache`: use CacheDataset (faster, but keeps every preprocessed volume in memory) instead of Dataset.

    Only the basic transforms are included, but you can add more operations from the Monai documentation.
    https://monai.io/docs.html
    """

    set_determinism(seed=0)

    path_train_volumes = sorted(glob(os.path.join(in_dir, "TrainVolumes", "*.nii.gz")))
    path_train_segmentation = sorted(glob(os.path.join(in_dir, "TrainSegmentation", "*.nii.gz")))

    path_test_volumes = sorted(glob(os.path.join(in_dir, "TestVolumes", "*.nii.gz")))
    path_test_segmentation = sorted(glob(os.path.join(in_dir, "TestSegmentation", "*.nii.gz")))

    train_files = [{"vol": image_name, "seg": label_name} for image_name, label_name in
                   zip(path_train_volumes, path_train_segmentation)]
    test_files = [{"vol": image_name, "seg": label_name} for image_name, label_name in
                  zip(path_test_volumes, path_test_segmentation)]

    train_transforms = Compose(
        [
            LoadImaged(keys=["vol", "seg"]),
            EnsureChannelFirstD(keys=["vol", "seg"]),
            # Resample to `pixdim`; nearest-neighbour interpolation keeps the label values discrete.
            Spacingd(keys=["vol", "seg"], pixdim=pixdim, mode=("bilinear", "nearest")),
            Orientationd(keys=["vol", "seg"], axcodes="RAS"),
            # Map intensities in [a_min, a_max] to [0, 1] and clip everything outside that window.
            ScaleIntensityRanged(keys=["vol"], a_min=a_min, a_max=a_max, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=["vol", "seg"], source_key="vol"),
            Resized(keys=["vol", "seg"], spatial_size=spatial_size),
            ToTensord(keys=["vol", "seg"]),

        ]
    )

    test_transforms = Compose(
        [
            LoadImaged(keys=["vol", "seg"]),
            EnsureChannelFirstD(keys=["vol", "seg"]),
            Spacingd(keys=["vol", "seg"], pixdim=pixdim, mode=("bilinear", "nearest")),
            Orientationd(keys=["vol", "seg"], axcodes="RAS"),
            ScaleIntensityRanged(keys=["vol"], a_min=a_min, a_max=a_max, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=["vol", "seg"], source_key="vol"),
            Resized(keys=["vol", "seg"], spatial_size=spatial_size),
            ToTensord(keys=["vol", "seg"]),

        ]
    )

    if cache:
        train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0)
        train_loader = DataLoader(train_ds, batch_size=1)

        test_ds = CacheDataset(data=test_files, transform=test_transforms, cache_rate=1.0)
        test_loader = DataLoader(test_ds, batch_size=1)

        return train_loader, test_loader

    else:
        train_ds = Dataset(data=train_files, transform=train_transforms)
        train_loader = DataLoader(train_ds, batch_size=1)

        test_ds = Dataset(data=test_files, transform=test_transforms)
        test_loader = DataLoader(test_ds, batch_size=1)

        return train_loader, test_loader

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
matplotlib
numpy
tqdm
glob2
dicom2nifti
pytest-shutil
nibabel

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceLoss, DiceCELoss

import torch
from preporcess import prepare
from utilities import train


# Adjust these paths to your own dataset and results folders.
data_dir = 'D:/Youtube/Organ and Tumor Segmentation/datasets/Task03_Liver/Data_Train_Test'
model_dir = 'D:/Youtube/Organ and Tumor Segmentation/results/results'
data_in = prepare(data_dir, cache=True)

# Fall back to the CPU when no CUDA GPU is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)


# Alternative loss with a class-weighted cross-entropy term (also needs `from utilities import calculate_weights`):
#loss_function = DiceCELoss(to_onehot_y=True, sigmoid=True, squared_pred=True, ce_weight=calculate_weights(1792651250,2510860).to(device))
loss_function = DiceLoss(to_onehot_y=True, sigmoid=True, squared_pred=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-5, weight_decay=1e-5, amsgrad=True)

if __name__ == '__main__':
    # Train for 600 epochs on the device selected above.
    train(model, data_in, loss_function, optimizer, 600, model_dir, device=device)
--------------------------------------------------------------------------------
/utilities.py:
--------------------------------------------------------------------------------
from monai.utils import first
import matplotlib.pyplot as plt
import torch
import os
import numpy as np
from monai.losses import DiceLoss
from tqdm import tqdm

def dice_metric(predicted, target):
    '''
    Compute a Dice score between `predicted` (the network output) and `target` (the label) as
    1 - DiceLoss; it is used as the metric value for both training and validation.
    '''
    dice_value = DiceLoss(to_onehot_y=True, sigmoid=True, squared_pred=True)
    value = 1 - dice_value(predicted, target).item()
    return value

def calculate_weights(val1, val2):
    '''
    Take the number of background (`val1`) and foreground (`val2`) voxels and return class `weights`
    for the cross-entropy loss (inverse class frequency, normalized to sum to 1).
    '''
    count = np.array([val1, val2])
    summ = count.sum()
    weights = count/summ
    weights = 1/weights
    summ = weights.sum()
    weights = weights/summ
    return torch.tensor(weights, dtype=torch.float32)

def train(model, data_in, loss, optim, max_epochs, model_dir, test_interval=1, device=torch.device("cuda:0")):
    '''
    Train `model` on `data_in`, the (train_loader, test_loader) pair returned by `prepare()`.
    Every `test_interval` epochs the model is evaluated on the test loader. The per-epoch losses and
    Dice metrics are saved as .npy files in `model_dir`, and the weights with the best mean test Dice
    are saved to `model_dir/best_metric_model.pth`.
    '''
    best_metric = -1
    best_metric_epoch = -1
    save_loss_train = []
    save_loss_test = []
    save_metric_train = []
    save_metric_test = []
    train_loader, test_loader = data_in

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        train_epoch_loss = 0
        train_step = 0
        epoch_metric_train = 0
        for batch_data in train_loader:

            train_step += 1

            volume = batch_data["vol"]
            label = batch_data["seg"]
            label = label != 0  # binarize: any non-zero label value becomes foreground
            volume, label = (volume.to(device), label.to(device))

            optim.zero_grad()
            outputs = model(volume)

            train_loss = loss(outputs, label)

            train_loss.backward()
            optim.step()

            train_epoch_loss += train_loss.item()
            # len(train_loader) is already the number of batches per epoch.
            print(
                f"{train_step}/{len(train_loader)}, "
                f"Train_loss: {train_loss.item():.4f}")

            train_metric = dice_metric(outputs, label)
            epoch_metric_train += train_metric
            print(f'Train_dice: {train_metric:.4f}')

        print('-'*20)

        train_epoch_loss /= train_step
        print(f'Epoch_loss: {train_epoch_loss:.4f}')
        save_loss_train.append(train_epoch_loss)
        np.save(os.path.join(model_dir, 'loss_train.npy'), save_loss_train)

        epoch_metric_train /= train_step
        print(f'Epoch_metric: {epoch_metric_train:.4f}')

        save_metric_train.append(epoch_metric_train)
        np.save(os.path.join(model_dir, 'metric_train.npy'), save_metric_train)

        if (epoch + 1) % test_interval == 0:

            model.eval()
            with torch.no_grad():
                test_epoch_loss = 0
                test_metric = 0
                epoch_metric_test = 0
                test_step = 0

                for test_data in test_loader:

                    test_step += 1

                    test_volume = test_data["vol"]
                    test_label = test_data["seg"]
                    test_label = test_label != 0
                    test_volume, test_label = (test_volume.to(device), test_label.to(device),)

                    test_outputs = model(test_volume)

                    test_loss = loss(test_outputs, test_label)
                    test_epoch_loss += test_loss.item()
                    test_metric = dice_metric(test_outputs, test_label)
                    epoch_metric_test += test_metric


                test_epoch_loss /= test_step
                print(f'test_loss_epoch: {test_epoch_loss:.4f}')
                save_loss_test.append(test_epoch_loss)
                np.save(os.path.join(model_dir, 'loss_test.npy'), save_loss_test)

                epoch_metric_test /= test_step
                print(f'test_dice_epoch: {epoch_metric_test:.4f}')
                save_metric_test.append(epoch_metric_test)
                np.save(os.path.join(model_dir, 'metric_test.npy'), save_metric_test)

                # Keep the weights of the epoch with the best mean test Dice.
                if epoch_metric_test > best_metric:
                    best_metric = epoch_metric_test
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), os.path.join(
                        model_dir, "best_metric_model.pth"))

                # Report the epoch mean, not the Dice of the last test batch.
                print(
                    f"current epoch: {epoch + 1} current mean dice: {epoch_metric_test:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} "
                    f"at epoch: {best_metric_epoch}"
                )


    print(
        f"train completed, best_metric: {best_metric:.4f} "
        f"at epoch: {best_metric_epoch}")


def show_patient(data, SLICE_NUMBER=1, train=True, test=False):
    """
    Show one slice of one patient from your datasets, so that you can check whether it looks right
    or whether you need to change/delete something.

    `data`: the (train_loader, test_loader) pair returned by `prepare()`. Call `prepare()` first with the
    transforms you want, then pass its output to this function to visualize the patient with those
    transforms applied.
    `SLICE_NUMBER`: the index of the slice to display (along the last spatial axis).
    `train`: display a patient from the training data (True by default).
    `test`: display a patient from the testing data.
    """

    check_patient_train, check_patient_test = data

    # Take the first batch of each loader; tensors are indexed [batch, channel, x, y, slice].
    view_train_patient = first(check_patient_train)
    view_test_patient = first(check_patient_test)


    if train:
        plt.figure("Visualization Train", (12, 6))
        plt.subplot(1, 2, 1)
        plt.title(f"vol {SLICE_NUMBER}")
        plt.imshow(view_train_patient["vol"][0, 0, :, :, SLICE_NUMBER], cmap="gray")

        plt.subplot(1, 2, 2)
        plt.title(f"seg {SLICE_NUMBER}")
        plt.imshow(view_train_patient["seg"][0, 0, :, :, SLICE_NUMBER])
        plt.show()

    if test:
        plt.figure("Visualization Test", (12, 6))
        plt.subplot(1, 2, 1)
        plt.title(f"vol {SLICE_NUMBER}")
        plt.imshow(view_test_patient["vol"][0, 0, :, :, SLICE_NUMBER], cmap="gray")

        plt.subplot(1, 2, 2)
        plt.title(f"seg {SLICE_NUMBER}")
        plt.imshow(view_test_patient["seg"][0, 0, :, :, SLICE_NUMBER])
        plt.show()


def calculate_pixels(data):
    '''
    Count background and foreground voxels over all batches of `data` (e.g. a train loader), where any
    non-zero label value counts as foreground. Returns an array of shape (1, 2),
    [[background_count, foreground_count]], whose two entries can be passed to `calculate_weights`.
    '''
    val = np.zeros((1, 2))

    for batch in tqdm(data):
        # bincount with minlength=2 always yields [background, foreground], even when a volume
        # contains only one of the two classes.
        batch_label = np.asarray(batch["seg"] != 0).astype(np.int64).ravel()
        count = np.bincount(batch_label, minlength=2)
        val += count

    print('The last values:', val)
    return val

--------------------------------------------------------------------------------
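
`calculate_weights()` in `utilities.py` turns background/foreground voxel counts (for example, the totals from `calculate_pixels()`) into class weights for a cross-entropy term, as in the commented-out `DiceCELoss` line of `train.py`. The arithmetic is inverse-frequency weighting normalized to sum to 1: w_c = (1/f_c) / sum_k(1/f_k), where f_c is the fraction of voxels in class c. The NumPy-only sketch below reproduces that arithmetic so it can be checked by hand; `inverse_frequency_weights` is a hypothetical name, not a function in the repo.

```Python
import numpy as np


def inverse_frequency_weights(counts):
    """Hypothetical NumPy mirror of the arithmetic in utilities.calculate_weights.

    counts: per-class voxel counts, e.g. [background, foreground].
    Returns weights proportional to 1 / class frequency, normalized to sum to 1.
    Every count must be > 0, otherwise 1 / frequency is undefined.
    """
    counts = np.asarray(counts, dtype=np.float64)
    if np.any(counts <= 0):
        raise ValueError("every class needs at least one voxel")
    freq = counts / counts.sum()      # fraction of voxels in each class
    inverse = 1.0 / freq              # rare classes get large values
    return inverse / inverse.sum()    # normalize so the weights sum to 1


# Toy counts: 9 background voxels and 3 foreground voxels.
print(inverse_frequency_weights([9, 3]))  # weights 0.25 and 0.75

# The counts that appear in the commented-out DiceCELoss line of train.py.
w = inverse_frequency_weights([1792651250, 2510860])
print(w)  # roughly 0.0014 for background and 0.9986 for foreground
```

For two classes the result is simply the class frequencies swapped (the background weight equals the foreground fraction and vice versa), which is why the background weight is so small for liver CT. The repo's `calculate_weights()` returns the same values as a `torch.float32` tensor, which `train.py` moves to the device before passing it as `ce_weight`.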