├── .gitignore
├── Data
    └── MG_scan_test.nii.gz
├── Docker
    └── Dockerfile
├── MULTI_SEG
    ├── src
    │   ├── Sort_New_data.py
    │   ├── compute_metrics.py
    │   ├── correct_file.py
    │   ├── data_split_csv.py
    │   ├── init_training_data.py
    │   ├── merge_seg.py
    │   ├── models.py
    │   ├── post_process_test.py
    │   ├── predict_CBCTSeg.py
    │   ├── rescall_all.py
    │   ├── train_CBCTseg.py
    │   └── utils.py
    └── vtkToSTL.py
├── README.md
└── requirements.txt


/.gitignore:
--------------------------------------------------------------------------------

*.DS_Store
*.pyc
--------------------------------------------------------------------------------
/Data/MG_scan_test.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Maxlo24/AMASSS_CBCT/43f0c69fa68894effcfeb23679083a7cf694888e/Data/MG_scan_test.nii.gz
--------------------------------------------------------------------------------
/Docker/Dockerfile:
--------------------------------------------------------------------------------
FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime

# System packages in a single layer: git/wget/unzip are used below to fetch the code
# and the models; ffmpeg, libsm6 and libxext6 are the extra libraries the image has
# always installed. The apt lists are deleted in the same RUN so they never persist.
RUN apt-get update && \
    apt-get install -y git wget unzip ffmpeg libsm6 libxext6 && \
    rm -rf /var/lib/apt/lists/*

# Pretrained models, extracted into /app/data (WORKDIR creates the directory).
# Download, extraction and clean-up share one RUN so the archive is not kept in a layer.
ARG MODEL_VERSION="1.0.2-beta.1"
WORKDIR /app/data
RUN wget https://github.com/Maxlo24/AMASSS_CBCT/releases/download/v$MODEL_VERSION/ALL_MODELS.zip && \
    unzip ALL_MODELS.zip && \
    rm -f ALL_MODELS.zip

# Source code of the tagged release. The archive extracts to AMASSS_CBCT-<version>;
# only MULTI_SEG is kept, and both the archive and that folder are removed afterwards.
ARG RELEASE_VERSION="1.0.2-beta.1"
WORKDIR /app
RUN wget https://github.com/Maxlo24/AMASSS_CBCT/archive/refs/tags/v$RELEASE_VERSION.zip && \
    unzip v$RELEASE_VERSION.zip && \
    mv AMASSS_CBCT-$RELEASE_VERSION/MULTI_SEG MULTI_SEG && \
    rm -rf v$RELEASE_VERSION.zip AMASSS_CBCT-$RELEASE_VERSION

# NOTE(review): in the repository tree of this snapshot requirements.txt appears at the
# top level rather than under MULTI_SEG/ - confirm the layout of the downloaded release.
RUN pip install -r MULTI_SEG/requirements.txt

RUN mkdir scans
--------------------------------------------------------------------------------
/MULTI_SEG/src/Sort_New_data.py:
--------------------------------------------------------------------------------
import argparse
import glob
import sys
import os
import shutil
from utils import SetSpacing,KeepLabel

IMAGE_EXTENSIONS = (".nrrd", ".nrrd.gz", ".nii", ".nii.gz", ".gipl", ".gipl.gz")


def main(args):
    """Pair each patient's scan with its segmentation and resample both.

    Files are grouped by the text before the first "_" of their name. A name containing
    "scan"/"Scan" marks the scan; otherwise a name containing "seg"/"Seg" marks the
    segmentation. Patients missing either file are reported and skipped. For every
    spacing in ``args.spacing`` the pair is written to
    ``<out>/<prefix>/<prefix>-<patient>_scan_Sp<spacing>.nii.gz`` and
    ``..._seg_Sp<spacing>.nii.gz``, where ``<spacing>`` is the value with its dot removed.

    Optional attributes ``args.prefix`` (default "UFG") and ``args.keep_label``
    (default 4) keep older callers that build ``args`` themselves working.
    """
    prefix = getattr(args, "prefix", "UFG")
    keep_label = getattr(args, "keep_label", 4)

    print("Reading folder : ", args.input_dir)
    print("Selected spacings : ", args.spacing)

    patients = {}
    normpath = os.path.normpath("/".join([args.input_dir, '**', '']))
    for img_fn in sorted(glob.iglob(normpath, recursive=True)):
        basename = os.path.basename(img_fn)
        # Test the extension on the file name only: a folder whose name contains ".nii"
        # must not make every path below it look like an image.
        if not os.path.isfile(img_fn) or not any(ext in basename for ext in IMAGE_EXTENSIONS):
            continue

        patient = basename.split(".")[0].split("_")[0]
        data = patients.setdefault(patient, {})
        if "scan" in basename or "Scan" in basename:
            data["scan"] = img_fn
        elif "seg" in basename or "Seg" in basename:
            data["seg"] = img_fn

    invalid_patients = []
    for patient, data in patients.items():
        if "scan" not in data:
            print("Missing scan for patient :", patient)
        if "seg" not in data:
            print("Missing seg segmentation patient :", patient)
        if "scan" not in data or "seg" not in data:
            invalid_patients.append(patient)
    for patient in invalid_patients:
        del patients[patient]

    outpath = os.path.normpath("/".join([args.out, prefix]))
    os.makedirs(outpath, exist_ok=True)

    n_written = 0
    for patient, data in patients.items():
        for sp in args.spacing:
            spacing = str(sp).replace(".", "")
            scan_name = prefix + "-" + patient + "_scan_Sp" + spacing + ".nii.gz"
            seg_name = prefix + "-" + patient + "_seg_Sp" + spacing + ".nii.gz"
            seg_path = os.path.join(outpath, seg_name)

            SetSpacing(data["scan"], [sp, sp, sp], outpath=os.path.join(outpath, scan_name))
            # Nearest-neighbour interpolation keeps segmentation label values intact.
            SetSpacing(data["seg"], [sp, sp, sp], "NearestNeighbor", seg_path)
            # NOTE(review): KeepLabel is defined in utils (not shown); presumably it keeps
            # only `keep_label` in the mask - confirm.
            KeepLabel(seg_path, seg_path, keep_label)
            n_written += 1

    print(n_written)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MD_reader', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    input_group = parser.add_argument_group('Input files')
    input_group.add_argument('-i','--input_dir', type=str, help='Input directory with 3D images',required=True)

    output_params = parser.add_argument_group('Output parameters')
    output_params.add_argument('-o','--out', type=str, help='Output directory', required=True)
    output_params.add_argument('--prefix', type=str, default='UFG', help='Dataset name used as output sub-folder and file-name prefix')

    input_group.add_argument('-sp', '--spacing', nargs="+", type=float, help='Wanted output x spacing', default=[0.5])
    input_group.add_argument('--keep_label', type=int, default=4, help='Label value passed to utils.KeepLabel')

    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
/MULTI_SEG/src/compute_metrics.py:
--------------------------------------------------------------------------------
import argparse
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import metrics
import time
import SimpleITK as sitk
import os
import glob
import math
import tqdm

IMAGE_EXTENSIONS = (".nrrd", ".nrrd.gz", ".nii", ".nii.gz", ".gipl", ".gipl.gz")
METRICS_NAMES = ['AUPRC','AUPRC - Baseline','F1_Score','Fbeta_Score','Accuracy','Recall','Precision','File']


def _safe_div(num, den):
    """Return num / den, or 0.0 when den is zero (nothing to score)."""
    return num / den if den else 0.0


def compute_tp_tn_fp_fn(y_true, y_pred):
    """Count true/false positives/negatives between two label arrays.

    Any non-zero voxel is foreground, so 0/1, 0/255 and boolean masks give the same
    counts. Returns ``(tp, tn, fp, fn)`` as Python ints.
    """
    truth = np.asarray(y_true) != 0
    guess = np.asarray(y_pred) != 0
    tp = int(np.count_nonzero(truth & guess))
    tn = int(np.count_nonzero(~truth & ~guess))
    fp = int(np.count_nonzero(~truth & guess))
    fn = int(np.count_nonzero(truth & ~guess))
    return tp, tn, fp, fn


def compute_precision(tp, fp):
    """Precision tp / (tp + fp); 0.0 when nothing was predicted positive."""
    return _safe_div(tp, tp + fp)


def compute_recall(tp, fn):
    """Recall tp / (tp + fn); 0.0 when the ground truth has no positive voxel."""
    return _safe_div(tp, tp + fn)


def compute_f1_score(precision, recall):
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    return _safe_div(2 * precision * recall, precision + recall)


def compute_fbeta_score(precision, recall, beta):
    """F-beta score (beta > 1 weights recall more); 0.0 when the denominator is zero."""
    return _safe_div((1 + beta**2) * precision * recall, beta**2 * precision + recall)


def compute_accuracy(tp,tn,fp,fn):
    """Fraction of voxels classified correctly; 0.0 for empty inputs."""
    return _safe_div(tp + tn, tp + tn + fp + fn)


def compute_auc(GT, pred):
    """Area under the ROC curve (sklearn ``roc_auc_score``)."""
    return metrics.roc_auc_score(GT, pred)


def compute_auprc(GT, pred, plot=True):
    """Return the area under the precision-recall curve, optionally plotting it.

    The plot has recall on the x axis and precision on the y axis.
    """
    prec, rec, _ = metrics.precision_recall_curve(GT, pred)
    if plot:
        plt.plot(rec, prec)
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        plt.show()
    return metrics.auc(rec, prec)


def compute_average_precision(GT, pred):
    """Return (average precision, fraction of positive ground-truth voxels).

    The positive fraction is the precision-recall baseline of a random classifier.
    Ground truth is binarised (non-zero = positive) before scoring.
    """
    truth = (np.asarray(GT) != 0).astype(np.uint8)
    ratio = np.count_nonzero(truth) / truth.size
    return metrics.average_precision_score(truth, pred), ratio


def collect_patients(directory):
    """Group prediction / ground-truth files under *directory* by patient id.

    The id is the file name cut before "_Pred_Sp", "_seg_Sp" or "_scan_Sp". Files whose
    name contains "_Pred_" are predictions; otherwise "_seg_" marks the ground truth.
    """
    patients = {}
    normpath = os.path.normpath("/".join([directory, '**', '']))
    for img_fn in sorted(glob.iglob(normpath, recursive=True)):
        basename = os.path.basename(img_fn)
        if not any(ext in basename for ext in IMAGE_EXTENSIONS):
            continue
        file_name = basename.split(".")[0]
        patient = file_name.split("_Pred_Sp")[0].split("_seg_Sp")[0].split("_scan_Sp")[0]
        data = patients.setdefault(patient, {})
        if "_Pred_" in basename:
            data["pred"] = img_fn
        elif "_seg_" in basename:
            data["seg"] = img_fn
    return patients


def main(args):
    """Score every prediction against its ground truth.

    Writes All_metrics.xlsx (one row per file) and Average_metrics.xlsx (Mean and STD
    rows) to the current working directory.
    """
    patients = collect_patients(args.dir)
    start_time = time.time()

    rows = []
    for patient, data in tqdm.tqdm(patients.items()):
        if "seg" not in data or "pred" not in data:
            print("Skipping", patient, ": missing prediction or ground truth")
            continue
        GT = sitk.GetArrayFromImage(sitk.ReadImage(data["seg"])).flatten()
        pred = sitk.GetArrayFromImage(sitk.ReadImage(data["pred"])).flatten()

        tp, tn, fp, fn = compute_tp_tn_fp_fn(GT, pred)
        recall = compute_recall(tp, fn)
        precision = compute_precision(tp, fp)
        f1 = compute_f1_score(precision, recall)
        fbeta = compute_fbeta_score(precision, recall, 2)
        acc = compute_accuracy(tp, tn, fp, fn)
        auprc, ratio = compute_average_precision(GT, pred)

        rows.append([auprc, ratio, f1, fbeta, acc, recall, precision,
                     os.path.basename(data["pred"]).split('.')[0]])

    total_values = pd.DataFrame(rows, columns=METRICS_NAMES)

    # One "Mean" row and one "STD" row over every numeric column.
    numeric = total_values[METRICS_NAMES[:-1]].astype(float)
    mean_values = pd.DataFrame(
        [list(numeric.mean()) + ["Mean"], list(numeric.std()) + ["STD"]],
        columns=METRICS_NAMES,
    )

    end_time = time.time()

    total_values.to_excel("All_metrics.xlsx")
    mean_values.to_excel("Average_metrics.xlsx")

    print(total_values)
    print("Took", end_time - start_time, "s")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Compute segmentation metrics', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-i', '--dir', type=str, default="/Users/luciacev-admin/Desktop/TEST_METRICS",
                        help='Folder containing the *_Pred_* predictions and *_seg_* ground truths')
    main(parser.parse_args())
--------------------------------------------------------------------------------
/MULTI_SEG/src/correct_file.py:
--------------------------------------------------------------------------------

import argparse
import glob
import sys
import os

from utils import(
    SetSpacing,
    CloseCBCTSeg,
    CorrectHisto,
)

IMAGE_EXTENSIONS = (".nrrd", ".nrrd.gz", ".nii", ".nii.gz", ".gipl", ".gipl.gz")


def main(args):
    """Close segmentations and correct scan histograms for every image under ``args.dir``.

    Files whose name contains "seg"/"Seg" go through CloseCBCTSeg with ``args.radius``.
    Files containing "scan"/"Scan" go through CorrectHisto with the 0.01 / 0.99 bounds.
    Each result is written to the same relative path under ``args.out``. When ``--out``
    is omitted the inputs are overwritten in place.
    """
    outdir = args.out if args.out else args.dir
    img_fn_array = []
    seg_fn_array = []

    normpath = os.path.normpath("/".join([args.dir, '**', '']))
    for img_fn in glob.iglob(normpath, recursive=True):
        basename = os.path.basename(img_fn)
        if not os.path.isfile(img_fn) or not any(ext in basename for ext in IMAGE_EXTENSIONS):
            continue
        out = os.path.join(outdir, os.path.relpath(img_fn, args.dir))
        if "scan" in basename or "Scan" in basename:
            img_fn_array.append({"img": img_fn, "out": out})
        if "seg" in basename or "Seg" in basename:
            seg_fn_array.append({"img": img_fn, "out": out})

    # NOTE(review): CloseCBCTSeg / CorrectHisto live in utils (not shown); their second
    # argument is assumed to be the output path, as the original in-place calls suggest.
    for img_obj in seg_fn_array:
        os.makedirs(os.path.dirname(img_obj["out"]), exist_ok=True)
        CloseCBCTSeg(img_obj["img"], img_obj["out"], args.radius)

    for img_obj in img_fn_array:
        os.makedirs(os.path.dirname(img_obj["out"]), exist_ok=True)
        CorrectHisto(img_obj["img"], img_obj["out"], 0.01, 0.99)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MD_reader', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    input_group = parser.add_argument_group('Input files')
    input_group.add_argument('-i','--dir', type=str, help='Input directory with 3D images',required=True)

    output_params = parser.add_argument_group('Output parameters')
    output_params.add_argument('-o','--out', type=str, help='Output directory (default: overwrite the input files in place)')

    input_group.add_argument('-rad', '--radius', type=int, help='Radius of the closing', default=3)

    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
/MULTI_SEG/src/data_split_csv.py:
--------------------------------------------------------------------------------
import argparse

from utils import*


def main(args):
    """Build a train/test workspace from ``args.dir`` into ``args.out`` with utils.GenWorkSpace."""
    # NOTE(review): GenWorkSpace is defined in utils (not shown); its second argument is
    # presumably the held-out fraction - confirm.
    GenWorkSpace(args.dir, args.ratio, args.out)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Split a dataset', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-i', '--dir', type=str, default="/Users/luciacev-admin/Desktop/Cranial_Base_Dataset", help='Dataset directory')
    parser.add_argument('-o', '--out', type=str, default="/Users/luciacev-admin/Desktop/CV_TEST", help='Output directory')
    parser.add_argument('-r', '--ratio', type=float, default=0.2, help='Second argument passed to GenWorkSpace')
    main(parser.parse_args())
--------------------------------------------------------------------------------
/MULTI_SEG/src/init_training_data.py:
--------------------------------------------------------------------------------
from utils import*
import argparse
import glob
import sys
import os
from shutil import copyfile

IMAGE_EXTENSIONS = (".nrrd", ".nrrd.gz", ".nii", ".nii.gz", ".gipl", ".gipl.gz")


def _patient_id(img_fn):
    """Build the "<folder>-<patient>" id of an image path.

    The patient part is the first two "_"-separated fields of the file name when there
    are three or more fields, and the first field when there are two. When the name
    contains a "-", the part before the first "-" is used instead. The folder is the
    parent directory, or the grand-parent when the parent's name appears in the
    patient part.
    """
    file_name = os.path.basename(img_fn).split(".")[0]
    elements_ = file_name.split("_")
    elements_dash = file_name.split("-")
    patient = ""
    if len(elements_) > 2:
        patient = elements_[0] + "_" + elements_[1]
    elif len(elements_) > 1:
        patient = elements_[0]
    if len(elements_dash) > 1:
        patient = elements_dash[0]

    folder_name = os.path.basename(os.path.dirname(img_fn))
    if folder_name in patient:
        folder_name = os.path.basename(os.path.dirname(os.path.dirname(img_fn)))
    return folder_name + "-" + patient


def main(args):
    """Resample every scan/segmentation pair under ``args.input_dir`` into ``args.out``.

    Outputs are ``<out>/<folder>/<folder>-<patient>_scan.nii.gz`` and
    ``..._MAND-Seg.nii.gz``. The script exits with status 1 if any patient lacks a scan
    or a segmentation.
    """
    print("Reading folder : ", args.input_dir)
    print("Selected spacings : ", args.spacing)

    patients = {}
    normpath = os.path.normpath("/".join([args.input_dir, '**', '']))
    for img_fn in sorted(glob.iglob(normpath, recursive=True)):
        basename = os.path.basename(img_fn)
        if not os.path.isfile(img_fn) or not any(ext in basename for ext in IMAGE_EXTENSIONS):
            continue

        patient = _patient_id(img_fn)
        print(patient)
        data = patients.setdefault(patient, {})

        if "scan" in basename or "Scan" in basename:
            data["scan"] = img_fn
            data["dir"] = os.path.dirname(img_fn)
        elif "seg" in basename or "Seg" in basename:
            data["seg"] = img_fn
        else:
            print("----> Unrecognise CBCT file found at :", img_fn)

    error = False
    for patient, data in patients.items():
        if "scan" not in data:
            print("Missing scan for patient :", patient)
            error = True
        if "seg" not in data:
            print("Missing segmentation patient :", patient)
            error = True

    if error:
        # Exit explicitly: a bare `raise` outside an except block fails with an
        # unrelated "No active exception to reraise" RuntimeError.
        print("ERROR : folder have missing/unrecognise files", file=sys.stderr)
        sys.exit(1)

    # A single spacing value is applied to all three axes (isotropic resampling).
    sp = list(args.spacing)
    if len(sp) == 1:
        sp = sp * 3

    os.makedirs(args.out, exist_ok=True)

    for patient, data in patients.items():
        ScanOutpath = os.path.join(args.out, patient.split("-")[0])
        os.makedirs(ScanOutpath, exist_ok=True)

        scan_name = patient + "_scan.nii.gz"
        seg_name = patient + "_MAND-Seg.nii.gz"

        SetSpacing(data["scan"], output_spacing=sp, outpath=os.path.join(ScanOutpath, scan_name))
        SetSpacing(data["seg"], output_spacing=sp, interpolator="NearestNeighbor", outpath=os.path.join(ScanOutpath, seg_name))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MD_reader', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    input_group = parser.add_argument_group('Input files')
    input_group.add_argument('-i','--input_dir', type=str, help='Input directory with 3D images',required=True)

    output_params = parser.add_argument_group('Output parameters')
    output_params.add_argument('-o','--out', type=str, help='Output directory', required=True)

    input_group.add_argument('-sp', '--spacing', nargs="+", type=float, help='Wanted output x spacing', default=[0.5,0.5,0.5])

    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
/MULTI_SEG/src/merge_seg.py:
--------------------------------------------------------------------------------
import argparse
import glob
import sys
import os
import SimpleITK as sitk
import numpy as np

IMAGE_EXTENSIONS = (".nrrd", ".nrrd.gz", ".nii", ".nii.gz", ".gipl", ".gipl.gz")


def main(args):
    """Merge each patient's per-structure segmentations into one labelled image.

    Files under ``args.input`` whose name contains "Seg" are parsed as
    ``<patient>_<STRUCT>_Seg.<ext>`` and grouped by patient. For each patient, the
    structures present in both ``args.merging_order`` and ``args.structures`` are painted
    in merging order: voxels equal to 1 in a structure's mask receive that structure's
    label from ``args.labels``, and later structures overwrite earlier ones. The result is
    written as ``<patient>_MERGED_Seg.nii.gz`` into ``args.out``, or next to the patient's
    segmentations when ``args.out`` is not set.
    """
    if len(args.structures) != len(args.labels):
        print("ERROR : --structures and --labels must have the same length", file=sys.stderr)
        sys.exit(1)
    label_dic = dict(zip(args.structures, args.labels))

    patients = {}
    normpath = os.path.normpath("/".join([args.input, '**', '']))
    for img_fn in glob.iglob(normpath, recursive=True):
        basename = os.path.basename(img_fn)
        if not os.path.isfile(img_fn) or "Seg" not in basename:
            continue
        if not any(ext in basename for ext in IMAGE_EXTENSIONS):
            continue
        # "<patient>_<STRUCT>_Seg.nii.gz" -> ["<patient parts>", ..., "<STRUCT>"]
        patient_seg = basename.split("Seg")[0][:-1].split("_")
        seg_id = patient_seg.pop(-1)
        patient = "_".join(patient_seg)
        if patient not in patients:
            patients[patient] = {"dir": os.path.dirname(img_fn)}
        patients[patient][seg_id] = img_fn

    for patient, data in patients.items():
        print(patient)

        merge_lst = [seg_id for seg_id in args.merging_order if seg_id in data and seg_id in label_dic]
        if not merge_lst:
            print("No segmentation to merge for patient :", patient)
            continue

        first_id = merge_lst[0]
        first_img = sitk.ReadImage(data[first_id])
        seg = sitk.GetArrayFromImage(first_img)
        merged_seg = np.where(seg == 1, label_dic[first_id], seg)

        for seg_id in merge_lst[1:]:
            seg = sitk.GetArrayFromImage(sitk.ReadImage(data[seg_id]))
            merged_seg = np.where(seg == 1, label_dic[seg_id], merged_seg)

        # The merged image reuses the geometry of the first segmentation.
        output = sitk.GetImageFromArray(merged_seg)
        output.SetSpacing(first_img.GetSpacing())
        output.SetDirection(first_img.GetDirection())
        output.SetOrigin(first_img.GetOrigin())
        output = sitk.Cast(output, sitk.sitkInt16)

        out_dir = args.out if args.out else data["dir"]
        os.makedirs(out_dir, exist_ok=True)
        writer = sitk.ImageFileWriter()
        writer.SetFileName(os.path.join(out_dir, patient + "_MERGED_Seg.nii.gz"))
        writer.Execute(output)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Merge segmentations', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    input_group = parser.add_argument_group('Input files')
    input_group.add_argument('-i','--input', type=str, help='Input directory with 3D segmentations',required=True)

    output_params = parser.add_argument_group('Output parameters')
    # No default computed from the command line here: parsing argv before the parser is
    # complete rejects every option defined afterwards.
    output_params.add_argument('-o','--out', type=str, help='Output directory (default: folder of each patient segmentations)', default=None)

    input_group.add_argument('-s', '--structures', nargs="+", type=str, help='Structures to merge', default=["MAND","CB","MAX","CV","UAW"])
    input_group.add_argument('-l', '--labels', nargs="+", type=int, help='Labels of each structures', default=[1,2,2,5,3])

    input_group.add_argument('-mo','--merging_order',nargs="+", type=str, help='order of the merging', default=["CV","SKIN","UAW","CB","MAX","MAND","CAN","RCL","RCU"])

    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
/MULTI_SEG/src/models.py:
--------------------------------------------------------------------------------
from monai.networks.nets import UNETR,UNet


def Create_UNETR(input_channel, label_nbr,cropSize):
    """Build the MONAI UNETR network used for CBCT segmentation.

    Args:
        input_channel: number of input image channels.
        label_nbr: number of output channels (segmentation labels).
        cropSize: spatial size of the input patches, forwarded as ``img_size``.

    Returns:
        A freshly constructed ``monai.networks.nets.UNETR``. The transformer uses the
        ViT-Base sizes (hidden size 768, MLP 3072, 12 heads) with instance
        normalisation, residual blocks and 5 % dropout.
    """
    return UNETR(
        in_channels=input_channel,
        out_channels=label_nbr,
        img_size=cropSize,
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="perceptron",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.05,
    )
--------------------------------------------------------------------------------
/MULTI_SEG/src/post_process_test.py:
--------------------------------------------------------------------------------
import getpass
import string
from matplotlib.pyplot import axis
import cc3d
import numpy as np
import SimpleITK as sitk
import itk
from utils import *

# Predicted segmentations of one test scan, keyed by structure name.
seg_path_dic = {
    "MAND":"/Users/luciacev-admin/Desktop/MANDSEG_TEST/CP66_MAND_Pred_sp0-25.nii.gz",
    "SKIN":"/Users/luciacev-admin/Desktop/MANDSEG_TEST/CP66_SKIN_Pred_sp0-25.nii.gz",
    "CV":"/Users/luciacev-admin/Desktop/MANDSEG_TEST/CP66_CV_Pred_sp0-25.nii.gz",
    "CB":"/Users/luciacev-admin/Desktop/MANDSEG_TEST/CP66_CB_Pred_sp0-25.nii.gz",
    "MAX":"/Users/luciacev-admin/Desktop/MANDSEG_TEST/CP66_MAX_Pred_sp0-25.nii.gz",
}

merging_order = ["CV","CB","MAX","MAND"]

outpath = "test.nii.gz"

# Guarded so that importing this module does not start a merge.
if __name__ == "__main__":
    MergeSeg(seg_path_dic,outpath,merging_order)



# #Get image from url
# def get_image_from_url(url):
#     import urllib.request
#     import io
#     resp = urllib.request.urlopen(url)
#     image = np.asarray(bytearray(resp.read()), dtype="uint8")
#     image = np.reshape(image, (256, 256))
#     return image



# #plot image
# def plot_image(image):
#     plt.imshow(image,cmap='gray')
#     plt.show()

#
plot_image(get_image_from_url('https://www.google.com/search?q=chat&rlz=1C5GCEM_enUS964US964&sxsrf=APq-WBsq5hGK6kOKg2_s2qx7Er00jJ8_jg:1648668396144&source=lnms&tbm=isch&sa=X&ved=2ahUKEwjQr-qwyO72AhXEXM0KHYY4DfIQ_AUoAXoECAEQAw&biw=1680&bih=882&dpr=2#imgrc=YVqXM2zc5FB_5M')) 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | # # img = itk.imread("/Users/luciacev-admin/Desktop/Scans_RS/UoP/Segs/UoP-362_seg_Sp05.nii.gz") 70 | 71 | # # img_info = itk.template(img)[1] 72 | # # pixel_type = img_info[0] 73 | # # pixel_dimension = img_info[1] 74 | # # ImageType = itk.Image[pixel_type, pixel_dimension] 75 | 76 | 77 | # # ImageType = itk.Image[itk.US, 3] 78 | # # BinaryFillholeImageFilter = itk.BinaryFillholeImageFilter[ImageType].New() 79 | # # BinaryFillholeImageFilter.SetInput(img) 80 | # # BinaryFillholeImageFilter.SetForegroundValue(1) 81 | # # BinaryFillholeImageFilter.Update() 82 | # # filled_itk_img = BinaryFillholeImageFilter.GetOutput() 83 | 84 | # # itk.imwrite(filled_itk_img,"test.nii.gz") 85 | 86 | # # output = sitk.ReadImage("/Users/luciacev-admin/Desktop/Maxime segmentations 2/AH1-seg-TM.gipl.gz") 87 | # # closing_radius = 1 88 | # # output = sitk.BinaryDilate(output, [closing_radius] * output.GetDimension()) 89 | # # output = sitk.BinaryFillhole(output) 90 | # # output = sitk.BinaryErode(output, [closing_radius] * output.GetDimension()) 91 | 92 | # # writer = sitk.ImageFileWriter() 93 | # # writer.SetFileName("SKIN_FILL.nii.gz") 94 | # # writer.Execute(output) 95 | 96 | 97 | 98 | # input_img = sitk.ReadImage("/Users/luciacev-admin/Desktop/test/RC-P1_scan_Sp05.nii.gz") 99 | # input_seg = sitk.ReadImage("/Users/luciacev-admin/Desktop/test/RC-P1_seg_Sp05.nii.gz") 100 | 101 | 102 | # closing_radius = 1 103 | 104 | # output = sitk.BinaryDilate(input_seg, [closing_radius] * input_seg.GetDimension()) 105 | # output = sitk.BinaryFillhole(output) 106 | # output = 
sitk.GetArrayFromImage(output) 107 | # output = np.transpose(output, (2, 0, 1)) 108 | # # output, N = cc3d.largest_k( 109 | # # labels_in, k=1, 110 | # # connectivity=26, delta=0, 111 | # # return_N=True, 112 | # # ) 113 | # output = cc3d.connected_components(output) 114 | # output = np.transpose(output, (1, 2, 0)) 115 | 116 | # # closing_radius = 3 117 | 118 | # # output = sitk.GetImageFromArray(output) 119 | # # output = sitk.BinaryDilate(output, [closing_radius] * output.GetDimension()) 120 | # # output = sitk.BinaryFillhole(output) 121 | # # output = sitk.BinaryErode(output, [closing_radius] * output.GetDimension()) 122 | 123 | # stats = cc3d.statistics(output) 124 | # tooth = stats['bounding_boxes'][1] 125 | # # print(tooth) 126 | 127 | # output = output[tooth[0].start:tooth[0].stop,tooth[1].start:tooth[1].stop,tooth[2].start:tooth[2].stop ] 128 | 129 | # # print(stats["voxel_counts"]) 130 | 131 | 132 | # # labels_out = cc3d.dust( 133 | # # labels_in, threshold=10, 134 | # # connectivity=26, in_place=False 135 | # # ) 136 | # # labels_out = cc3d.dust(labels_in) 137 | 138 | 139 | 140 | # output = sitk.GetImageFromArray(output) 141 | # output.SetSpacing(input_img.GetSpacing()) 142 | # output.SetDirection(input_img.GetDirection()) 143 | # output.SetOrigin(input_img.GetOrigin()) 144 | 145 | # writer = sitk.ImageFileWriter() 146 | # writer.SetFileName("test.nii.gz") 147 | # writer.Execute(output) 148 | 149 | # # closing_radius = 8 150 | # # output = sitk.BinaryDilate(output, [closing_radius] * output.GetDimension()) 151 | # # output = sitk.BinaryErode(output, [closing_radius] * output.GetDimension()) 152 | 153 | # # writer = sitk.ImageFileWriter() 154 | # # writer.SetFileName("closed.nii.gz") 155 | # # writer.Execute(output) 156 | 157 | # # closed = sitk.GetArrayFromImage(output) 158 | 159 | # # stats = cc3d.statistics(labels_out) 160 | # # # print(stats) 161 | # # # print("mid = ", np.mean(stats['centroids'], axis = 0)) 162 | # # mand_bbox = 
stats['bounding_boxes'][1] 163 | # # # print(mand_bbox) 164 | # # rng_lst = [] 165 | # # mid_lst = [] 166 | # # for slices in mand_bbox: 167 | # # rng = slices.stop-slices.start 168 | # # mid = (2/3)*rng+slices.start 169 | # # rng_lst.append(rng) 170 | # # mid_lst.append(mid) 171 | 172 | # # print(rng_lst,mid_lst) 173 | 174 | 175 | 176 | # # dif = closed - labels_out 177 | 178 | # # print(np.shape(labels_out[:,:,:150])) 179 | # # print(labels_out[:,:,:150]) 180 | # # print(np.shape(closed[:,:,150:])) 181 | # # print(closed[:,:,150:]) 182 | 183 | # # merge_slice = int(mid_lst[0]) 184 | # # print(merge_slice) 185 | # # out = np.concatenate((labels_out[:merge_slice,:,:],closed[merge_slice:,:,:]),axis=0) 186 | 187 | 188 | # # output = sitk.GetImageFromArray(out) 189 | # # output.SetSpacing(input_img.GetSpacing()) 190 | # # output.SetDirection(input_img.GetDirection()) 191 | # # output.SetOrigin(input_img.GetOrigin()) 192 | 193 | # # writer = sitk.ImageFileWriter() 194 | # # writer.SetFileName("/Users/luciacev-admin/Desktop/MANDSEG_TEST/PRED_HP/MARILIA-30_Pred_Sp05.nii.gz") 195 | # # writer.Execute(output) 196 | 197 | 198 | # """ 199 | 200 | # labels_in = np.ones((512, 512, 512), dtype=np.int32) 201 | # labels_out = cc3d.connected_components(labels_in) # 26-connected 202 | 203 | # connectivity = 6 # only 4,8 (2D) and 26, 18, and 6 (3D) are allowed 204 | # labels_out = cc3d.connected_components(labels_in, connectivity=connectivity) 205 | 206 | # # If you're working with continuously valued images like microscopy 207 | # # images you can use cc3d to perform a very rough segmentation. 208 | # # If delta = 0, standard high speed processing. If delta > 0, then 209 | # # neighbor voxel values <= delta are considered the same component. 210 | # # The algorithm can be 2-10x slower though. Zero is considered 211 | # # background and will not join to any other voxel. 
212 | # labels_out = cc3d.connected_components(labels_in, delta=10) 213 | 214 | # # You can extract the number of labels (which is also the maximum 215 | # # label value) like so: 216 | # labels_out, N = cc3d.connected_components(labels_in, return_N=True) # free 217 | # # -- OR -- 218 | # labels_out = cc3d.connected_components(labels_in) 219 | # N = np.max(labels_out) # costs a full read 220 | 221 | # # You can extract individual components using numpy operators 222 | # # This approach is slow, but makes a mutable copy. 223 | # for segid in range(1, N+1): 224 | # extracted_image = labels_out * (labels_out == segid) 225 | # process(extracted_image) # stand in for whatever you'd like to do 226 | 227 | # # If a read-only image is ok, this approach is MUCH faster 228 | # # if the image has many contiguous regions. A random image 229 | # # can be slower. binary=True yields binary images instead 230 | # # of numbered images. 231 | # for label, image in cc3d.each(labels_out, binary=False, in_place=True): 232 | # process(image) # stand in for whatever you'd like to do 233 | 234 | # # Image statistics like voxel counts, bounding boxes, and centroids. 235 | # stats = cc3d.statistics(labels_out) 236 | 237 | # # Remove dust from the input image. Removes objects with 238 | # # fewer than `threshold` voxels. 239 | # labels_out = cc3d.dust( 240 | # labels_in, threshold=100, 241 | # connectivity=26, in_place=False 242 | # ) 243 | 244 | # # Get a labeling of the k largest objects in the image. 245 | # # The output will be relabeled from 1 to N. 246 | # labels_out, N = cc3d.largest_k( 247 | # labels_in, k=10, 248 | # connectivity=26, delta=0, 249 | # return_N=True, 250 | # ) 251 | # labels_in *= (labels_out > 0) # to get original labels 252 | 253 | # # We also include a region adjacency graph function 254 | # # that returns a set of undirected edges. 
# edges = cc3d.region_graph(labels_out, connectivity=connectivity)

# # You can also generate a voxel connectivty graph that encodes
# # which directions are passable from a given voxel as a bitfield.
# # This could also be seen as a method of eroding voxels fractionally
# # based on their label adjacencies.
# # See help(cc3d.voxel_connectivity_graph) for details.
# graph = cc3d.voxel_connectivity_graph(labels, connectivity=connectivity)

# """

# --------------------------------------------------------------------------------
# /MULTI_SEG/src/predict_CBCTSeg.py
# --------------------------------------------------------------------------------
from models import*
from utils import*
import glob
import time
import os
import shutil
import random
import string

#generate random id
def id_generator(size=6, chars=string.ascii_uppercase + string.digits):
    """Return a random string of *size* characters drawn from *chars*.

    Used to give every run its own temporary folder name.
    """
    return ''.join(random.choice(chars) for _ in range(size))

from monai.data import (
    DataLoader,
    Dataset,
    SmartCacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

import argparse

#region Global variables
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DEVICE = torch.device("cpu")

TRANSLATE ={
    "Mandible" : "MAND",
    "Maxilla" : "MAX",
    "Cranial-base" : "CB",
    "Cervical-vertebra" : "CV",
    "Root-canal" : "RC",
    "Mandibular-canal" : "MCAN",
    "Upper-airway" : "UAW",
    "Skin" : "SKIN",
    "Teeth" : "TEETH"
}

INV_TRANSLATE = {}
for k,v in TRANSLATE.items():
    INV_TRANSLATE[v] = k

LABELS = {

    "LARGE":{
        "MAND" : 1,
        "CB" : 2,
        "UAW" : 3,
        "MAX" : 4,
        "CV" : 5,
        "SKIN" : 6,
    },
    "SMALL":{
        "MAND" : 1,
        "RC" : 2,
        "MAX" : 4,
    }
}


LABEL_COLORS = {
    1: [216, 101, 79],
    2: [128, 174, 128],
    3: [0, 0, 0],
    4: [230, 220, 70],
    5: [111, 184, 210],
    6: [172, 122, 101],
}

NAMES_FROM_LABELS = {"LARGE":{}, "SMALL":{}}
for group,data in LABELS.items():
    for k,v in data.items():
        NAMES_FROM_LABELS[group][v] = INV_TRANSLATE[k]


MODELS_GROUP = {
    "LARGE": {
        "FF":
        {
            "MAND" : 1,
            "CB" : 2,
            "UAW" : 3,
            "MAX" : 4,
            "CV" : 5,
        },
        "SKIN":
        {
            "SKIN" : 1,
        }
    },


    "SMALL": {
        "HD-MAND":
        {
            "MAND" : 1
        },
        "HD-MAX":
        {
            "MAX" : 1
        },
        "RC":
        {
            "RC" : 1
        },
    },
}

# File-name fragments that identify a readable image.
_SCAN_EXTENSIONS = [".nrrd", ".nrrd.gz", ".nii", ".nii.gz", ".gipl", ".gipl.gz"]

#endregion


def SaveSeg(file_path, spacing ,seg_arr, input_path,temp_path, outputdir,temp_folder, save_vtk, smoothing = 5, model_size= "LARGE"):
    """Write *seg_arr* to *file_path* in the geometry of the scan at *input_path*.

    The array is first written to *temp_path* with *spacing* (``SavePrediction``),
    then resampled onto the reference image and written to *file_path*
    (``SetSpacingFromRef``). When *save_vtk* is true, ``SavePredToVTK`` is also
    called with *smoothing*, *temp_folder*, *outputdir* and *model_size*.
    """
    print("Saving segmentation for ", file_path)

    SavePrediction(seg_arr,input_path,temp_path,output_spacing = spacing)
    SetSpacingFromRef(
        temp_path,
        input_path,
        outpath=file_path
    )

    if save_vtk:
        SavePredToVTK(file_path,temp_folder, smoothing, out_folder=outputdir,model_size=model_size)


def CropSkin(skin_seg_arr, thickness):
    """Reduce a binary skin mask to an outer shell about *thickness* voxels thick.

    Steps: fill the holes of the mask, erode the filled mask by *thickness*,
    zero every voxel of the eroded interior, then keep only the largest
    26-connected component.
    """
    skin_img = sitk.GetImageFromArray(skin_seg_arr)
    skin_img = sitk.BinaryFillhole(skin_img)

    eroded_img = sitk.BinaryErode(skin_img, [thickness] * skin_img.GetDimension())

    skin_arr = sitk.GetArrayFromImage(skin_img)
    eroded_arr = sitk.GetArrayFromImage(eroded_img)

    croped_skin = np.where(eroded_arr==1, 0, skin_arr)

    out, N = cc3d.largest_k(
        croped_skin, k=1,
        connectivity=26, delta=0,
        return_N=True,
    )

    return out


def CleanArray(seg_arr,radius):
    """Clean a binary mask: closing with *radius* (dilate, fill holes, erode),
    then keep only the largest 26-connected component."""
    input_img = sitk.GetImageFromArray(seg_arr)
    output = sitk.BinaryDilate(input_img, [radius] * input_img.GetDimension())
    output = sitk.BinaryFillhole(output)
    output = sitk.BinaryErode(output, [radius] * output.GetDimension())

    labels_in = sitk.GetArrayFromImage(output)
    out, N = cc3d.largest_k(
        labels_in, k=1,
        connectivity=26, delta=0,
        return_N=True,
    )

    return out


def _find_available_models(dir_models):
    """Map model IDs to ``.pth`` files found recursively under *dir_models*.

    The model ID is the second ``_``-separated field of the file name. Files
    without such a field are reported and skipped instead of raising IndexError.
    """
    available_models = {}
    normpath = os.path.normpath("/".join([dir_models, '**', '']))
    for model_fn in glob.iglob(normpath, recursive=True):
        basename = os.path.basename(model_fn)
        if not basename.endswith(".pth"):
            continue
        fields = basename.split("_")
        if len(fields) < 2:
            print("Skipping model with unexpected file name:", model_fn)
            continue
        available_models[fields[1]] = model_fn
    return available_models


def _select_models(models_dict, available_models, structures):
    """Return {model_id: path} for available models predicting any of *structures*."""
    models_to_use = {}
    for model_id, model_structs in models_dict.items():
        if model_id in available_models and any(struct in model_structs for struct in structures):
            models_to_use[model_id] = available_models[model_id]
    return models_to_use


def _is_scan_to_segment(basename):
    """True for image files that are not already a prediction or a segmentation."""
    has_image_ext = any(ext in basename for ext in _SCAN_EXTENSIONS)
    is_segmentation = any(txt in basename for txt in ["_Pred","seg","Seg"])
    return has_image_ext and not is_segmentation


def _base_output_dir(args):
    """Folder receiving the results: --output_dir, else the input's folder."""
    if args.output_dir is not None:
        return args.output_dir
    if os.path.isfile(args.input):
        # abspath so that a bare file name gives its folder instead of "".
        return os.path.dirname(os.path.abspath(args.input))
    return args.input


def _collect_scans(input_path, temp_fold):
    """Pre-process the scan(s) at *input_path* into *temp_fold* with CorrectHisto.

    *input_path* is either one scan (used as is) or a folder searched
    recursively with ``_is_scan_to_segment``. Returns the dataset entries
    ``{"scan": corrected copy, "name": original path, "temp_path": ...}``.
    """
    if os.path.isfile(input_path):
        print("Loading scan :", input_path)
        scan_files = [input_path]
    else:
        print("Loading data from",input_path )
        normpath = os.path.normpath("/".join([input_path, '**', '']))
        scan_files = [
            img_fn for img_fn in sorted(glob.iglob(normpath, recursive=True))
            if os.path.isfile(img_fn) and _is_scan_to_segment(os.path.basename(img_fn))
        ]

    temp_pred_path = os.path.join(temp_fold,"temp_Pred.nii.gz")
    data_list = []
    for index, img_fn in enumerate(scan_files):
        # The index prefix keeps same-named scans from different sub-folders
        # from overwriting/reusing each other's corrected copy.
        new_path = os.path.join(temp_fold, "%d_%s" % (index, os.path.basename(img_fn)))
        CorrectHisto(img_fn, new_path,0.01, 0.99)
        data_list.append({"scan":new_path, "name":img_fn, "temp_path":temp_pred_path})
    return data_list


def _prediction_name(base_name, prediction_id):
    """Output file name template; ``XXXX`` is later replaced by the structure."""
    pred_id = "_XXXX-Seg_"+ prediction_id
    if "_scan" in base_name:
        return base_name.replace("_scan",pred_id)
    if "_Scan" in base_name:
        return base_name.replace("_Scan",pred_id)
    stem, *extensions = base_name.split(".")
    return ".".join([stem + pred_id] + extensions)


def _predict_structures(input_img, models_to_use, models_dict, cropSize, args):
    """Run every selected model on *input_img*; return {structure: binary array}."""
    seg_not_to_clean = ["CV","RC"]
    prediction_segmentation = {}

    for model_id,model_path in models_to_use.items():
        net = Create_UNETR(
            input_channel = 1,
            label_nbr= len(models_dict[model_id].keys()) + 1,
            cropSize=cropSize
        ).to(DEVICE)

        print("Loading model", model_path)
        net.load_state_dict(torch.load(model_path,map_location=DEVICE))
        net.eval()

        val_outputs = sliding_window_inference(input_img, cropSize, args.nbr_GPU_worker, net,overlap=args.precision)
        pred_data = torch.argmax(val_outputs, dim=1).detach().cpu().type(torch.int16)
        seg_arr = pred_data.permute(0,3,2,1).squeeze(0).numpy()

        for struct, label in models_dict[model_id].items():
            sep_arr = np.where(seg_arr == label, 1,0)
            if struct == "SKIN":
                sep_arr = CropSkin(sep_arr,5)
            elif struct not in seg_not_to_clean:
                sep_arr = CleanArray(sep_arr,2)
            prediction_segmentation[struct] = sep_arr

    return prediction_segmentation


def _segment_and_save(batch, args, models_to_use, models_dict, model_size, spacing, cropSize, outputdir, temp_fold):
    """Segment the single scan in *batch* and write its separate/merged outputs."""
    input_img, input_path,temp_path = (batch["scan"].to(DEVICE), batch["name"],batch["temp_path"])

    image = input_path[0]
    print("Working on :",image)
    baseName = os.path.basename(image)
    scan_name= baseName.split(".")
    pred_name = _prediction_name(baseName, args.prediction_ID)

    # Derive the per-scan folder from the base folder every time: appending to
    # a shared variable used to nest each scan's folder in the previous one.
    scan_outputdir = outputdir
    if args.save_in_folder:
        scan_outputdir = os.path.join(outputdir, scan_name[0] + "_" + "SegOut")
        print("Output dir :",scan_outputdir)
    os.makedirs(scan_outputdir, exist_ok=True)

    prediction_segmentation = _predict_structures(input_img, models_to_use, models_dict, cropSize, args)

    seg_to_save = {}
    for struct in args.skul_structure:
        if struct in prediction_segmentation:
            seg_to_save[struct] = prediction_segmentation[struct]
        else:
            print("No model available to segment", struct, "- skipped")

    save_vtk = args.gen_vtk

    if "SEPARATE" in args.merge or len(args.skul_structure) == 1:
        for struct,segmentation in seg_to_save.items():
            file_path = os.path.join(scan_outputdir,pred_name.replace('XXXX',struct))
            SaveSeg(
                file_path = file_path,
                spacing = spacing,
                seg_arr=segmentation,
                input_path=input_path[0],
                outputdir=scan_outputdir,
                temp_path=temp_path[0],
                temp_folder=temp_fold,
                save_vtk=args.gen_vtk,
                smoothing=args.vtk_smooth,
                model_size=model_size
            )
            # Once separate files are written, VTK is not generated again for the merged file.
            save_vtk = False

    if "MERGE" in args.merge and len(args.skul_structure) > 1 and seg_to_save:
        print("Merging")
        file_path = os.path.join(scan_outputdir,pred_name.replace('XXXX',"MERGED"))
        merged_seg = np.zeros(next(iter(seg_to_save.values())).shape)
        # Later structures in merging_order overwrite earlier ones where they overlap.
        for struct in args.merging_order:
            if struct in seg_to_save:
                merged_seg = np.where(seg_to_save[struct] == 1, LABELS[model_size][struct], merged_seg)
        SaveSeg(
            file_path = file_path,
            spacing = spacing,
            seg_arr=merged_seg,
            input_path=input_path[0],
            outputdir=scan_outputdir,
            temp_path=temp_path[0],
            temp_folder=temp_fold,
            save_vtk=save_vtk,
            smoothing=args.vtk_smooth,
            model_size=model_size
        )


def main(args):
    """Segment the scan(s) at ``args.input`` with the models in ``args.dir_models``.

    Picks the models that predict at least one requested structure, copies
    each scan through ``CorrectHisto`` into a private temporary folder, runs
    sliding-window inference, post-processes each structure and saves the
    separate and/or merged segmentations. The temporary folder is removed
    even when a step raises.
    """
    # UNETR / sliding_window_inference need integer ROI sizes; --crop_size is parsed as float.
    cropSize = [int(size) for size in args.crop_size]

    temp_fold = os.path.join(args.temp_fold, "temp_" + id_generator())
    os.makedirs(temp_fold, exist_ok=True)

    startTime = time.time()
    try:
        print("Loading models from", args.dir_models)
        available_models = _find_available_models(args.dir_models)
        print("Available models:", available_models)

        if args.high_def:
            model_size = "SMALL"
            spacing = [0.16,0.16,0.32]
        else:
            model_size = "LARGE"
            spacing = [0.4,0.4,0.4]
        MODELS_DICT = MODELS_GROUP[model_size]

        models_to_use = _select_models(MODELS_DICT, available_models, args.skul_structure)
        print(models_to_use)

        outputdir = _base_output_dir(args)
        data_list = _collect_scans(args.input, temp_fold)

        pred_ds = Dataset(
            data=data_list,
            transform=CreatePredTransform(spacing),
        )
        pred_loader = DataLoader(
            dataset=pred_ds,
            batch_size=1,
            shuffle=False,
            num_workers=args.nbr_CPU_worker,
            pin_memory=True
        )

        startTime = time.time()
        with torch.no_grad():
            for batch in pred_loader:
                _segment_and_save(batch, args, models_to_use, MODELS_DICT, model_size, spacing, cropSize, outputdir, temp_fold)
    finally:
        try:
            shutil.rmtree(temp_fold)
        except OSError as e:
            print("Error: %s : %s" % (temp_fold, e.strerror))

    print("Done in %.2f seconds" % (time.time() - startTime))


def _str2bool(value):
    """argparse type for booleans: ``type=bool`` turns the string "False" into True."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Expected a boolean value, got %r" % value)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Perform CBCT segmentation', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    input_group = parser.add_argument_group('directory')

    input_group.add_argument('-i','--input', type=str, help='Path to the scans folder', default='/app/data/scans')
    input_group.add_argument('-o', '--output_dir', type=str, help='Folder to save output', default=None)
    input_group.add_argument('-dm', '--dir_models', type=str, help='Folder with the models', default='/app/data/ALL_MODELS')
    input_group.add_argument('-temp', '--temp_fold', type=str, help='temporary folder', default='..')

    input_group.add_argument('-ss', '--skul_structure', nargs="+", type=str, help='Skul structure to segment', default=["CV","UAW","CB","MAX","MAND"])
    input_group.add_argument('-hd','--high_def', type=_str2bool, help='Use high def models',default=False)
    input_group.add_argument('-m', '--merge', nargs="+", type=str, help='merge the segmentations', default=["MERGE"])

    input_group.add_argument('-sf', '--save_in_folder', type=_str2bool, help='Save the output in one folder', default=True)
    input_group.add_argument('-id', '--prediction_ID', type=str, help='ID appended to the prediction file names', default="Pred")

    input_group.add_argument('-vtk', '--gen_vtk', type=_str2bool, help='Generate vtk file', default=True)
    input_group.add_argument('-vtks','--vtk_smooth', type=int, help='Smoothness of the vtk', default=5)


    input_group.add_argument('-sp', '--spacing', nargs="+", type=float, help='Wanted output x spacing (unused: the spacing is chosen by --high_def)', default=[0.4,0.4,0.4])
    input_group.add_argument('-cs', '--crop_size', nargs="+", type=float, help='Wanted crop size', default=[128,128,128])
    input_group.add_argument('-pr', '--precision', type=float, help='precision of the prediction', default=0.5)
    input_group.add_argument('-mo','--merging_order',nargs="+", type=str, help='order of the merging', default=["SKIN","CV","UAW","CB","MAX","MAND","CAN","RC"])

    input_group.add_argument('-ncw', '--nbr_CPU_worker', type=int, help='Number of worker', default=5)
    input_group.add_argument('-ngw', '--nbr_GPU_worker', type=int, help='Number of worker', default=1)


    args = parser.parse_args()
    main(args)

# --------------------------------------------------------------------------------
# /MULTI_SEG/src/rescall_all.py
# --------------------------------------------------------------------------------
from utils import*
import argparse
import glob
import sys
import os


def main(args):
    """Resample, in place, every image under ``args.input_dir`` to ``args.spacing``.

    Files whose name contains ``_scan`` are resampled with SetSpacing's
    default interpolator; every other image is treated as a label map and
    resampled with "NearestNeighbor" so label values are not blended.
    """
    print("Reading folder : ", args.input_dir)
    print("Selected spacings : ", args.spacing)

    spacing = args.spacing
    image_extensions = [".nrrd", ".nrrd.gz", ".nii", ".nii.gz", ".gipl", ".gipl.gz"]

    normpath = os.path.normpath("/".join([args.input_dir, '**', '']))
    for img_fn in sorted(glob.iglob(normpath, recursive=True)):
        # Test the file name only: folder names containing "_scan" or ".nii"
        # used to misclassify every file inside them.
        basename = os.path.basename(img_fn)
        if not os.path.isfile(img_fn):
            continue
        if not any(ext in basename for ext in image_extensions):
            continue

        if "_scan" in basename:
            SetSpacing(img_fn,spacing,outpath=img_fn)
        else:
            SetSpacing(img_fn,spacing,interpolator="NearestNeighbor" ,outpath=img_fn)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MD_reader', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    input_group = parser.add_argument_group('Input files')
    input_group.add_argument('-i','--input_dir', type=str, help='Input directory with 3D images',required=True)

    input_group.add_argument('-sp', '--spacing', nargs="+", type=float, help='Wanted output x spacing', default=[0.4,0.4,0.4])

    args = parser.parse_args()

    main(args)

# --------------------------------------------------------------------------------
# /MULTI_SEG/src/train_CBCTseg.py
# --------------------------------------------------------------------------------
import torch
from models import *
from utils import *

import argparse

import logging
import sys

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.losses import DiceCELoss

from monai.data import (
    DataLoader,
    CacheDataset,
    SmartCacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

from torch.utils.tensorboard import SummaryWriter


def main(args):
    """Train a Swin-UNETR segmentation model on the patients in ``args.dir_patients``.

    ``args.test_percentage`` percent of the data is kept for validation; both
    sets go through ``CacheDataset`` (cache_rate=1.0). A ``TrainingMaster``
    then runs ``args.max_epoch`` epochs, saving ``best_model.pth`` in
    ``args.dir_model`` and TensorBoard logs in ``<args.dir_data>/Runs``.
    """
    label_nbr = args.nbr_label
    nbr_workers = args.nbr_worker

    # MONAI crop / ROI sizes must be integers; --crop_size is parsed as float.
    cropSize = [int(size) for size in args.crop_size]

    train_transforms = CreateTrainTransform(cropSize,1,4)
    val_transforms = CreateValidationTransform()

    trainingSet,validationSet = GetTrainValDataset(args.dir_patients,args.test_percentage/100)

    # Create_UNETR(input_channel=1, label_nbr=label_nbr, cropSize=cropSize) is the alternative backbone.
    model = Create_SwinUNETR(
        input_channel=1,
        label_nbr=label_nbr,
        cropSize=cropSize
    ).to(DEVICE)

    torch.backends.cudnn.benchmark = True

    train_ds = CacheDataset(
        data=trainingSet,
        transform=train_transforms,
        cache_rate=1.0,
        num_workers=nbr_workers,
    )
    train_loader = DataLoader(
        train_ds, batch_size=1,
        shuffle=True,
        num_workers=nbr_workers,
        pin_memory=True
    )
    val_ds = CacheDataset(
        data=validationSet,
        transform=val_transforms,
        cache_rate=1.0,
        num_workers=nbr_workers
    )
    val_loader = DataLoader(
        val_ds, batch_size=1,
        shuffle=False,
        num_workers=nbr_workers,
        pin_memory=True
    )

    TM = TrainingMaster(
        model = model,
        train_loader=train_loader,
        val_loader=val_loader,
        save_model_dir=args.dir_model,
        save_runs_dir=args.dir_data,
        nbr_label = label_nbr,
        FOV=cropSize,
        device=DEVICE
    )

    TM.Process(args.max_epoch)

class TrainingMaster:
    """Epoch loop for a segmentation model: DiceCE training, sliding-window
    Dice validation, best-model checkpointing and TensorBoard logging."""

    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        save_model_dir,
        save_runs_dir,
        nbr_label = 2,
        FOV = [64,64,64],
        device = DEVICE,
        ) -> None:
        self.model = model
        self.device = device
        self.loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
        self.post_label = AsDiscrete(to_onehot=True,num_classes=nbr_label)
        self.post_pred = AsDiscrete(argmax=True, to_onehot=True,num_classes=nbr_label)
        self.dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

        self.save_model_dir = save_model_dir
        if not os.path.exists(self.save_model_dir):
            os.makedirs(self.save_model_dir)

        run_path = save_runs_dir + "/Runs"
        if not os.path.exists(run_path):
            os.makedirs(run_path)
        self.tensorboard = SummaryWriter(run_path)

        self.val_loader = val_loader
        self.train_loader = train_loader
        self.FOV = FOV

        self.epoch = 0
        self.best_dice = 0
        self.loss_lst = []
        self.dice_lst = []

        # sw_batch_size passed to sliding_window_inference during validation.
        self.predictor = 10

    def Process(self,num_epoch):
        """Run *num_epoch* train/validate cycles, then close the TensorBoard writer."""
        for _ in range(num_epoch):
            self.Train()
            self.Validate()
            self.epoch += 1
        self.tensorboard.close()

    def Train(self):
        """One training epoch; logs the mean loss as "Training loss"."""
        self.model.train()
        epoch_loss = 0
        steps = 0
        epoch_iterator = tqdm(
            self.train_loader, desc="Training (loss=X.X)", dynamic_ncols=True
        )
        for batch in epoch_iterator:
            steps += 1
            x, y = (batch["scan"].to(self.device), batch["seg"].to(self.device))

            logit_map = self.model(x)
            loss = self.loss_function(logit_map, y)
            loss.backward()
            loss_value = loss.item()
            epoch_loss += loss_value
            self.optimizer.step()
            self.optimizer.zero_grad()
            epoch_iterator.set_description(
                "Training (loss=%2.5f)" % (loss_value)
            )

        if steps == 0:
            print("Training set is empty: skipping training epoch.")
            return

        mean_loss = epoch_loss/steps
        self.loss_lst.append(mean_loss)
        self.tensorboard.add_scalar("Training loss",mean_loss,self.epoch)
        # flush, not close: the writer is reused every epoch and closed in Process.
        self.tensorboard.flush()

    def Validate(self):
        """Compute the mean Dice over the validation set and save the model
        to ``best_model.pth`` when it beats the best Dice so far."""
        self.model.eval()
        self.dice_metric.reset()
        last_batch = None
        epoch_iterator_val = tqdm(
            self.val_loader, desc="Validate (dice=X.X)", dynamic_ncols=True
        )
        with torch.no_grad():
            for batch in epoch_iterator_val:
                val_inputs, val_labels = (batch["scan"].to(self.device), batch["seg"].to(self.device))

                val_outputs = sliding_window_inference(val_inputs, self.FOV, self.predictor, self.model,overlap=0.2)
                val_labels_convert = [
                    self.post_label(val_label_tensor) for val_label_tensor in decollate_batch(val_labels)
                ]
                val_output_convert = [
                    self.post_pred(val_pred_tensor) for val_pred_tensor in decollate_batch(val_outputs)
                ]
                self.dice_metric(y_pred=val_output_convert, y=val_labels_convert)
                # Running mean over the cases seen so far (progress bar only).
                running_dice = self.dice_metric.aggregate().item()
                epoch_iterator_val.set_description(
                    "Validate (dice=%2.5f)" % (running_dice)
                )
                last_batch = (val_inputs, val_labels, val_outputs)

        if last_batch is None:
            print("Validation set is empty: skipping validation.")
            return

        # Aggregate once over all cases; averaging the per-step running means
        # over-weighted the first validation cases.
        mean_dice_val = self.dice_metric.aggregate().item()
        self.dice_metric.reset()
        self.dice_lst.append(mean_dice_val)

        if mean_dice_val > self.best_dice:
            torch.save(self.model.state_dict(), os.path.join(self.save_model_dir,"best_model.pth"))
            print("Model Was Saved ! Current Best Avg. Dice: {} Previous Best Avg. Dice: {}".format(mean_dice_val, self.best_dice))
            self.best_dice = mean_dice_val
        else:
            print("Model Was Not Saved ! Best Avg. Dice: {} Current Avg. Dice: {}".format(self.best_dice, mean_dice_val))

        self.tensorboard.add_scalar("Validation dice",mean_dice_val,self.epoch)

        self.PrintSlices(*last_batch)
        self.tensorboard.flush()

    def RandomPermutChannels(self,batch,batch2):
        """Apply the same random permutation of the last three axes to both tensors
        (each of three permutations with probability 0.25, else unchanged)."""
        prob = np.random.rand()
        if prob < 0.25:
            order = (0,1,2,4,3)
        elif prob < 0.50:
            order = (0,1,4,3,2)
        elif prob < 0.75:
            order = (0,1,3,2,4)
        else:
            return batch, batch2
        return batch.permute(*order), batch2.permute(*order)

    def PrintSlices(self,val_inputs,val_labels,val_outputs):
        """Log input, label and prediction slices taken along the last axis at
        10%..80% of its size to TensorBoard as "Validation images"."""
        size = val_inputs.shape[4]
        seg = torch.argmax(val_outputs, dim=1).detach()

        inpt_lst = []
        lab_lst = []
        seg_lst = []
        for fraction in [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8]:
            slice_nbr = int(size*fraction)

            inpt_lst.append(val_inputs.cpu()[0, 0, :, :, slice_nbr].unsqueeze(0))
            lab_lst.append(val_labels.cpu()[0, 0, :, :, slice_nbr].unsqueeze(0))
            seg_lst.append(seg.cpu()[0, :, :, slice_nbr].unsqueeze(0))

        img_lst = inpt_lst + lab_lst + seg_lst
        slice_view = torch.cat(img_lst,dim=0).unsqueeze(1)
        self.tensorboard.add_images("Validation images",slice_view,self.epoch)

    def SaveScans(self,val_inputs,val_outputs,step):
        """Debug helper: write the argmax prediction and the input of the first
        batch item to ``<step>_seg.nii.gz`` / ``<step>_scan.nii.gz``."""
        data = torch.argmax(val_outputs, dim=1).detach().cpu().type(torch.int16)
        print(data.shape)
        img = data.numpy()[0][:]
        output = sitk.GetImageFromArray(img)

        writer = sitk.ImageFileWriter()
        writer.SetFileName(str(step)+'_seg.nii.gz')
        writer.Execute(output)

        # Move to CPU first: .numpy() fails on CUDA tensors.
        img = val_inputs.detach().cpu().squeeze(0).numpy()[0][:]
        output = sitk.GetImageFromArray(img)

        writer = sitk.ImageFileWriter()
        writer.SetFileName(str(step)+'_scan.nii.gz')
        writer.Execute(output)


# #####################################
#  Args
# #####################################

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Training to find ROI for Automatic Landmarks Identification', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # dir_data / dir_patients / dir_model default to paths derived from the
    # previous option. They are resolved after parsing: calling parse_args()
    # while options are still being registered rejected every later flag.
    input_group = parser.add_argument_group('dir')
    input_group.add_argument('--dir_project', type=str, help='Directory with all the project',default='/Users/luciacev-admin/Documents/Projects/Benchmarks/CBCT_Seg_benchmark')
    input_group.add_argument('--dir_data', type=str, help='Input directory with 3D images (default: <dir_project>/data)', default=None)
    input_group.add_argument('--dir_patients', type=str, help='Input directory with 3D images (default: <dir_data>/Patients)',default=None)
    input_group.add_argument('--dir_model', type=str, help='Output directory of the training (default: <dir_data>/Models)',default=None)

    input_group.add_argument('-mn', '--model_name', type=str, help='Name of the model', default="MandSeg_model")
    input_group.add_argument('-vp', '--test_percentage', type=int, help='Percentage of data to keep for validation', default=13)
    input_group.add_argument('-cs', '--crop_size', nargs="+", type=float, help='Wanted crop size', default=[128 ,128, 128])
    input_group.add_argument('-me', '--max_epoch', type=int, help='Number of training epocs', default=250)
    input_group.add_argument('-nl', '--nbr_label', type=int, help='Number of label', default=6)
    input_group.add_argument('-bs', '--batch_size', type=int, help='batch size', default=10)
    input_group.add_argument('-nw', '--nbr_worker', type=int, help='Number of worker', default=10)

    args = parser.parse_args()

    if args.dir_data is None:
        args.dir_data = args.dir_project + '/data'
    if args.dir_patients is None:
        args.dir_patients = args.dir_data + '/Patients'
    if args.dir_model is None:
        args.dir_model = args.dir_data + '/Models'

    main(args)

# --------------------------------------------------------------------------------
# /MULTI_SEG/src/utils.py
# --------------------------------------------------------------------------------
import numpy as np
import SimpleITK as sitk
import itk
import vtk

import os
import matplotlib.pyplot as plt
from scipy import stats

from tqdm import tqdm
from sklearn.model_selection import train_test_split
import torch

import datetime
import glob
import sys
import cc3d
import shutil



# ----- MONAI ------

from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    AddChanneld,
    AddChannel,
    Compose,
    CropForegroundd,
    LoadImage,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandSpatialCropd,
    RandShiftIntensityd,
    ScaleIntensityd,
    ScaleIntensity,
    Spacingd,
    Spacing,
    Rotate90d,
    RandRotate90d,
    ToTensord,
    ToTensor,
    SaveImaged,
    SaveImage,
    RandCropByLabelClassesd,
    Lambdad,
    CastToTyped,
    SpatialCrop,
    BorderPadd,
    RandAdjustContrastd,
    HistogramNormalized,
    NormalizeIntensityd,
)

#region Global variables

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DATA_TYPE = torch.float32


TRANSLATE ={
    "Mandible" : "MAND",
    "Maxilla" : "MAX",
    "Cranial-base" : "CB",
    "Cervical-vertebra" : "CV",
    "Root-canal" : "RC",
    "Mandibular-canal" : "MCAN",
    "Upper-airway" : "UAW",
    "Skin" : "SKIN",
    "Teeth" : "TEETH"
}

INV_TRANSLATE = {}
for k,v in TRANSLATE.items():
    INV_TRANSLATE[v] = k

LABELS = {

    "LARGE":{
        "MAND" : 1,
        "CB" : 2,
        "UAW" : 3,
        "MAX" : 4,
        "CV" : 5,
        "SKIN" : 6,
    },
    "SMALL":{
        "MAND" : 1,
        "RC" : 2,
        "MAX" : 4,
    }
}


LABEL_COLORS = {
    1: [216, 101, 79],
    2: [128, 174, 128],
    3: [0, 0, 0],
    4: [230, 220, 70],
    5: [111, 184, 210],
    6: [172, 122, 101],
}

NAMES_FROM_LABELS = {"LARGE":{}, "SMALL":{}}
for group,data in LABELS.items():
    for k,v in data.items():
        NAMES_FROM_LABELS[group][v] = INV_TRANSLATE[k]


MODELS_GROUP = {
    "LARGE": {
        "FF":
        {
            "MAND" : 1,
            "CB" : 2,
            "UAW" : 3,
            "MAX" : 4,
            "CV" : 5,
        },
        "SKIN":
        {
            "SKIN" : 1,
        }
    },


    "SMALL": {
        "HD-MAND":
        {
            "MAND" : 1
        },
        "HD-MAX":
        {
            "MAX" : 1
        },
        "RC":
        {
            "RC" : 1
        },
    },
}

#endregion






######## ######## ### ## ## ###### ######## ####### ######## ## ## ######
## ## ## ## ## ### ## ## ## ## ## ## ## ## ### ### ## ##
## ## ## ## ## #### ## ## ## ## ## ## ## #### #### ##
## ######## ## ## ## ## ## ###### ###### ## ## ######## ## ### ## ######
## ## ## ######### ## #### ## ## ## ## ## ## ## ## ##
## ## ## ## ## ## ### ## ## ## ## ## ## ## ##
## ## ## 159 | ## ## ## ## ## ## ## ###### ## ####### ## ## ## ## ###### 160 | 161 | 162 | def CreateTrainTransform(CropSize = [64,64,64],padding=10,num_sample=10): 163 | train_transforms = Compose( 164 | [ 165 | LoadImaged(keys=["scan", "seg"]), 166 | AddChanneld(keys=["scan", "seg"]), 167 | BorderPadd(keys=["scan", "seg"],spatial_border=padding), 168 | ScaleIntensityd( 169 | keys=["scan"],minv = 0.0, maxv = 1.0, factor = None 170 | ), 171 | # CropForegroundd(keys=["scan", "seg"], source_key="scan"), 172 | RandCropByPosNegLabeld( 173 | keys=["scan", "seg"], 174 | label_key="seg", 175 | spatial_size=CropSize, 176 | pos=1, 177 | neg=1, 178 | num_samples=num_sample, 179 | image_key="scan", 180 | image_threshold=0, 181 | ), 182 | RandFlipd( 183 | keys=["scan", "seg"], 184 | spatial_axis=[0], 185 | prob=0.20, 186 | ), 187 | RandFlipd( 188 | keys=["scan", "seg"], 189 | spatial_axis=[1], 190 | prob=0.20, 191 | ), 192 | RandFlipd( 193 | keys=["scan", "seg"], 194 | spatial_axis=[2], 195 | prob=0.20, 196 | ), 197 | RandRotate90d( 198 | keys=["scan", "seg"], 199 | prob=0.10, 200 | max_k=3, 201 | ), 202 | RandShiftIntensityd( 203 | keys=["scan"], 204 | offsets=0.10, 205 | prob=0.50, 206 | ), 207 | RandAdjustContrastd( 208 | keys=["scan"], 209 | prob=0.8, 210 | gamma = (0.5,2) 211 | ), 212 | ToTensord(keys=["scan", "seg"]), 213 | ] 214 | ) 215 | 216 | return train_transforms 217 | 218 | def CreateValidationTransform(): 219 | 220 | val_transforms = Compose( 221 | [ 222 | LoadImaged(keys=["scan", "seg"]), 223 | AddChanneld(keys=["scan", "seg"]), 224 | ScaleIntensityd( 225 | keys=["scan"],minv = 0.0, maxv = 1.0, factor = None 226 | ), 227 | # CropForegroundd(keys=["scan", "seg"], source_key="scan"), 228 | RandFlipd( 229 | keys=["scan", "seg"], 230 | spatial_axis=[0], 231 | prob=0.20, 232 | ), 233 | RandFlipd( 234 | keys=["scan", "seg"], 235 | spatial_axis=[1], 236 | prob=0.20, 237 | ), 238 | RandFlipd( 239 | keys=["scan", "seg"], 240 | spatial_axis=[2], 241 | prob=0.20, 242 | ), 
243 | RandRotate90d( 244 | keys=["scan", "seg"], 245 | prob=0.10, 246 | max_k=3, 247 | ), 248 | RandAdjustContrastd( 249 | keys=["scan"], 250 | prob=0.8, 251 | gamma = (0.5,2) 252 | ), 253 | ToTensord(keys=["scan", "seg"]), 254 | ] 255 | ) 256 | 257 | 258 | # val_transforms = Compose( 259 | # [ 260 | # LoadImaged(keys=["scan", "seg"]), 261 | # AddChanneld(keys=["scan", "seg"]), 262 | # ScaleIntensityd( 263 | # keys=["scan"],minv = 0.0, maxv = 1.0, factor = None 264 | # ), 265 | # ToTensord(keys=["scan", "seg"]), 266 | # ] 267 | # ) 268 | 269 | return val_transforms 270 | 271 | def CreatePredTransform(spacing): 272 | pred_transforms = Compose( 273 | [ 274 | LoadImaged(keys=["scan"]), 275 | AddChanneld(keys=["scan"]), 276 | ScaleIntensityd( 277 | keys=["scan"],minv = 0.0, maxv = 1.0, factor = None 278 | ), 279 | Spacingd(keys=["scan"],pixdim=spacing), 280 | ToTensord(keys=["scan"]), 281 | ] 282 | ) 283 | return pred_transforms 284 | 285 | def CreatePredictTransform(data,spacing): 286 | 287 | pre_transforms = Compose( 288 | [AddChannel(),ScaleIntensity(minv = 0.0, maxv = 1.0, factor = None),ToTensor()] 289 | ) 290 | 291 | input_img = sitk.ReadImage(data) 292 | img = input_img 293 | img = ItkToSitk(Rescale(data,[spacing,spacing,spacing])) 294 | img = sitk.GetArrayFromImage(img) 295 | # img = CorrectImgContrast(img,0.,0.99) 296 | pre_img = pre_transforms(img) 297 | pre_img = pre_img.type(DATA_TYPE) 298 | return pre_img,input_img 299 | 300 | 301 | def CorrectImgContrast(img,min_porcent,max_porcent): 302 | img_min = np.min(img) 303 | img_max = np.max(img) 304 | img_range = img_max - img_min 305 | # print(img_min,img_max,img_range) 306 | 307 | definition = 1000 308 | histo = np.histogram(img,definition) 309 | cum = np.cumsum(histo[0]) 310 | cum = cum - np.min(cum) 311 | cum = cum / np.max(cum) 312 | 313 | res_high = list(map(lambda i: i> max_porcent, cum)).index(True) 314 | res_max = (res_high * img_range)/definition + img_min 315 | 316 | res_low = list(map(lambda i: i> 
min_porcent, cum)).index(True) 317 | res_min = (res_low * img_range)/definition + img_min 318 | 319 | img = np.where(img > res_max, res_max,img) 320 | img = np.where(img < res_min, res_min,img) 321 | 322 | return img 323 | 324 | ######## ######## ### #### ## ## #### ## ## ###### 325 | ## ## ## ## ## ## ### ## ## ### ## ## ## 326 | ## ## ## ## ## ## #### ## ## #### ## ## 327 | ## ######## ## ## ## ## ## ## ## ## ## ## ## #### 328 | ## ## ## ######### ## ## #### ## ## #### ## ## 329 | ## ## ## ## ## ## ## ### ## ## ### ## ## 330 | ## ## ## ## ## #### ## ## #### ## ## ###### 331 | 332 | 333 | def GenWorkSpace(dir,test_percentage,out_dir): 334 | 335 | data_dic = {} 336 | normpath = os.path.normpath("/".join([dir, '**', ''])) 337 | for img_fn in sorted(glob.iglob(normpath, recursive=True)): 338 | # print(img_fn) 339 | basename = os.path.basename(img_fn) 340 | 341 | if True in [ext in basename for ext in [".nrrd", ".nrrd.gz", ".nii", ".nii.gz", ".gipl", ".gipl.gz"]]: 342 | file_name = basename.split(".")[0] 343 | elements_dash = file_name.split("-") 344 | file_folder = elements_dash[0] 345 | info = elements_dash[1].split("_scan_Sp")[0].split("_seg_Sp") 346 | patient = info[0] 347 | 348 | # print(patient) 349 | 350 | if file_folder not in data_dic.keys(): 351 | data_dic[file_folder] = {} 352 | 353 | if patient not in data_dic[file_folder].keys(): 354 | data_dic[file_folder][patient] = {} 355 | 356 | if "_scan" in basename: 357 | data_dic[file_folder][patient]["scan"] = img_fn 358 | 359 | elif "_seg" in basename: 360 | data_dic[file_folder][patient]["seg"] = img_fn 361 | else: 362 | print("----> Unrecognise CBCT file found at :", img_fn) 363 | 364 | # print(data_dic) 365 | error = False 366 | folder_dic = {} 367 | for folder,patients in data_dic.items(): 368 | if folder not in folder_dic.keys(): 369 | folder_dic[folder] = [] 370 | for patient,data in patients.items(): 371 | if "scan" not in data.keys(): 372 | print("Missing scan for patient :",patient,"at",data["dir"]) 373 
| error = True 374 | if "seg" not in data.keys(): 375 | print("Missing segmentation patient :",patient,"at",data["dir"]) 376 | error = True 377 | folder_dic[folder].append(data) 378 | 379 | if error: 380 | print("ERROR : folder have missing/unrecognise files", file=sys.stderr) 381 | raise 382 | 383 | 384 | # print(folder_dic) 385 | train_data,valid_data = [],[] 386 | num_patient = 0 387 | nbr_cv_fold = int(1/test_percentage) 388 | # print(nbr_cv_fold) 389 | # nbr_cv_fold = 1 390 | i = 0 391 | for i in range(nbr_cv_fold): 392 | 393 | cv_dir_out = os.path.join(out_dir,"CV_fold_" + str(i)) 394 | if not os.path.exists(cv_dir_out): 395 | os.makedirs(cv_dir_out) 396 | 397 | data_fold = os.path.join(cv_dir_out,"data") 398 | if not os.path.exists(data_fold): 399 | os.makedirs(data_fold) 400 | 401 | patients_fold = os.path.join(data_fold,"Patients") 402 | if not os.path.exists(patients_fold): 403 | os.makedirs(patients_fold) 404 | 405 | test_fold = os.path.join(data_fold,"test") 406 | if not os.path.exists(test_fold): 407 | os.makedirs(test_fold) 408 | 409 | for folder,patients in folder_dic.items(): 410 | 411 | len_lst = len(patients) 412 | len_test = int(len_lst/nbr_cv_fold) 413 | start = i*len_test 414 | end = (i+1)*len_test 415 | if end > len_lst: end = len_lst 416 | training_patients = patients[:start] + patients[end:] 417 | test_patients = patients[start:end] 418 | 419 | train_cv_dir_out = os.path.join(patients_fold,folder) 420 | if not os.path.exists(train_cv_dir_out): 421 | os.makedirs(train_cv_dir_out) 422 | 423 | for patient in training_patients: 424 | shutil.copyfile(patient["scan"], os.path.join(train_cv_dir_out,os.path.basename(patient["scan"]))) 425 | shutil.copyfile(patient["seg"], os.path.join(train_cv_dir_out,os.path.basename(patient["seg"]))) 426 | 427 | test_cv_dir_out = os.path.join(test_fold,folder) 428 | if not os.path.exists(test_cv_dir_out): 429 | os.makedirs(test_cv_dir_out) 430 | 431 | for patient in test_patients: 432 | 
shutil.copyfile(patient["scan"], os.path.join(test_cv_dir_out,os.path.basename(patient["scan"]))) 433 | shutil.copyfile(patient["seg"], os.path.join(test_cv_dir_out,os.path.basename(patient["seg"]))) 434 | 435 | 436 | # print(training_patients) 437 | # print(test_patients) 438 | 439 | 440 | 441 | def GetTrainValDataset(dir,val_percentage): 442 | data_dic = {} 443 | 444 | print("Loading data from :",dir) 445 | normpath = os.path.normpath("/".join([dir, '**', ''])) 446 | for img_fn in sorted(glob.iglob(normpath, recursive=True)): 447 | # print(img_fn) 448 | basename = os.path.basename(img_fn) 449 | # print(basename) 450 | 451 | if True in [ext in basename for ext in [".nrrd", ".nrrd.gz", ".nii", ".nii.gz", ".gipl", ".gipl.gz"]]: 452 | file_name = basename.split(".")[0] 453 | # elements_uder = file_name.split("_") 454 | patient = basename.split("_MERGED")[0].split("_scan")[0].split("_SKIN")[0] 455 | file_folder = os.path.basename(os.path.dirname(img_fn)) 456 | 457 | # elements_dash = file_name.split("-") 458 | # file_folder = elements_dash[0] 459 | # info = elements_dash[1].split("_scan_Sp")[0].split("_seg_Sp") 460 | # patient = info[0] 461 | 462 | 463 | # file_folder = "test" 464 | # patient = elements_uder[0] 465 | 466 | # print(patient) 467 | 468 | if file_folder not in data_dic.keys(): 469 | data_dic[file_folder] = {} 470 | 471 | if patient not in data_dic[file_folder].keys(): 472 | data_dic[file_folder][patient] = {} 473 | 474 | if "_scan" in basename: 475 | data_dic[file_folder][patient]["scan"] = img_fn 476 | data_dic[file_folder][patient]["file_name"] = img_fn 477 | 478 | elif "MERGED-Seg" in basename: 479 | data_dic[file_folder][patient]["seg"] = img_fn 480 | 481 | # seg_img = sitk.ReadImage(img_fn) 482 | # seg_arr = sitk.GetArrayFromImage(seg_img) 483 | # seg_max = np.max(seg_arr) 484 | # seg_min = np.min(seg_arr) 485 | 486 | 487 | # if seg_max > 6: 488 | # print("----> Segmentation image has more than 6 labels :",img_fn) 489 | 490 | # for label in 
range(1,6): 491 | # # print(label) 492 | # if label not in seg_arr: 493 | # print(f"----> Segmentation image has missing label {label} :",img_fn) 494 | 495 | # if seg_max < 0: 496 | # print("----> Segmentation image has neg label :",img_fn) 497 | # else: 498 | # print("----> Unrecognise CBCT file found at :", img_fn) 499 | 500 | 501 | # print(data_dic) 502 | error = False 503 | folder_dic = {} 504 | for folder,patients in data_dic.items(): 505 | if folder not in folder_dic.keys(): 506 | folder_dic[folder] = [] 507 | for patient,data in patients.items(): 508 | if "scan" not in data.keys(): 509 | print("Missing scan for patient :",patient) 510 | error = True 511 | if "seg" not in data.keys(): 512 | print("Missing segmentation patient :",patient) 513 | error = True 514 | folder_dic[folder].append(data) 515 | 516 | if error: 517 | print("ERROR : folder have missing/unrecognise files", file=sys.stderr) 518 | raise 519 | 520 | 521 | # print(folder_dic) 522 | train_data,valid_data = [],[] 523 | num_patient = 0 524 | for folder,patients in folder_dic.items(): 525 | tr,val = train_test_split(patients,test_size=val_percentage,shuffle=True) 526 | train_data += tr 527 | valid_data += val 528 | num_patient += len(patients) 529 | 530 | print("Total patient:", num_patient) 531 | 532 | return train_data,valid_data 533 | 534 | 535 | ######## ####### ####### ## ###### 536 | ## ## ## ## ## ## ## ## 537 | ## ## ## ## ## ## ## 538 | ## ## ## ## ## ## ###### 539 | ## ## ## ## ## ## ## 540 | ## ## ## ## ## ## ## ## 541 | ## ####### ####### ######## ###### 542 | 543 | def MergeSeg(seg_path_dic,out_path,seg_order): 544 | merge_lst = [] 545 | for id in seg_order: 546 | if id in seg_path_dic.keys(): 547 | merge_lst.append(seg_path_dic[id]) 548 | 549 | first_img = sitk.ReadImage(merge_lst[0]) 550 | main_seg = sitk.GetArrayFromImage(first_img) 551 | for i in range(len(merge_lst)-1): 552 | label = i+2 553 | img = sitk.ReadImage(merge_lst[i+1]) 554 | seg = sitk.GetArrayFromImage(img) 555 | 
main_seg = np.where(seg==1,label,main_seg) 556 | 557 | output = sitk.GetImageFromArray(main_seg) 558 | output.SetSpacing(first_img.GetSpacing()) 559 | output.SetDirection(first_img.GetDirection()) 560 | output.SetOrigin(first_img.GetOrigin()) 561 | output = sitk.Cast(output, sitk.sitkInt16) 562 | 563 | writer = sitk.ImageFileWriter() 564 | writer.SetFileName(out_path) 565 | writer.Execute(output) 566 | return output 567 | 568 | 569 | def CorrectHisto(filepath,outpath,min_porcent=0.01,max_porcent = 0.95,i_min=-1500, i_max=4000): 570 | 571 | print("Correcting scan contrast :", filepath) 572 | input_img = sitk.ReadImage(filepath) 573 | input_img = sitk.Cast(input_img, sitk.sitkFloat32) 574 | img = sitk.GetArrayFromImage(input_img) 575 | 576 | 577 | img_min = np.min(img) 578 | img_max = np.max(img) 579 | img_range = img_max - img_min 580 | # print(img_min,img_max,img_range) 581 | 582 | definition = 1000 583 | histo = np.histogram(img,definition) 584 | cum = np.cumsum(histo[0]) 585 | cum = cum - np.min(cum) 586 | cum = cum / np.max(cum) 587 | 588 | res_high = list(map(lambda i: i> max_porcent, cum)).index(True) 589 | res_max = (res_high * img_range)/definition + img_min 590 | 591 | res_low = list(map(lambda i: i> min_porcent, cum)).index(True) 592 | res_min = (res_low * img_range)/definition + img_min 593 | 594 | res_min = max(res_min,i_min) 595 | res_max = min(res_max,i_max) 596 | 597 | 598 | # print(res_min,res_min) 599 | 600 | img = np.where(img > res_max, res_max,img) 601 | img = np.where(img < res_min, res_min,img) 602 | 603 | output = sitk.GetImageFromArray(img) 604 | output.SetSpacing(input_img.GetSpacing()) 605 | output.SetDirection(input_img.GetDirection()) 606 | output.SetOrigin(input_img.GetOrigin()) 607 | output = sitk.Cast(output, sitk.sitkInt16) 608 | 609 | 610 | writer = sitk.ImageFileWriter() 611 | writer.SetFileName(outpath) 612 | writer.Execute(output) 613 | return output 614 | 615 | 616 | def CloseCBCTSeg(filepath,outpath, closing_radius = 1): 617 | 
""" 618 | Close the holes in the CBCT 619 | 620 | Parameters 621 | ---------- 622 | filePath 623 | path of the image file 624 | radius 625 | radius of the closing to apply to the seg 626 | outpath 627 | path to save the new image 628 | """ 629 | 630 | print("Reading:", filepath) 631 | input_img = sitk.ReadImage(filepath) 632 | img = sitk.GetArrayFromImage(input_img) 633 | 634 | img = np.where(img > 0, 1,img) 635 | output = sitk.GetImageFromArray(img) 636 | output.SetSpacing(input_img.GetSpacing()) 637 | output.SetDirection(input_img.GetDirection()) 638 | output.SetOrigin(input_img.GetOrigin()) 639 | 640 | output = sitk.BinaryDilate(output, [closing_radius] * output.GetDimension()) 641 | output = sitk.BinaryFillhole(output) 642 | output = sitk.BinaryErode(output, [closing_radius] * output.GetDimension()) 643 | 644 | writer = sitk.ImageFileWriter() 645 | writer.SetFileName(outpath) 646 | writer.Execute(output) 647 | return output 648 | 649 | def ItkToSitk(itk_img): 650 | new_sitk_img = sitk.GetImageFromArray(itk.GetArrayFromImage(itk_img), isVector=itk_img.GetNumberOfComponentsPerPixel()>1) 651 | new_sitk_img.SetOrigin(tuple(itk_img.GetOrigin())) 652 | new_sitk_img.SetSpacing(tuple(itk_img.GetSpacing())) 653 | new_sitk_img.SetDirection(itk.GetArrayFromMatrix(itk_img.GetDirection()).flatten()) 654 | return new_sitk_img 655 | 656 | 657 | def Rescale(filepath,output_spacing=[0.5, 0.5, 0.5]): 658 | print("Resample :", filepath, ", with spacing :", output_spacing) 659 | img = itk.imread(filepath) 660 | 661 | spacing = np.array(img.GetSpacing()) 662 | output_spacing = np.array(output_spacing) 663 | 664 | if not np.array_equal(spacing,output_spacing): 665 | 666 | size = itk.size(img) 667 | scale = spacing/output_spacing 668 | 669 | output_size = (np.array(size)*scale).astype(int).tolist() 670 | output_origin = img.GetOrigin() 671 | 672 | #Find new origin 673 | output_physical_size = np.array(output_size)*np.array(output_spacing) 674 | input_physical_size = 
np.array(size)*spacing 675 | output_origin = np.array(output_origin) - (output_physical_size - input_physical_size)/2.0 676 | 677 | img_info = itk.template(img)[1] 678 | pixel_type = img_info[0] 679 | pixel_dimension = img_info[1] 680 | 681 | VectorImageType = itk.Image[pixel_type, pixel_dimension] 682 | InterpolatorType = itk.LinearInterpolateImageFunction[VectorImageType, itk.D] 683 | 684 | interpolator = InterpolatorType.New() 685 | resampled_img = ResampleImage(img,output_size,output_spacing,output_origin,img.GetDirection(),interpolator,VectorImageType) 686 | return resampled_img 687 | 688 | else: 689 | return img 690 | 691 | 692 | 693 | def ResampleImage(input,size,spacing,origin,direction,interpolator,IVectorImageType,OVectorImageType): 694 | ResampleType = itk.ResampleImageFilter[IVectorImageType, OVectorImageType] 695 | 696 | # print(input) 697 | 698 | resampleImageFilter = ResampleType.New() 699 | resampleImageFilter.SetInput(input) 700 | resampleImageFilter.SetOutputSpacing(spacing.tolist()) 701 | resampleImageFilter.SetOutputOrigin(origin) 702 | resampleImageFilter.SetOutputDirection(direction) 703 | resampleImageFilter.SetInterpolator(interpolator) 704 | resampleImageFilter.SetSize(size) 705 | resampleImageFilter.Update() 706 | 707 | resampled_img = resampleImageFilter.GetOutput() 708 | return resampled_img 709 | 710 | 711 | def SetSpacing(filepath,output_spacing=[0.5, 0.5, 0.5],interpolator="Linear",outpath=-1): 712 | """ 713 | Set the spacing of the image at the wanted scale 714 | 715 | Parameters 716 | ---------- 717 | filePath 718 | path of the image file 719 | output_spacing 720 | whanted spacing of the new image file (default : [0.5, 0.5, 0.5]) 721 | outpath 722 | path to save the new image 723 | """ 724 | 725 | print("Reading:", filepath) 726 | img = itk.imread(filepath) 727 | 728 | # Dimension = 3 729 | # InputPixelType = itk.D 730 | 731 | # InputImageType = itk.Image[InputPixelType, Dimension] 732 | 733 | # reader = 
itk.ImageFileReader[InputImageType].New() 734 | # reader.SetFileName(filepath) 735 | # img = reader.GetOutput() 736 | 737 | spacing = np.array(img.GetSpacing()) 738 | output_spacing = np.array(output_spacing) 739 | 740 | if not np.array_equal(spacing,output_spacing): 741 | 742 | size = itk.size(img) 743 | scale = spacing/output_spacing 744 | 745 | output_size = (np.array(size)*scale).astype(int).tolist() 746 | output_origin = img.GetOrigin() 747 | 748 | #Find new origin 749 | # output_physical_size = np.array(output_size)*np.array(output_spacing) 750 | # input_physical_size = np.array(size)*spacing 751 | # output_origin = np.array(input_origin) - (output_physical_size - input_physical_size)/2.0 752 | 753 | img_info = itk.template(img)[1] 754 | pixel_type = img_info[0] 755 | pixel_dimension = img_info[1] 756 | 757 | print(pixel_type) 758 | 759 | VectorImageType = itk.Image[pixel_type, pixel_dimension] 760 | 761 | if interpolator == "NearestNeighbor": 762 | InterpolatorType = itk.NearestNeighborInterpolateImageFunction[VectorImageType, itk.D] 763 | # print("Rescale Seg with spacing :", output_spacing) 764 | elif interpolator == "Linear": 765 | InterpolatorType = itk.LinearInterpolateImageFunction[VectorImageType, itk.D] 766 | # print("Rescale Scan with spacing :", output_spacing) 767 | 768 | interpolator = InterpolatorType.New() 769 | resampled_img = ResampleImage(img,output_size,output_spacing,output_origin,img.GetDirection(),interpolator,VectorImageType,VectorImageType) 770 | 771 | if outpath != -1: 772 | itk.imwrite(resampled_img, outpath) 773 | return resampled_img 774 | 775 | else: 776 | # print("Already at the wanted spacing") 777 | if outpath != -1: 778 | itk.imwrite(img, outpath) 779 | return img 780 | 781 | 782 | 783 | def SavePrediction(img,ref_filepath, outpath, output_spacing): 784 | 785 | # print("Saving prediction for : ", ref_filepath) 786 | 787 | # print(data) 788 | 789 | ref_img = sitk.ReadImage(ref_filepath) 790 | 791 | 792 | 793 | output = 
sitk.GetImageFromArray(img) 794 | output.SetSpacing(output_spacing) 795 | output.SetDirection(ref_img.GetDirection()) 796 | output.SetOrigin(ref_img.GetOrigin()) 797 | output = sitk.Cast(output, sitk.sitkInt16) 798 | 799 | writer = sitk.ImageFileWriter() 800 | writer.SetFileName(outpath) 801 | writer.Execute(output) 802 | 803 | 804 | 805 | def CleanScan(file_path): 806 | input_img = sitk.ReadImage(file_path) 807 | 808 | 809 | closing_radius = 2 810 | output = sitk.BinaryDilate(input_img, [closing_radius] * input_img.GetDimension()) 811 | output = sitk.BinaryFillhole(output) 812 | output = sitk.BinaryErode(output, [closing_radius] * output.GetDimension()) 813 | 814 | labels_in = sitk.GetArrayFromImage(input_img) 815 | out, N = cc3d.largest_k( 816 | labels_in, k=1, 817 | connectivity=26, delta=0, 818 | return_N=True, 819 | ) 820 | output = sitk.GetImageFromArray(out) 821 | # closed = sitk.GetArrayFromImage(output) 822 | 823 | # stats = cc3d.statistics(out) 824 | # mand_bbox = stats['bounding_boxes'][1] 825 | # rng_lst = [] 826 | # mid_lst = [] 827 | # for slices in mand_bbox: 828 | # rng = slices.stop-slices.start 829 | # mid = (2/3)*rng+slices.start 830 | # rng_lst.append(rng) 831 | # mid_lst.append(mid) 832 | 833 | # merge_slice = int(mid_lst[0]) 834 | # out = np.concatenate((out[:merge_slice,:,:],closed[merge_slice:,:,:]),axis=0) 835 | # output = sitk.GetImageFromArray(out) 836 | 837 | output.SetSpacing(input_img.GetSpacing()) 838 | output.SetDirection(input_img.GetDirection()) 839 | output.SetOrigin(input_img.GetOrigin()) 840 | output = sitk.Cast(output, sitk.sitkInt16) 841 | 842 | writer = sitk.ImageFileWriter() 843 | writer.SetFileName(file_path) 844 | writer.Execute(output) 845 | 846 | 847 | def SetSpacingFromRef(filepath,refFile,interpolator = "NearestNeighbor",outpath=-1): 848 | """ 849 | Set the spacing of the image the same as the reference image 850 | 851 | Parameters 852 | ---------- 853 | filepath 854 | image file 855 | refFile 856 | path of the 
reference image 857 | interpolator 858 | Type of interpolation 'NearestNeighbor' or 'Linear' 859 | outpath 860 | path to save the new image 861 | """ 862 | 863 | img = itk.imread(filepath) 864 | ref = itk.imread(refFile) 865 | 866 | img_sp = np.array(img.GetSpacing()) 867 | img_size = np.array(itk.size(img)) 868 | 869 | ref_sp = np.array(ref.GetSpacing()) 870 | ref_size = np.array(itk.size(ref)) 871 | ref_origin = ref.GetOrigin() 872 | ref_direction = ref.GetDirection() 873 | 874 | Dimension = 3 875 | InputPixelType = itk.D 876 | 877 | InputImageType = itk.Image[InputPixelType, Dimension] 878 | 879 | reader = itk.ImageFileReader[InputImageType].New() 880 | reader.SetFileName(filepath) 881 | img = reader.GetOutput() 882 | 883 | # reader2 = itk.ImageFileReader[InputImageType].New() 884 | # reader2.SetFileName(refFile) 885 | # ref = reader2.GetOutput() 886 | 887 | if not (np.array_equal(img_sp,ref_sp) and np.array_equal(img_size,ref_size)): 888 | img_info = itk.template(img)[1] 889 | Ipixel_type = img_info[0] 890 | Ipixel_dimension = img_info[1] 891 | 892 | ref_info = itk.template(ref)[1] 893 | Opixel_type = ref_info[0] 894 | Opixel_dimension = ref_info[1] 895 | 896 | OVectorImageType = itk.Image[Opixel_type, Opixel_dimension] 897 | IVectorImageType = itk.Image[Ipixel_type, Ipixel_dimension] 898 | 899 | if interpolator == "NearestNeighbor": 900 | InterpolatorType = itk.NearestNeighborInterpolateImageFunction[InputImageType, itk.D] 901 | # print("Rescale Seg with spacing :", output_spacing) 902 | elif interpolator == "Linear": 903 | InterpolatorType = itk.LinearInterpolateImageFunction[InputImageType, itk.D] 904 | # print("Rescale Scan with spacing :", output_spacing) 905 | 906 | interpolator = InterpolatorType.New() 907 | resampled_img = ResampleImage(img,ref_size.tolist(),ref_sp,ref_origin,ref_direction,interpolator,InputImageType,InputImageType) 908 | 909 | output = ItkToSitk(resampled_img) 910 | output = sitk.Cast(output, sitk.sitkInt16) 911 | 912 | # if img_sp[0] 
> ref_sp[0]: 913 | closing_radius = 2 914 | MedianFilter = sitk.MedianImageFilter() 915 | MedianFilter.SetRadius(closing_radius) 916 | output = MedianFilter.Execute(output) 917 | 918 | 919 | if outpath != -1: 920 | writer = sitk.ImageFileWriter() 921 | writer.SetFileName(outpath) 922 | writer.Execute(output) 923 | # itk.imwrite(resampled_img, outpath) 924 | return output 925 | 926 | else: 927 | output = ItkToSitk(img) 928 | output = sitk.Cast(output, sitk.sitkInt16) 929 | if outpath != -1: 930 | writer = sitk.ImageFileWriter() 931 | writer.SetFileName(outpath) 932 | writer.Execute(output) 933 | return output 934 | 935 | 936 | 937 | def KeepLabel(filepath,outpath,labelToKeep): 938 | 939 | # print("Reading:", filepath) 940 | input_img = sitk.ReadImage(filepath) 941 | img = sitk.GetArrayFromImage(input_img) 942 | 943 | for i in range(np.max(img)): 944 | label = i+1 945 | if label != labelToKeep: 946 | img = np.where(img == label, 0,img) 947 | 948 | img = np.where(img > 0, 1,img) 949 | 950 | output = sitk.GetImageFromArray(img) 951 | output.SetSpacing(input_img.GetSpacing()) 952 | output.SetDirection(input_img.GetDirection()) 953 | output.SetOrigin(input_img.GetOrigin()) 954 | 955 | writer = sitk.ImageFileWriter() 956 | writer.SetFileName(outpath) 957 | writer.Execute(output) 958 | return output 959 | 960 | def Write(vtkdata, output_name): 961 | outfilename = output_name 962 | print("Writting:", outfilename) 963 | polydatawriter = vtk.vtkPolyDataWriter() 964 | polydatawriter.SetFileName(outfilename) 965 | polydatawriter.SetInputData(vtkdata) 966 | polydatawriter.Write() 967 | 968 | def SavePredToVTK(file_path,temp_folder,smoothing, out_folder, model_size): 969 | print("Generating VTK for ", file_path) 970 | 971 | img = sitk.ReadImage(file_path) 972 | img_arr = sitk.GetArrayFromImage(img) 973 | 974 | 975 | present_labels = [] 976 | for label in range(np.max(img_arr)): 977 | if label+1 in img_arr: 978 | present_labels.append(label+1) 979 | 980 | for i in present_labels: 
981 | label = i 982 | seg = np.where(img_arr == label, 1,0) 983 | 984 | output = sitk.GetImageFromArray(seg) 985 | 986 | output.SetOrigin(img.GetOrigin()) 987 | output.SetSpacing(img.GetSpacing()) 988 | output.SetDirection(img.GetDirection()) 989 | output = sitk.Cast(output, sitk.sitkInt16) 990 | 991 | temp_path = temp_folder +f"/tempVTK_{label}.nrrd" 992 | # print(temp_path) 993 | 994 | writer = sitk.ImageFileWriter() 995 | writer.SetFileName(temp_path) 996 | writer.Execute(output) 997 | 998 | surf = vtk.vtkNrrdReader() 999 | surf.SetFileName(temp_path) 1000 | surf.Update() 1001 | # print(surf) 1002 | 1003 | dmc = vtk.vtkDiscreteMarchingCubes() 1004 | dmc.SetInputConnection(surf.GetOutputPort()) 1005 | dmc.GenerateValues(100, 1, 100) 1006 | 1007 | # LAPLACIAN smooth 1008 | SmoothPolyDataFilter = vtk.vtkSmoothPolyDataFilter() 1009 | SmoothPolyDataFilter.SetInputConnection(dmc.GetOutputPort()) 1010 | SmoothPolyDataFilter.SetNumberOfIterations(smoothing) 1011 | SmoothPolyDataFilter.SetFeatureAngle(120.0) 1012 | SmoothPolyDataFilter.SetRelaxationFactor(0.6) 1013 | SmoothPolyDataFilter.Update() 1014 | 1015 | model = SmoothPolyDataFilter.GetOutput() 1016 | 1017 | color = vtk.vtkUnsignedCharArray() 1018 | color.SetName("Colors") 1019 | color.SetNumberOfComponents(3) 1020 | color.SetNumberOfTuples( model.GetNumberOfCells() ) 1021 | 1022 | for i in range(model.GetNumberOfCells()): 1023 | color_tup=LABEL_COLORS[label] 1024 | color.SetTuple(i, color_tup) 1025 | 1026 | model.GetCellData().SetScalars(color) 1027 | 1028 | 1029 | # model.GetPointData().SetS 1030 | 1031 | # SINC smooth 1032 | # smoother = vtk.vtkWindowedSincPolyDataFilter() 1033 | # smoother.SetInputConnection(dmc.GetOutputPort()) 1034 | # smoother.SetNumberOfIterations(30) 1035 | # smoother.BoundarySmoothingOff() 1036 | # smoother.FeatureEdgeSmoothingOff() 1037 | # smoother.SetFeatureAngle(120.0) 1038 | # smoother.SetPassBand(0.001) 1039 | # smoother.NonManifoldSmoothingOn() 1040 | # 
smoother.NormalizeCoordinatesOn() 1041 | # smoother.Update() 1042 | 1043 | # print(SmoothPolyDataFilter.GetOutput()) 1044 | 1045 | # outputFilename = "Test.vtk" 1046 | outpath = out_folder + "/VTK files/" + os.path.basename(file_path).split('.')[0] + f"_{NAMES_FROM_LABELS[model_size][label]}_model.vtk" 1047 | 1048 | if not os.path.exists(os.path.dirname(outpath)): 1049 | os.makedirs(os.path.dirname(outpath)) 1050 | Write(model, outpath) 1051 | 1052 | 1053 | 1054 | def ConvertSimpleItkImageToItkImage(_sitk_image: sitk.Image, _pixel_id_value): 1055 | """ 1056 | Converts SimpleITK image to ITK image 1057 | :param _sitk_image: SimpleITK image 1058 | :param _pixel_id_value: Type of the pixel in SimpleITK format (for example: itk.F, itk.UC) 1059 | :return: ITK image 1060 | """ 1061 | array: np.ndarray = sitk.GetArrayFromImage(_sitk_image) 1062 | itk_image: itk.Image = itk.GetImageFromArray(array) 1063 | itk_image = CopyImageMetaInformationFromSimpleItkImageToItkImage(itk_image, _sitk_image, _pixel_id_value) 1064 | return itk_image 1065 | 1066 | def CopyImageMetaInformationFromSimpleItkImageToItkImage(_itk_image: itk.Image, _reference_sitk_image: sitk.Image, _output_pixel_type) -> itk.Image: 1067 | """ 1068 | Copies the meta information from SimpleITK image to ITK image 1069 | :param _itk_image: Source ITK image 1070 | :param _reference_sitk_image: Original SimpleITK image from which will be copied the meta information 1071 | :param _pixel_type: Type of the pixel in SimpleITK format (for example: itk.F, itk.UC) 1072 | :return: ITK image with the new meta information 1073 | """ 1074 | _itk_image.SetOrigin(_reference_sitk_image.GetOrigin()) 1075 | _itk_image.SetSpacing(_reference_sitk_image.GetSpacing()) 1076 | 1077 | # Setting the direction (cosines of the study coordinate axis direction in the space) 1078 | reference_image_direction: np.ndarray = np.eye(3) 1079 | np_dir_vnl = itk.GetVnlMatrixFromArray(reference_image_direction) 1080 | itk_image_direction = 
_itk_image.GetDirection() 1081 | itk_image_direction.GetVnlMatrix().copy_in(np_dir_vnl.data_block()) 1082 | 1083 | dimension: int = _itk_image.GetImageDimension() 1084 | input_image_type = type(_itk_image) 1085 | output_image_type = itk.Image[_output_pixel_type, dimension] 1086 | 1087 | castImageFilter = itk.CastImageFilter[input_image_type, output_image_type].New() 1088 | castImageFilter.SetInput(_itk_image) 1089 | castImageFilter.Update() 1090 | result_itk_image: itk.Image = castImageFilter.GetOutput() 1091 | 1092 | return result_itk_image 1093 | 1094 | def PlotState(img,label,x,y,z): 1095 | img_shape = img.shape 1096 | label_shape = label.shape 1097 | print(f"image shape: {img_shape}, label shape: {label_shape}") 1098 | plt.figure("scan", (18, 6)) 1099 | plt.subplot(3, 2, 1) 1100 | plt.title("scan") 1101 | plt.imshow(img[0, :, :, z].detach().cpu(), cmap="gray") 1102 | plt.subplot(3, 2, 2) 1103 | plt.title("seg") 1104 | plt.imshow(label[0, :, :, z].detach().cpu()) 1105 | plt.subplot(3, 2, 3) 1106 | plt.imshow(img[0, :, y, :].detach().cpu(), cmap="gray") 1107 | plt.subplot(3, 2, 4) 1108 | plt.imshow(label[0, :, y, :].detach().cpu()) 1109 | plt.subplot(3, 2, 5) 1110 | plt.imshow(img[0, x, :, :].detach().cpu(), cmap="gray") 1111 | plt.subplot(3, 2, 6) 1112 | plt.imshow(label[0, x, :, :].detach().cpu()) 1113 | plt.show() 1114 | 1115 | -------------------------------------------------------------------------------- /MULTI_SEG/vtkToSTL.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import vtk 4 | import argparse 5 | 6 | def convertFile(filepath, outdir): 7 | if not os.path.isdir(outdir): 8 | os.makedirs(outdir) 9 | if os.path.isfile(filepath): 10 | basename = os.path.basename(filepath) 11 | print("Copying file:", basename) 12 | basename = os.path.splitext(basename)[0] 13 | outfile = os.path.join(outdir, basename+".stl") 14 | reader = vtk.vtkGenericDataObjectReader() 15 | 
reader.SetFileName(filepath) 16 | reader.Update() 17 | writer = vtk.vtkSTLWriter() 18 | writer.SetInputConnection(reader.GetOutputPort()) 19 | writer.SetFileName(outfile) 20 | return writer.Write()==1 21 | return False 22 | 23 | def convertFiles(indir, outdir): 24 | files = os.listdir(indir) 25 | files = [ os.path.join(indir,f) for f in files if f.endswith('.vtk') ] 26 | ret = 0 27 | print("In:", indir) 28 | print("Out:", outdir) 29 | for f in files: 30 | ret += convertFile(f, outdir) 31 | print("Successfully converted %d out of %d files." % (ret, len(files))) 32 | 33 | def run(args): 34 | convertFiles(args.indir, args.outdir) 35 | 36 | if __name__ == '__main__': 37 | parser = argparse.ArgumentParser(description="VTK to STL converter") 38 | parser.add_argument('indir', help="Path to input directory.") 39 | parser.add_argument('--outdir', '-o', default= parser.parse_args().indir, help="Path to output directory.") 40 | parser.set_defaults(func=run) 41 | args = parser.parse_args() 42 | ret = args.func(args) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Automatic Multi-Anatomical Skull Structure Segmentation of Cone-Beam Computed Tomography scans Using 3D UNETR 2 | *AMASSS-CBCT* 3 | 4 | 5 | # Presentation 6 | 7 | - The segmentation of medical and dental images is a fundamental step in automated clinical decision support systems. It supports the entire clinical workflows from diagnosis, therapy planning, intervention and follow-up. 8 | - We propose a novel tool to accurately process a full face segmentation in about 5 minutes that would otherwise require an average of 7h of manual work by experienced clinicians. 9 | - This work focuses on the integration of the state-of-the-art UNEt TRansformers (UNETR) of the Medical Open Network for Artificial Intelligence (MONAI) framework. 
10 | - We trained and tested our models using 618 de-identified Cone-Beam Computed Tomography (CBCT) volumetric images of the head acquired with several parameters from different centers for a generalized clinical application. 11 | - Our results on a 5-fold cross-validation showed high accuracy and robustness with an F1 score up to $0.962\pm0.02$. 12 | - Full face model made by combining the mandible, the maxilla, the cranial base, the cervical vertebra and the skin segmentations: 13 | 14 | --- 15 | ![Segmentation](https://user-images.githubusercontent.com/46842010/155926868-ca81d82b-8735-4f33-97af-0c3d616e6910.png) 16 | 17 | 18 | # How to use AMASSS-CBCT 19 | 20 | - Using the [AMASSS-CBCT 3D Slicer module](https://github.com/DCBIA-OrthoLab/SlicerAutomatedDentalTools). You can segment a CBCT scan with no coding language required using the 3D Slicer GUI. 21 | - Using the [DSCI](https://dsci.dent.umich.edu/#/) web-based platform, you can segment a CBCT scan with no coding language required. 22 | - Locally, you can run the scripts on your computer using Docker or the source code from GitHub. 23 | 24 | 25 | # Local usage 26 | 27 | ## Arguments to run AMASSS-CBCT prediction script 28 | 29 | Required arguments: 30 | ``` 31 | -i or 32 | -o 33 | -dm 34 | ``` 35 | 36 | Segmentation options: 37 | ``` 38 | -ss 39 | -hd 40 | -m 41 | ``` 42 | 43 | Valid arguments for `-ss`: 44 | - `MAND` (mandible) 45 | - `MAX` (maxilla) 46 | - `CB` (cranial base) 47 | - `CV` (cervical vertebra) 48 | - `UAW` (upper airway) 49 | - `SKIN` (skin) 50 | - `RC` (root canal) 51 | 52 | 53 | 54 | Save options: 55 | ``` 56 | -sf 57 | -id 58 | ``` 59 | 60 | 3D surface generation arguments: 61 | ``` 62 | -vtk 63 | -vtks 64 | ``` 65 | 66 | Technical arguments: 67 | ``` 68 | -sp 69 | -cs 70 | -pr 71 | -mo 72 | ``` 73 | 74 | Computing power arguments: 75 | 76 | ``` 77 | -ncw 78 | -ngw 79 | ``` 80 | 81 | 82 | ## Use Docker image 83 | You can get the AMASSS Docker image by running the following command lines.
84 | 85 | **Getting the Docker image** 86 | 87 | Pull the prebuilt image from Docker Hub: 88 | ``` 89 | docker pull dcbia/amasss:latest 90 | ``` 91 | 92 | Or build it yourself from the Dockerfile directory (the commands below use this local `amasss` tag; if you pulled the image instead, replace `amasss:latest` with `dcbia/amasss:latest`): 93 | 94 | ``` 95 | docker build -t amasss . 96 | ``` 97 | 98 | **Automatic segmentation** 99 | *Running on CPU* 100 | ``` 101 | docker run --rm --shm-size=5gb -v :/app/data/scans amasss:latest python3 /app/MULTI_SEG/src/predict_CBCTSeg.py 102 | ``` 103 | *Running on GPU* 104 | ``` 105 | docker run --rm --shm-size=5gb --gpus all -v :/app/data/scans amasss:latest python3 /app/MULTI_SEG/src/predict_CBCTSeg.py 106 | ``` 107 | 108 | **Information** 109 | - A ***test scan*** "MG_scan_test.nii.gz" is provided in the Data folder of the AMASSS repository. 110 | - If prediction on the ***GPU is not working***, make sure you have installed the NVIDIA Container Toolkit: 111 | https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker 112 | 113 | **Segmentation options/arguments examples** 114 | - By default, the mandible (MAND), the maxilla (MAX), the cranial base (CB), the cervical vertebra (CV) and the upper airway (UAW) structures are segmented and a merged segmentation is generated.
115 | To choose which structure to segment, you can use the following arguments: 116 | ``` 117 | -ss MAND MAX CB CV UAW 118 | ``` 119 | To deactivate the merging step, you can use the following argument: 120 | ``` 121 | -m False 122 | ``` 123 | - By default, the prediction uses 5 CPU processes and a batch size of 5 on the GPU (which requires around 8 GB of GPU memory). You can change these numbers with the following arguments: 124 | ``` 125 | -ncw 2 -ngw 2 126 | ``` 127 | (ncw for the CPU and ngw for the GPU) 128 | 129 | ___ 130 | 131 | ## Using the scripts 132 | ## Prerequisites 133 | 134 | Python 3.8.8 with the following libraries: 135 | 136 | **Main libraries:** 137 | 138 | > monai==0.7.0 \ 139 | > torch==1.10.1 \ 140 | > itk==5.2.1 \ 141 | > numpy==1.20.1 \ 142 | > simpleitk==2.1.1 143 | 144 | You can install the required libraries by running the following command: 145 | 146 | ``` 147 | pip install -r requirements.txt 148 | ``` 149 | 150 | 151 | ## Running the code 152 | 153 | Using the script [predict_CBCTSeg.py](MULTI_SEG/src/predict_CBCTSeg.py) 154 | 155 | Basic usage: 156 | ``` 157 | python3 predict_CBCTSeg.py -i -o -dm 158 | ``` 159 | Example: segment the mandible, the maxilla, the cranial base, the cervical vertebra and the upper airway of a large-FOV scan (MG_scan_test.nii.gz), with 3D surface generation and merging: 160 | 161 | ``` 162 | python3 predict_CBCTSeg.py -i ./Data/MG_scan_test.nii.gz -o ./Data -dm -ss MAND MAX CB CV UAW -vtk True 163 | ``` 164 | 165 | Prediction steps 166 | 167 | ![prediction](https://user-images.githubusercontent.com/46842010/155927157-19206e54-7a90-4816-8eb7-72369a04c39e.png) 168 | 169 | 170 | 171 | # Train AMASSS-CBCT 172 | 173 | ### Prepare the data 174 | 175 | **Spacing** 176 | To organise the files and resample them to the desired spacing: 177 | 178 | Resample the scans to the desired spacing using the script [init_training_data.py](MULTI_SEG/src/init_training_data.py): 179 | 180 | ``` 181 | python3
init_training_data.py -i "path of the input folder with the scans and the segs" -o "path of the output folder" 182 | ``` 183 | By default the spacing is set to 0.5, but you can change it or add other spacings with the argument: 184 | ``` 185 | -sp 0.X1 0.X2 ... 0.Xn 186 | ``` 187 | 188 | **Contrast adjustment** 189 | To correct the image contrast and fill the holes in the segmentations, 190 | run the following command line using the script [correct_file.py](MULTI_SEG/src/correct_file.py): 191 | 192 | ``` 193 | python3 correct_file.py -i -o -rad 194 | ``` 195 | 196 | Expected results of the contrast adjustment: 197 | ![ContrastAdjust](https://user-images.githubusercontent.com/46842010/155178176-7e735867-4ad2-412d-9ac0-c47fe9d7cd8e.png) 198 | 199 | 200 | ### Organise the training folder 201 | 202 | Organise the training folder as follows: 203 | 204 | Screen Shot 2022-02-22 at 12 11 56 PM 205 | 206 | 207 | ### Start the training 208 | 209 | Use the script [train_CBCTseg.py](MULTI_SEG/src/train_CBCTseg.py): 210 | 211 | 212 | ``` 213 | python3 train_CBCTseg.py --dir_project 214 | 215 | ``` 216 | Additional options: 217 | ``` 218 | -mn 219 | -vp 220 | -cs 221 | -me 222 | -nl 223 | -bs 224 | -nw 225 | ``` 226 | 227 | You can launch a TensorBoard session to follow the training progress: 228 | 229 | 230 | ___ 231 | 232 | 233 | 234 | Results 235 | 236 | ![RESULTS](https://user-images.githubusercontent.com/46842010/155927668-906b4fae-4249-4556-a4fa-7a622e9c6c81.png) 237 | 238 | 239 | 240 | 241 | # Acknowledgements 242 | 243 | Authors: Maxime Gillot (University of Michigan), Baptiste Baquero (UoM), Celia Le (UoM), Romain Deleat-Besson (UoM), Lucia Cevidanes (UoM), Jonas Bianchi (UoM), Marcela Gurgel (UoM), Marilia Yatabe (UoM), Najla Al Turkestani (UoM), Kayvan Najarian (UoM), Reza Soroushmehr (UoM), Steve Pieper (ISOMICS), Ron Kikinis (Harvard Medical School), Beatriz Paniagua ( Kitware ), Jonathan Gryak (UoM), Marcos Ioshida (UoM), Camila Massaro (UoM),
Liliane Gomes (UoM), Heesoo Oh (University of the Pacific), Karine Evangelista (UoM), Cauby Chaves Jr (University of Ceara), Daniela Garib (University of São Paulo), Fábio Costa (University of Ceara), Erika Benavides (UoM), Fabiana Soki (UoM), Jean-Christophe Fillion-Robin (Kitware), Hina Joshi (University of North Carolina), Juan Prieto (Dept. of Psychiatry UNC at Chapel Hill) 244 | 245 | Supported by NIDCR R01 024450, AAOF Grabber Family Teaching and Research Award and by Research Enhancement Award Activity 141 from the University of the Pacific, Arthur A. Dugoni School of Dentistry. 246 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | aiohttp==3.8.1 3 | aiosignal==1.2.0 4 | async-timeout==4.0.1 5 | cachetools==4.2.2 6 | charset-normalizer==2.0.9 7 | connected-components-3d==3.9.1 8 | einops==0.3.2 9 | frozenlist==1.2.0 10 | fvcore==0.1.5.post20211023 11 | gdown==3.13.0 12 | google-auth==1.35.0 13 | google-auth-oauthlib==0.4.6 14 | grpcio==1.40.0 15 | iopath==0.1.9 16 | itk==5.2.1.post1 17 | itk-core==5.2.1.post1 18 | itk-filtering==5.2.1.post1 19 | itk-io==5.2.1.post1 20 | itk-numerics==5.2.1.post1 21 | itk-registration==5.2.1.post1 22 | itk-segmentation==5.2.1.post1 23 | markdown==3.3.4 24 | medpy==0.4.0 25 | monai==0.7.0 26 | multidict==5.2.0 27 | nibabel==3.2.1 28 | oauthlib==3.1.1 29 | plotly==5.4.0 30 | protobuf==3.17.3 31 | pyasn1==0.4.8 32 | pyasn1-modules==0.2.8 33 | pytorch-ignite==0.4.2 34 | pyvista==0.33.3 35 | requests-oauthlib==1.3.0 36 | rsa==4.7.2 37 | scooby==0.5.11 38 | simpleitk==2.1.1 39 | sklearn==0.0 40 | tenacity==8.0.1 41 | tensorboard==2.6.0 42 | tensorboard-data-server==0.6.1 43 | tensorboard-plugin-wit==1.8.0 44 | termcolor==1.1.0 45 | torch==1.10.1 46 | torch-tb-profiler==0.3.1 47 | torchsummary==1.5.1 48 | torchvision==0.11.2 49 | vtk==9.1.0 50 | wslink==1.2.1 51 | yacs==0.1.8 52 |
yarl==1.7.2 53 | seaborn==0.11.2 54 | numpy==1.21.5 --------------------------------------------------------------------------------