├── utils
│   ├── __init__.py
│   ├── ddsm_downloader.py
│   ├── ddsm_png_converter.py
│   └── ddsm_preprocessor.py
├── datasets
│   ├── __init__.py
│   ├── classification_dataset.py
│   └── generic_dataset.py
├── transforms
│   ├── __init__.py
│   ├── patches_random.py
│   ├── patches_centered.py
│   └── patches_normal.py
├── .gitignore
├── .gitattributes
├── requirements.txt
├── config.json
├── examples
│   ├── whole_image_classification_dataset.py
│   ├── config.json
│   ├── random_patch_classification_dataset.py
│   ├── centered_patch_classification_dataset.py
│   ├── centered_patch_classification_train_val_split.py
│   └── centered_patch_classification_crossval.py
├── setup.py
├── README.md
└── ddsm_dataset_factory.py
/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transforms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | CBIS_DDSM 2 | .idea 3 | config.json 4 | *.pyc 5 | venv 6 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pydicom 2 | Pillow 3 | tqdm 4 | requests 5 | numpy 6 | pandas 7 | matplotlib 8 | opencv-python 9 | shapely -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "download_path" : "./CBIS_DDSM/", 3 | "manifest" : "./resources/CBIS-DDSM-All-doiJNLP-zzWs5zfZ.tcia", 4 | "mass_train_csv": "./resources/mass_case_description_train_set.csv", 5 | "mass_test_csv": "./resources/mass_case_description_test_set.csv", 6 | "calc_train_csv": "./resources/calc_case_description_train_set.csv", 7 | "calc_test_csv": "./resources/calc_case_description_test_set.csv" 8 | } 9 | -------------------------------------------------------------------------------- /examples/whole_image_classification_dataset.py: -------------------------------------------------------------------------------- 1 | from ddsm_dataset_factory import CBISDDSMDatasetFactory 2 | 3 | dataset = CBISDDSMDatasetFactory('./config.json') \ 4 | .drop_attributes("assessment", "breast_density", "subtlety") \ 5 | .map_attribute_value('pathology', {'BENIGN_WITHOUT_CALLBACK': 'BENIGN'}) \ 6 | .show_counts() \ 7 | .create_classification('pathology') 8 | dataset.visualize() -------------------------------------------------------------------------------- /examples/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "download_path" : "../CBIS_DDSM/", 3 | "manifest" : "../resources/CBIS-DDSM-All-doiJNLP-zzWs5zfZ.tcia", 4 | "mass_train_csv": "../resources/mass_case_description_train_set.csv", 5 | "mass_test_csv": "../resources/mass_case_description_test_set.csv", 6 | "calc_train_csv": 
"../resources/calc_case_description_train_set.csv", 7 | "calc_test_csv": "../resources/calc_case_description_test_set.csv" 8 | } 9 | -------------------------------------------------------------------------------- /examples/random_patch_classification_dataset.py: -------------------------------------------------------------------------------- 1 | from ddsm_dataset_factory import CBISDDSMDatasetFactory 2 | 3 | dataset = CBISDDSMDatasetFactory('./config.json') \ 4 | .drop_attributes("assessment", "breast_density", "subtlety") \ 5 | .map_attribute_value('pathology', {'BENIGN_WITHOUT_CALLBACK': 'BENIGN'}) \ 6 | .show_counts() \ 7 | .lesion_patches_random() \ 8 | .create_classification('pathology') 9 | dataset.visualize() -------------------------------------------------------------------------------- /examples/centered_patch_classification_dataset.py: -------------------------------------------------------------------------------- 1 | from ddsm_dataset_factory import CBISDDSMDatasetFactory 2 | 3 | dataset = CBISDDSMDatasetFactory('./config.json', 4 | include_train_set=True, 5 | include_test_set=True, 6 | include_calcifications=False, 7 | include_masses=True) \ 8 | .drop_attributes("assessment", "breast_density", "subtlety") \ 9 | .map_attribute_value('pathology', {'BENIGN_WITHOUT_CALLBACK': 'BENIGN'}) \ 10 | .show_counts() \ 11 | .lesion_patches_centered() \ 12 | .create_classification('pathology') 13 | dataset.visualize() -------------------------------------------------------------------------------- /examples/centered_patch_classification_train_val_split.py: -------------------------------------------------------------------------------- 1 | from ddsm_dataset_factory import CBISDDSMDatasetFactory 2 | from torchvision import transforms 3 | 4 | dataset = CBISDDSMDatasetFactory('./config.json') \ 5 | .drop_attributes("assessment", "breast_density", "subtlety") \ 6 | .map_attribute_value('pathology', {'BENIGN_WITHOUT_CALLBACK': 'BENIGN'}) \ 7 | .show_counts() \ 8 | .lesion_patches_centered() \ 9 | .add_image_transforms([transforms.Resize(512)]) \ 10 | .add_image_transforms([transforms.RandomAffine(degrees=180, scale=(0.7, 1.5)), 11 | transforms.RandomHorizontalFlip(), 12 | transforms.RandomVerticalFlip()], for_val=False) \ 13 | .add_image_transforms([transforms.Lambda(lambda x: x.repeat(3, 1, 1))]) \ 14 | .create_classification('pathology', mask_input=True) 15 | 16 | print(len(dataset)) 17 | train_set, val_set = dataset.split_train_val(0.2, shuffle=True) 18 | print(len(train_set), len(val_set)) 19 | train_set.visualize() -------------------------------------------------------------------------------- /examples/centered_patch_classification_crossval.py: -------------------------------------------------------------------------------- 1 | from ddsm_dataset_factory import CBISDDSMDatasetFactory 2 | from torchvision import transforms 3 | 4 | dataset = CBISDDSMDatasetFactory('./config.json') \ 5 | .drop_attributes("assessment", "breast_density", "subtlety") \ 6 | .map_attribute_value('pathology', {'BENIGN_WITHOUT_CALLBACK': 'BENIGN'}) \ 7 | .show_counts() \ 8 | .lesion_patches_centered() \ 9 | .add_image_transforms([transforms.Resize(512)]) \ 10 | .add_image_transforms([transforms.RandomAffine(degrees=180, scale=(0.7, 1.5)), 11 | transforms.RandomHorizontalFlip(), 12 | transforms.RandomVerticalFlip()], for_val=False) \ 13 | .add_image_transforms([transforms.Lambda(lambda x: x.repeat(3, 1, 1))]) \ 14 | .create_classification('pathology', mask_input=True) 15 | 16 | cv_datasets = 
dataset.split_crossval(5) 17 | 18 | # fold 1 19 | train_set, val_set = cv_datasets[0] 20 | print(len(train_set), len(val_set)) 21 | train_set.visualize() -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | from utils.ddsm_downloader import CBISDDSMDownloader 5 | from utils.ddsm_png_converter import CBISDDSMConverter 6 | from utils.ddsm_preprocessor import CBISDDSMPreprocessor 7 | 8 | parser = argparse.ArgumentParser(prog='CBIS DDSM Setup', 9 | description="Welcome to the CBIS DDSM Dataloader library.\n" 10 | "Setup includes: \n" 11 | " 1. downloading the database, \n" 12 | " 2. converting images to PNG \n" 13 | " 3. organizing lesion information for PyTorch dataset creation \n" 14 | "Please ensure that the information in the 'config.json' file is correctly set and the files exist.\n" 15 | "This process will take a while.\n") 16 | parser.add_argument('-c', '--config_file', default='./config.json', 17 | help='Path to the configuration file. Default=config.json') 18 | parser.add_argument('-d', action='store_true', help='If used, dcm files will be deleted during conversion to free up space. ' 19 | 'However, if download runs again it will need to download the whole dataset again.') 20 | args = parser.parse_args() 21 | 22 | with open(args.config_file, 'r') as cf: 23 | config = json.load(cf) 24 | 25 | downloader = CBISDDSMDownloader(config['manifest'], config['download_path']) 26 | downloader.start() 27 | downloader.start() # Run twice for checks 28 | 29 | converter = CBISDDSMConverter(config['download_path'], delete_dcm=args.d) 30 | converter.start() 31 | converter.start() # Run twice for checks 32 | 33 | preprocessor = CBISDDSMPreprocessor(config['download_path'], 34 | (config['mass_train_csv'], config['calc_train_csv']), 35 | (config['mass_test_csv'], config['calc_test_csv'])) 36 | preprocessor.start() 37 | -------------------------------------------------------------------------------- /transforms/patches_random.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def _find_boundaries(x, y, w, h, image_shape, patch_size, min_overlap): 5 | max_h, min_h = max(h, patch_size[1]), min(h, patch_size[1]) 6 | max_w, min_w = max(w, patch_size[0]), min(w, patch_size[0]) 7 | 8 | min_y = y - max_h + min_overlap * min_h 9 | min_x = x - max_w + min_overlap * min_w 10 | max_y = y + h - int(min_overlap * min_h) 11 | max_x = x + w - int(min_overlap * min_w) 12 | 13 | min_y = int(max(min_y, 0)) 14 | min_x = int(max(min_x, 0)) 15 | max_y = int(max(min(max_y, image_shape[0] - patch_size[1] - 1), min_y)) 16 | max_x = int(max(min(max_x, image_shape[1] - patch_size[0] - 1), min_x)) 17 | 18 | return min_x, max_x, min_y, max_y 19 | 20 | class RandomPatches(torch.nn.Module): 21 | def __init__(self, patch_size=(1024, 1024), min_overlap=0.9): 22 | super(RandomPatches, self).__init__() 23 | self.min_overlap = min_overlap 24 | self.patch_size = patch_size 25 | 26 | def forward(self, sample): 27 | image_tensor_list, item = sample['image_tensor_list'], sample['item'] 28 | image_shape = image_tensor_list[-1].shape[1:3] 29 | 30 | abnorm_w = (item['maxx'] - item['minx']) / 2 31 | abnorm_x = int(abnorm_w + item['minx']) 32 | abnorm_h = (item['maxy'] - item['miny']) / 2 33 | abnorm_y = int(abnorm_h + item['miny']) 34 | 35 | min_x, max_x, min_y, max_y = _find_boundaries(abnorm_x, abnorm_y, abnorm_w, abnorm_h, 
image_shape, self.patch_size, 36 | self.min_overlap) 37 | 38 | patch_y = torch.randint(min_y, max_y + 1, (1,)) 39 | patch_x = torch.randint(min_x, max_x + 1, (1,)) 40 | 41 | out_tensors = [] 42 | for image_tensor in image_tensor_list: 43 | image_tensor_cropped = image_tensor[:, patch_y: patch_y + self.patch_size[1], patch_x: patch_x + self.patch_size[0]] 44 | out_tensors.append(image_tensor_cropped) 45 | 46 | sample = {'image_tensor_list': out_tensors, 'item': item} 47 | return sample 48 | 49 | def __repr__(self): 50 | detail = f"(patch_size={self.patch_size}, min_overlap={self.min_overlap})" 51 | return f"{self.__class__.__name__}{detail}" -------------------------------------------------------------------------------- /transforms/patches_centered.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class CenteredPatches(torch.nn.Module): 5 | def __init__(self, patch_size): 6 | super(CenteredPatches, self).__init__() 7 | self.patch_size = patch_size 8 | 9 | def forward(self, sample): 10 | image_tensor_list, item = sample['image_tensor_list'], sample['item'] 11 | image_shape = image_tensor_list[-1].shape[1:3] 12 | 13 | cx = item['cx'] 14 | cy = item['cy'] 15 | 16 | minx_naive = int(cx - self.patch_size[0] / 2) 17 | minx = max((0, minx_naive)) 18 | dx1 = minx - minx_naive 19 | 20 | maxx_naive = int(cx + self.patch_size[0] / 2) 21 | maxx = min((maxx_naive, image_shape[1])) 22 | dx2 = maxx_naive - maxx 23 | 24 | if dx1 != 0 and dx2 != 0: 25 | print('Warning: patch size bigger than image x-dimension. Please select a smaller patch size.') 26 | else: 27 | minx -= dx2 28 | maxx += dx1 29 | 30 | miny_naive = int(cy - self.patch_size[1] / 2) 31 | miny = max((0, miny_naive)) 32 | dy1 = miny - miny_naive 33 | 34 | maxy_naive = int(cy + self.patch_size[1] / 2) 35 | maxy = min((maxy_naive, image_shape[0])) 36 | dy2 = maxy_naive - maxy 37 | 38 | if dy1 != 0 and dy2 != 0: 39 | print('Warning: patch size bigger than image y-dimension. Please select a smaller patch size.') 40 | else: 41 | miny -= dy2 42 | maxy += dy1 43 | 44 | out_tensors = [] 45 | for image_tensor in image_tensor_list: 46 | image_tensor = image_tensor[:, miny: maxy, minx: maxx] 47 | out_tensors.append(image_tensor) 48 | 49 | sample = {'image_tensor_list': out_tensors, 'item': item} 50 | return sample 51 | 52 | def __repr__(self): 53 | detail = f"(patch_size={self.patch_size})" 54 | return f"{self.__class__.__name__}{detail}" 55 | 56 | def centered_patch_transform(patch_size=(1024, 1024)): 57 | def perform(sample): 58 | image_tensor_list, item = sample['image_tensor_list'], sample['item'] 59 | image_shape = image_tensor_list[-1].shape[1:3] 60 | 61 | cx = item['cx'] 62 | cy = item['cy'] 63 | 64 | minx_naive = int(cx - patch_size[0] / 2) 65 | minx = max((0, minx_naive)) 66 | dx1 = minx - minx_naive 67 | 68 | maxx_naive = int(cx + patch_size[0] / 2) 69 | maxx = min((maxx_naive, image_shape[1])) 70 | dx2 = maxx_naive - maxx 71 | 72 | if dx1 != 0 and dx2 != 0: 73 | print('Warning: patch size bigger than image x-dimension. Please select a smaller patch size.') 74 | else: 75 | minx -= dx2 76 | maxx += dx1 77 | 78 | miny_naive = int(cy - patch_size[1] / 2) 79 | miny = max((0, miny_naive)) 80 | dy1 = miny - miny_naive 81 | 82 | maxy_naive = int(cy + patch_size[1] / 2) 83 | maxy = min((maxy_naive, image_shape[0])) 84 | dy2 = maxy_naive - maxy 85 | 86 | if dy1 != 0 and dy2 != 0: 87 | print('Warning: patch size bigger than image y-dimension. 
Please select a smaller patch size.') 88 | else: 89 | miny -= dy2 90 | maxy += dy1 91 | 92 | out_tensors = [] 93 | for image_tensor in image_tensor_list: 94 | image_tensor = image_tensor[:, miny: maxy, minx: maxx] 95 | out_tensors.append(image_tensor) 96 | 97 | sample = {'image_tensor_list': out_tensors, 'item': item} 98 | return sample 99 | 100 | return perform -------------------------------------------------------------------------------- /utils/ddsm_downloader.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import requests 3 | import zipfile 4 | from io import BytesIO 5 | import os 6 | import json 7 | from tqdm import tqdm 8 | import argparse 9 | 10 | BASE_URL_IMAGE = 'https://services.cancerimagingarchive.net/nbia-api/services/v1/getImage?SeriesInstanceUID={}' 11 | BASE_URL_METADATA = 'https://services.cancerimagingarchive.net/nbia-api/services/v1/getSeriesMetaData?SeriesInstanceUID={}' 12 | 13 | 14 | class CBISDDSMDownloader: 15 | def __init__(self, manifest_path, download_path, skip_existing=True): 16 | self.__skip_existing = skip_existing 17 | self.__download_path = download_path 18 | self.__manifest_file_path = manifest_path 19 | self.__image_series_UID = [] 20 | 21 | def __parse_manifest(self): 22 | with open(self.__manifest_file_path) as file: 23 | found_starting_line_flag = False 24 | for line in file: 25 | if not found_starting_line_flag: 26 | if 'ListOfSeriesToDownload=' in line: 27 | found_starting_line_flag = True 28 | else: 29 | self.__image_series_UID.append(line.strip()) 30 | 31 | if found_starting_line_flag: 32 | print("Found {} items to download.".format(len(self.__image_series_UID))) 33 | else: 34 | print("Incorrect format of the manifest file provided!") 35 | 36 | @staticmethod 37 | def __get_metadata(series_uid): 38 | response = requests.get(BASE_URL_METADATA.format(series_uid)) 39 | response_dict = json.loads(response.content.decode("utf-8"))[0] 40 | return response_dict 41 | 42 | @staticmethod 43 | def __exists(download_path, num_imgs): 44 | if os.path.exists(download_path): 45 | folder_contents = os.listdir(download_path) 46 | num_dcm = len(list(item for item in folder_contents if item.endswith('.dcm'))) 47 | if num_dcm == num_imgs: 48 | return True 49 | return False 50 | 51 | @staticmethod 52 | def __download_extract_image(series_uid, path): 53 | response = requests.get(BASE_URL_IMAGE.format(series_uid)) 54 | with zipfile.ZipFile(BytesIO(response.content)) as z: 55 | z.extractall(path) 56 | 57 | def __payload(self, seriesUID): 58 | metadata = self.__get_metadata(seriesUID) 59 | folder_name = metadata['Subject ID'] 60 | series_uid = metadata['Series UID'] 61 | study_uid = metadata['Study UID'] 62 | num_imgs = int(metadata['Number of Images']) 63 | download_path = os.path.join(self.__download_path, folder_name, study_uid, series_uid) 64 | 65 | if self.__skip_existing and self.__exists(download_path, num_imgs): 66 | return 67 | 68 | self.__download_extract_image(seriesUID, download_path) 69 | 70 | def start(self): 71 | self.__parse_manifest() 72 | 73 | with concurrent.futures.ThreadPoolExecutor() as executor: 74 | # Start the load operations and mark each future with its URL 75 | future_to_url = {executor.submit(self.__payload, uid): uid for uid in self.__image_series_UID} 76 | for future in tqdm(concurrent.futures.as_completed(future_to_url), total=len(self.__image_series_UID), unit="file"): 77 | url = future_to_url[future] 78 | try: 79 | future.result() 80 | except Exception as exc: 81 | 
print(f"{url} generated an exception: {exc}") 82 | 83 | 84 | if __name__ == "__main__": 85 | parser = argparse.ArgumentParser(prog='CBIS DDSM Downloader') 86 | parser.add_argument('-m', '--manifest', default='../resources/CBIS-DDSM-All-doiJNLP-zzWs5zfZ.tcia', 87 | help='Path to the manifest file.') 88 | parser.add_argument('-p', '--path', default='../CBIS_DDSM', 89 | help='Path to the download folder. It will be created if not existing.') 90 | args = parser.parse_args() 91 | downloader = CBISDDSMDownloader(args.manifest, args.path) 92 | downloader.start() 93 | -------------------------------------------------------------------------------- /datasets/classification_dataset.py: -------------------------------------------------------------------------------- 1 | from datasets.generic_dataset import CBISDDSMGenericDataset 2 | 3 | 4 | class CBISDDSMClassificationDataset(CBISDDSMGenericDataset): 5 | def __init__(self, dataframe, 6 | download_path, 7 | label_field, 8 | label_list, 9 | masks=False, 10 | transform=None, 11 | train_image_transform=None, 12 | train_image_transform_for_mask_flags=None, 13 | test_image_transform=None, 14 | test_image_transform_for_mask_flags=None): 15 | super().__init__(dataframe, 16 | download_path, 17 | masks=masks, 18 | transform=transform, 19 | train_image_transform=train_image_transform, 20 | train_image_transform_for_mask_flags=train_image_transform_for_mask_flags, 21 | test_image_transform=test_image_transform, 22 | test_image_transform_for_mask_flags=test_image_transform_for_mask_flags) 23 | 24 | self.label_field = label_field 25 | self.label_list = label_list 26 | 27 | def __getitem__(self, idx): 28 | image_tensor, item = super().__getitem__(idx) 29 | 30 | label_full = item[self.label_field] 31 | label = self.label_list.index(label_full) 32 | 33 | return image_tensor, label 34 | 35 | def _get_label_visualize(self, label): 36 | label = self.label_list[label] 37 | return label 38 | 39 | @property 40 | def num_classes(self): 41 | return len(self.label_list) 42 | 43 | def split_train_val(self, val_ratio, shuffle=False, random_state=None): 44 | df1, df2 = self._split_dataframe(val_ratio, shuffle, random_state) 45 | val_dataset = CBISDDSMClassificationDataset(df2, self.download_path, self.label_field, self.label_list, 46 | masks=self.include_masks, transform=self.transform, 47 | train_image_transform=self._train_image_transforms, 48 | train_image_transform_for_mask_flags=self._train_image_transform_for_mask_flags, 49 | test_image_transform=self._test_image_transforms, 50 | test_image_transform_for_mask_flags=self._test_image_transform_for_mask_flags) 51 | val_dataset.test_mode() 52 | train_dataset = CBISDDSMClassificationDataset(df1, self.download_path, self.label_field, self.label_list, 53 | masks=self.include_masks, transform=self.transform, 54 | train_image_transform=self._train_image_transforms, 55 | train_image_transform_for_mask_flags=self._train_image_transform_for_mask_flags, 56 | test_image_transform=self._test_image_transforms, 57 | test_image_transform_for_mask_flags=self._test_image_transform_for_mask_flags) 58 | train_dataset.train_mode() 59 | return train_dataset, val_dataset 60 | 61 | def split_crossval(self, folds, shuffle=False, random_state=None): 62 | dataframe_pairs = self._split_dataframe_crossval(folds, shuffle, random_state) 63 | dataset_pairs = [] 64 | for i in range(folds): 65 | train_dataset = CBISDDSMClassificationDataset(dataframe_pairs[i][0], self.download_path, self.label_field, 66 | self.label_list, 67 | masks=self.include_masks, 
transform=self.transform, 68 | train_image_transform=self._train_image_transforms, 69 | train_image_transform_for_mask_flags=self._train_image_transform_for_mask_flags, 70 | test_image_transform=self._test_image_transforms, 71 | test_image_transform_for_mask_flags=self._test_image_transform_for_mask_flags) 72 | train_dataset.train_mode() 73 | val_dataset = CBISDDSMClassificationDataset(dataframe_pairs[i][1], self.download_path, self.label_field, 74 | self.label_list, 75 | masks=self.include_masks, transform=self.transform, 76 | train_image_transform=self._train_image_transforms, 77 | train_image_transform_for_mask_flags=self._train_image_transform_for_mask_flags, 78 | test_image_transform=self._test_image_transforms, 79 | test_image_transform_for_mask_flags=self._test_image_transform_for_mask_flags) 80 | val_dataset.test_mode() 81 | dataset_pairs.append((train_dataset, val_dataset)) 82 | return dataset_pairs 83 | -------------------------------------------------------------------------------- /utils/ddsm_png_converter.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import os 3 | import pydicom 4 | from PIL import Image 5 | from tqdm import tqdm 6 | import argparse 7 | 8 | 9 | class CBISDDSMConverter: 10 | def __init__(self, download_path, skip_existing=True, delete_dcm=False): 11 | self.__download_path = download_path 12 | self.__skip_existing = skip_existing 13 | self.__delete_dcm = delete_dcm 14 | self.__initialize_lists() 15 | 16 | def __initialize_lists(self): 17 | self.__dcm_image_list = [] 18 | self.__to_delete_dcm_image_list = [] 19 | self.__num_skipped = 0 20 | 21 | def __find_images(self, root_path): 22 | directory_list = os.listdir(root_path) 23 | for dir in directory_list: 24 | dir_path_1 = os.path.join(root_path, dir) 25 | if not os.path.isdir(dir_path_1): 26 | continue 27 | directory_list_1 = os.listdir(dir_path_1) 28 | for dir_1 in directory_list_1: 29 | dir_path_2 = os.path.join(dir_path_1, dir_1) 30 | directory_list_2 = os.listdir(dir_path_2) 31 | for dir_2 in directory_list_2: 32 | dir_path = os.path.join(dir_path_2, dir_2) 33 | contents_list = os.listdir(dir_path) 34 | dcm_list = list(item for item in contents_list if item.endswith('.dcm')) 35 | png_list = list(item for item in contents_list if item.endswith('.png')) 36 | if self.__delete_dcm: 37 | for img in dcm_list: 38 | img_path = os.path.join(dir_path, img) 39 | self.__to_delete_dcm_image_list.append(img_path) 40 | if self.__skip_existing and len(dcm_list) == len(png_list): 41 | self.__num_skipped += len(dcm_list) 42 | continue 43 | for img in dcm_list: 44 | img_path = os.path.join(dir_path, img) 45 | self.__dcm_image_list.append(img_path) 46 | print("Found {} dcm images to convert. 
Skipped {}.".format(len(self.__dcm_image_list), self.__num_skipped)) 47 | 48 | @staticmethod 49 | def __get_png_path(dcm_path): 50 | path, name_ext = os.path.split(dcm_path) 51 | name, _ = os.path.splitext(name_ext) 52 | name_int = int(name) - 1 # Start numbering from 0 53 | name = str(name_int).zfill(6) 54 | output_path = os.path.join(path, name + '.png') 55 | return output_path 56 | 57 | @staticmethod 58 | def __dicom_to_png(input_path, output_path): 59 | ds = pydicom.dcmread(input_path, force=True) 60 | pixel_array = ds.pixel_array 61 | image = Image.fromarray(pixel_array) 62 | image.save(output_path, format='PNG', lossless=True) 63 | 64 | def __payload_convert(self, input_path): 65 | output_path = self.__get_png_path(input_path) 66 | self.__dicom_to_png(input_path, output_path) 67 | 68 | def __payload_delete(self, input_path): 69 | os.remove(input_path) 70 | 71 | def start(self): 72 | self.__initialize_lists() 73 | self.__find_images(self.__download_path) 74 | num_fails = 0 75 | with concurrent.futures.ThreadPoolExecutor() as executor: 76 | # Start the load operations and mark each future with its URL 77 | future_to_url = {executor.submit(self.__payload_convert, uid): uid for uid in self.__dcm_image_list} 78 | for future in tqdm(concurrent.futures.as_completed(future_to_url), total=len(self.__dcm_image_list), 79 | unit="file"): 80 | url = future_to_url[future] 81 | try: 82 | future.result() 83 | except Exception as exc: 84 | num_fails += 1 85 | print(f"{url} generated an exception: {exc}") 86 | if num_fails > 0: 87 | print( 88 | 'Conversion failed for {} dcm images. Please re-run the downloader to fix incorrect downloads.'.format( 89 | num_fails)) 90 | if self.__delete_dcm: 91 | print('Cleaning up DICOM images...') 92 | with concurrent.futures.ThreadPoolExecutor() as executor: 93 | # Start the load operations and mark each future with its URL 94 | future_to_url = {executor.submit(self.__payload_delete, uid): uid for uid in self.__to_delete_dcm_image_list} 95 | for future in tqdm(concurrent.futures.as_completed(future_to_url), total=len(self.__to_delete_dcm_image_list), 96 | unit="file"): 97 | url = future_to_url[future] 98 | try: 99 | future.result() 100 | except Exception as exc: 101 | num_fails += 1 102 | print(f"{url} generated an exception: {exc}") 103 | if num_fails > 0: 104 | print( 105 | 'Deletion failed for {} dcm images. Please re-run the downloader to fix incorrect downloads.'.format( 106 | num_fails)) 107 | 108 | 109 | if __name__ == "__main__": 110 | parser = argparse.ArgumentParser(prog='CBIS DDSM Converter') 111 | parser.add_argument('-p', '--path', default='../CBIS_DDSM', 112 | help='Path to the download folder. 
 /transforms/patches_normal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transforms.patches_random import _find_boundaries 4 | 5 | class PatchesNormalWrapper(torch.nn.Module): 6 | def __init__(self, other_transform, probability=0.5, patch_size=(1024, 1024), min_breast_overlap=0.5, max_abnorm_overlap=0.1, max_tries=5): 7 | super(PatchesNormalWrapper, self).__init__() 8 | self.other_transform = other_transform 9 | self.probability = probability 10 | self.patch_size = patch_size 11 | self.min_breast_overlap = min_breast_overlap 12 | self.max_abnorm_overlap = max_abnorm_overlap 13 | self.max_tries = max_tries 14 | 15 | def forward(self, sample): 16 | choice = torch.randint(0, 100, (1,)) 17 | if choice >= self.probability * 100: 18 | return self.other_transform(sample) 19 | else: 20 | image_tensor_list, item = sample['image_tensor_list'], sample['item'] 21 | image_shape = image_tensor_list[0].shape[1:3] # (H, W), as expected by _find_boundaries 22 | 23 | abnorm_w = (item['maxx'] - item['minx']) / 2 24 | abnorm_x = int(abnorm_w + item['minx']) 25 | abnorm_h = (item['maxy'] - item['miny']) / 2 26 | abnorm_y = int(abnorm_h + item['miny']) 27 | 28 | abnorm_min_x, abnorm_max_x, abnorm_min_y, abnorm_max_y = _find_boundaries(abnorm_x, abnorm_y, 29 | abnorm_w, abnorm_h, 30 | image_shape, self.patch_size, 31 | 1 - self.max_abnorm_overlap) 32 | 33 | breast_w = (item['breast_maxx'] - item['breast_minx']) / 2 34 | breast_x = int(breast_w + item['breast_minx']) 35 | breast_h = (item['breast_maxy'] - item['breast_miny']) / 2 36 | breast_y = int(breast_h + item['breast_miny']) 37 | 38 | breast_min_x, breast_max_x, breast_min_y, breast_max_y = _find_boundaries(breast_x, breast_y, 39 | breast_w, breast_h, 40 | image_shape, self.patch_size, 41 | self.min_breast_overlap) 42 | counter = 0 43 | while True: 44 | patch_y = torch.randint(breast_min_y, breast_max_y + 1, (1,)) 45 | patch_x = torch.randint(breast_min_x, breast_max_x + 1, (1,)) 46 | 47 | if (patch_x < abnorm_min_x or patch_x > abnorm_max_x) and ( 48 | patch_y < abnorm_min_y or patch_y > abnorm_max_y): 49 | break 50 | counter += 1 51 | if counter == self.max_tries: 52 | # Could not find a normal patch away from the lesion; fall back to the wrapped transform 53 | return self.other_transform(sample) 54 | 55 | out_tensors = [] 56 | for image_tensor in image_tensor_list: 57 | image_tensor = image_tensor[:, patch_y: patch_y + self.patch_size[1], patch_x: patch_x + self.patch_size[0]] 58 | out_tensors.append(image_tensor) 59 | 60 | item['pathology'] = 'NORMAL' 61 | sample = {'image_tensor_list': out_tensors, 'item': item} 62 | return sample 63 | 64 | 65 | 66 | def normal_patch_transform_wrapper(other_transform, probability=0.5, patch_size=(1024, 1024), min_breast_overlap=0.5, max_abnorm_overlap=0.1, max_tries=5): 67 | def perform(sample): 68 | choice = torch.randint(0, 100, (1,)) 69 | if choice >= probability * 100: 70 | return other_transform(sample) 71 | else: 72 | image_tensor_list, item = sample['image_tensor_list'], sample['item'] 73 | image_shape = image_tensor_list[0].shape[1:3] # (H, W), as expected by _find_boundaries 74 | 75 | abnorm_w = (item['maxx'] - item['minx']) / 2 76 | abnorm_x = int(abnorm_w + item['minx']) 77 | abnorm_h = (item['maxy'] - item['miny']) / 2 78 | abnorm_y = int(abnorm_h + item['miny']) 79 | 80 | abnorm_min_x, abnorm_max_x, abnorm_min_y, abnorm_max_y = _find_boundaries(abnorm_x, abnorm_y, 81 | abnorm_w, abnorm_h, 82 | image_shape, patch_size, 83 | 1 - max_abnorm_overlap) 84 | 85 | breast_w = (item['breast_maxx'] - item['breast_minx']) / 2 86 | breast_x = int(breast_w + item['breast_minx']) 87 | breast_h = (item['breast_maxy'] - item['breast_miny']) / 2 88 | breast_y = int(breast_h + item['breast_miny']) 89 | 90 | breast_min_x, breast_max_x, breast_min_y, breast_max_y = _find_boundaries(breast_x, breast_y, 91 | breast_w, breast_h, 92 | image_shape, patch_size, 93 | min_breast_overlap) 94 | counter = 0 95 | while True: 96 | patch_y = torch.randint(breast_min_y, breast_max_y + 1, (1,)) 97 | patch_x = torch.randint(breast_min_x, breast_max_x + 1, (1,)) 98 | 99 | if (patch_x < abnorm_min_x or patch_x > abnorm_max_x) and ( 100 | patch_y < abnorm_min_y or patch_y > abnorm_max_y): 101 | break 102 | counter += 1 103 | if counter == max_tries: 104 | # Could not find a normal patch away from the lesion; fall back to the wrapped transform 105 | return other_transform(sample) 106 | 107 | out_tensors = [] 108 | for image_tensor in image_tensor_list: 109 | image_tensor = image_tensor[:, patch_y: patch_y + patch_size[1], patch_x: patch_x + patch_size[0]] 110 | out_tensors.append(image_tensor) 111 | 112 | item['pathology'] = 'NORMAL' 113 | sample = {'image_tensor_list': out_tensors, 'item': item} 114 | return sample 115 | 116 | 117 | return perform 118 | --------------------------------------------------------------------------------
 /README.md: -------------------------------------------------------------------------------- 1 | # CBIS DDSM Dataloader 2 | 3 | This repository facilitates the pre-processing of the CBIS-DDSM mammographic database. 4 | It involves downloading the whole database, converting the DICOM images to PNG and parsing the database files. 5 | Finally, a custom PyTorch Dataset and Dataloader can be created in a versatile way, covering whole mammograms, patches 6 | and segmentation masks of the lesions (see examples below). 7 | 8 | ## Setup 9 | First, install the project requirements using 10 | 11 | ```shell 12 | pip install -r requirements.txt 13 | ``` 14 | Next, the `config.json` file should be edited to specify the `download_path` variable. 15 | By default, the database will be downloaded in the root directory, in a new folder `CBIS_DDSM`. Alternatively, edit 16 | the file specifying a different directory: 17 | ```json lines 18 | { 19 | "download_path": "<path to download>", 20 | ... 21 | } 22 | ``` 23 | Next, run `setup.py`. The following command-line arguments are supported: 24 | ```shell 25 | optional arguments: 26 | -h, --help show this help message and exit 27 | -c CONFIG_FILE, --config_file CONFIG_FILE 28 | Path to the configuration file. Default=config.json 29 | -d If used, dcm files will be deleted during conversion to free up space. However, if download runs again it will need to download the whole dataset again. 30 | ``` 31 | The `setup.py` script will download the database to the provided path, convert 32 | the images to PNG format and pre-process the database csv files. Note that separate scripts for each one of these 33 | processes are provided in the `utils` folder. 34 | 35 | ## Creating a dataset 36 | Datasets are created using the class `CBISDDSMDatasetFactory` that provides a versatile way to filter lesions, 37 | manage their attributes and apply transformations on the corresponding images. A detailed description of the factory 38 | functions is given below. Additionally, the folder `examples` provides common cases of dataset creation. Keep in mind that a different copy of `config.json` is 39 | provided in this folder, which should point to the same `download_path` as the original setup config.
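For instance, a minimal whole-image classification dataset (mirroring `examples/whole_image_classification_dataset.py`) can be created as follows:

```python
from ddsm_dataset_factory import CBISDDSMDatasetFactory

dataset = CBISDDSMDatasetFactory('./config.json') \
    .map_attribute_value('pathology', {'BENIGN_WITHOUT_CALLBACK': 'BENIGN'}) \
    .create_classification('pathology')
dataset.visualize()
```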
40 | 41 | ### Lesion type and subset selection 42 | The CBIS-DDSM database provides two distinct subsets for `mass` and `calcification` lesion types 43 | that are further split into two subsets for training and testing purposes. 44 | With `CBISDDSMDatasetFactory` the selection of the appropriate subset or the merging of subsets is 45 | supported via the constructor arguments: 46 | 47 | ```python 48 | CBISDDSMDatasetFactory('./config.json', 49 | include_train_set=True, 50 | include_test_set=False, 51 | include_calcifications=False, 52 | include_masses=True) 53 | 54 | ``` 55 | 56 | ### Attribute manipulation 57 | In CBIS-DDSM, a broad set of attributes is provided for each lesion. The `CBISDDSMDatasetFactory` provides the function 58 | `.map_attribute_value()` to change a specific value into another one. For example, it is a common case for the 59 | `pathology` label `BENIGN_WITHOUT_CALLBACK` to be changed to `BENIGN`. This can be achieved with the following code: 60 | ```python 61 | dataset = CBISDDSMDatasetFactory('./config.json') \ 62 | .map_attribute_value('pathology', {'BENIGN_WITHOUT_CALLBACK': 'BENIGN'}) 63 | ``` 64 | Additionally, attributes that are not relevant can be dropped via the `.drop_attributes()` method: 65 | 66 | ```python 67 | dataset = CBISDDSMDatasetFactory('./config.json') \ 68 | .drop_attributes("assessment", "breast_density", "subtlety") 69 | ``` 70 | 71 | ### Patch transforms 72 | 73 | By default, the examples above will provide whole mammogram images. However, processing lesion patches is a common case 74 | for mammographic CAD systems. `CBISDDSMDatasetFactory` provides two types of patch transforms: 75 | 76 | #### A. Centered patch transform 77 | By using the option 78 | ```python 79 | .lesion_patches_centered(shape=(1024, 1024)) 80 | ``` 81 | the factory will provide patches that are centered around each lesion. 82 | The `shape` parameter specifies the dimensions of the patch in pixels. 83 | The default size is set to `(1024, 1024)`, which is sufficient for all the masses in the dataset. 84 | 85 | ```python 86 | dataset = CBISDDSMDatasetFactory('./config.json') \ 87 | .lesion_patches_centered() 88 | ``` 89 | Please note that there are some cases where the mass is located near the boundary of the mammogram image. 90 | In these cases the patch is adjusted (translated) to contain the mass even if it is not centered. 91 | An example of this option is given in `examples/centered_patch_classification_dataset.py`. 92 | #### B. Random patch transform 93 | By using the option 94 | ```python 95 | .lesion_patches_random(shape=(1024, 1024), min_overlap=0.9) 96 | ``` 97 | the factory will provide random patches of size `shape`, sampled at random locations around the lesion. 98 | The `min_overlap` parameter specifies the minimum percentage of overlap that the patch should have with the lesion. 99 | An example of this option is given in `examples/random_patch_classification_dataset.py`.
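The factory source additionally exposes a `normal_probability` argument on `.lesion_patches_random()` (default `0.0`). Based on the behaviour implemented in `transforms/patches_normal.py`: with the given probability, a patch of normal tissue is sampled from within the breast region but away from the lesion, and its `pathology` label is set to `NORMAL` (the `NORMAL` class is then appended to the label list automatically). A sketch of its use:

```python
dataset = CBISDDSMDatasetFactory('./config.json') \
    .lesion_patches_random(shape=(1024, 1024), min_overlap=0.9, normal_probability=0.3) \
    .create_classification('pathology')
```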
100 | ### Image transforms 101 | `CBISDDSMDatasetFactory` supports the application of PyTorch image transforms on the CBIS-DDSM samples, 102 | both whole images and patches. This is achieved via the function 103 | ```python 104 | .add_image_transforms(transform_list, for_train=True, for_val=True) 105 | ``` 106 | that accepts a list of transforms. The parameters `for_train` and `for_val` constrain the application of the 107 | transform to a specific mode (training mode or validation mode). In this way, the preprocessing transforms can be applied 108 | to all the samples, while the augmentation transforms can be applied only for training. 109 | After the dataset creation, the functions `.train_mode()` and `.test_mode()` activate the corresponding configuration. 110 | An example of this option is given in `examples/centered_patch_classification_train_val_split.py`. 111 | ### Splitting 112 | The dataset returned from `CBISDDSMDatasetFactory` provides two options for splitting the dataset for training and validation 113 | purposes. 114 | 115 | #### A. Train-val split 116 | By using the option 117 | ```python 118 | dataset.split_train_val(val_ratio, shuffle=False, random_state=None) 119 | ``` 120 | the dataset will return a tuple with two distinct datasets, one for training and one for validation. 121 | The parameter `val_ratio` specifies the ratio that will be held out for validation. 122 | An example of this option is given in `examples/centered_patch_classification_train_val_split.py` 123 | #### B. Cross-validation 124 | By using the option 125 | ```python 126 | dataset.split_crossval(folds, shuffle=False, random_state=None) 127 | ``` 128 | the dataset will return a list with `folds` training/validation splits of the dataset. For each split, the training 129 | dataset will contain a ratio of `(folds - 1)/folds` of the total samples while the validation set will 130 | contain `1/folds` of the total samples. The partitioning is performed in a mutually exclusive fashion, i.e. 131 | each sample is used exactly once for validation. An example of this option is given in 132 | `examples/centered_patch_classification_crossval.py`.
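### Consuming the datasets
The objects returned by `create_classification()` are standard PyTorch `Dataset`s, so they can be wrapped in a `torch.utils.data.DataLoader`. A minimal sketch (batch size and worker count are arbitrary choices here, and all patches are assumed to have been resized to a common shape): each batch yields a list of image tensors, with the mask batch appended when `mask_input=True`, plus a tensor of integer labels.

```python
from torch.utils.data import DataLoader

train_set, val_set = dataset.split_train_val(0.2, shuffle=True)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)

for image_list, labels in train_loader:
    images = image_list[0]  # shape (B, C, H, W); image_list[1] is the mask batch if requested
    # ... forward pass / training step ...
```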
-------------------------------------------------------------------------------- /datasets/generic_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Union, Tuple 3 | 4 | from torch.utils.data import Dataset 5 | from PIL import Image, ImageFile 6 | ImageFile.LOAD_TRUNCATED_IMAGES = True # Workaround found in: https://stackoverflow.com/questions/42462431/oserror-broken-data-stream-when-reading-image-file 7 | import torch 8 | from matplotlib import pyplot as plt 9 | from torchvision.transforms import functional as F 10 | from torchvision.transforms import Compose 11 | import pandas as pd 12 | 13 | class CBISDDSMGenericDataset(Dataset): 14 | def __init__(self, 15 | dataframe: pd.DataFrame, 16 | download_path: str, 17 | masks: bool = False, 18 | transform: Compose = None, 19 | train_image_transform: Union[List[torch.nn.Module], Tuple[torch.nn.Module]] = None, 20 | train_image_transform_for_mask_flags=None, 21 | test_image_transform: Union[List[torch.nn.Module], Tuple[torch.nn.Module]] = None, 22 | test_image_transform_for_mask_flags=None): 23 | self.dataframe: pd.DataFrame = dataframe 24 | self.download_path: str = download_path 25 | self.transform: Compose = transform 26 | self.include_masks: bool = masks 27 | self.current_index: int = 0 28 | self.__train_mode: bool = True 29 | self.__test_mode: bool = False 30 | self._train_image_transforms: Union[List[torch.nn.Module], Tuple[torch.nn.Module]] = train_image_transform 31 | self._train_image_transform_for_mask_flags = train_image_transform_for_mask_flags 32 | self._test_image_transforms = test_image_transform 33 | self._test_image_transform_for_mask_flags = test_image_transform_for_mask_flags 34 | 35 | def __getitem__(self, index): 36 | item = self.dataframe.iloc[index].to_dict() 37 | img_path = os.path.join(self.download_path, item['image_path']) 38 | image = Image.open(img_path) 39 | max_value = 65536 if image.mode == 'I' else 256 40 | image_tensor = F.pil_to_tensor(image).float() 41 | image_tensor /= max_value 42 | image_tensor_list = [image_tensor] 43 | 44 | if self.include_masks: 45 | mask_img_path = os.path.join(self.download_path, item['mask_path']) 46 | mask_image = Image.open(mask_img_path) 47 | mask_image_tensor = F.pil_to_tensor(mask_image).float() 48 | mask_image_tensor /= 255 49 | image_tensor_list.append(mask_image_tensor) 50 | 51 | sample = {'image_tensor_list': image_tensor_list, 'item': item} 52 | 53 | if self.transform is not None: 54 | sample = self.transform(sample) 55 | 56 | if self.__train_mode and self._train_image_transforms is not None: 57 | for transform, mask_flag in zip(self._train_image_transforms, self._train_image_transform_for_mask_flags): 58 | state = torch.get_rng_state() # share the RNG state so image and mask receive identical random transforms 59 | for i in range(len(sample['image_tensor_list'])): 60 | torch.set_rng_state(state) 61 | sample['image_tensor_list'][i] = transform(sample['image_tensor_list'][i]) 62 | if not mask_flag: 63 | break # this transform applies to the image only, skip the mask 64 | elif self.__test_mode and self._test_image_transforms is not None: 65 | for transform, mask_flag in zip(self._test_image_transforms, self._test_image_transform_for_mask_flags): 66 | state = torch.get_rng_state() 67 | for i in range(len(sample['image_tensor_list'])): 68 | torch.set_rng_state(state) 69 | sample['image_tensor_list'][i] = transform(sample['image_tensor_list'][i]) 70 | if not mask_flag: 71 | break 72 | elif not (self.__train_mode or self.__test_mode): 73 | raise Exception("No train/test mode selected") 74 | 75 | return sample['image_tensor_list'], sample['item'] 76 | 77 | def 
__len__(self): 78 | return len(self.dataframe.index) 79 | 80 | def __iter__(self): 81 | self.current_index = 0 82 | return self 83 | 84 | def __next__(self): 85 | if self.current_index < len(self): 86 | x = self[self.current_index] 87 | self.current_index += 1 88 | return x 89 | else: 90 | raise StopIteration 91 | 92 | def _get_img_visualize(self, image): 93 | return image 94 | 95 | def _get_label_visualize(self, item): 96 | return f'{item["patient_id"]}_{item["left_right"]}_{item["view"]}' 97 | 98 | def train_mode(self): 99 | self.__train_mode = True 100 | self.__test_mode = False 101 | return self 102 | 103 | def test_mode(self): 104 | self.__train_mode = False 105 | self.__test_mode = True 106 | return self 107 | 108 | def _split_dataframe(self, split_ratio, shuffle=False, random_state=None): 109 | if shuffle: 110 | dataframe = self.dataframe.sample(frac=1, random_state=random_state) 111 | else: 112 | dataframe = self.dataframe 113 | 114 | num_samples = len(dataframe.index) 115 | num_samples1 = int(num_samples * split_ratio) 116 | 117 | dataframe1 = dataframe.iloc[num_samples1:, :] 118 | dataframe2 = dataframe.iloc[:num_samples1, :] 119 | 120 | return dataframe1, dataframe2 121 | 122 | def _split_dataframe_crossval(self, folds, shuffle=False, random_state=None): 123 | if shuffle: 124 | dataframe = self.dataframe.sample(frac=1, random_state=random_state) 125 | else: 126 | dataframe = self.dataframe 127 | 128 | num_samples = len(dataframe.index) 129 | num_sample_per_fold = int(num_samples / folds) 130 | 131 | fold_dataframes = [] 132 | for i in range(folds): 133 | start_i = i * num_sample_per_fold 134 | end_i = (i + 1) * num_sample_per_fold 135 | fold_dataframe = dataframe.iloc[start_i:end_i, :] 136 | fold_dataframes.append(fold_dataframe) 137 | 138 | cv_dataframe_pairs = [] 139 | for i in range(folds): 140 | train_dataframe = pd.concat(list(d for ind, d in enumerate(fold_dataframes) if ind != i), ignore_index=True) 141 | val_dataframe = fold_dataframes[i] 142 | cv_dataframe_pairs.append((train_dataframe, val_dataframe)) 143 | 144 | return cv_dataframe_pairs 145 | 146 | def visualize(self): 147 | if self.include_masks: 148 | figure = plt.figure(figsize=(1, 2)) 149 | else: 150 | figure = plt.figure() 151 | 152 | def plot(e): 153 | plt.clf() 154 | image_list, item = next(self) 155 | 156 | image = image_list[0].transpose(0, 2) 157 | 158 | if self.include_masks: 159 | figure.add_subplot(1, 2, 1) 160 | 161 | mask = image_list[1].transpose(0, 2) 162 | 163 | plt.imshow(self._get_img_visualize(image)) 164 | plt.title(self._get_label_visualize(item), backgroundcolor='white') 165 | figure.add_subplot(1, 2, 2) 166 | plt.imshow(self._get_img_visualize(mask), cmap='gray') 167 | else: 168 | plt.imshow(self._get_img_visualize(image), cmap='gray') 169 | plt.title(self._get_label_visualize(item), backgroundcolor='white') 170 | plt.draw() 171 | 172 | figure.canvas.mpl_connect('key_press_event', plot) 173 | plt.show() -------------------------------------------------------------------------------- /utils/ddsm_preprocessor.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import argparse 3 | import os 4 | 5 | import cv2 6 | from PIL import Image 7 | import numpy as np 8 | import pandas 9 | from matplotlib import pyplot as plt 10 | from tqdm import tqdm 11 | import concurrent.futures 12 | 13 | 14 | class CBISDDSMPreprocessor: 15 | def __init__(self, download_path, csv_files_train, csv_files_test): 16 | self.__download_path = download_path 17 | 
self.__csv_files_train = csv_files_train 18 | self.__csv_files_test = csv_files_test 19 | self.__not_found = 0 20 | self.__other_errors = 0 21 | 22 | @staticmethod 23 | def __locate_lesion(mask_img, item_dict): 24 | mask = np.array(mask_img) 25 | ys, xs = np.where(mask) 26 | item_dict['minx'] = xs.min().tolist() 27 | item_dict['maxx'] = xs.max().tolist() 28 | item_dict['miny'] = ys.min().tolist() 29 | item_dict['maxy'] = ys.max().tolist() 30 | item_dict['cx'] = int((item_dict['maxx'] - item_dict['minx']) / 2 + item_dict['minx']) 31 | item_dict['cy'] = int((item_dict['maxy'] - item_dict['miny']) / 2 + item_dict['miny']) 32 | 33 | return True 34 | 35 | @staticmethod 36 | def __locate_breast(image, item_dict): 37 | image = (np.array(image) / 255).astype(np.uint8) 38 | threshold = int(0.05 * image.max()) 39 | _, image_binary = cv2.threshold(image, threshold, 255, cv2.THRESH_BINARY) 40 | contours, _ = cv2.findContours(image_binary, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) 41 | contours_areas = [cv2.contourArea(cont) for cont in contours] 42 | biggest_contour_idx = np.argmax(contours_areas) 43 | breast_contour = contours[biggest_contour_idx] 44 | breast_contour = cv2.convexHull(breast_contour) 45 | 46 | epsilon = 0.001 * cv2.arcLength(breast_contour, True) 47 | approxCurve = cv2.approxPolyDP(breast_contour, epsilon, True) 48 | 49 | # print(approxCurve.shape[0]) 50 | # image_binary = np.dstack([image_binary]*3) 51 | # cv2.drawContours(image_binary, [approxCurve], 0, (255, 0, 0), 3) 52 | # plt.imshow(image_binary, cmap='gray') 53 | # plt.show() 54 | 55 | item_dict['breast_minx'] = breast_contour[:, 0, 0].min() 56 | item_dict['breast_maxx'] = breast_contour[:, 0, 0].max() 57 | item_dict['breast_miny'] = breast_contour[:, 0, 1].min() 58 | item_dict['breast_maxy'] = breast_contour[:, 0, 1].max() 59 | item_dict['breast_cx'] = int((item_dict['breast_maxx'] - item_dict['breast_minx']) / 2 + item_dict['breast_minx']) 60 | item_dict['breast_cy'] = int((item_dict['breast_maxy'] - item_dict['breast_miny']) / 2 + item_dict['breast_miny']) 61 | item_dict['breast_poly'] = approxCurve[:, 0, :].tolist() 62 | return True 63 | 64 | def __payload(self, row): 65 | item_dict = { 66 | "patient_id": row[0], 67 | "breast_density": row[1], 68 | "left_right": row[2], 69 | "view": row[3], 70 | "lesion_type": row[5], 71 | "type1": row[6], 72 | "type2": row[7], 73 | "assessment": row[8], 74 | "pathology": row[9], 75 | "subtlety": row[10], 76 | "image_path": os.path.splitext(row[11])[0] + '.png', 77 | "patch_path": os.path.splitext(row[12])[0] + '.png', 78 | "mask_path": os.path.splitext(row[13])[0] + '.png' 79 | } 80 | image = Image.open(os.path.join(self.__download_path, item_dict['image_path'])) 81 | mask_img = Image.open(os.path.join(self.__download_path, item_dict['mask_path'])) 82 | # CBIS-DDSM has the problem that sometimes the paths of the patch and the mask are swapped, so that 83 | # 'patch_path' points to the mask image and vice versa. 84 | # This is detected by comparing the image size and the mask size. 85 | # However, sometimes the mask image is a little smaller than the image regardless of whether they are swapped. 86 | # So we need to check if the mask is a lot smaller as well. 
87 | if image.size != mask_img.size and (mask_img.size[0] < image.size[0] * 0.5 or mask_img.size[1] < image.size[1] * 0.5): 88 | tmp = item_dict['mask_path'] 89 | item_dict['mask_path'] = item_dict['patch_path'] 90 | item_dict['patch_path'] = tmp 91 | mask_img = Image.open(os.path.join(self.__download_path, item_dict['mask_path'])) 92 | 93 | result = self.__locate_lesion(mask_img, item_dict) 94 | 95 | if not result: 96 | raise Exception() 97 | 98 | result = self.__locate_breast(image, item_dict) 99 | 100 | if not result: 101 | raise Exception() 102 | 103 | return item_dict 104 | 105 | def __parse_file(self, file_list, out_csv_path): 106 | rows = [] 107 | data = [] 108 | for csv_file in file_list: 109 | with open(csv_file) as fin: 110 | reader = csv.reader(fin, delimiter=',', quotechar='"') 111 | next(reader) 112 | 113 | for row in reader: 114 | rows.append(row) 115 | 116 | with concurrent.futures.ThreadPoolExecutor() as executor: 117 | future_to_row = {executor.submit(self.__payload, row): row for row in rows} 118 | for future in tqdm(concurrent.futures.as_completed(future_to_row), total=len(rows), unit='abnormalities'): 119 | try: 120 | item_dict = future.result() 121 | except FileNotFoundError: 122 | self.__not_found += 1 123 | continue 124 | except Exception: 125 | self.__other_errors += 1 126 | continue 127 | 128 | data.append(item_dict) 129 | 130 | df = pandas.DataFrame(data) 131 | df.to_csv(out_csv_path) 132 | 133 | def start(self): 134 | print('Processing {} abnormality csv files for training.'.format(len(self.__csv_files_train))) 135 | self.__parse_file(self.__csv_files_train, os.path.join(self.__download_path, 'lesions_train.csv')) 136 | 137 | print('Processing {} abnormality csv files for testing.'.format(len(self.__csv_files_test))) 138 | self.__parse_file(self.__csv_files_test, os.path.join(self.__download_path, 'lesions_test.csv')) 139 | 140 | if self.__not_found > 0: 141 | print('Could not locate {} files. Please re-run the downloader.'.format(self.__not_found)) 142 | 143 | 144 | if __name__ == "__main__": 145 | parser = argparse.ArgumentParser(prog='CBIS DDSM Preprocessor') 146 | parser.add_argument('-p', '--path', default='../CBIS_DDSM', 147 | help='Path to the download folder containing the converted images.')
148 | parser.add_argument('-tr', '--csv_files_train', nargs='+', 149 | default=['../resources/calc_case_description_train_set.csv', 150 | '../resources/mass_case_description_train_set.csv'], 151 | help='One or more csv files to process, as downloaded from the TCIA repository.') 152 | parser.add_argument('-te', '--csv_files_test', nargs='+', 153 | default=['../resources/calc_case_description_test_set.csv', 154 | '../resources/mass_case_description_test_set.csv'], 155 | help='One or more csv files to process, as downloaded from the TCIA repository.') 156 | args = parser.parse_args() 157 | preprocessor = CBISDDSMPreprocessor(args.path, args.csv_files_train, args.csv_files_test) 158 | preprocessor.start() 159 | -------------------------------------------------------------------------------- /ddsm_dataset_factory.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | import os.path 4 | import pandas as pd 5 | from typing import List, Dict, Tuple 6 | from torchvision.transforms import Compose 7 | from tqdm import tqdm 8 | from datasets.generic_dataset import CBISDDSMGenericDataset 9 | from transforms.patches_centered import CenteredPatches 10 | from transforms.patches_random import RandomPatches 11 | from transforms.patches_normal import normal_patch_transform_wrapper 12 | from datasets.classification_dataset import CBISDDSMClassificationDataset 13 | from PIL import Image 14 | import numpy as np 15 | 16 | class CBISDDSMDatasetFactory: 17 | def __init__(self, 18 | config_path, 19 | include_train_set=True, 20 | include_test_set=False, 21 | include_masses=True, 22 | include_calcifications=False) -> None: 23 | self.__config = self.__read_config(config_path) 24 | self.__download_folder = self.__config["download_path"] 25 | self.__train: bool = include_train_set 26 | self.__test: bool = include_test_set 27 | self.__dataframe = None 28 | self.__excluded_attrs: List[str] = [] 29 | self.__excluded_values: Dict[str, set] = {'lesion_type': {'mass', 'calcification'}} 30 | self.__attribute_mapped_values: Dict[str, Dict[str, str]] = {} 31 | self.__transform_list = [] 32 | self.__image_transform_list = [] 33 | self.__image_transform_list_applied_training = [] 34 | self.__image_transform_list_applied_validation = [] 35 | self.__image_transform_list_applied_mask = [] 36 | self.__plus_normal = False 37 | self.__patch_transform_selected = False 38 | self.__from_cache = False 39 | 40 | if include_masses: 41 | self.__excluded_values['lesion_type'].remove('mass') 42 | 43 | if include_calcifications: 44 | self.__excluded_values['lesion_type'].remove('calcification') 45 | 46 | @staticmethod 47 | def __read_config(config_path): 48 | with open(config_path, 'r') as cf: 49 | config = json.load(cf) 50 | return config 51 | 52 | def __fetch_filter_lesions(self): 53 | try: 54 | csv_file_list = [] 55 | if self.__train: 56 | csv_file_list.append(os.path.join(self.__download_folder, 'lesions_train.csv')) 57 | if self.__test: 58 | csv_file_list.append(os.path.join(self.__download_folder, 'lesions_test.csv')) 59 | 60 | self.__dataframe = pd.concat((pd.read_csv(f) for f in csv_file_list), ignore_index=True) 61 | except Exception: 62 | print(f'Database seems not properly set up in folder {self.__config["download_path"]}. Please (re)run setup.py or check the paths in config.json.')
63 | return 64 | 65 | self.__dataframe.drop(self.__excluded_attrs, axis=1, inplace=True) 66 | 67 | for attribute, value_set in self.__excluded_values.items(): 68 | self.__dataframe = self.__dataframe.loc[~self.__dataframe[attribute].isin(list(value_set))] 69 | 70 | for attribute, mapping in self.__attribute_mapped_values.items(): 71 | for v1, v2 in mapping.items(): 72 | self.__dataframe[attribute].replace(v1, v2, inplace=True) 73 | 74 | self.__dataframe.reset_index(inplace=True, drop=True) 75 | 76 | def drop_attribute_values(self, attribute: str, *value_list: str): 77 | value_set = self.__excluded_values.setdefault(attribute, set()) # setdefault so newly seen attributes are stored back 78 | for v in value_list: 79 | value_set.add(v) 80 | return self 81 | 82 | def map_attribute_value(self, attribute: str, mapping: Dict[str, str]): 83 | attribute_mapping = self.__attribute_mapped_values.get(attribute, dict()) 84 | attribute_mapping.update(mapping) 85 | self.__attribute_mapped_values[attribute] = attribute_mapping 86 | return self 87 | 88 | def drop_attributes(self, *attribute_list: str): 89 | self.__excluded_attrs.extend(attribute_list) 90 | return self 91 | 92 | def show_counts(self): 93 | if not self.__from_cache: 94 | self.__fetch_filter_lesions() 95 | 96 | df = self.__dataframe 97 | labels_list = ["lesion_type", "type1", "type2", "pathology", "assessment", "breast_density", "subtlety"] 98 | print(os.linesep) 99 | for label in labels_list: 100 | if label in df.columns: 101 | print(df[label].value_counts(sort=True, ascending=False)) 102 | print(os.linesep) 103 | 104 | return self 105 | 106 | def lesion_patches_centered(self, shape: Tuple[int] = (1024, 1024)): 107 | if self.__patch_transform_selected: 108 | raise Exception('Patch transform already selected!') 109 | self.__transform_list.append(CenteredPatches(shape)) 110 | self.__patch_transform_selected = True 111 | return self 112 | 113 | def lesion_patches_random(self, shape: Tuple[int] = (1024, 1024), min_overlap=0.9, normal_probability=0.0): 114 | if self.__patch_transform_selected: 115 | raise Exception('Patch transform already selected!') 116 | patch_transform = RandomPatches(shape, min_overlap=min_overlap) 117 | if normal_probability > 0: 118 | self.__plus_normal = True 119 | patch_transform = normal_patch_transform_wrapper(patch_transform, normal_probability, shape, 120 | 1 - min_overlap) 121 | self.__transform_list.append(patch_transform) 122 | self.__patch_transform_selected = True 123 | return self 124 | 125 | def cache_here(self): 126 | self.__fetch_filter_lesions() 127 | cache_name = hashlib.sha1(pd.util.hash_pandas_object(self.__dataframe, index=True).values) 128 | for trans in self.__transform_list: 129 | cache_name.update(bytes(str(trans), 'utf-8')) 130 | for trans in self.__image_transform_list: 131 | cache_name.update(bytes(str(trans), 'utf-8')) 132 | cache_name = cache_name.hexdigest() 133 | 134 | cache_path = os.path.join(self.__download_folder, 'cache', cache_name) 135 | cache_dataframe_path = os.path.join(cache_path, "dataframe.csv") 136 | 137 | if os.path.exists(cache_path) and os.path.exists(cache_dataframe_path): 138 | self.__dataframe = pd.read_csv(cache_dataframe_path) 139 | 140 | else: 141 | os.makedirs(cache_path, exist_ok=True) 142 | dataset = CBISDDSMGenericDataset(self.__dataframe, self.__download_folder, 143 | masks=True, 144 | transform=Compose(self.__transform_list), 145 | train_image_transform=self.__image_transform_list, train_image_transform_for_mask_flags=self.__image_transform_list_applied_mask, 146 | test_image_transform=self.__image_transform_list, test_image_transform_for_mask_flags=self.__image_transform_list_applied_mask) 147 | 148 | counter = 0 149 | for a in tqdm(dataset): 150 | sample_name = f"{counter:05d}" 151 | 152 | image_name = sample_name + ".png" 153 | image_path = os.path.join(cache_path, image_name) 154 | image = a[0][0].cpu().detach().numpy().squeeze() * 255 155 | img_pil = Image.fromarray(image.astype(np.uint8)) 156 | img_pil.save(image_path) 157 | self.__dataframe.at[counter, "image_path"] = image_name 158 | 159 | mask_name = sample_name + "_mask.png" 160 | mask_path = os.path.join(cache_path, mask_name) 161 | mask = a[0][1].cpu().detach().numpy().squeeze() * 255 162 | mask_pil = Image.fromarray(mask.astype(np.uint8)) 163 | mask_pil.save(mask_path) 164 | self.__dataframe.at[counter, "mask_path"] = mask_name 165 | 166 | counter += 1 167 | 168 | self.__dataframe.to_csv(cache_dataframe_path) 169 | 170 | self.__transform_list.clear() 171 | self.__image_transform_list.clear() 172 | self.__image_transform_list_applied_training.clear() 173 | self.__image_transform_list_applied_validation.clear() 174 | self.__download_folder = cache_path 175 | self.__from_cache = True 176 | 177 | return self 178 | 179 | 180 | def add_image_transforms(self, transform_list: List, for_train: bool = True, for_val: bool = True, for_mask=True): 181 | self.__image_transform_list.extend(transform_list) 182 | self.__image_transform_list_applied_training.extend([for_train]*len(transform_list)) 183 | self.__image_transform_list_applied_validation.extend([for_val]*len(transform_list)) 184 | self.__image_transform_list_applied_mask.extend([for_mask]*len(transform_list)) 185 | return self 186 | 187 | def split_cross_validation(self, k_folds=5): # Note: these flags are currently unused; call split_crossval() on the created dataset instead 188 | self.__split_validation = False 189 | self.__split_cross_validation = True 190 | self.__cross_validation_folds = k_folds 191 | return self 192 | 193 | def create_classification(self, attribute: str, mask_input: bool = False): 194 | if not self.__from_cache: 195 | self.__fetch_filter_lesions() 196 | 197 | label_list = self.__dataframe[attribute].unique().tolist() 198 | if self.__plus_normal: 199 | label_list.append('NORMAL') 200 | 201 | train_image_transforms = [trans for trans, ft in 202 | zip(self.__image_transform_list, self.__image_transform_list_applied_training) if 203 | ft] 204 | train_image_transform_for_mask_flags = [flag for flag, ft in 205 | zip(self.__image_transform_list_applied_mask, self.__image_transform_list_applied_training) if 206 | ft] 207 | val_transforms = [trans for trans, fv in 208 | zip(self.__image_transform_list, self.__image_transform_list_applied_validation) if fv] 209 | val_image_transform_for_mask_flags = [flag for flag, fv in 210 | zip(self.__image_transform_list_applied_mask, self.__image_transform_list_applied_validation) if fv] 211 | 212 | dataset = CBISDDSMClassificationDataset(self.__dataframe, self.__download_folder, attribute, 213 | label_list, 214 | masks=mask_input, transform=Compose(self.__transform_list), 215 | train_image_transform=train_image_transforms, 216 | train_image_transform_for_mask_flags=train_image_transform_for_mask_flags, 217 | test_image_transform=val_transforms, 218 | test_image_transform_for_mask_flags=val_image_transform_for_mask_flags) 219 | 220 | return dataset 221 | --------------------------------------------------------------------------------