├── docs
│   ├── tox_pattern.png
│   └── Scheme_extended.png
├── .gitignore
├── requirements.txt
├── utils
│   ├── timed_input.py
│   ├── imports.py
│   ├── training_utils.py
│   ├── data_processing.py
│   └── image_dataset_reader.py
├── LICENSE
├── model_use_example.py
├── data_preparation
│   └── compute_means.py
├── configs
│   ├── cfg_training_cnn.py
│   └── cfg_anomaly_detector.py
├── models
│   ├── losses.py
│   └── pretrained_networks.py
├── README.md
├── train_cnn.py
└── anomaly_detector.py

/docs/tox_pattern.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Boehringer-Ingelheim/anomaly-detection-in-histology/HEAD/docs/tox_pattern.png
--------------------------------------------------------------------------------
/docs/Scheme_extended.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Boehringer-Ingelheim/anomaly-detection-in-histology/HEAD/docs/Scheme_extended.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Pycharm editor settings
2 | .idea
3 | .DS_Store
4 | 
5 | *.pyc
6 | 
7 | tb_logs/
8 | results/
9 | data/
10 | 
11 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.5.1
2 | numpy==1.22.3
3 | pandas==1.4.2
4 | Pillow==9.0.1
5 | prettytable==3.3.0
6 | scikit_learn==1.0.2
7 | seaborn==0.11.2
8 | torch==1.11.0
9 | torchvision==0.12.0
10 | tqdm==4.64.0
--------------------------------------------------------------------------------
/utils/timed_input.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import select
3 | import logging
4 | 
5 | 
6 | def limited_time_input(message, delay):
7 | 
8 |     print(message)
9 | 
10 |     i, _, _ = select.select([sys.stdin], [], [], delay)
11 | 
12 |     if i:
13 |         read = sys.stdin.readline().strip()
14 |     else:
15 |         logging.info("skipping user input")
16 |         #print("skipping user input")
17 |         read = ""
18 | 
19 |     return read
20 | 
21 | if __name__ == "__main__":
22 |     choice = limited_time_input("please input your choice", 10)
23 |     if not choice:
24 |         print('nothing was chosen')
25 |     else:
26 |         print('{} was chosen'.format(choice))
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Boehringer Ingelheim
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/model_use_example.py:
--------------------------------------------------------------------------------
1 | from models.pretrained_networks import EfficientNet_B0_320
2 | from utils.training_utils import apply_net
3 | from utils.data_processing import CodesProcessor
4 | from torch.utils.data import DataLoader
5 | from utils.image_dataset_reader import HistImagesDataset
6 | from torchvision import transforms
7 | 
8 | path_to_tissues = (
9 |     {'folder': "path_to_images", 'label': 'liver', 'ext': 'png'},
10 | )
11 | 
12 | cnn_model_path = 'path_to_trained_model'  # pretrained BIHN model, *.pt file
13 | 
14 | n_samples_per_folder = 3  # number of image samples to test on
15 | n_classes = 16  # number of classes the model was trained on
16 | dev = "cpu"
17 | 
18 | tr_normalize = transforms.Normalize(mean=(0.5788, 0.3551, 0.5655), std=(1, 1, 1))
19 | transforms_seq = transforms.Compose([transforms.ToTensor(), tr_normalize])
20 | images_dataset = HistImagesDataset(*path_to_tissues, n_samples=n_samples_per_folder, transform=transforms_seq)
21 | test_data_loader = DataLoader(images_dataset)
22 | 
23 | model = EfficientNet_B0_320(path_trained_model=cnn_model_path, n_classes=n_classes, dev=dev).to(dev)
24 | code_processor = CodesProcessor()
25 | apply_net(model, dev, test_data_loader, verbose=True, code_processor=code_processor)
26 | features = code_processor.get_codes()
27 | 
28 | print(f"There are {features.shape[0]} feature vectors of length {features.shape[1]}")
29 | 
30 | 
31 | 
32 | 
--------------------------------------------------------------------------------
/data_preparation/compute_means.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | Image.MAX_IMAGE_PIXELS = 600000000
3 | from torchvision import transforms
4 | 
5 | import os
6 | import sys
7 | 
8 | prj_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
9 | sys.path.append(prj_root)
10 | 
11 | from utils.image_dataset_reader import HistImagesDataset, samples_per_location_from_samples_per_class
12 | import torch
13 | from torch.utils.data import DataLoader
14 | from tqdm import tqdm
15 | 
16 | root_folder = os.path.join(prj_root, 'data/')
17 | 
18 | normal_images_paths = (
19 | 
20 |     {'folder': root_folder + "train/mt_mouse_liver", 'label': 'liver', 'ext': 'png'},
21 | 
22 | )
23 | 
24 | img_extension = 'png'
25 | patch_size = (256, 256)
26 | n_patches = 6921  # number of image patches per class; set to None to use all
27 | 
28 | 
29 | def image_means(normal_images_paths, patch_size, n_patches=None):
30 | 
31 |     transforms_seq = transforms.Compose([transforms.RandomCrop(patch_size), transforms.ToTensor()])
32 | 
33 |     n_patches_per_location = None
34 |     if n_patches:
35 |         n_patches_per_location = samples_per_location_from_samples_per_class(*normal_images_paths, samples_per_class=n_patches)
36 | 
37 |     images_dataset = HistImagesDataset(*normal_images_paths, n_samples=n_patches_per_location, transform=transforms_seq)
38 | 
39 |     # when n_patches is None, n_patches_per_location stays None; fall back to the dataset length
40 |     n_total = sum(n_patches_per_location) if n_patches_per_location else len(images_dataset)
41 |     means, stds = comp_means(images_dataset, n_total)
42 | 
43 |     return means, stds
44 | 
45 | 
46 | def comp_means(images_dataset, n_patches):
47 | 
48 |     im_loader = DataLoader(images_dataset, num_workers=0)
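    # note: averaging the per-patch means below gives the exact global mean, but averaging
    # the per-patch std values only approximates the global std (an exact value would require
    # pooling per-pixel sums and sums of squares across all patches)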
49 | 
50 |     progress = tqdm(im_loader, total=n_patches)
51 | 
52 |     means = 0
53 |     stds = 0
54 |     n_patches_read = 0
55 |     for samples in progress:
56 |         image = samples['image']
57 | 
58 |         image = torch.squeeze(image)
59 |         std, mean = torch.std_mean(image, dim=(1, 2))
60 |         means += mean
61 |         stds += std
62 |         n_patches_read += 1
63 | 
64 |     means = means / n_patches_read
65 |     stds = stds / n_patches_read
66 |     print('average value: {}, std value: {}, based on {} images'.format(means, stds, n_patches_read))
67 | 
68 |     return means, stds
69 | 
70 | #---------------
71 | 
72 | image_means(normal_images_paths, patch_size, n_patches=n_patches)
--------------------------------------------------------------------------------
/utils/imports.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import pickle
4 | import ast
5 | import logging
6 | 
7 | def import_config_from_file(path):
8 | 
9 |     module_name = os.path.splitext(os.path.split(path)[1])[0]
10 |     spec = importlib.util.spec_from_file_location(module_name, path)
11 |     cfg = importlib.util.module_from_spec(spec)
12 |     spec.loader.exec_module(cfg)
13 | 
14 |     return cfg
15 | 
16 | 
17 | def show_configuration(cfg, file=None):
18 | 
19 |     if not file:
20 |         print('---used configuration---')
21 |     for att in dir(cfg):
22 | 
23 |         if not att.startswith('_'):
24 |             attr_value = getattr(cfg, att)
25 | 
26 |             normal_print = True
27 |             if isinstance(attr_value, (list, tuple, dict)) and len(str(attr_value)) > 100:
28 |                 for val in attr_value:
29 |                     if isinstance(val, (list, tuple, dict)):
30 |                         normal_print = False
31 |                         break
32 | 
33 |             if normal_print:
34 |                 print('{:<30} {:>100}'.format(att, str(attr_value)), file=file)
35 |             else:
36 |                 for i, current_attr in enumerate(attr_value):
37 |                     print('{:<30} {:>100}'.format(att + '[' + str(i) + ']', str(current_attr)), file=file)
38 | 
39 | 
40 | def save_configuration(cfg, path):
41 | 
42 |     with open(path, 'w') as f:
43 |         show_configuration(cfg, file=f)
44 | 
45 | 
46 | def pickle_configuraton_as_dictionary(cfg, path):
47 | 
48 |     dic = {}
49 |     for att in dir(cfg):
50 | 
51 |         if not att.startswith('_'):
52 |             dic[att] = getattr(cfg, att)
53 | 
54 |     pickle.dump(dic, open(path, 'wb'))
55 | 
56 | 
57 | def update_configuration(parser=None, cfg=None):
58 |     """
59 |     Adds or overwrites configuration with parameters provided on the command line
60 | 
61 |     :param parser: parser from argparse
62 |     :param cfg: configuration (will be changed to the new configuration inplace)
63 |     :return: configuration with overwritten/added parameters from the command line
64 | 
65 |     cfg = update_configuration(parser) - will load the configuration from the file provided via --config cfg_file.py, and adds/overwrites
66 |     parameters from the command line with the prefix P, e.g. -Pdevice='cuda:0' -Pseed=100 overwrites the parameters 'device' and
67 |     'seed' from cfg_file.py
68 | 
69 |     cfg = update_configuration(parser, cfg) - will return the configuration cfg with parameters overwritten from the command line
70 |     with the prefix P, e.g. -Pdevice='cuda:0' -Pseed=100 overwrites the parameters 'device' and 'seed' from cfg.
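
    For example, a full command line for this repository's training script (with parameter names
    as defined in configs/cfg_training_cnn.py) would be:
        python train_cnn.py --config configs/cfg_training_cnn.py -Pseed_number=100 -Pdevice_name='cuda:0'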
    Also works inplace: update_configuration(parser, cfg)
71 | 
72 |     cfg = update_configuration(cfg=cfg) - just returns the configuration cfg (nothing is done)
73 | 
74 |     """
75 | 
76 |     if cfg is None and parser is None:
77 |         logging.error("if configuration cfg is not given, a parser must be provided to load the configuration")
78 |         raise ValueError("if configuration cfg is not given, a parser must be provided to load the configuration")
79 | 
80 |     if parser is not None:
81 |         parser.add_argument('--config', type=str, help="config file", default=None)
82 |         parser.add_argument('-P', action='append')
83 |         args = parser.parse_args()
84 | 
85 |         if cfg is None:
86 |             cfg = import_config_from_file(args.config)
87 | 
88 |         if args.P is not None:
89 |             overwrite_parameters = {}
90 |             for key, value in [s.split('=') for s in args.P]:
91 |                 try:
92 |                     key = ast.literal_eval(key)
93 |                 except ValueError:
94 |                     pass
95 | 
96 |                 try:
97 |                     value = ast.literal_eval(value)
98 |                 except (ValueError, SyntaxError):
99 |                     pass
100 | 
101 |                 overwrite_parameters[key] = value
102 | 
103 |             for par in overwrite_parameters.keys():
104 |                 setattr(cfg, par, overwrite_parameters[par])
105 | 
106 |     return cfg
107 | 
108 | 
--------------------------------------------------------------------------------
/utils/training_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import logging
4 | import torch.nn as nn  # must stay here
5 | import torch.optim as optim  # must stay here
6 | 
7 | import os
8 | 
9 | 
10 | def apply_net(model, device, val_data_loader, n_val_samples=np.inf, verbose=True, code_processor=None, image_processor=None):
11 | 
12 |     """
13 |     Validation of the trained model
14 | 
15 |     :param device: device to put data on
16 |     :param val_data_loader: data loader that provides images to be processed
17 |     :param model: network model to be used.
It should already be on the right device 18 | :param n_val_samples: if given uses that number of samples from data loader 19 | :param code_processor: callable class that can gather code vectors 20 | :param image_processor: receives input and output images (of auto-encoder) and computes and saves a reconstruction error 21 | 22 | """ 23 | 24 | def get_patch_coordinates(samples): 25 | # reads 'info' field of samples with stored x, y, w, h information in the string and parse it to dictionary 26 | # with the corresponding x, y, w, h fields, each of which contains a list of integers of the length equal 27 | # to the number of images in the batch 28 | 29 | if 'info' in samples.keys(): 30 | patch_info = samples['info'] 31 | else: 32 | return None 33 | 34 | x, y, w, h = [], [], [], [] 35 | for info in patch_info: 36 | 37 | temp = info.split(', ')[0].split(': ') 38 | assert temp[0] == 'x' 39 | x.append(int(temp[1])) 40 | 41 | temp = info.split(', ')[1].split(': ') 42 | assert temp[0] == 'y' 43 | y.append(int(temp[1])) 44 | 45 | temp = info.split(', ')[2].split(': ') 46 | assert temp[0] == 'w' 47 | w.append(int(temp[1])) 48 | 49 | temp = info.split(', ')[3].split(': ') 50 | assert temp[0] == 'h' 51 | h.append(int(temp[1])) 52 | 53 | return {'x': x, 'y': y, 'w': w, 'h': h} 54 | 55 | 56 | def extend_image_names(samples): 57 | 58 | if 'info' in samples.keys(): 59 | patch_info = samples['info'] 60 | else: 61 | patch_info = ['x: N, y: N, w: N, h: N'] * len(samples['image_name']) 62 | 63 | im_names = [] 64 | for im_path, info in zip(samples['image_name'], patch_info): 65 | info = info.split(',') 66 | x = info[0].split()[1] 67 | y = info[1].split()[1] 68 | w = info[2].split()[1] 69 | h = info[3].split()[1] 70 | 71 | image_name = os.path.basename(im_path) 72 | image_name = os.path.splitext(image_name)[0] 73 | im_names.append(image_name + '_x' + x + '_y' + y + '_w' + w + '_h' + h) 74 | 75 | return im_names 76 | 77 | init_model_status = model.training 78 | model.eval() 79 | 80 | val_loss = torch.tensor(0.0) 81 | diff_measure = [] 82 | n_processed = 0 83 | data_iter = iter(val_data_loader) 84 | with torch.no_grad(): 85 | counter = 0 86 | 87 | while n_processed < n_val_samples: 88 | 89 | try: 90 | samples = next(data_iter) 91 | except StopIteration: 92 | if n_val_samples != np.inf: 93 | logging.warning( 94 | 'Warning: finite iterator was provided, it was exhausted before required n_samples were generated') 95 | break 96 | 97 | images = samples['image'].to(device) 98 | 99 | output_dict = model(images) 100 | 101 | if 'rec_image' in output_dict: 102 | outputs = output_dict['rec_image'] 103 | else: 104 | outputs = None 105 | 106 | if 'codes' in output_dict: 107 | embeddings = output_dict['codes'] 108 | elif 'pooled_codes' in output_dict: 109 | embeddings = output_dict['pooled_codes'] 110 | else: 111 | embeddings = None 112 | 113 | n_processed += images.shape[0] 114 | counter += 1 115 | 116 | # collect computed codes if requested 117 | if code_processor is not None: 118 | 119 | im_names = samples['image_name'][:] 120 | im_labels = samples['string_label'][:] 121 | coordinates = get_patch_coordinates(samples) 122 | 123 | code_processor(embeddings, im_names, im_labels, coordinates) 124 | 125 | 126 | 127 | if image_processor is not None: 128 | if image_processor.save_images_path: 129 | 130 | im_names = extend_image_names(samples) 131 | image_processor(images, outputs, im_names) 132 | 133 | else: 134 | image_processor(images, outputs) 135 | 136 | 137 | if verbose: 138 | logging.info('{} validation patches were 
read'.format(n_processed))
139 | 
140 |     # setting initial status of the model
141 |     model.train(init_model_status)
142 | 
143 |     return None
144 | 
145 | 
146 | 
147 | 
148 | 
149 | 
150 | 
151 | 
152 | 
153 | 
--------------------------------------------------------------------------------
/configs/cfg_training_cnn.py:
--------------------------------------------------------------------------------
1 | import numpy as _np
2 | import time as _time
3 | import os as _os
4 | 
5 | string_time = _time.strftime("%y%m%d_%H%M%S")
6 | 
7 | test_run = False  # set to True for a quick test run; no results are saved. Allows checking that there are no run-time errors
8 | seed_number = 500
9 | 
10 | # root folder for inputs (data) and outputs - defined here as the root folder of the code repository.
11 | # You can define your own arbitrary root path for data (outside of the project code)
12 | _prj_root = _os.path.dirname(_os.path.dirname(_os.path.abspath(__file__)))
13 | #_prj_root = '/your root path to input and output data/'
14 | 
15 | # path to trained model, tensorboard logs etc. You can define your own arbitrary path.
16 | path_to_results = _os.path.join(_prj_root, "train_results")
17 | 
18 | # path to the training data; the trailing separator is required, since the folder paths below are built by string concatenation. You can define your own arbitrary path.
19 | path_to_data = _os.path.join(_prj_root, "data/train/")
20 | 
21 | # Target combination of species, organ, and staining - the tissue where anomalies should be found
22 | data_staining = "Masson"
23 | #data_staining = "HE"
24 | organ = "Liver"
25 | animal = "Mouse"
26 | #animal = "Rat"
27 | 
28 | # training data - all healthy, but might be cluttered with low quality examples
29 | path_to_tissues = (
30 |     {'folder': path_to_data + "mt_mouse_brain", 'label': 'brain', 'ext': 'png'},
31 |     {'folder': path_to_data + "mt_mouse_heart", 'label': 'heart', 'ext': 'png'},
32 |     {'folder': path_to_data + "mt_mouse_kidney", 'label': 'kidney', 'ext': 'png'},
33 |     {'folder': path_to_data + "mt_mouse_lung", 'label': 'lung', 'ext': 'png'},
34 |     {'folder': path_to_data + "mt_mouse_pancreas", 'label': 'pancreas', 'ext': 'png'},
35 |     {'folder': path_to_data + "mt_mouse_spleen", 'label': 'spleen', 'ext': 'png'},
36 | 
37 |     {'folder': path_to_data + "mt_mouse_liver", 'label': 'liver', 'ext': 'png'},
38 | 
39 |     {'folder': path_to_data + "mt_rat_liver", 'label': 'liver_rat', 'ext': 'png'},
40 | 
41 |     {'folder': path_to_data + "he_mouse_brain", 'label': 'he_brain', 'ext': 'png'},
42 |     {'folder': path_to_data + "he_mouse_kidney", 'label': 'he_kidney', 'ext': 'png'},
43 |     {'folder': path_to_data + "he_mouse_spleen", 'label': 'he_spleen', 'ext': 'png'},
44 |     {'folder': path_to_data + "he_mouse_pancreas", 'label': 'he_pancreas', 'ext': 'png'},
45 |     {'folder': path_to_data + "he_mouse_heart", 'label': 'he_heart', 'ext': 'png'},
46 |     {'folder': path_to_data + "he_mouse_lung", 'label': 'he_lung', 'ext': 'png'},
47 | 
48 |     {'folder': path_to_data + "he_mouse_liver", 'label': 'he_liver', 'ext': 'png'},
49 | 
50 |     {'folder': path_to_data + "he_rat_liver", 'label': 'he_liver_rat', 'ext': 'png'},
51 | )
52 | 
53 | number_of_classes = len(set([loc['label'] for loc in path_to_tissues]))
54 | 
55 | centerloss_classes = 'derived'
56 | #centerloss_classes = None # do not use centerloss
57 | #centerloss_classes = ('liver', 'heart')
58 | #centerloss_classes = 'liver'
59 | #centerloss_classes = 'he_liver'
60 | #centerloss_classes = ('liver', 'liver_rat')
61 | #centerloss_classes = 'all' # tighten distribution for all the classes
62 | if centerloss_classes == 'derived':
63 |     if data_staining == "HE" and animal == "Mouse" and organ == "Liver":
"Mouse" and organ == "Liver": 64 | centerloss_classes = 'he_liver' 65 | elif data_staining == "Masson" and animal == "Mouse" and organ == "Liver": 66 | centerloss_classes = 'liver' 67 | elif data_staining == "HE" and animal == "Rat" and organ == "Liver": 68 | centerloss_classes = 'he_liver_rat' 69 | elif data_staining == "Masson" and animal == "Rat" and organ == "Liver": 70 | centerloss_classes = 'liver_rat' 71 | else: 72 | raise RuntimeError("such a combination of animal staining and organ is not yet implemented") 73 | 74 | mixup_classes = [['brain', 'heart', 'kidney', 'lung', 'pancreas', 'spleen', 'liver', 'liver_rat'], 75 | ['he_brain', 'he_kidney', 'he_spleen','he_pancreas', 'he_heart', 'he_liver', 'he_liver_rat', 'he_lung']] 76 | 77 | #mixup_classes = False # False or comment out 78 | 79 | 80 | #model_name = 'VGG_11' # class defined in pretrained_networks.py 81 | #model_name = 'DenseNet_121' # class defined in pretrained_networks.py 82 | #model_name = 'ResNet_18' 83 | #model_name = 'DenseNet_121_512' 84 | model_name = 'EfficientNet_B0_320' 85 | #model_name = 'EfficientNet_B0' 86 | #model_name = 'EfficientNet_B2_352' 87 | #model_name = 'EfficientNet_B2' 88 | #model_name = 'ConvNeXt' 89 | #model_name = 'VT_B_32' 90 | 91 | # arbitrary description of th experiment 92 | description = 'test' 93 | 94 | device_name = "cuda:0" 95 | 96 | num_workers = 3 # should be at least 4 times the number of GPUs used. Beyond that almost no speed gain 97 | 98 | # image transformations 99 | normalize_mean = (0.5788, 0.3551, 0.5655) 100 | normalize_std = (1, 1, 1) 101 | 102 | aug_brightness=(0.8, 1.2) 103 | aug_contrast=(0.8, 1.2) 104 | 105 | n_samples_train_per_class = 6920 # number of samples to be taken for each class for training + validation 106 | n_samples_val = _np.int(n_samples_train_per_class / 10 * number_of_classes) 107 | 108 | batch_size = 64 109 | num_epochs = 15 110 | patch_size = (256, 256) 111 | #patch_size = (224, 224) 112 | cl_weight = 1.0 # center-loss weight 113 | ce_weight = 1.0 # cross entropy weight 114 | model_lr= 0.001 # learning rate 115 | ce_momentum= 0.9 # cross entropy loss, momentum 116 | 117 | train_step_show = 10 118 | 119 | if test_run: 120 | mixup_classes = False 121 | n_samples_train_per_class = 500 122 | n_samples_val = _np.int(n_samples_train_per_class / 10 * number_of_classes) 123 | num_epochs = 2 124 | 125 | 126 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CenterLoss(nn.Module): 6 | """Center loss. 7 | 8 | Reference: 9 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 10 | 11 | Args: 12 | num_classes (int): number of classes. 13 | feat_dim (int): feature dimension. 14 | """ 15 | 16 | def __init__(self, num_classes=10, feat_dim=2, device=None, constrained_classes=None, mu=0.5): 17 | super(CenterLoss, self).__init__() 18 | self.num_classes = num_classes 19 | self.feat_dim = feat_dim 20 | self.device = device 21 | 22 | # random initialization of the centers. 
23 |         centers = torch.randn(self.num_classes, self.feat_dim)
24 |         if device:
25 |             self.centers = centers.to(device)
26 |         else:
27 |             self.centers = centers
28 | 
29 |         self.centers_were_set = False
30 |         self.mu = mu
31 | 
32 |         self.constrained_classes = constrained_classes
33 | 
34 |         # if constrained_classes is not None:
35 |         #     print('center loss is used for class {} only'.format(constrained_classes))
36 | 
37 |     def forward(self, x, labels):
38 |         """
39 |         Args:
40 |             x: feature matrix with shape (batch_size, feat_dim).
41 |             labels: ground truth labels with shape (batch_size).
42 |         """
43 | 
44 |         # update centers based on input x
45 |         with torch.no_grad():
46 |             for label in range(self.num_classes):
47 |                 idx = labels == label
48 |                 class_available = torch.any(idx)
49 | 
50 |                 if class_available:
51 |                     new_center = x[idx].mean(dim=0)
52 | 
53 |                     if self.centers_were_set:
54 |                         # this corresponds to the gradient-based center update in the paper, but here an explicit update is used, without the need for pytorch to learn the centers
55 |                         self.centers[label] = (1 - self.mu) * self.centers[label] + self.mu * new_center
56 |                     else:
57 |                         self.centers[label] = new_center
58 |                         self.centers_were_set = True
59 | 
60 |         batch_size = x.size(0)
61 | 
62 |         # ||c||**2 + ||x||**2 - 2*c*x
63 |         distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
64 |                   torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
65 |         distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)
66 | 
67 |         # a simpler way to compute the distances already exists in pytorch (it includes an unnecessary square root),
68 |         # but it had problems with NaNs, probably after taking the root of a negative value, at least in pytorch version 1.2
69 |         # distmat = torch.cdist(x, self.centers)**2
70 | 
71 |         classes = torch.arange(self.num_classes).long()
72 |         if self.device: classes = classes.to(self.device)
73 |         labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
74 |         mask = labels.eq(classes.expand(batch_size, self.num_classes))
75 | 
76 |         n_per_class = mask.sum(dim=0)
77 |         n_per_class[n_per_class == 0] = 1  # to prevent dividing by 0
78 | 
79 |         dist = distmat * mask.float()
80 |         dist = dist.clamp(min=1e-12, max=1e+12)
81 |         loss = dist.sum(dim=0)
82 |         loss = loss / n_per_class.float()
83 | 
84 |         if self.constrained_classes:
85 |             loss_within = loss[self.constrained_classes]
86 |         else:
87 |             #loss_within = loss.sum()
88 |             loss_within = loss
89 | 
90 |         loss_within = loss_within.mean()
91 | 
92 |         # # # between class centers loss
93 |         # mask = ~mask
94 |         # n_per_class = mask.sum(dim=0)
95 |         # n_per_class[n_per_class == 0] = 1  # to prevent dividing by 0
96 |         #
97 |         # dist = distmat * mask.float()
98 |         # dist = dist.clamp(min=1e-12, max=1e+12)
99 |         # loss = dist.sum(dim=0)
100 |         # loss = loss / n_per_class.float()
101 |         #
102 |         # if self.constrained_classes:
103 |         #     loss_between = loss[self.constrained_classes]
104 |         # else:
105 |         #     # loss_between = loss.sum()
106 |         #     loss_between = loss
107 |         #
108 |         # loss_between = loss_between.mean()
109 | 
110 |         #loss = loss_within / loss_between
111 |         #loss = loss_within / math.sqrt(loss_between)
112 |         #loss = loss_within - 0.2*loss_between
113 |         loss = loss_within
114 | 
115 |         loss = loss / x.size(1)  # making the center loss independent of the dimensionality of the feature vectors (not needed when using a ratio of distances)
116 |         #loss = math.sqrt(loss)
117 |         return loss
118 | 
119 | class CenterLossOld(nn.Module):
120 |     """Center loss.
121 | 
122 |     Reference:
123 |     Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
124 | 
125 |     Args:
126 |         num_classes (int): number of classes.
127 |         feat_dim (int): feature dimension.
128 |     """
129 | 
130 |     def __init__(self, num_classes=10, feat_dim=2, device=None, one_class=None):
131 |         super(CenterLossOld, self).__init__()
132 |         self.num_classes = num_classes
133 |         self.feat_dim = feat_dim
134 | 
135 |         if device is None:
136 |             self.device = None
137 |             self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
138 | 
139 |         else:
140 |             self.device = device
141 |             self.centers = nn.Parameter((torch.randn(self.num_classes, self.feat_dim)).to(self.device))
142 | 
143 |         self.one_class = one_class
144 |         if one_class is not None:
145 |             print('center loss is used for class {} only'.format(one_class))
146 | 
147 |     def forward(self, x, labels):
148 |         """
149 |         Args:
150 |             x: feature matrix with shape (batch_size, feat_dim).
151 |             labels: ground truth labels with shape (batch_size).
152 |         """
153 |         batch_size = x.size(0)
154 |         distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
155 |                   torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
156 |         #distmat.addmm_(1, -2, x, self.centers.t())
157 |         distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)
158 | 
159 |         classes = torch.arange(self.num_classes).long()
160 |         if self.device: classes = classes.to(self.device)
161 |         labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
162 |         mask = labels.eq(classes.expand(batch_size, self.num_classes))
163 | 
164 |         dist = distmat * mask.float()
165 |         dist = dist.clamp(min=1e-12, max=1e+12)
166 |         loss = dist.sum(dim=0)
167 | 
168 |         if self.one_class:
169 |             loss = loss[self.one_class]
170 |         else:
171 |             loss = loss.sum()
172 | 
173 |         loss = loss / batch_size
174 | 
175 |         return loss
176 | 
177 | 
--------------------------------------------------------------------------------
/configs/cfg_anomaly_detector.py:
--------------------------------------------------------------------------------
1 | import os as _os
2 | import pickle as _pickle
3 | 
4 | 
5 | show_non_liver = False  # used for t-sne visualization of non-liver data from the auxiliary task
6 | seed_number = 500
7 | 
8 | # root folder for inputs and outputs
9 | _prj_root = _os.path.dirname(_os.path.dirname(_os.path.abspath(__file__)))
10 | #_prj_root = '/your root path to input and output data/'
11 | 
12 | # the evaluation results will be written here
13 | output_path = _os.path.join(_prj_root, 'test_results/')
14 | 
15 | _train_data = _os.path.join(_prj_root, "data/train/")
16 | _test_data = _os.path.join(_prj_root, "data/test/")
17 | 
18 | ad_model = "CNN_location"  # the anomaly detection model will be taken from the folder where the CNN model is. If an anomaly model with the same name as the CNN model (except the file extension) does not exist, it will be trained and saved at this location.
19 | #ad_model = "/BIHNmodels/EfficientNet_B0_320_Masson_Liver_Mouse_acc0.9755.pkl" 20 | #ad_model = "" # new anomaly detection model will be trained 21 | 22 | #cnn_model = '' # ImageNet pretrained 23 | #cnn_model = /BIHNmodels/EfficientNet_B0_320_Masson_Liver_Mouse_acc0.9755.pt" 24 | cnn_model = _os.path.join(_prj_root, 'train_results/30830_124948/EfficientNet_B0_320_Masson_Liver_Mouse_230830_124948_acc0.9339.pt') 25 | 26 | aug_saturation=(0.4, 1.6) 27 | aug_hue= (-0.05, 0.05) 28 | 29 | augmentation = True # if True augments (one class) training examples (useful for larger models like densenet) 30 | 31 | csv_liver_tissue_anomalies = 'liver_tissue_anomalies.csv' 32 | csv_liver_tissue_testnormals = 'liver_tissue_testnormals.csv' 33 | 34 | dev = "cuda:0" 35 | 36 | batch_size = 32 37 | 38 | #n_features_visualization = {'normal': 1000, 'NAS anomaly': 1000} # labels should correspond to the labels of datasets 39 | n_features_visualization = {'brain': 1000, 'heart': 1000, 'kidney': 1000, 'lung': 1000, 'pancreas': 1000, 'spleen': 1000, 40 | 'liver': 1000, 'liver_rat': 1000, 'he_brain': 1000, 'he_kidney': 1000, 'he_spleen': 1000, 41 | 'he_pancreas': 1000, 'he_heart': 1000, 'he_liver': 1000, 'he_lung': 1000, 'he_liver_rat': 1000, 42 | 'normal': 1000, 'NAS anomaly': 300} # labels should correspond to the labels of datasets 43 | 44 | test_normal_patches_max = None # None - use all 45 | test_anomaly_patches_per_class_max = None # None - use all 46 | visual_test_auxiliary_patches_per_class = 1000 47 | train_patches_for_train_max = None # None - use all 48 | 49 | # one-class classifier to be used 50 | clf = "svm.OneClassSVM(nu=0.1, kernel='rbf')" 51 | #clf = "svm.OneClassSVM(nu=0.05, kernel='rbf')" 52 | #clf = "LocalOutlierFactor(n_neighbors=30, novelty=True, metric='minkowski')" 53 | 54 | ext2save = 'png' 55 | 56 | save_images = True # write png images in addition to HALO annotations for all required WSI 57 | save_n_FN = 100 # number of falsely classified anomaly patch images to be saved 58 | save_n_FP = 100 # number of falsely classified normal patch examples to be saved 59 | 60 | description = "running with pre-trained BIHN models" 61 | 62 | if ad_model == 'CNN_location': 63 | ad_model, _ = _os.path.splitext(cnn_model) 64 | ad_model = ad_model + '.pkl' 65 | 66 | # getting parameters from training_cnn configuration file 67 | _path_to_configuration_pkl = _os.path.splitext(cnn_model)[0] + '_training_configuration.pkl' 68 | 69 | try: 70 | _pkl_dic = _pickle.load(open(_path_to_configuration_pkl, 'rb')) 71 | except FileNotFoundError: 72 | raise RuntimeError("Meta data cannot be read. 
Probable reason: the old version of the trained model did not include meta data, or no model was given") 73 | else: 74 | print('taking parameters from saved along the model meta data: {}'.format(_path_to_configuration_pkl)) 75 | n_trained_classes = int(_pkl_dic['number_of_classes']) 76 | 77 | normalize_mean = _pkl_dic['normalize_mean'] 78 | normalize_std = _pkl_dic['normalize_std'] 79 | 80 | aug_brightness = _pkl_dic['aug_brightness'] 81 | aug_contrast = _pkl_dic['aug_contrast'] 82 | 83 | model_architecture = _pkl_dic['model_name'] 84 | 85 | patch_size = _pkl_dic['patch_size'] 86 | 87 | animal = _pkl_dic['animal'] 88 | 89 | try: 90 | data_staining = _pkl_dic['data_staining'] 91 | except: 92 | if _pkl_dic['centerloss_classes'] == 'liver': 93 | data_staining = "Masson" 94 | elif _pkl_dic['centerloss_classes'] == 'he_liver': 95 | data_staining = "HE" 96 | elif _pkl_dic['centerloss_classes'] == 'he_liver_rat': 97 | data_staining = "HE" 98 | elif _pkl_dic['centerloss_classes'] == 'liver_rat': 99 | data_staining = "Masson" 100 | else: 101 | raise RuntimeError("Error: trained center loss class is not valid") 102 | 103 | # ----data-------- 104 | paths_non_liver_tissues_test = () 105 | if show_non_liver: 106 | paths_non_liver_tissues_test = ( 107 | {'folder': _train_data + "mt_mouse_brain", 'label': 'brain', 'ext': 'png'}, 108 | {'folder': _train_data + "mt_mouse_heart", 'label': 'heart', 'ext': 'png'}, 109 | {'folder': _train_data + "mt_mouse_kidney", 'label': 'kidney', 'ext': 'png'}, 110 | {'folder': _train_data + "mt_mouse_lung", 'label': 'lung', 'ext': 'png'}, 111 | {'folder': _train_data + "mt_mouse_pancreas", 'label': 'pancreas', 'ext': 'png'}, 112 | {'folder': _train_data + "mt_mouse_spleen", 'label': 'spleen', 'ext': 'png'}, 113 | 114 | {'folder': _train_data + "mt_mouse_liver", 'label': 'liver', 'ext': 'png'}, 115 | 116 | {'folder': _train_data + "mt_rat_liver", 'label': 'liver_rat', 'ext': 'png'}, 117 | 118 | {'folder': _train_data + "he_mouse_brain", 'label': 'he_brain', 'ext': 'png'}, 119 | {'folder': _train_data + "he_mouse_kidney", 'label': 'he_kidney', 'ext': 'png'}, 120 | {'folder': _train_data + "he_mouse_spleen", 'label': 'he_spleen', 'ext': 'png'}, 121 | {'folder': _train_data + "he_mouse_pancreas", 'label': 'he_pancreas', 'ext': 'png'}, 122 | {'folder': _train_data + "he_mouse_heart", 'label': 'he_heart', 'ext': 'png'}, 123 | {'folder': _train_data + "he_mouse_lung", 'label': 'he_lung', 'ext': 'png'}, 124 | 125 | {'folder': _train_data + "he_mouse_liver", 'label': 'he_liver', 'ext': 'png'}, 126 | 127 | {'folder': _train_data + "he_rat_liver", 'label': 'he_liver_rat', 'ext': 'png'}, 128 | 129 | ) 130 | 131 | 132 | paths_liver_anomaly_test = () # anomalies for quantitative test (labeled png) 133 | if data_staining == "Masson": 134 | paths_liver_anomaly_test = ( 135 | 136 | {'folder': _test_data + "NAFLD_anomaly_mt_mouse_liver", 'label': 'NAS anomaly', 'ext': 'png'}, 137 | 138 | ) 139 | elif data_staining == "HE": 140 | paths_liver_anomaly_test = ( 141 | 142 | {'folder': _test_data + "NAFLD_anomaly_he_mouse_liver", 'label': 'NAS anomaly', 'ext': 'png'}, 143 | 144 | ) 145 | 146 | paths_normal = () # healthy for training one class classifier 147 | if data_staining == "Masson" and animal == "Mouse": 148 | paths_normal = ( 149 | 150 | {'folder': _train_data + "mt_mouse_liver", 'label': 'normal_train', 'ext': 'png'}, 151 | 152 | ) 153 | elif data_staining == "HE" and animal == "Mouse": 154 | paths_normal = ( 155 | 156 | {'folder': _train_data + "he_mouse_liver", 'label': 'normal_train', 
'ext': 'png'},
157 | 
158 |     )
159 | elif data_staining == "Masson" and animal == "Rat":
160 |     paths_normal = (
161 |         {'folder': _train_data + "mt_rat_liver", 'label': 'normal_train', 'ext': 'png'},
162 |     )
163 | elif data_staining == "HE" and animal == "Rat":
164 |     paths_normal = (
165 |         {'folder': _train_data + "he_rat_liver", 'label': 'normal_train', 'ext': 'png'},
166 |     )
167 | 
168 | paths_normal_test = ()  # healthy samples for the quantitative and visual tests; these were not used for training
169 | if data_staining == "Masson":
170 |     paths_normal_test = (
171 | 
172 |         {'folder': _test_data + "normal_mt_mouse_liver", 'label': 'normal', 'ext': 'png'},
173 | 
174 |     )
175 | elif data_staining == "HE":
176 |     paths_normal_test = (
177 | 
178 |         {'folder': _test_data + "normal_he_mouse_liver", 'label': 'normal', 'ext': 'png'},
179 | 
180 |     )
181 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning Image Representations for Anomaly Detection
2 | 
3 | -------
4 | 
5 | This repository contains the PyTorch implementation of **training image representations** and **performance evaluation** of the approach introduced in
6 | *I. Zingman, B. Stierstorfer, C. Lempp, F. Heinemann. ["Learning image representations for anomaly detection: application to discovery of
7 | histological alterations in drug development", Medical Image Analysis, 2024.](https://www.sciencedirect.com/science/article/pii/S1361841523003274#da1)*
8 | It is also available on [ArXiv](https://arxiv.org/abs/2210.07675) and temporarily has free access on [Elsevier](https://authors.elsevier.com/a/1iIdF4rfPmE0%7EA).
9 | 
10 | The paper develops a method for anomaly detection in whole slide images of stained tissue samples in order to routinely screen histopathological data for abnormal alterations in tissue.
11 | 
12 | ![GitHub Logo](docs/tox_pattern.png)
13 | 
14 | **Figure** above shows detection of adverse drug reactions by the Boehringer Ingelheim Histological Network (BIHN) based anomaly detection. **A:** The developed Anomaly Detection (AD) method detects induced tissue alterations in
15 | the liver of mouse after administration of an experimental compound. The fraction of abnormal tiles increases with the dosage of the compound. The
16 | compound was previously found to have toxic side effects in toxicological screening by pathologists. Each dot corresponds to a single Whole Slide Image (WSI). Three arrows
17 | correspond to three WSI examples given in **B**. Stars on the top of the graph show statistical significance of the change compared to the mean of the control
18 | group. **B:** Examples of detected anomalies. In the control group (left image) blood and a few other non-pathological structures result in a low level of false
19 | positives. Detections in compound treated groups (two right images) correspond to pathological alterations and were confirmed by a pathologist.
20 | 
21 | ------
22 | **Requirements**
23 | 
24 | ```PyTorch```, ```NumPy```, ```Pillow```, ```scikit-learn```
25 | 
26 | The code in the repository was tested under ```Python 3.9``` with an 11 GB GPU and the packages listed in ```requirements.txt```.
27 | It should, however, also run with earlier Python versions and smaller GPU memory.
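
A typical environment setup is sketched below (assuming ```pip``` and a virtual environment; any equivalent workflow works):

```
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```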
28 | 
29 | **Experiments** (training image representations and performance evaluation)
30 | 
31 | ![GitHub Logo](docs/Scheme_extended.png)
32 | 
33 | **Setting up dataset**
34 | 
35 | * The training dataset with normal tissue of different species, organs, and stainings can be downloaded from the ```data/train/``` folder at https://osf.io/gqutd/.
36 | This dataset was used for training image representations.
37 | 
38 | * The evaluation dataset with normal mouse liver tissue and mouse tissue with Non-Alcoholic Fatty Liver Disease (NAFLD) can
39 | be downloaded from the ```data/test/``` folder at https://osf.io/gqutd/
40 | 
41 | * Due to the large sizes of the zip files, it is recommended to download each zip file separately.
42 | 
43 | * Create the folder structure shown below under the root folder of your repository with the cloned code, or in any other location.
44 | In the latter case, set the ```_prj_root``` variable to the chosen location in the ```configs/cfg_training_cnn.py``` and ```configs/cfg_anomaly_detector.py``` configuration files.
45 | We use *.py configuration files, not e.g. yaml, which allows more flexibility and is convenient for prototyping.
46 | 
47 | * Unzip the downloaded data files to the corresponding folders within the created folder structure
48 | 
49 | * If you want to use pre-trained models (instead of training yourself)
50 |   * download them from the ```trained models/``` folder at https://osf.io/gqutd/.
51 |   * unzip and save the ```EfficientNet_B0_320_HE_Liver_Mouse_acc0.9762.pt``` and ```EfficientNet_B0_320_Masson_Liver_Mouse_acc0.9755.pt``` CNN models,
52 | the corresponding ```EfficientNet_B0_320_Masson_Liver_Mouse_acc0.9755.pkl``` and ```EfficientNet_B0_320_HE_Liver_Mouse_acc0.9762.pkl```
53 | anomaly detection models (one-class SVM classifiers), and the corresponding ```EfficientNet_B0_320_HE_Liver_Mouse_acc0.9762_training_configuration.pkl```
54 | and ```EfficientNet_B0_320_Masson_Liver_Mouse_acc0.9755_training_configuration.pkl```
55 | configuration files into e.g. a ```BIHN_models``` folder under the project root.
56 | 
57 | 
58 | **Folder structure for project's input**
59 | ```
60 | .
61 | ├── data
62 |     ├── test
63 |     │   ├── NAFLD_anomaly_he_mouse_liver
64 |     │   ├── NAFLD_anomaly_mt_mouse_liver
65 |     │   ├── normal_he_mouse_liver
66 |     │   └── normal_mt_mouse_liver
67 |     └── train
68 |         ├── he_mouse_brain
69 |         ├── he_mouse_heart
70 |         ├── he_mouse_kidney
71 |         ├── he_mouse_liver
72 |         ├── he_mouse_lung
73 |         ├── he_mouse_pancreas
74 |         ├── he_mouse_spleen
75 |         ├── he_rat_liver
76 |         ├── mt_mouse_brain
77 |         ├── mt_mouse_heart
78 |         ├── mt_mouse_kidney
79 |         ├── mt_mouse_liver
80 |         ├── mt_mouse_lung
81 |         ├── mt_mouse_pancreas
82 |         ├── mt_mouse_spleen
83 |         └── mt_rat_liver
84 | 
85 | 
86 | ```
87 | 
88 | **Training**
89 | 
90 | * Set the variable ```data_staining``` in ```configs/cfg_training_cnn.py``` to either ```Masson``` (Masson's Trichrome staining) or ```HE``` (H&E staining), which will
91 | adjust the training of image representations for anomaly detection in images of tissue with the corresponding staining. If you store the training data in your own location, update
92 | ```path_to_data``` accordingly.
93 | * Run ```python train_cnn.py --config configs/cfg_training_cnn.py```
94 | * The code generates a ```train_results/stamp``` folder with trained models (models for each epoch and the best one), confusion matrix, configuration
95 | and log files, where *stamp* is a unique number that is set for each run. You can redefine the output
96 | folder in the configuration file ```configs/cfg_training_cnn.py```, if needed, by updating ```path_to_results```.
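
Individual configuration parameters can also be overridden from the command line with the ```-P``` prefix (see ```update_configuration``` in ```utils/imports.py```). For example, a quick smoke test of the training pipeline, using the ```test_run``` parameter defined in the configuration file:

```
python train_cnn.py --config configs/cfg_training_cnn.py -Ptest_run=True
```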
97 | 
98 | **Evaluation**
99 | 
100 | * Set the ```cnn_model``` variable in ```configs/cfg_anomaly_detector.py``` to the path of the trained CNN model, relative to the project root, which was
101 | generated as ```train_results/stamp/model_name.pt``` during the training step above. Alternatively, you can set an arbitrary path to a pre-trained CNN model ```*.pt``` downloaded from https://osf.io/gqutd.
102 | * If you've downloaded an anomaly model ```*.pkl``` from https://osf.io/gqutd/, set ```ad_model``` to its location. Alternatively, if you want to train the anomaly model (a one-class classifier) on your own, set ```ad_model``` to the empty string ```""``` or to ```"CNN_location"```.
103 | * Run ```python anomaly_detector.py --config configs/cfg_anomaly_detector.py```. The code will output the evaluation results to the ```test_results``` folder.
104 | If an anomaly model (one-class classifier) was trained, it will be saved to the folder where the CNN model is.
105 | 
106 | *Expected performance of anomaly detection with BIHN models*
107 | 
108 | | Staining         | Balanced accuracy | AU-ROC | F1 score |
109 | |------------------|:-----------------:|:------:|:--------:|
110 | | H&E              |      94.20%       | 97.33% |  94.09%  |
111 | | Masson Trichrome |      97.51%       | 99.03% |  97.51%  |
112 | 
113 | * To evaluate other algorithms from the [Anomalib library](https://github.com/openvinotoolkit/anomalib) on our dataset with NAFLD pathology,
114 | please consult the Anomalib section *Custom Dataset*. In particular, one needs to set appropriate paths in the yaml configuration files of the chosen method, located at ```anomalib_root/anomalib/models/method/config_file.yaml```.
115 | The path fields to be set in the yaml are ```normal_dir```, ```abnormal_dir```, and ```normal_test_dir```, which should point to the ```./data/train/*mouse_liver/```, ```./data/test/NAFLD_anomaly_*_mouse_liver```, and ```./data/test/normal_*_mouse_liver``` data paths correspondingly.
116 | The star in the paths refers to the particular staining type, ```mt``` or ```he```, you want to experiment with. The ```task``` field should be set to "classification".
117 | * To evaluate the [DPA](https://github.com/ninatu/anomaly_detection) approach, we adapted the ```Camalyon16Dataset``` class to read images from the NAFLD dataset.
118 | We obtained our best results for DPA using the ```camelyon16``` ```wo_pg_unsupervsed``` default configuration with the following parameters tuned: ```inner_dims: 16, latent_dim:16``` (for both decoder and encoder, same values for all layers as in the default configuration), ```initial_image_res:256, max_image_res:256, crop_size: 256```.
119 | The batch size was reduced to 64 to be able to run on 256x256 size images.
120 | 
121 | 
122 | **Use of pretrained BIHN models in your own projects**
123 | 
124 | In order to use the pretrained BIHN models (*.pt files that can be downloaded from https://osf.io/gqutd/) to generate
125 | feature representations of histopathological images (Masson or H&E) for your own tasks, you can consult the code example in ```model_use_example.py```.
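
In short, the example loads a BIHN CNN, runs it over a folder of image patches, and collects the pooled feature codes. The sketch below condenses ```model_use_example.py```; the quoted paths are placeholders to be replaced with your own:

```python
from models.pretrained_networks import EfficientNet_B0_320
from utils.training_utils import apply_net
from utils.data_processing import CodesProcessor
from utils.image_dataset_reader import HistImagesDataset
from torch.utils.data import DataLoader
from torchvision import transforms

# normalization constants the BIHN models were trained with (see configs/cfg_training_cnn.py)
tr = transforms.Compose([transforms.ToTensor(),
                         transforms.Normalize(mean=(0.5788, 0.3551, 0.5655), std=(1, 1, 1))])
dataset = HistImagesDataset({'folder': "path_to_images", 'label': 'liver', 'ext': 'png'},
                            n_samples=3, transform=tr)

model = EfficientNet_B0_320(path_trained_model="path_to_trained_model", n_classes=16, dev="cpu").to("cpu")
code_processor = CodesProcessor()
apply_net(model, "cpu", DataLoader(dataset), code_processor=code_processor)
features = code_processor.get_codes()  # one feature vector per image patch
```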
126 | 
127 | **Citing**
128 | ```markdown
129 | @article{zingman2022anomaly,
130 |       title={Learning image representations for anomaly detection: application to discovery of histological alterations in drug development},
131 |       author={Igor Zingman and Birgit Stierstorfer and Charlotte Lempp and Fabian Heinemann},
132 |       year={2022},
133 |       journal={CoRR},
134 |       volume={abs/2210.07675},
135 |       eprinttype = {arXiv},
136 |       url = {https://arxiv.org/abs/2210.07675}
137 | }
138 | ```
139 | ```markdown
140 | @online{NAFLD_dataset,
141 |   author = {Igor Zingman and Birgit Stierstorfer and Fabian Heinemann},
142 |   title = {{NAFLD} pathology and healthy tissue samples},
143 |   year = {2022},
144 |   url = {https://osf.io/gqutd/},
145 | }
146 | ```
147 | 
148 | 
149 | 
150 | 
151 | 
152 | 
153 | 
154 | 
155 | 
156 | 
157 | 
--------------------------------------------------------------------------------
/utils/data_processing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 | import os
5 | import pickle
6 | import time
7 | import logging
8 | 
9 | 
10 | def set_seed(seed):
11 |     random.seed(seed)
12 |     os.environ['PYTHONHASHSEED'] = str(seed)
13 |     np.random.seed(seed)
14 |     torch.manual_seed(seed)
15 | 
16 | 
17 | class CodesProcessor:
18 | 
19 |     def __init__(self, cached_codes_path=None, cache_thr=3.0):
20 | 
21 |         logging.info('code processor is initialized with {} GB maximum GPU memory threshold'.format(cache_thr))
22 |         self.codes = None
23 |         self.patch_coordinates = {'x': [], 'y': [], 'w': [], 'h': []}
24 |         self.image_names = []
25 |         self.image_labels = []
26 | 
27 |         if cached_codes_path is not None:
28 |             self.cached_codes_path = cached_codes_path + '/cached_codes_temp' + time.strftime("%H%M%S_%d%m%y") + '.pt'
29 |         else:
30 |             self.cached_codes_path = None
31 | 
32 |         self.cache_status = False
33 |         self.cache_thr = cache_thr
34 |         #self.avg_pool = nn.AvgPool2d(kernel_size=average_pooling_size)
35 | 
36 | 
37 |     def __call__(self, codes, image_names, image_labels, coordinates=None):
38 | 
39 |         #codes = self.avg_pool(codes)
40 | 
41 |         if self.codes is None:
42 |             self.codes = codes
43 |         else:
44 |             self.codes = torch.cat((self.codes, codes), dim=0)
45 | 
46 |         if coordinates is not None:
47 |             self.patch_coordinates['x'].extend(coordinates['x'])
48 |             self.patch_coordinates['y'].extend(coordinates['y'])
49 |             self.patch_coordinates['w'].extend(coordinates['w'])
50 |             self.patch_coordinates['h'].extend(coordinates['h'])
51 |         elif self.patch_coordinates is not None:
52 |             self.patch_coordinates = None  # first call
53 |             assert len(self.image_names) == 0
54 | 
55 |         self.image_names.extend(image_names)
56 |         self.image_labels.extend(image_labels)
57 | 
58 |         gb_taken = self.codes.element_size() * self.codes.nelement() / 1024 / 1024 / 1024
59 |         if self.cached_codes_path is not None and gb_taken > self.cache_thr:
60 |             logging.info('{} GB memory taken, {} patch codes gathered, saving to disk to prevent GPU crash'.format(gb_taken, self.codes.shape[0]))
61 |             self.backup_codes()
62 | 
63 |     def backup_codes(self):
64 |         logging.info('saving codes')
65 | 
66 |         if self.cache_status:
67 |             cached_codes = torch.load(self.cached_codes_path, map_location='cpu')
68 |             cached_codes = torch.cat((self.codes.cpu(), cached_codes), dim=0)
69 |             torch.save(cached_codes, self.cached_codes_path)
70 |         else:
71 |             torch.save(self.codes.cpu(), self.cached_codes_path)
72 | 
73 |         self.codes = None
74 |         self.cache_status = True
75 | 
76 |         logging.info('codes were saved')
77 | 
78 |     def
get_codes(self): 79 | 80 | if self.cache_status: 81 | 82 | codes = torch.load(self.cached_codes_path, map_location='cpu') 83 | codes = torch.cat((self.codes.cpu(), codes), dim=0) 84 | else: 85 | codes = self.codes.cpu() 86 | 87 | codes = codes.numpy() 88 | codes = np.squeeze(codes) 89 | 90 | return codes 91 | 92 | def get_patch_coordinates(self): 93 | return self.patch_coordinates 94 | 95 | def get_image_names(self): 96 | return self.image_names 97 | 98 | def get_image_labels(self): 99 | return self.image_labels 100 | 101 | 102 | 103 | def read_features(file_name, paths, label_names, n_codes=None, group_method=None, group_size=None, arrange=None): 104 | """ 105 | Reads and groups features in different ways. 106 | 107 | :param file_name: name of the file with extension that contains latent codes 108 | :param paths: full paths to the files, each of which contains codes for a specific category 109 | :param n_codes: number of codes to be sampled for every category - makes the number of samples equal for each category (even when it is not so in the data) 110 | :param group_method: 'mean', 'concat', or 'max' - a way to combine the data from patch codes to image codes - requires group_size parameter 111 | :param group_size: the number of codes to be grouped together 112 | :param label_names: list of string labels corresponding to the data in paths 113 | :param arrange: If provided, rearranges the feature vectors, such that they will be continuous sequence for particular spatial windows in the original tensor. 114 | 'arrange' is a dictionary with fields win_size and tens_size. win_size is a spatial side size of square patch with channel features. 115 | tens_size is a spatial side size of the whole tensor of features 116 | :return features (n_samples x n_features) and sample labels - indexes of 'paths' or 'label_names' 117 | """ 118 | 119 | assert len(paths) == len(label_names), "length of paths and label_names lists must be same" 120 | 121 | def rearrange(features, win_size, tensor_size): 122 | assert tensor_size % win_size == 0, "tensor_size must be multiple of win_size" 123 | assert features.shape[0] % (tensor_size*tensor_size) == 0, "number of feature vectors maust be a multiple of tensor_size squared" 124 | 125 | n_tensors = features.shape[0] / tensor_size 126 | #n_win = np.int(tensor_size / win_size) 127 | 128 | cur_index = 0 129 | #check_up = [] 130 | rearranged_features = np.zeros_like(features) 131 | for tens in range(0, features.shape[0], tensor_size*tensor_size): 132 | for y_block in range(0, tensor_size, win_size): 133 | for x in range(tensor_size): 134 | for y in range(y_block, win_size + y_block): 135 | ind = x *tensor_size + y + tens 136 | rearranged_features[cur_index] = features[ind, :] 137 | #check_up.append(ind) 138 | 139 | cur_index += 1 140 | 141 | assert features.shape[0] == cur_index, "tensor_size does not match the number of feature vectors" 142 | #return rearranged_features, check_up 143 | return rearranged_features 144 | 145 | 146 | 147 | features = None 148 | labels = [] 149 | for n, path in enumerate(paths): 150 | 151 | path = os.path.join(path, file_name) 152 | with open(path, 'rb') as f: 153 | cur_features = pickle.load(f) 154 | 155 | if arrange: 156 | cur_features = rearrange(cur_features, arrange['win_size'], arrange['tens_size']) 157 | 158 | feat_vec_len = cur_features.shape[-1] 159 | if group_method == 'concat': 160 | cur_features = np.reshape(cur_features, (-1, group_size*feat_vec_len), order='C') 161 | if group_method == 'concatabs': 162 | cur_features = 
np.reshape(np.abs(cur_features), (-1, group_size*feat_vec_len), order='C') 163 | elif group_method == 'abs': 164 | cur_features = np.abs(cur_features) 165 | elif group_method == 'meanabs': 166 | cur_features = np.reshape(cur_features, (-1, group_size, feat_vec_len), order='C') 167 | cur_features = np.mean(np.abs(cur_features), axis=1) 168 | elif group_method == 'mean': 169 | cur_features = np.reshape(cur_features, (-1, group_size, feat_vec_len), order='C') 170 | cur_features = np.mean(cur_features, axis=1) 171 | elif group_method == 'maxabs': 172 | cur_features = np.reshape(cur_features, (-1, group_size, feat_vec_len), order='C') 173 | cur_features = np.max(np.abs(cur_features), axis=1) 174 | elif group_method == 'minabs': 175 | cur_features = np.reshape(cur_features, (-1, group_size, feat_vec_len), order='C') 176 | cur_features = np.min(np.abs(cur_features), axis=1) 177 | elif group_method == 'norm': 178 | cur_features = np.reshape(cur_features, (-1, group_size, feat_vec_len), order='C') 179 | cur_features = np.linalg.norm(cur_features, axis=1) 180 | elif group_method == 'meanstd': 181 | cur_features = np.reshape(cur_features, (-1, group_size, feat_vec_len), order='C') 182 | cur_1 = np.mean(cur_features, axis=1) 183 | cur_2 = np.std(cur_features, axis=1) 184 | cur_features = np.concatenate((cur_1, cur_2), axis=1) 185 | 186 | # use only part of feature vectors to reduce running time 187 | if n_codes: 188 | if n_codes > cur_features.shape[0]: 189 | logging.info('the number of requested feature vectors for {} does not exist'.format(label_names[n])) 190 | logging.info('using only {} codes'.format(cur_features.shape[0])) 191 | n_codes_cur = cur_features.shape[0] 192 | else: 193 | logging.info('using all {} requested codes for {}'.format(n_codes, label_names[n])) 194 | n_codes_cur = n_codes 195 | 196 | #ind = np.random.randint(0, high=n_codes_cur, size=n_codes_cur) 197 | ind = np.random.permutation(cur_features.shape[0])[:n_codes_cur] 198 | cur_features = cur_features[ind, :] 199 | 200 | else: 201 | logging.info('using all {} codes for {}'.format(cur_features.shape[0], label_names[n])) 202 | 203 | cur_labels = [n] * cur_features.shape[0] 204 | 205 | if features is not None: 206 | features = np.concatenate((features, cur_features), axis=0) 207 | else: # first run 208 | features = cur_features 209 | 210 | labels.extend(cur_labels) 211 | 212 | return features, labels 213 | 214 | 215 | 216 | 217 | 218 | 219 | -------------------------------------------------------------------------------- /train_cnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Oct 15 14:42:19 2019 5 | 6 | @author: zingman 7 | """ 8 | 9 | # ----------------------------------------------------------------------------- 10 | # classifies histological images into two categories healthy/non-healthy using 11 | # tiled patch images saved on the hard drive 12 | # ----------------------------------------------------------------------------- 13 | 14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | 19 | from torch.utils.data import DataLoader 20 | from torchvision import transforms 21 | from utils.image_dataset_reader import HistImagesDataset, samples_per_location_from_samples_per_class 22 | import torch.optim as optim 23 | import matplotlib.pyplot as plt 24 | from tqdm import tqdm 25 | from torch.utils.tensorboard import SummaryWriter 26 | import shutil 27 | 28 | from time import perf_counter 29 
| import os
30 | import datetime
31 | import copy
32 | import time
33 | from utils.data_processing import set_seed
34 | 
35 | import models.pretrained_networks as HistoModel
36 | from utils.timed_input import limited_time_input
37 | import logging
38 | from models.losses import CenterLoss
39 | import argparse
40 | 
41 | from utils.imports import show_configuration, save_configuration, \
42 |     pickle_configuraton_as_dictionary, update_configuration
43 | 
44 | from sklearn.metrics import confusion_matrix
45 | import seaborn as sn
46 | import pandas as pd
47 | 
48 | #plt.ion()
49 | 
50 | def evaluate_model(val_loader, loss_fun, n_samples_val, model_skel):
51 | 
52 |     str_labels = val_loader.dataset.str_labels
53 |     # string labels are encoded by sequential integer labels beginning from 0
54 |     integer_labels = list(range(len(str_labels)))
55 | 
56 |     final_val_loss, final_val_acc, conf_mat = validation(model_skel, iter(val_loader), loss_fun, n_samples=n_samples_val, confusion=True, integer_labels=integer_labels)
57 |     print('validation accuracy of the model: {:.4f}, validation loss of the model: {:.4f}'.format(final_val_acc,
58 |                                                                                                   final_val_loss))
59 |     # visualize confusion matrix
60 |     df_cf = pd.DataFrame(conf_mat, index=str_labels, columns=str_labels)
61 |     plt.figure(figsize=(15, 10))
62 |     sn.heatmap(df_cf, annot=True)
63 | 
64 |     plt.ylabel('True label')
65 |     plt.xlabel('Predicted label')
66 |     plt.title('accuracy: {}'.format(final_val_acc))
67 | 
68 | def validation(model, data_iter, loss_fun, n_samples=None, centerloss_fun=None, ce_weight=1.0, cl_weight=0.0, confusion=False, integer_labels=None):
69 |     """
70 |     :param model: model to be validated
71 |     :param data_iter: data iterator (use iter(data_set) or iter(data_loader)). The iterator can be infinite.
72 |     If the iterator is finite and is exhausted before n_samples were generated, a warning is printed out
73 |     :param n_samples: number of samples to be used for validating the model
74 |     :param confusion: also calculate the confusion matrix
75 |     :return: accuracy, loss values, and confusion matrix if requested
76 |     """
77 | 
78 |     if confusion and integer_labels is None:
79 |         raise ValueError("when the confusion matrix needs to be calculated, integer labels must be provided")
80 | 
81 | 
82 |     if not n_samples:
83 |         n_samples = np.inf
84 | 
85 |     model.eval()
86 | 
87 |     val_loss = torch.tensor(0.0, device=device)
88 |     n_pred = torch.tensor(0.0, device=device)
89 |     n_processed = 0.0
90 |     conf_mat = None
91 |     with torch.no_grad():
92 |         counter = 0
93 |         # for samples in data_loader:
94 |         while n_processed < n_samples:
95 | 
96 |             try:
97 |                 samples = next(data_iter)
98 |             except StopIteration:
99 |                 logging.warning(
100 |                     'a finite iterator was provided and was exhausted after {} images, before the required {} were generated'.format(n_processed, n_samples))
101 |                 break
102 | 
103 |             images = samples['image'].to(device)
104 |             int_labels = samples['label']
105 |             n_processed += len(int_labels)
106 |             int_labels = int_labels.to(device)
107 | 
108 |             outputs = model(images)
109 |             predictions = outputs['categories']
110 |             entropy_loss = loss_fun(predictions, int_labels)
111 | 
112 |             if (cl_weight != 0) and (centerloss_fun is not None):
113 |                 features = outputs['pooled_codes']
114 |                 center_loss = centerloss_fun(features, int_labels)
115 |                 loss = ce_weight * entropy_loss + cl_weight * center_loss
116 |             else:
117 |                 loss = entropy_loss
118 | 
119 |             predicted_values = torch.max(predictions, 1)[1]
120 |             n_pred += torch.sum(predicted_values == int_labels)
121 | 
122 |             val_loss += loss
123 | 
124 |             if confusion:
125 | 
126
| if counter > 0: 127 | conf_mat += confusion_matrix(int_labels.cpu().numpy(), predicted_values.cpu().numpy(), labels=integer_labels) 128 | else: 129 | conf_mat = confusion_matrix(int_labels.cpu().numpy(), predicted_values.cpu().numpy(), labels=integer_labels) 130 | 131 | counter += 1 132 | 133 | loss_value = val_loss.item() / counter 134 | accuracy = n_pred.item() / n_processed 135 | return loss_value, accuracy, conf_mat 136 | 137 | 138 | def train_epoch(model, data_loader, optimizer, loss_fun, tb_writer, iter_step_show=10, centerloss=None, ce_weight=1.0, cl_weight=0.0): 139 | 140 | try: 141 | len_dataset = len(data_loader) 142 | except TypeError: 143 | len_dataset = float("inf") 144 | 145 | model.train() 146 | 147 | train_loss = 0.0 148 | counter = 0 149 | progress = tqdm(data_loader, desc="Batch loss: ", total=len_dataset, disable=False) 150 | for samples in progress: 151 | 152 | try: 153 | train_epoch.iteration += 1 154 | except AttributeError: 155 | train_epoch.iteration = 0 156 | 157 | images = samples['image'].to(device) 158 | 159 | int_labels = samples['label'] 160 | int_labels = int_labels.to(device) 161 | 162 | optimizer.zero_grad() 163 | 164 | outputs = model(images) 165 | predictions = outputs['categories'] 166 | entropy_loss_criterion = loss_fun(predictions, int_labels) 167 | 168 | if (cl_weight != 0) and (centerloss is not None): 169 | features = outputs['pooled_codes'] 170 | center_loss_criterion = centerloss(features, int_labels) 171 | loss = ce_weight * entropy_loss_criterion + cl_weight * center_loss_criterion 172 | else: 173 | loss = entropy_loss_criterion 174 | 175 | loss.backward() 176 | optimizer.step() 177 | 178 | train_loss += loss.item() 179 | 180 | if counter % iter_step_show == 0: 181 | progress.set_description("Batch loss: {:.4f}".format(loss.item())) 182 | tb_writer.add_scalar('Batch_loss', loss.item(), train_epoch.iteration) 183 | 184 | counter += 1 185 | 186 | epoch_loss = train_loss / counter 187 | 188 | return epoch_loss 189 | 190 | 191 | #----------------------------------------------------------------------------------------------------------------------- 192 | # ----------------------------------------main code--------------------------------------------------------------------- 193 | #----------------------------------------------------------------------------------------------------------------------- 194 | 195 | assert torch.cuda.is_available(), "GPU is not available" 196 | 197 | logging.basicConfig(level=logging.INFO) 198 | parser = argparse.ArgumentParser(description='training cnn') 199 | cfg = update_configuration(parser) 200 | 201 | if not cfg.description: 202 | cfg.description = limited_time_input("Please enter description of an experiment...", 60) 203 | 204 | print('\n') 205 | show_configuration(cfg) 206 | print('\n') 207 | 208 | device = torch.device(cfg.device_name) 209 | 210 | set_seed(cfg.seed_number) 211 | 212 | t_start = perf_counter() 213 | 214 | # defining paths and creating output directories 215 | output_path = cfg.path_to_results + '/' + cfg.string_time 216 | output_path_tb = cfg.path_to_results + '_tb/' + cfg.string_time 217 | os.makedirs(output_path_tb) 218 | os.makedirs(output_path) 219 | tb_writer = SummaryWriter(output_path_tb) 220 | 221 | file_handler = logging.FileHandler(os.path.join(output_path, 'training_cnn.log')) 222 | logging.root.addHandler(file_handler) 223 | 224 | tr_resize = transforms.Resize(cfg.patch_size) 225 | 226 | tr_normalize = transforms.Normalize(mean=cfg.normalize_mean, std=cfg.normalize_std) 227 | bc_jitter 
= transforms.ColorJitter(brightness=cfg.aug_brightness, contrast=cfg.aug_contrast) 228 | 229 | transforms_seq_train = transforms.Compose([transforms.CenterCrop(cfg.patch_size), bc_jitter, transforms.ToTensor(), tr_normalize]) 230 | transforms_seq_val = transforms.Compose([transforms.CenterCrop(cfg.patch_size), transforms.ToTensor(), tr_normalize]) 231 | 232 | n_samples_train_per_location = samples_per_location_from_samples_per_class(*cfg.path_to_tissues, samples_per_class=cfg.n_samples_train_per_class) 233 | images_dataset = HistImagesDataset(*cfg.path_to_tissues, n_samples=n_samples_train_per_location, transform=transforms_seq_train, repetition=True) 234 | 235 | n_classes = len(images_dataset.str_labels) 236 | assert cfg.number_of_classes == n_classes 237 | 238 | images_validation, images_train = images_dataset.split_set(cfg.n_samples_val, transform_validation=transforms_seq_val) 239 | 240 | if getattr(cfg, 'mixup_classes', False): 241 | images_train.prepare_mixup(cfg.mixup_classes) 242 | 243 | train_loader = DataLoader(images_train, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=True, pin_memory=False) 244 | val_loader = DataLoader(images_validation, batch_size=cfg.batch_size, num_workers=cfg.num_workers, pin_memory=False) 245 | 246 | NetworkModel = getattr(HistoModel, cfg.model_name) 247 | model = NetworkModel(n_classes=n_classes, dev=device).to(device) 248 | 249 | try: 250 | logging.info("training dataset consist of {} images".format(len(images_train))) 251 | logging.info("validation dataset consist of {} images".format(len(images_validation))) 252 | except TypeError: 253 | logging.info('iterable dataset is used, the size cannot be determined a priori') 254 | 255 | 256 | # --------------------------- 257 | # -----------training ------ 258 | # ---------------------------- 259 | 260 | if cfg.centerloss_classes is not None: 261 | if cfg.centerloss_classes == 'all': 262 | chosen_class_int = None 263 | elif isinstance(cfg.centerloss_classes, (list, tuple)): 264 | chosen_class_int = [] 265 | for lb in cfg.centerloss_classes: 266 | chosen_class_int.append(images_train.get_int_label(lb)) 267 | elif isinstance(cfg.centerloss_classes, str): 268 | chosen_class_int = images_train.get_int_label(cfg.centerloss_classes) 269 | else: 270 | assert False, 'centerloss_classes in configuration file is not valid' 271 | 272 | metric_loss = CenterLoss(num_classes=n_classes, feat_dim=model.fv_length(), device=device, constrained_classes=chosen_class_int, mu=0.5) 273 | logging.info('chosen classes in centerloss: {} ({})'.format(cfg.centerloss_classes, chosen_class_int)) 274 | else: 275 | metric_loss = None 276 | 277 | loss_fun = nn.CrossEntropyLoss() 278 | optimizer = optim.SGD(model.parameters(), lr=cfg.model_lr, momentum=cfg.ce_momentum) 279 | 280 | best_val_acc = 0.0 281 | best_model_wts = copy.deepcopy(model.state_dict()) 282 | 283 | n_trained_param = model.count_parameters() 284 | 285 | logging.info("number of parameters to be trained is {}".format(n_trained_param)) 286 | 287 | t_start_epoch = perf_counter() 288 | for epoch in range(cfg.num_epochs): 289 | 290 | logging.info('Epoch {}/{}, training ...'.format(epoch, cfg.num_epochs - 1)) 291 | epoch_training_av_loss = train_epoch(model, train_loader, optimizer, loss_fun, tb_writer, cfg.train_step_show, centerloss=metric_loss, ce_weight=cfg.ce_weight, cl_weight=cfg.cl_weight) 292 | with torch.cuda.device(cfg.device_name): # by default cuda:0 is used 293 | torch.cuda.empty_cache() 294 | 295 | logging.info('validating on a separate validation 
dataset...') 296 | epoch_val_loss, epoch_val_acc, _ = validation(model, iter(val_loader), loss_fun, n_samples=cfg.n_samples_val, centerloss_fun=metric_loss, ce_weight=cfg.ce_weight, cl_weight=cfg.cl_weight) 297 | with torch.cuda.device(cfg.device_name): # by default cuda:0 is used 298 | torch.cuda.empty_cache() 299 | 300 | logging.info('validating on train dataset...') 301 | epoch_train_loss, epoch_train_acc, _ = validation(model, iter(train_loader), loss_fun, n_samples=cfg.n_samples_val, centerloss_fun=metric_loss, ce_weight=cfg.ce_weight, cl_weight=cfg.cl_weight) 302 | with torch.cuda.device(cfg.device_name): # by default cuda:0 is used 303 | torch.cuda.empty_cache() 304 | 305 | if epoch_val_acc > best_val_acc: 306 | best_val_acc = epoch_val_acc 307 | best_model_wts = copy.deepcopy(model.state_dict()) 308 | 309 | tb_writer.add_scalar('average_training_loss', epoch_training_av_loss, epoch) 310 | tb_writer.add_scalar('validation_loss', epoch_val_loss, epoch) 311 | tb_writer.add_scalar('validation_accuracy', epoch_val_acc, epoch) 312 | tb_writer.add_scalar('training_loss', epoch_train_loss, epoch) 313 | tb_writer.add_scalar('training_accuracy', epoch_train_acc, epoch) 314 | 315 | logging.info('Training loss: {:.4f}'.format(epoch_train_loss)) 316 | logging.info('Training accuracy: {:.4f}'.format(epoch_train_acc)) 317 | logging.info('Validation loss: {:.4f}'.format(epoch_val_loss)) 318 | logging.info('Validation accuracy: {:.4f}'.format(epoch_val_acc)) 319 | logging.info('-' * 10) 320 | 321 | if epoch == 0: 322 | save_configuration(cfg, os.path.join(output_path, 'training_cnn_configuration.txt')) 323 | with open(os.path.join(output_path, 'training_cnn_configuration.txt'), 'a') as fh: 324 | print(model, file=fh) 325 | 326 | # saving the resulted model 327 | model_file_name = cfg.model_name + '_epoch: {}_acc{:.4f}'.format(epoch, epoch_val_acc) 328 | saved_model_path_full = os.path.join(output_path, model_file_name + '.pt') 329 | torch.save(model.state_dict(), saved_model_path_full) 330 | logging.info('current model was saved to {}'.format(saved_model_path_full)) 331 | 332 | t_end_epoch = perf_counter() 333 | logging.info('training the epoch took {} sec'.format(t_end_epoch - t_start_epoch)) 334 | t_start_epoch = t_end_epoch 335 | 336 | tb_writer.close() 337 | 338 | t_end = perf_counter() 339 | logging.info('training took {} sec'.format(t_end - t_start)) 340 | 341 | # saving the best model 342 | model_file_name = cfg.model_name + '_' + cfg.data_staining + '_' + cfg.organ + '_' + cfg.animal + '_' + cfg.string_time + '_acc{:.4f}'.format(best_val_acc) 343 | saved_model_path_full = os.path.join(output_path, model_file_name + '.pt') 344 | torch.save(best_model_wts, saved_model_path_full) 345 | logging.info('best model was saved to {}'.format(saved_model_path_full)) 346 | model_for_evaluation = NetworkModel(n_classes=n_classes, path_trained_model=saved_model_path_full, dev=device).to(device) 347 | evaluate_model(val_loader, loss_fun, cfg.n_samples_val, model_for_evaluation) 348 | plt.savefig(os.path.join(output_path, 'confusion_matrix_best_model.png')) 349 | 350 | logging.info('tensorboard log was saved to {}'.format(output_path_tb)) 351 | 352 | # save configuration that can be read together with the models it was used to train 353 | saved_configuration_path_full = os.path.join(output_path, model_file_name + '_training_configuration.pkl') 354 | pickle_configuraton_as_dictionary(cfg, saved_configuration_path_full ) 355 | logging.info('pickle configuration file was saved to 
{}'.format(saved_configuration_path_full))
356 |
357 | if hasattr(cfg, 'test_run') and cfg.test_run is True:
358 |     print('output stored in {} and {} will be removed'.format(output_path, output_path_tb))
359 |     answer = input('Do you really want to remove these folders?')
360 |     if answer in ('yes', 'y', 'Y', 'YES', 'Yes'):
361 |         print('output stored in {} and {} are being removed'.format(output_path, output_path_tb))
362 |         time.sleep(5)
363 |         shutil.rmtree(output_path_tb)
364 |         shutil.rmtree(output_path)
365 |         print('output stored in {} and {} were removed'.format(output_path, output_path_tb))
366 |
--------------------------------------------------------------------------------
/models/pretrained_networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import models
4 | import torch.nn.functional as F
5 | from prettytable import PrettyTable
6 | import math
7 | import logging
8 |
9 | codes_field = 'codes'
10 | pooled_codes_field = 'pooled_codes'
11 | classes_field = 'categories'
12 |
13 |
14 | class CustomNet(nn.Module):
15 |     def __init__(self, model_name, custom_trained_model=True):
16 |         super().__init__()
17 |
18 |         # define standard architecture
19 |         specific_model = getattr(models, model_name)
20 |         if custom_trained_model:
21 |             self.model = specific_model(pretrained=False)
22 |         else:
23 |             self.model = specific_model(pretrained=True)
24 |
25 |     def _find_number_of_features(self):
26 |
27 |         n_features = None
28 |
29 |         modules = list(self.model.named_modules())
30 |         for n in reversed(range(len(modules))):
31 |
32 |             if hasattr(modules[n][1], "out_channels"):
33 |                 n_features = modules[n][1].out_channels
34 |                 logging.info("{} features were discovered as out_channels in a CNN layer".format(n_features))
35 |                 break
36 |
37 |         if n_features is None:
38 |             logging.error("the number of features could not be found")
39 |             raise RuntimeError("the number of features could not be found")
40 |
41 |         return n_features
42 |
43 |     def add_classifier(self, n_features=None, n_hidden=64, *, n_classes):
44 |
45 |         if not n_features:
46 |             self.n_features = self._find_number_of_features()
47 |         else:
48 |             self.n_features = n_features
49 |             logging.info("{} features were requested".format(self.n_features))
50 |
51 |         self.model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
52 |         self.model.classifier = nn.Sequential(
53 |             nn.Linear(self.n_features * 1 * 1, n_hidden),
54 |             nn.ReLU(True),
55 |             nn.Linear(n_hidden, n_classes),
56 |
57 |         )
58 |
59 |     def req_grad(self, n, status):
60 |         '''
61 |         Sets the requires_grad field for network parameters up to feature convolutional layer 'n', to freeze (status=False) or unfreeze (status=True) the weights to be learned
62 |         :param n: the number of leading feature convolutional layers to be affected, i.e. layers [0, n-1]
63 |         :param status: False - freeze weights, True - unfreeze weights
64 |         :return:
65 |         '''
66 |         for layer_n in range(n):
67 |             for parameter in self.model.features[layer_n].parameters():
68 |                 parameter.requires_grad = status
69 |
70 |     def fv_length(self):
71 |         return self.n_features
72 |
73 |     def custom_init_fcl_weights_vgg(self, n_first=0, n_last=None):
74 |         """
75 |         Initializes the weights of the fully connected layers according to the VGG recommendations (as in the original PyTorch VGG code)
76 |         :param n_first: first layer to be initialized
77 |         :param n_last: last layer to be initialized
78 |         :return:
79 |         """
80 |
81 |         if n_last is None:
82 |             n_last = len(self.model.classifier) - 1
83 |
84 |         for layer_n in range(n_first, n_last + 1):
85 |             layer =
self.model.classifier[layer_n] 86 | 87 | if isinstance(layer, nn.Linear): 88 | assert layer.weight.requires_grad is True 89 | nn.init.normal_(layer.weight, 0, 0.01) # 0.01 90 | assert layer.bias.requires_grad is True 91 | nn.init.constant_(layer.bias, 0) 92 | 93 | def custom_init_fcl_efficientnet(self, n_first=0, n_last=None): 94 | 95 | if n_last is None: 96 | n_last = len(self.model.classifier) - 1 97 | 98 | for layer_n in range(n_first, n_last + 1): 99 | layer = self.model.classifier[layer_n] 100 | 101 | if isinstance(layer, nn.Linear): 102 | 103 | init_range = 1.0 / math.sqrt(layer.out_features) 104 | assert layer.weight.requires_grad is True 105 | nn.init.uniform_(layer.weight, -init_range, init_range) 106 | assert layer.bias.requires_grad is True 107 | nn.init.zeros_(layer.bias) 108 | 109 | def count_parameters(self, verbose=False): 110 | table = PrettyTable(["Modules", "Parameters"]) 111 | total_params = 0 112 | for name, parameter in self.named_parameters(): 113 | if not parameter.requires_grad: 114 | continue 115 | 116 | param = parameter.numel() 117 | table.add_row([name, param]) 118 | total_params += param 119 | 120 | if verbose: 121 | print(table) 122 | print(f"Total Trainable Params: {total_params}") 123 | 124 | return total_params 125 | 126 | class VGG_11(CustomNet): 127 | def __init__(self, path_trained_model='', *, n_classes, dev): 128 | 129 | model_name = 'vgg11' 130 | 131 | super().__init__(model_name=model_name, custom_trained_model=bool(path_trained_model)) 132 | 133 | self.add_classifier(n_classes=n_classes) 134 | 135 | # initialize weights of the model 136 | if path_trained_model: 137 | self.load_state_dict(torch.load(path_trained_model, map_location=dev)) 138 | else: 139 | self.custom_init_fcl_weights_vgg() 140 | 141 | def forward(self, x): 142 | 143 | codes = self.model.features(x) 144 | pooled_codes = self.model.avgpool(codes) 145 | pooled_codes = torch.flatten(pooled_codes, 1) 146 | x = self.model.classifier(pooled_codes) 147 | 148 | return {classes_field: x, pooled_codes_field: pooled_codes} 149 | 150 | 151 | class VGG_16(CustomNet): 152 | def __init__(self, path_trained_model='', *, n_classes, dev): 153 | 154 | model_name = 'vgg16' 155 | 156 | super().__init__(model_name=model_name, custom_trained_model=bool(path_trained_model)) 157 | 158 | self.add_classifier(n_classes=n_classes) 159 | 160 | # initialize weights of the model 161 | if path_trained_model: 162 | self.load_state_dict(torch.load(path_trained_model, map_location=dev)) 163 | else: 164 | self.custom_init_fcl_weights_vgg() 165 | 166 | def forward(self, x): 167 | 168 | codes = self.model.features(x) 169 | pooled_codes = self.model.avgpool(codes) 170 | pooled_codes = torch.flatten(pooled_codes, 1) 171 | x = self.model.classifier(pooled_codes) 172 | 173 | return {classes_field: x, pooled_codes_field: pooled_codes} 174 | 175 | 176 | class VGG_19(CustomNet): 177 | def __init__(self, path_trained_model='', *, n_classes, dev): 178 | 179 | model_name = 'vgg19' 180 | 181 | super().__init__(model_name=model_name, custom_trained_model=bool(path_trained_model)) 182 | 183 | self.add_classifier(n_classes=n_classes) 184 | 185 | # initialize weights of the model 186 | if path_trained_model: 187 | self.load_state_dict(torch.load(path_trained_model, map_location=dev)) 188 | else: 189 | self.custom_init_fcl_weights_vgg() 190 | 191 | def forward(self, x): 192 | 193 | codes = self.model.features(x) 194 | pooled_codes = self.model.avgpool(codes) 195 | pooled_codes = torch.flatten(pooled_codes, 1) 196 | x = 
self.model.classifier(pooled_codes) 197 | 198 | return {classes_field: x, pooled_codes_field: pooled_codes} 199 | 200 | 201 | class EfficientNet_B0(CustomNet): 202 | def __init__(self, path_trained_model='', *, n_classes, dev): 203 | 204 | model_name = 'efficientnet_b0' 205 | 206 | super().__init__(model_name=model_name, custom_trained_model=bool(path_trained_model)) 207 | 208 | self.add_classifier(n_classes=n_classes) 209 | 210 | # initialize weights of the model 211 | if path_trained_model: 212 | self.load_state_dict(torch.load(path_trained_model, map_location=dev)) 213 | else: 214 | self.custom_init_fcl_weights_vgg() 215 | 216 | def forward(self, x): 217 | 218 | codes = self.model.features(x) 219 | pooled_codes = self.model.avgpool(codes) 220 | pooled_codes = torch.flatten(pooled_codes, 1) 221 | x = self.model.classifier(pooled_codes) 222 | 223 | return {classes_field: x, pooled_codes_field: pooled_codes} 224 | 225 | 226 | class EfficientNet_B2(CustomNet): 227 | def __init__(self, path_trained_model='', *, n_classes, dev): 228 | 229 | model_name = 'efficientnet_b2' 230 | 231 | super().__init__(model_name=model_name, custom_trained_model=bool(path_trained_model)) 232 | 233 | self.add_classifier(n_classes=n_classes) 234 | 235 | # initialize weights of the model 236 | if path_trained_model: 237 | self.load_state_dict(torch.load(path_trained_model, map_location=dev)) 238 | else: 239 | self.custom_init_fcl_weights_vgg() 240 | 241 | def forward(self, x): 242 | 243 | codes = self.model.features(x) 244 | pooled_codes = self.model.avgpool(codes) 245 | pooled_codes = torch.flatten(pooled_codes, 1) 246 | x = self.model.classifier(pooled_codes) 247 | 248 | return {classes_field: x, pooled_codes_field: pooled_codes} 249 | 250 | 251 | class EfficientNet_B0_320(CustomNet): 252 | def __init__(self, path_trained_model='', *, n_classes, dev): 253 | 254 | model_name = 'efficientnet_b0' 255 | n_features = 320 256 | 257 | super().__init__(model_name=model_name, custom_trained_model=bool(path_trained_model)) 258 | 259 | self.model.features = self.model.features[:-1] 260 | self.model.features.add_module('addedBN', nn.BatchNorm2d(n_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) 261 | self.model.features.add_module('AddedSiLu', nn.SiLU(inplace=True)) 262 | 263 | self.add_classifier(n_classes=n_classes) 264 | assert n_features == self.n_features 265 | 266 | # initialize weights of the model 267 | if path_trained_model: 268 | self.load_state_dict(torch.load(path_trained_model, map_location=dev)) 269 | else: 270 | self.custom_init_fcl_weights_vgg() 271 | 272 | def forward(self, x): 273 | 274 | codes = self.model.features(x) 275 | pooled_codes = self.model.avgpool(codes) 276 | pooled_codes = torch.flatten(pooled_codes, 1) 277 | x = self.model.classifier(pooled_codes) 278 | 279 | return {classes_field: x, pooled_codes_field: pooled_codes} 280 | 281 | class EfficientNet_B2_352(CustomNet): 282 | def __init__(self, path_trained_model='', *, n_classes, dev): 283 | 284 | model_name = 'efficientnet_b2' 285 | n_features = 352 286 | 287 | super().__init__(model_name=model_name, custom_trained_model=bool(path_trained_model)) 288 | 289 | self.model.features = self.model.features[:-1] 290 | self.model.features.add_module('addedBN', nn.BatchNorm2d(n_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) 291 | self.model.features.add_module('AddedSiLu', nn.SiLU(inplace=True)) 292 | 293 | self.add_classifier(n_classes=n_classes) 294 | assert n_features == self.n_features 295 | 296 | 
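        # Note on the truncation above (it applies to EfficientNet_B0_320 in the same way):
        # torchvision's efficientnet 'features' ends with a 1x1 Conv2dNormActivation block that
        # expands the channels (to 1280 for B0, 1408 for B2). features[:-1] drops that block,
        # exposing the last MBConv stage output (320 channels for B0, 352 for B2), and the
        # appended BatchNorm2d + SiLU restore the normalization/activation the dropped block
        # used to provide.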
# initialize weights of the model 297 | if path_trained_model: 298 | self.load_state_dict(torch.load(path_trained_model, map_location=dev)) 299 | else: 300 | self.custom_init_fcl_weights_vgg() 301 | 302 | def forward(self, x): 303 | 304 | codes = self.model.features(x) 305 | pooled_codes = self.model.avgpool(codes) 306 | pooled_codes = torch.flatten(pooled_codes, 1) 307 | x = self.model.classifier(pooled_codes) 308 | 309 | return {classes_field: x, pooled_codes_field: pooled_codes} 310 | 311 | 312 | class ConvNeXt(CustomNet): 313 | def __init__(self, path_trained_model='', *, n_classes, dev): 314 | 315 | model_name = 'convnext_tiny' 316 | 317 | super().__init__(model_name=model_name, custom_trained_model=bool(path_trained_model)) 318 | 319 | self.add_classifier(n_classes=n_classes) 320 | 321 | # initialize weights of the model 322 | if path_trained_model: 323 | self.load_state_dict(torch.load(path_trained_model, map_location=dev)) 324 | else: 325 | self.custom_init_fcl_weights_vgg() 326 | 327 | def forward(self, x): 328 | 329 | codes = self.model.features(x) 330 | pooled_codes = self.model.avgpool(codes) 331 | pooled_codes = torch.flatten(pooled_codes, 1) 332 | x = self.model.classifier(pooled_codes) 333 | 334 | return {classes_field: x, pooled_codes_field: pooled_codes} 335 | 336 | 337 | class ResNet_18(CustomNet): 338 | def __init__(self, path_trained_model='', *, n_classes, dev): 339 | 340 | model_name = 'resnet18' 341 | 342 | super().__init__(model_name=model_name, custom_trained_model=bool(path_trained_model)) 343 | 344 | self.add_classifier(n_classes=n_classes) 345 | 346 | # initialize weights of the model 347 | if path_trained_model: 348 | self.load_state_dict(torch.load(path_trained_model, map_location=dev)) 349 | else: 350 | self.custom_init_fcl_weights_vgg() 351 | 352 | def forward(self, x): 353 | 354 | x = self.model.conv1(x) 355 | x = self.model.bn1(x) 356 | x = self.model.relu(x) 357 | x = self.model.maxpool(x) 358 | 359 | x = self.model.layer1(x) 360 | x = self.model.layer2(x) 361 | x = self.model.layer3(x) 362 | codes = self.model.layer4(x) 363 | pooled_codes = self.model.avgpool(codes) 364 | 365 | pooled_codes = torch.flatten(pooled_codes, 1) 366 | x = self.model.classifier(pooled_codes) 367 | 368 | return {classes_field: x, pooled_codes_field: pooled_codes} 369 | 370 | 371 | class DenseNet_121(CustomNet): 372 | def __init__(self, path_trained_model='', *, n_classes, dev): 373 | 374 | model_name = 'densenet121' 375 | n_features = 1024 376 | 377 | super().__init__(model_name=model_name, custom_trained_model=bool(path_trained_model)) 378 | 379 | self.add_classifier(n_classes=n_classes, n_features=n_features) 380 | 381 | # initialize weights of the model 382 | if path_trained_model: 383 | self.load_state_dict(torch.load(path_trained_model, map_location=dev)) 384 | else: 385 | self.custom_init_fcl_weights_vgg() 386 | 387 | def forward(self, x): 388 | 389 | codes = self.model.features(x) 390 | codes = F.relu(codes, inplace=True) 391 | pooled_codes = self.model.avgpool(codes) 392 | pooled_codes = torch.flatten(pooled_codes, 1) 393 | x = self.model.classifier(pooled_codes) 394 | 395 | return {classes_field: x, pooled_codes_field: pooled_codes} 396 | 397 | 398 | class DenseNet_121_512(CustomNet): 399 | def __init__(self, path_trained_model='', *, n_classes, dev): 400 | 401 | model_name = 'densenet121' 402 | n_features = 512 403 | 404 | super().__init__(model_name=model_name, custom_trained_model=bool(path_trained_model)) 405 | 406 | self.model.features = self.model.features[:-2] 
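        # features[:-2] removes densenet121's final dense block and its 'norm5' layer, so the
        # feature extractor now ends at transition3, whose output has 512 channels (hence
        # n_features = 512); the BatchNorm2d added below stands in for the removed 'norm5'.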
407 | self.model.features.add_module('addedBN', nn.BatchNorm2d(n_features)) 408 | 409 | self.add_classifier(n_classes=n_classes) 410 | assert n_features == self.n_features 411 | 412 | # initialize weights of the model 413 | if path_trained_model: 414 | self.load_state_dict(torch.load(path_trained_model, map_location=dev)) 415 | else: 416 | self.custom_init_fcl_weights_vgg() 417 | 418 | def forward(self, x): 419 | 420 | codes = self.model.features(x) 421 | codes = F.relu(codes, inplace=True) 422 | pooled_codes = self.model.avgpool(codes) 423 | pooled_codes = torch.flatten(pooled_codes, 1) 424 | x = self.model.classifier(pooled_codes) 425 | 426 | return {classes_field: x, pooled_codes_field: pooled_codes} 427 | 428 | 429 | class VT_B_32(CustomNet): 430 | def __init__(self, path_trained_model='', *, n_classes, dev): 431 | 432 | model_name = 'vit_b_32' 433 | 434 | super().__init__(model_name=model_name, custom_trained_model=bool(path_trained_model)) 435 | 436 | self.add_classifier(n_classes=n_classes) 437 | 438 | # initialize weights of the model 439 | if path_trained_model: 440 | self.load_state_dict(torch.load(path_trained_model, map_location=dev)) 441 | else: 442 | self.custom_init_fcl_weights_vgg() 443 | 444 | def forward(self, x): 445 | 446 | x = self.model._process_input(x) 447 | n = x.shape[0] 448 | 449 | # Expand the class token to the full batch 450 | batch_class_token = self.model.class_token.expand(n, -1, -1) 451 | x = torch.cat([batch_class_token, x], dim=1) 452 | 453 | x = self.model.encoder(x) 454 | 455 | # Classifier "token" as used by standard language architectures 456 | pooled_codes = x[:, 0] 457 | 458 | x = self.model.classifier(pooled_codes) 459 | 460 | return {classes_field: x, pooled_codes_field: pooled_codes} 461 | 462 | 463 | 464 | 465 | 466 | 467 | 468 | 469 | 470 | 471 | 472 | 473 | 474 | 475 | 476 | 477 | 478 | 479 | 480 | 481 | 482 | -------------------------------------------------------------------------------- /anomaly_detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | from sklearn import svm # must stay here (eval) 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | from utils.training_utils import apply_net 7 | import torch 8 | from utils.data_processing import CodesProcessor, set_seed 9 | from torchvision import transforms 10 | from utils.imports import show_configuration, save_configuration, update_configuration 11 | from utils.image_dataset_reader import HistImagesDataset, samples_per_location_from_samples_per_class 12 | from torch.utils.data import DataLoader 13 | import logging 14 | import pickle 15 | from sklearn import manifold 16 | from matplotlib import cm 17 | import pandas as pd 18 | from matplotlib.font_manager import FontProperties 19 | import time 20 | import os 21 | import argparse 22 | from time import perf_counter 23 | from sklearn.metrics import roc_auc_score, balanced_accuracy_score, f1_score, roc_curve 24 | from shutil import copyfile 25 | import models.pretrained_networks as HistoModel 26 | from utils.timed_input import limited_time_input 27 | matplotlib.use('Qt5Agg') 28 | 29 | 30 | def anomaly_training(model, clf, data_loader, dev, cache_path=None): 31 | 32 | logging.info('computing codes for training anomaly classifier') 33 | code_processor = CodesProcessor(cached_codes_path=cache_path) 34 | apply_net(model, dev, data_loader, verbose=True, code_processor=code_processor) 35 | 36 | logging.info('all codes for training were computed') 37 | 38 | 
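    # (apply_net() has pushed one embedding per image into code_processor; get_codes()
    # below returns them stacked as an (n_images, feature_length) array, and the
    # one-class classifier is then fit on it without labels -- the training data is
    # assumed to be all normal tissue.)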
features = code_processor.get_codes() 39 | logging.info('codes are generated, starting training anomaly classifier') 40 | clf.fit(features) 41 | 42 | return clf, features 43 | 44 | 45 | def anomaly_detection(clf, model, data_loader, dev, feat_container=None, cache_path=None): 46 | 47 | code_processor = CodesProcessor(cached_codes_path=cache_path) 48 | apply_net(model, dev, data_loader, verbose=True, code_processor=code_processor) 49 | logging.info('all codes were computed') 50 | 51 | # classify feature representations 52 | features = code_processor.get_codes() 53 | image_names = code_processor.get_image_names() 54 | image_labels = code_processor.get_image_labels() 55 | 56 | assert len(features) == len(image_names) 57 | 58 | scores = clf.decision_function(features) 59 | 60 | if feat_container is not None: 61 | 62 | feat_container['labels'].extend(image_labels) 63 | 64 | if len(feat_container['features']) == 0: 65 | feat_container['features'] = np.zeros((0, features.shape[1])) 66 | feat_container['features'] = np.concatenate((feat_container['features'], features), axis=0) 67 | 68 | return scores, image_names, image_labels 69 | 70 | 71 | def save_images_from_excel(path2csv, path2save, n_examples, ex_type='FN'): 72 | 73 | """ 74 | Copies n_examples images with the highest/lowest scores (most normals/anomalous, corresponding to most false negative/positive examples - normal/abnormal detections from anomaly/normal data) 75 | to newly created subfolder 'ex_type' in path2save, for ex_type='FN'/'FP'. All the data is read from csv file path2csv 76 | :param path2csv: 77 | :param path2save: 78 | :param n_examples: maximal number of extreme examples to be copied (lower number might be copied if thee are less of normal/anomalous images) 79 | :param ex_type: must be 'FN' or 'FP' - false negatives for anomalous input data or false positives for normal input data 80 | :return: 81 | """ 82 | 83 | # read csv 84 | df = pd.read_csv(path2csv) 85 | image_paths = df['image_path'].to_list() 86 | image_names = df['image name'].to_list() 87 | anomaly_scores = df['anomaly scores'].to_list() 88 | labels = df['original label'].to_list() 89 | 90 | assert len(image_paths) == len(image_names) 91 | 92 | if ex_type == 'FN': 93 | idx_sorted = np.argsort(-np.array(anomaly_scores)) 94 | elif ex_type == 'FP': 95 | idx_sorted = np.argsort(np.array(anomaly_scores)) 96 | else: 97 | raise RuntimeError("argument ex_type must be 'FN' or 'FP'") 98 | 99 | image_path_new = os.path.join(path2save, ex_type) 100 | if not os.path.exists(image_path_new): 101 | os.makedirs(image_path_new) 102 | 103 | for n, idx in enumerate(idx_sorted): 104 | 105 | image_path = image_paths[idx] 106 | image_name = image_names[idx] 107 | anomaly_score = anomaly_scores[idx] 108 | label = labels[idx] 109 | 110 | if n == n_examples: 111 | break 112 | 113 | if ex_type == 'FN': 114 | if anomaly_score < 0: 115 | break 116 | elif ex_type == 'FP': 117 | if anomaly_score > 0: 118 | break 119 | else: 120 | raise RuntimeError("argument ex_type must be 'FN' or 'FP'") 121 | 122 | image_name_new, ext = os.path.splitext(image_name) 123 | image_name_new = image_name_new + '_' + label + '_score:' + str(anomaly_score).replace(' ', '_') + ext 124 | 125 | full_path_new = os.path.join(image_path_new, image_name_new) 126 | 127 | full_path_old = os.path.join(image_path, image_name) 128 | 129 | copyfile(full_path_old, full_path_new) 130 | 131 | 132 | def TSNE_visualization(features, labels, colormap='tab10', n_features=None, scores=None, save_path=None): 133 | """ 134 | 135 | :param 
features: 136 | :param labels: data original labels 137 | :param colormap: 138 | :param n_features: maximal allowed number of feature vectors to be shown 139 | :param scores: anomaly scores for the data. To be used only for the second plot to see detection boundary 140 | :param save_path: path to save TSNE images 141 | :return: 142 | """ 143 | 144 | assert len(features) == len(labels) 145 | if scores is not None: 146 | scores = np.array(scores) 147 | assert len(labels) == len(scores) 148 | 149 | perplexity = 30 150 | colormap = cm.get_cmap(colormap) 151 | 152 | me = manifold.TSNE(n_components=2, perplexity=perplexity, verbose=1, method='barnes_hut') 153 | new_features = me.fit_transform(features) 154 | 155 | # --------plot data TSNE clusters-------- 156 | # plot new embeddings 157 | fig, ax = plt.subplots() 158 | fig2, ax2 = plt.subplots() 159 | set_labels = set(labels) 160 | 161 | # make the order in which 'normal' or 'NAS_anomaly' labels will be in the beginning of the set_labels list 162 | n_lab = [] 163 | a_lab = [] 164 | for lab in set_labels: 165 | if ('normal' in lab) or ('NAS anomaly' in lab): 166 | n_lab.append(lab) 167 | else: 168 | a_lab.append(lab) 169 | set_labels = n_lab + a_lab 170 | 171 | n_labels = len(set_labels) 172 | for n, label in enumerate(set_labels): 173 | 174 | if label == 'normal' or label == 'NAS anomaly': 175 | marker_size = 5 176 | marker_shape = 'o' 177 | else: 178 | marker_size = 5 179 | marker_shape = 'o' 180 | 181 | # select indexes of current class only 182 | ind = [labels[i] == label for i in range(len(labels))] 183 | 184 | class_features = new_features[ind, :] 185 | if scores is not None: 186 | class_scores = scores[ind] 187 | 188 | # randomly select only part of features 189 | if n_features: 190 | ind = np.random.permutation(class_features.shape[0])[:n_features] 191 | class_features = class_features[ind, :] 192 | if scores is not None: 193 | class_scores = class_scores[ind] 194 | 195 | ax.scatter(class_features[:, 0], class_features[:, 1], label=label, marker=marker_shape, s=marker_size, c=np.reshape(colormap(n/(n_labels-1)), (1, 4)), alpha=0.5) 196 | 197 | if scores is not None: 198 | ind_anomaly = class_scores > 0 199 | 200 | h1 = ax2.scatter(class_features[ind_anomaly, 0], class_features[ind_anomaly, 1], marker='o', label='normal', s=marker_size, c='b', alpha=0.5) 201 | h2 = ax2.scatter(class_features[~ind_anomaly, 0], class_features[~ind_anomaly, 1], marker='o', label='anomaly', s=marker_size, c='r', alpha=0.5) 202 | 203 | leg = ax.legend(bbox_to_anchor=(0.98, 1), prop=fontP, markerscale=2.5) 204 | # makes opacity = 100% instead of taking opacity from the points drawn 205 | for lh in leg.legendHandles: 206 | lh.set_alpha(1) 207 | 208 | if scores is not None: 209 | leg2 = ax2.legend(bbox_to_anchor=(0.98, 1), prop=fontP, handles=[h1, h2], markerscale=2.5) 210 | # makes opacity = 100% instead of taking opacity from the points drawn 211 | for lh in leg2.legendHandles: 212 | lh.set_alpha(1) 213 | 214 | if save_path: 215 | fig.savefig(os.path.join(save_path, 'feature_vectors_TSNE.png'), bbox_inches='tight') 216 | plt.close() 217 | fig2.savefig(os.path.join(save_path, 'detections_TSNE.png'), bbox_inches='tight') 218 | plt.close() 219 | 220 | 221 | def save_to_excel(scores, image_names, image_labels, csv_path): 222 | 223 | paths = [] 224 | file_names = [] 225 | for i in range(len(image_names)): 226 | path, file_name = os.path.split(image_names[i]) 227 | paths.append(path) 228 | file_names.append(file_name) 229 | 230 | data = pd.DataFrame(scores, 
columns=['anomaly scores']) 231 | data['original label'] = image_labels 232 | data['image_path'] = paths 233 | data['image name'] = file_names 234 | 235 | data.to_csv(csv_path, index=False) 236 | 237 | return 238 | 239 | 240 | def create_data_loader(paths, cfg, n_patches=None, batch_size=1, augmentation=False): 241 | 242 | if not paths: 243 | return None 244 | 245 | if n_patches and isinstance(n_patches, int): 246 | n_patches = samples_per_location_from_samples_per_class(*paths, samples_per_class=n_patches) 247 | 248 | tr_normalize = transforms.Normalize(mean=cfg.normalize_mean, std=cfg.normalize_std) 249 | if augmentation: 250 | hs_jitter = transforms.ColorJitter(saturation=cfg.aug_saturation, hue=cfg.aug_hue) 251 | bc_jitter = transforms.ColorJitter(brightness=cfg.aug_brightness, contrast=cfg.aug_contrast) 252 | 253 | if not augmentation: 254 | seq = [transforms.CenterCrop(cfg.patch_size), transforms.ToTensor(), tr_normalize] 255 | else: 256 | seq = [transforms.CenterCrop(cfg.patch_size), hs_jitter, bc_jitter, transforms.ToTensor(), tr_normalize] 257 | 258 | transforms_seq = transforms.Compose(seq) 259 | 260 | dataset = HistImagesDataset(*paths, transform=transforms_seq, n_samples=n_patches) 261 | 262 | data_loader = DataLoader(dataset, num_workers=0, batch_size=batch_size) 263 | 264 | return data_loader 265 | 266 | 267 | def sample_features(embeddings, n_samples, anomaly_scores): 268 | 269 | features_to_visualize = {'features': np.zeros((0, 0)), 'labels': []} 270 | scores = [] 271 | 272 | labs = set(embeddings['labels']) 273 | lab_intersection = labs.intersection(n_samples.keys()) 274 | assert len(lab_intersection) >= 2 275 | 276 | if len(labs.intersection(n_samples.keys())) != len(labs): 277 | logging.info('labels from datasets to be visualized in TSNE and requested labels for visualization do not completely coinside') 278 | 279 | for lab in labs: 280 | if lab in n_samples.keys(): 281 | idx = [i for i in range(len(embeddings['labels'])) if embeddings['labels'][i] == lab] 282 | if n_samples[lab] <= len(idx): 283 | idx = random.sample(idx, n_samples[lab]) 284 | else: 285 | logging.warning('number of requested features {} is larger than was calculated {}, using all calculated features'.format(n_samples[lab], len(idx))) 286 | 287 | sampled_labels = [embeddings['labels'][i] for i in idx] 288 | sampled_scores = [anomaly_scores[i] for i in idx] 289 | sampled_features = embeddings['features'][idx, :] 290 | 291 | scores.extend(sampled_scores) 292 | features_to_visualize['labels'].extend(sampled_labels) 293 | if len(features_to_visualize['features']) == 0: 294 | features_to_visualize['features'] = np.zeros((0, sampled_features.shape[1])) 295 | features_to_visualize['features'] = np.concatenate((features_to_visualize['features'], sampled_features), axis=0) 296 | 297 | return features_to_visualize, scores 298 | 299 | # ---------------------------------------------------------------------------------------------------------------------- 300 | # ---------------------------------------------------------------------------------------------------------------------- 301 | 302 | 303 | logging.basicConfig(level=logging.INFO) 304 | t_start = perf_counter() 305 | string_time = time.strftime("%H%M%S_%d%m%y") 306 | parser = argparse.ArgumentParser(description='anomaly detector') 307 | 308 | cfg = update_configuration(parser) 309 | 310 | if not cfg.description: 311 | cfg.description = limited_time_input("Please enter description of an experiment...", 30) 312 | 313 | print('\n') 314 | show_configuration(cfg) 
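# cfg.clf is eval'd further below (classifier = eval(cfg.clf)), so it is expected to be a
# string that constructs the one-class model; a hypothetical example value (not taken from
# the repository's configs):
#
#     cfg.clf = "svm.OneClassSVM(kernel='rbf', nu=0.1, gamma='scale')"
#
# this is also why the `from sklearn import svm` import above is marked 'must stay here (eval)'.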
315 | print('\n') 316 | 317 | set_seed(cfg.seed_number) 318 | 319 | fontP = FontProperties() 320 | fontP.set_size('xx-small') 321 | 322 | if not os.path.isdir(cfg.output_path): 323 | os.makedirs(cfg.output_path) 324 | 325 | if os.listdir(cfg.output_path): 326 | raise RuntimeError(f"output folder {cfg.output_path} is not empty") 327 | #assert not os.listdir(cfg.output_path), f"output folder {cfg.output_path} is not empty" 328 | 329 | #stream_handler = logging.StreamHandler(stream=sys.stdout) 330 | file_handler = logging.FileHandler(os.path.join(cfg.output_path, 'anomaly_detector.log')) 331 | logging.root.addHandler(file_handler) 332 | 333 | logging.info("staining: {}".format(cfg.data_staining)) 334 | logging.info("number of classes: {}".format(cfg.n_trained_classes)) 335 | logging.info('output folder: {}'.format(cfg.output_path)) 336 | 337 | dev = torch.device(cfg.dev) 338 | 339 | NetworkModel = getattr(HistoModel, cfg.model_architecture) 340 | model = NetworkModel(path_trained_model=cfg.cnn_model, n_classes=cfg.n_trained_classes, dev=dev) 341 | model.to(dev) 342 | 343 | if cfg.cnn_model: # path to model was defined 344 | CNN_model_name, _ = os.path.splitext(cfg.cnn_model) 345 | CNN_model_path, CNN_model_name = os.path.split(CNN_model_name) 346 | logging.info('CNN model {} will be used'.format(cfg.cnn_model)) 347 | 348 | #output_path = os.path.join(CNN_model_path, cfg.anomaly_model_folder + "_" + CNN_model_name) 349 | 350 | #anomaly_model_path = os.path.join(output_path, CNN_model_name + '.pkl') 351 | else: # path for model, use pre-trained standard model 352 | CNN_model_name = cfg.model_architecture + '_backbone' 353 | CNN_model_path = cfg.output_path 354 | #CNN_model_path = cfg.code_output_path + string_time 355 | 356 | CNN_model_path_full = os.path.join(CNN_model_path, CNN_model_name + '.pt') 357 | torch.save(model.state_dict(), CNN_model_path_full) 358 | logging.info('backbone model was saved to {}'.format(CNN_model_path_full)) 359 | 360 | #output_path = os.path.join(CNN_model_path, cfg.anomaly_model_folder + "_" + CNN_model_name) 361 | #anomaly_model_path = os.path.join(output_path, CNN_model_name + '.pkl') 362 | 363 | 364 | print("") 365 | if cfg.ad_model: # path to anomaly detection model was defined 366 | 367 | AD_model_name, _ = os.path.splitext(cfg.ad_model) 368 | AD_model_path, AD_model_name = os.path.split(AD_model_name) 369 | 370 | assert CNN_model_name == AD_model_name[:len(CNN_model_name)], "CNN and AD names must be the same except extention and possibly suffix" 371 | 372 | try: 373 | clf = pickle.load(open(cfg.ad_model, 'rb')) 374 | except FileNotFoundError: 375 | clf = None 376 | 377 | else: 378 | clf = None 379 | 380 | # try: 381 | # clf = pickle.load(open(anomaly_model_path, 'rb')) 382 | # except: 383 | # new_model = True 384 | # 385 | # try: 386 | # os.makedirs(output_path) 387 | # except FileExistsError: 388 | # logging.info('{} already exists. Probably training has already been started but then was canceled'.format(output_path)) 389 | # else: 390 | # ans = limited_time_input("Do you want to use already trained one class classifier model [y/n]? 
", 30) 391 | # if ans in ('y', 'Y', 'yes', 'YES', 'Yes'): 392 | # new_model = False 393 | # 394 | # logging.info("previousely trained one-class classifier will used") 395 | # 396 | # # create a new subfolder 397 | # path_folder, path_subfolder = os.path.split(output_path) 398 | # path_subfolder = path_subfolder + '_' + string_time 399 | # output_path = os.path.join(path_folder, path_subfolder) 400 | # else: 401 | # new_model = True 402 | # 403 | # logging.info("new one class classifier will be trained (the previous one will NOT be erased)") 404 | # # create a new subfolder 405 | # path_folder, path_subfolder = os.path.split(output_path) 406 | # path_subfolder = path_subfolder + '_retrained_' + string_time 407 | # output_path = os.path.join(path_folder, path_subfolder) 408 | # anomaly_model_path = os.path.join(output_path, CNN_model_name + '.pkl') 409 | # 410 | # try: 411 | # os.makedirs(output_path) 412 | # except FileExistsError: 413 | # raise RuntimeError('{} already exists'.format(output_path)) 414 | 415 | 416 | # if not cfg.cnn_model: # save backbone model 417 | # saved_model_path_full = os.path.join(CNN_model_path, CNN_model_name + '.pt') 418 | # torch.save(model.state_dict(), saved_model_path_full) 419 | # logging.info('backbone model was saved to {}'.format(saved_model_path_full)) 420 | 421 | # file_handler = logging.FileHandler(os.path.join(output_path, 'anomaly_detector.log')) 422 | # logging.root.addHandler(file_handler) 423 | # 424 | # logging.info("staining: {}".format(cfg.data_staining)) 425 | # logging.info("number of classes: {}".format(cfg.n_trained_classes)) 426 | 427 | if not clf: 428 | 429 | logging.info('a new model for anomaly detector will be trained') 430 | 431 | AD_model_name = CNN_model_name 432 | AD_model_path = CNN_model_path 433 | AD_model_path_full = os.path.join(AD_model_path, AD_model_name + '.pkl') 434 | 435 | # prevent overwriting if the model with same name was already existing 436 | if os.path.isfile(AD_model_path_full): 437 | AD_model_name = CNN_model_name + '_' + string_time 438 | AD_model_path_full = os.path.join(CNN_model_path, AD_model_name + '.pkl') 439 | 440 | logging.info("--------------training---------------------") 441 | training_data_loader = create_data_loader(cfg.paths_normal, cfg, n_patches=cfg.train_patches_for_train_max, batch_size=cfg.batch_size, augmentation=cfg.augmentation) 442 | classifier = eval(cfg.clf) 443 | clf, _ = anomaly_training(model, classifier, training_data_loader, dev, cache_path=cfg.output_path) 444 | logging.info('-----------training has been finished--------------') 445 | 446 | pickle.dump(clf, open(AD_model_path_full, 'wb')) 447 | logging.info('trained model was saved to {}'.format(AD_model_path_full)) 448 | 449 | t_end_training = perf_counter() 450 | logging.info('training on normal data took {} sec'.format(t_end_training - t_start)) 451 | 452 | #save_configuration(cfg, output_path + '/anomaly_detector_configuration.txt') 453 | 454 | #training_data_loader = create_data_loader(cfg.paths_normal, cfg, n_patches=cfg.train_patches_for_test_max, brightness_factor=cfg.brightness_factor, batch_size=cfg.batch_size) 455 | else: 456 | logging.info('anomaly detector model {} was found and will be used'.format(cfg.ad_model)) 457 | 458 | # if new_model: 459 | # 460 | # logging.info('anomaly detector model {} was not found'.format(anomaly_model_path)) 461 | # logging.info('a new model for anomaly detector will be trained') 462 | # 463 | # logging.info("--------------training---------------------") 464 | # training_data_loader = 
create_data_loader(cfg.paths_normal, cfg, n_patches=cfg.train_patches_for_train_max, batch_size=cfg.batch_size, augmentation=cfg.augmentation) 465 | # classifier = eval(cfg.clf) 466 | # clf, _ = anomaly_training(model, classifier, training_data_loader, dev, cache_path=output_path) 467 | # logging.info('-----------training has been finished--------------') 468 | # 469 | # pickle.dump(clf, open(anomaly_model_path, 'wb')) 470 | # logging.info('trained model was saved to {}'.format(anomaly_model_path)) 471 | # 472 | # else: 473 | # logging.info('already saved anomaly detector model {} was found and will be used'.format(anomaly_model_path)) 474 | 475 | save_configuration(cfg, os.path.join(cfg.output_path, 'anomaly_detector_configuration.txt')) 476 | #logging.info('output folder: {}'.format(output_path)) 477 | 478 | 479 | features_to_visualize = {'features': np.zeros((0, 0)), 'labels': []} 480 | scores = [] 481 | scores_normal_test, normal_test_im_names, normal_test_labels = None, None, None 482 | scores_liver_anomaly_test, liver_anomaly_test_im_names, liver_anomaly_test_labels = None, None, None 483 | scores_non_liver_test, non_liver_test_im_names, non_liver_test_labels = None, None, None 484 | 485 | print("") 486 | normal_test_data_loader = create_data_loader(cfg.paths_normal_test, cfg, n_patches=cfg.test_normal_patches_max, batch_size=cfg.batch_size) 487 | 488 | if normal_test_data_loader: 489 | logging.info('-----------anomaly detection in normal test data----------------') 490 | scores_normal_test, normal_test_im_names, normal_test_labels = anomaly_detection(clf, model, normal_test_data_loader, dev, feat_container=features_to_visualize, cache_path=cfg.output_path) 491 | if scores_normal_test is not None: 492 | scores.extend(scores_normal_test) 493 | save_to_excel(scores_normal_test, normal_test_im_names, normal_test_labels, os.path.join(cfg.output_path, cfg.csv_liver_tissue_testnormals)) 494 | save_images_from_excel(os.path.join(cfg.output_path, cfg.csv_liver_tissue_testnormals), cfg.output_path, cfg.save_n_FP, 'FP') 495 | 496 | print("") 497 | liver_anomaly_test_data_loader = create_data_loader(cfg.paths_liver_anomaly_test, cfg, n_patches=cfg.test_anomaly_patches_per_class_max, batch_size=cfg.batch_size) 498 | 499 | if liver_anomaly_test_data_loader: 500 | logging.info('-----------anomaly detection in liver with conditions----------------') 501 | scores_liver_anomaly_test, liver_anomaly_test_im_names, liver_anomaly_test_labels = anomaly_detection(clf, model, liver_anomaly_test_data_loader, dev, feat_container=features_to_visualize, cache_path=cfg.output_path) 502 | if scores_liver_anomaly_test is not None: 503 | scores.extend(scores_liver_anomaly_test) 504 | save_to_excel(scores_liver_anomaly_test, liver_anomaly_test_im_names, liver_anomaly_test_labels, os.path.join(cfg.output_path, cfg.csv_liver_tissue_anomalies)) 505 | save_images_from_excel(os.path.join(cfg.output_path, cfg.csv_liver_tissue_anomalies), cfg.output_path, cfg.save_n_FN, 'FN') 506 | 507 | 508 | print("") 509 | non_liver_test_data_loader = create_data_loader(cfg.paths_non_liver_tissues_test, cfg, n_patches=cfg.visual_test_auxiliary_patches_per_class, batch_size=cfg.batch_size) 510 | 511 | if non_liver_test_data_loader: 512 | logging.info('-----------anomaly detection in different to liver tissues---------------') 513 | scores_non_liver_test, non_liver_test_im_names, non_liver_test_labels = anomaly_detection(clf, model, non_liver_test_data_loader, dev, feat_container=features_to_visualize, cache_path=cfg.output_path) 514 | if 
scores_non_liver_test is not None: 515 | scores.extend(scores_non_liver_test) 516 | 517 | 518 | # t_end = perf_counter() 519 | # logging.info('training on normal data and inference took {} sec'.format(t_end - t_start)) 520 | 521 | # print to file 522 | if (scores_normal_test is not None) and (scores_liver_anomaly_test is not None): 523 | with open(os.path.join(cfg.output_path, 'anomaly_detector_results.txt'), 'w') as fh: 524 | 525 | logging.info('----------performance summary----------') 526 | print('----------performance summary-----------', file=fh) 527 | 528 | all_targets = np.concatenate((np.ones(len(scores_liver_anomaly_test)), np.zeros(len(scores_normal_test)))) 529 | all_scores = - np.concatenate((scores_liver_anomaly_test, scores_normal_test)) 530 | 531 | # balanced accuracy 532 | ba = balanced_accuracy_score(all_targets, all_scores > 0) 533 | logging.info('Balanced accuracy: {}'.format(ba)) 534 | print('Balanced accuracy: {}'.format(ba), file=fh) 535 | 536 | # calculate AUC 537 | auc = roc_auc_score(all_targets, all_scores) 538 | logging.info('AUC: {}'.format(auc)) 539 | print('AUC: {}'.format(auc), file=fh) 540 | 541 | # ROC curve 542 | fpr, tpr, thr_roc = roc_curve(all_targets, all_scores) 543 | ind_thr0 = np.argmin(np.abs(thr_roc)) 544 | fpr_thr0 = fpr[ind_thr0] 545 | tpr_thr0 = tpr[ind_thr0] 546 | 547 | fig, ax = plt.subplots() 548 | plt.plot(fpr, tpr) 549 | plt.plot(fpr_thr0, tpr_thr0, '*') 550 | plt.xlabel("False Positives Rate (1-specificity)") 551 | plt.ylabel("True Positives Rate (sensitivity)") 552 | fig.savefig(os.path.join(cfg.output_path, 'ROC.png'), bbox_inches='tight') 553 | plt.close() 554 | 555 | logging.info('predicted anomalies in normal data: {:.2f}%'.format(fpr_thr0 * 100)) 556 | print('predicted anomalies in normal data: {:.2f}%'.format(fpr_thr0 * 100), file=fh) 557 | logging.info('predicted normals in anomaly data: {:.2f}%'.format((1 - tpr_thr0) * 100)) 558 | print('predicted normals in anomaly data: {:.2f}%'.format((1 - tpr_thr0) * 100), file=fh) 559 | 560 | # f1 score 561 | f1 = f1_score(all_targets, all_scores > 0) 562 | logging.info('F1 score: {}'.format(f1)) 563 | print('F1 score: {}'.format(f1), file=fh) 564 | 565 | if scores: 566 | features_to_visualize_sampled, scores_sampled = sample_features(features_to_visualize, cfg.n_features_visualization, scores) 567 | TSNE_visualization(features_to_visualize_sampled['features'], features_to_visualize_sampled['labels'], scores=scores_sampled, save_path=cfg.output_path, colormap='tab20') 568 | 569 | t_end = perf_counter() 570 | logging.info('it took {} sec'.format(t_end - t_start)) 571 | 572 | -------------------------------------------------------------------------------- /utils/image_dataset_reader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Oct 8 12:11:10 2019 5 | 6 | @author: zingman 7 | """ 8 | from PIL import Image 9 | import glob 10 | import random 11 | from torch.utils.data import Dataset 12 | 13 | import numpy as np 14 | import copy 15 | 16 | import logging 17 | 18 | 19 | import time 20 | 21 | Image.MAX_IMAGE_PIXELS = 5000000000 22 | from tqdm import tqdm 23 | 24 | 25 | class color_transform_skimage(): 26 | 27 | def __init__(self, int_type_precision=np.uint8): 28 | """ 29 | Envelopes skimage color transforms such that they will output integer images suitable for LUT transformation. 
30 |         Below are definitions of the parameters, which should correspond to the skimage color transform used.
31 |         For color transformations use the forward() and backward() functions
32 |
33 |         :param int_type_precision: precision with which the LUT will be performed, e.g. for np.uint8 the LUT will have 256 values
34 |         """
35 |         self.logger = logging.getLogger('color_transform_skimage')
36 |         # -------these values should be adapted to the skimage color transform used--------
37 |         self.tr_type = float  # type of the output of the skimage color transformation
38 |
39 |         # self.tr_min = np.array([16, 16, 16], dtype=self.tr_type)  # min values of the output of color transform (YCbCr)
40 |         # self.tr_max = np.array([235, 240, 240], dtype=self.tr_type)  # max values of the output of color transform (YCbCr)
41 |
42 |         self.tr_min = np.array([0, 0, 0], dtype=self.tr_type)  # min values of the output of the color transform
43 |         self.tr_max = np.array([255, 255, 255], dtype=self.tr_type)  # max values of the output of the color transform
44 |
45 |         # color transforms used
46 |         # self.forward_transform = rgb2ycbcr  # forward color transform
47 |         # self.backward_transform = ycbcr2rgb  # backward color transform
48 |
49 |         self.forward_transform = lambda im_in: np.array(im_in, dtype=self.tr_type)  # identity transform (direct RGB use)
50 |         self.backward_transform = lambda im_in: im_in/np.expand_dims(np.expand_dims(self.tr_max, axis=0), axis=0)  # identity transform (direct RGB use), skimage transforms output in the [0, 1] range
51 |
52 |         #--------------
53 |
54 |         self.dig_type = int_type_precision  # integer type in which the mix-up transform (LUT) happens
55 |         self.max_intensity = np.iinfo(self.dig_type).max  # number of values - 1 used for the LUT and color histograms
56 |
57 |         self.input_type = None  # type of the input image to be processed, also of the processed output image
58 |         self.max_input_type = None  # maximal value of the input image being processed
59 |
60 |     def get_n_intensity_levels(self):
61 |         return self.max_intensity + 1
62 |
63 |     def _digitize(self, im):
64 |         """
65 |         Transforms the input image to the integer type (self.dig_type), rescaling from the transform output range [tr_min, tr_max] to [0, max_intensity]
66 |         :param im:
67 |         :return:
68 |         """
69 |
70 |         im_out = np.array(im, dtype=self.dig_type)
71 |         im = np.moveaxis(im, [0, 1, 2], [1, 2, 0])
72 |         for i, im_channel in enumerate(im):
73 |
74 |             if im_channel.min() < self.tr_min[i]:
75 |                 self.logger.warning("the skimage transform produced values smaller than the min defined in the color_transform_skimage class")
76 |             if im_channel.max() > self.tr_max[i]:
77 |                 self.logger.warning("the skimage transform produced values larger than the max defined in the color_transform_skimage class")
78 |
79 |             factor = float(self.max_intensity) / (self.tr_max[i] - self.tr_min[i])
80 |             im_channel = (im_channel - self.tr_min[i]) * factor
81 |             im_channel = np.clip(im_channel, 0, self.max_intensity)
82 |             im_channel = np.array(im_channel, dtype=self.dig_type)
83 |             im_out[:, :, i] = im_channel
84 |
85 |         return im_out
86 |
87 |     def _undigitize(self, im):
88 |         """
89 |         Transforms from the integer type used by the LUT to the skimage transform output type (self.tr_type).
90 |         :param im:
91 |         :return:
92 |         """
93 |
94 |         im_out = np.array(im, dtype=self.tr_type)
95 |         im = np.moveaxis(im, [0, 1, 2], [1, 2, 0])
96 |         for i, im_channel in enumerate(im):
97 |             factor = (self.tr_max[i] - self.tr_min[i]) / float(self.max_intensity)
98 |             im_channel = im_channel * factor + self.tr_min[i]
99 |             im_channel = np.clip(im_channel, self.tr_min[i], self.tr_max[i])
100 |             im_channel = np.array(im_channel, dtype=self.tr_type)
101 |             im_out[:, :, i] = im_channel
102 |
103 |         return im_out
104 |
105 |     def forward(self, im):
106 |         '''
107 |         im: input ndarray image of an integer dtype (np.iinfo is applied to its dtype)
108 |         return: image of the integer type self.dig_type
109 |
110 |         '''
111 |         self.input_type = im.dtype
112 |         self.max_input_type = np.iinfo(self.input_type).max
113 |
114 |         im2 = self.forward_transform(im)
115 |         assert im2.dtype == self.tr_type
116 |
117 |         im3 = self._digitize(im2)
118 |
119 |         return im3
120 |
121 |     def backward(self, im):
122 |         """
123 |         Backward color transformation from self.dig_type to the type of the original input image that passed the forward() transformation
124 |         :param im:
125 |         :return:
126 |         """
127 |
128 |         im1 = self._undigitize(im)
129 |         im2 = self.backward_transform(im1)
130 |
131 |         # the outputs of skimage.color transforms are floats in the [0, 1] range
132 |         im2 = np.clip(im2, 0.0, 1.0)  # however, values may in practice be higher than 1
133 |         im3 = np.array(im2 * self.max_input_type, dtype=self.input_type)
134 |
135 |         return im3
136 |
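# A minimal round-trip sketch for color_transform_skimage (assuming `im` is a uint8 RGB
# ndarray of shape (H, W, 3)):
#
#     ct = color_transform_skimage()       # identity RGB transforms by default
#     lut_im = ct.forward(im)              # float transform output, digitized to uint8
#     restored = ct.backward(lut_im)       # back to the dtype and range of `im`
#     assert restored.dtype == im.dtype
#
# With the default identity transforms the round trip is lossless up to rounding; enabling
# the commented rgb2ycbcr/ycbcr2rgb pair (with the matching tr_min/tr_max bounds) would make
# the LUT operate in YCbCr instead.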
137 |
138 | class HistImagesDataset(Dataset):
139 |
140 |     def __init__(self, *imsets, n_samples=None, repetition=False, transform=None, dataset_name=''):
141 |
142 |         """
143 |         The dataset is created from any number of image sets (imsets) defined with a class label
144 |         and a path (location) to images (see example of use in check_up() function). Several paths can belong to the same class.
145 |         Images and locations are chosen to be read in a random order. The sampling rate for each location is determined
146 |         by the number of images relative to the other locations, or by n_samples if given.
147 |         Output images are PIL images (unless 'transform' changes the type).
148 |         Images with the '_empty' suffix are ignored.
149 |         Images are read e.g. by indexing as dataset[i], which outputs a dictionary with 'image', 'label', 'string_label', and 'image_name'
150 |         fields. The 'label' field contains a numeric value of a label, while 'string_label' is a string label taken from 'imsets'.
151 |         The string_label provided with the input imsets can also be accessed with instance.get_str_label(label)
152 |
153 |         :param imsets: sequence of dictionaries with fields 'folder', 'label', 'ext', and optionally a pattern
154 |         'key' (rules used by the UNIX shell, e.g. '*67?', '10[1-2,5-7]') that should be a part of the file name of all files of interest.
155 |         :param n_samples: a number of images to read from each folder (dictionary, location), or a list with a different number of samples for each folder (dictionary, location).
156 |         It is used to restrict the number of images in the case of too large or unbalanced datasets.
157 |         You can use the 'samples_per_location_from_samples_per_class' function defined in this module for
158 |         defining the same number of images for each class (when the number of locations is larger than the number of classes)
159 |         :param repetition: effective only when n_samples is larger than the number of images in a folder (location).
160 |         When False, n_samples is reduced to the number of images in the folder; if True, n_samples images will be sampled from the folder with repetitions.
161 |         :param transform: pytorch transform to be applied to each read PIL image
162 |         :param dataset_name: the name of the dataset that will be used for logging
163 |
164 |         split_set(): method that can be used to divide the dataset into two disjoint datasets
165 |         (see example of use in check_up() function). Alternatively, torch.utils.data.random_split
166 |         can be used.
167 |
168 |         create_subset(): Creates a subset of the dataset. The original set is not changed.
169 |
170 |         prepare_mixup(): Computes the distribution for each class and switches on mix-up normalization -
171 |         transfer of the color appearance of images to a randomly chosen class
172 |
173 |         """
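        # A hypothetical construction sketch matching the docstring above (folder paths,
        # labels, and sample counts are placeholders):
        #
        #     imsets = ({'folder': 'data/liver_a', 'label': 'liver', 'ext': 'png'},
        #               {'folder': 'data/liver_b', 'label': 'liver', 'ext': 'png', 'key': '*67?'},
        #               {'folder': 'data/kidney', 'label': 'kidney', 'ext': 'png'})
        #     ds = HistImagesDataset(*imsets, n_samples=[100, 100, 200])
        #     sample = ds[0]  # dict with 'image', 'label', 'string_label', 'image_name'
        #
        # Here two locations share the 'liver' label, so the dataset has two classes even
        # though n_samples lists three per-location counts.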
138 | class HistImagesDataset(Dataset):
139 | 
140 |     def __init__(self, *imsets, n_samples=None, repetition=False, transform=None, dataset_name=''):
141 | 
142 |         """
143 |         The dataset is created from any number of image sets (imsets), each defined by a class label
144 |         and a path (location) to images (see the example of use in the check_up() function). Several paths can belong to the same class.
145 |         Images and locations are read in a random order. The sampling rate for each location is determined
146 |         by its number of images relative to the other locations, or by n_samples if given.
147 |         Output images are PIL images (unless 'transform' changes the type).
148 |         Images with the '_empty' suffix are ignored.
149 |         Images are read e.g. by indexing as dataset[i], which outputs a dictionary with 'image', 'label', 'string_label', and 'image_name'
150 |         fields. The 'label' field contains a numeric label, while 'string_label' is a string label taken from 'imsets'.
151 |         The string label provided with the input imsets can also be accessed with instance.get_str_label(label)
152 | 
153 |         :param imsets: sequence of dictionaries with fields 'folder', 'label', 'ext', and optionally a pattern
154 |             'key' (UNIX shell rules, e.g. '*67?', '10[1-2,5-7]') that should be part of the file names of all files of interest.
155 |         :param n_samples: the number of images to read from each folder (dictionary, location), or a list with a different number of samples for each folder (dictionary, location).
156 |             It is used to restrict the number of images in the case of too large or unbalanced datasets.
157 |             You can use the 'samples_per_location_from_samples_per_class' function defined in this module to
158 |             define the same number of images for each class (when the number of locations is larger than the number of classes).
159 |         :param repetition: effective only when n_samples is larger than the number of images in a folder (location).
160 |             When False, n_samples is reduced to the number of images in the folder; when True, n_samples images will be sampled from the folder with repetitions.
161 |         :param transform: pytorch transform to be applied to each read PIL image
162 |         :param dataset_name: the name of the dataset that will be used for logging
163 | 
164 |         split_set(): method that can be used to divide the dataset into two disjoint datasets
165 |             (see the example of use in the check_up() function). Alternatively, torch.utils.data.random_split
166 |             can be used.
167 | 
168 |         create_subset(): creates a subset of the dataset. The original set is not changed.
169 | 
170 |         prepare_mixup(): computes a distribution for each class and switches on mix-up normalization, i.e.
171 |             the transfer of the color appearance of images to a randomly chosen class
172 | 
173 |         """
174 |         if dataset_name:
175 |             self.logger = logging.getLogger('HistImagesDataset' + ':' + dataset_name)
176 |         else:
177 |             self.logger = logging.getLogger('HistImagesDataset')
178 | 
179 |         # if the given n_samples is an integer, create a list with 'n_samples' repeated for each location
180 |         if n_samples:
181 |             if isinstance(n_samples, int):
182 |                 n_samples = [n_samples] * len(imsets)
183 |             else:
184 |                 n_samples = list(n_samples[:])  # copies values in order to avoid changing the argument of the calling function; list() allows the argument to be either a list or a tuple
185 | 
186 |             assert isinstance(n_samples, list), 'the n_samples keyword should be either an integer or a list of numbers'
187 |             assert len(n_samples) == len(imsets), 'n_samples should be of a length equal to the number of image sets'
188 | 
189 |         self.str_labels = []  # set of string labels used in the dataset (its length equals the number of classes, i.e. the number of different labels)
190 |         self.file_names = []
191 |         self.labels = []  # integer labels for every image file
192 |         n_samples_per_class = []
193 |         for i, imset in enumerate(imsets, 0):
194 | 
195 |             if 'key' in imset.keys():
196 |                 key = imset['key']
197 |                 if key is None or len(key) == 0:
198 |                     key = '*'
199 |             else:
200 |                 key = '*'
201 | 
202 |             # read file names for every location, with the appropriate file extension and matching key
203 |             pattern = imset['folder'] + '/' + key + '.' + imset['ext']
204 |             # alphabetic sorting is for reproducibility purposes only
205 |             file_names = sorted(glob.glob(pattern))
206 | 
207 |             # exclude files with the '_empty' suffix
208 |             file_names = [x for x in file_names if ('_empty.' + imset['ext']) not in x]
209 | 
210 |             if len(file_names) == 0:
211 |                 self.logger.error('no files were found in: {}'.format(imset['folder']))
212 |             else:
213 |                 self.logger.debug('file names in {} were successfully read'.format(imset['folder']))
214 | 
215 |             # random sampling of file names (useful in the case of large or unbalanced datasets)
216 |             if n_samples:
217 |                 if n_samples[i] > len(file_names):
218 |                     if not repetition:
219 |                         self.logger.warning('the number of requested samples ({}) is larger than the size of the dataset ({}) in: {}. Using all available samples.'.format(n_samples[i], len(file_names), imset['folder']))
220 |                         n_samples[i] = len(file_names)
221 |                     else:
222 |                         self.logger.warning('the number of requested samples ({}) is larger than the size of the dataset ({}) in: {}. Multiple copies of the same images will be used.'.format(n_samples[i], len(file_names), imset['folder']))
223 | 
224 |                 file_names = sample_with_possible_repetition(file_names, n_samples[i])
225 | 
226 |             self.file_names += file_names
227 | 
228 |             if imset['label'] in self.str_labels:  # label has already appeared
229 |                 idx_class = self.str_labels.index(imset['label'])
230 |                 n_samples_per_class[idx_class] += len(file_names)
231 |             else:  # new label
232 |                 self.str_labels.append(imset['label'])
233 |                 idx_class = len(self.str_labels) - 1
234 |                 assert idx_class == self.str_labels.index(imset['label'])
235 |                 n_samples_per_class.append(len(file_names))
236 | 
237 |             self.labels += [idx_class] * len(file_names)
238 | 
239 |         self.len = len(self.labels)
240 | 
241 |         # randomize the order of the samples from different locations, to be used in __getitem__()
242 |         self.idx = self._random_index()
243 | 
244 |         self.transform = transform
245 | 
246 |         self.transform_lut = None  # the transformation between class intensities is not defined yet
247 | 
248 |         self.ImReader = Image.open
249 | 
250 |         self.n_channels_mixup = [True, True, True]  # for mix-up augmentation, transform (adapt the appearance of) only the channels with a True value
251 | 
252 |         self.color_transform = color_transform_skimage()
253 | 
254 |         self.logger.info(
255 |             'dataset was initialized. {} classes, {} images per class'.format(self.str_labels, n_samples_per_class))
256 | 
257 |         assert self.get_number_samples_per_class() == n_samples_per_class
258 | 
259 | 
260 |     def _random_index(self):  # generates random numbers up to the length of the dataset self.len
261 |         return random.sample(range(self.len), self.len)
262 | 
263 |     def get_str_label(self, int_label):
264 |         return self.str_labels[int_label]
265 | 
266 |     def get_int_label(self, string):
267 |         return self.str_labels.index(string)
268 | 
269 |     def get_number_samples_per_class(self):
270 | 
271 |         num_samples_per_class = []
272 |         for str_label in self.str_labels:
273 |             num_samples_per_class.append(sum([self.get_str_label(int_label) == str_label for int_label in self.labels]))
274 | 
275 |         return num_samples_per_class
276 | 
277 |     def compute_statistics(self):
278 | 
279 |         means = 0
280 |         stds = 0
281 |         n_patches_read = 0
282 |         number_of_classes = len(self.str_labels)
283 |         class_mean = [np.zeros(3) for _ in range(number_of_classes)]
284 |         class_std = [np.zeros(3) for _ in range(number_of_classes)]
285 |         n_class_patches_read = [0 for _ in range(number_of_classes)]
286 |         self.logger.info("computing statistics of the train dataset")
287 |         for i in tqdm(range(len(self))):
288 | 
289 |             im, label, _ = self._get_raw_image(i)
290 | 
291 |             # from PIL to ndarray
292 |             im = np.array(im)
293 | 
294 |             mean = np.mean(im, axis=(0, 1))
295 |             std = np.std(im, axis=(0, 1))
296 | 
297 |             means += mean
298 |             stds += std
299 |             n_patches_read += 1
300 | 
301 |             class_mean[label] += mean
302 |             class_std[label] += std
303 |             n_class_patches_read[label] += 1
304 | 
305 |         means = means / n_patches_read
306 |         stds = stds / n_patches_read
307 | 
308 |         class_means = {}
309 |         class_stds = {}
310 |         for label in range(number_of_classes):
311 |             str_label = self.get_str_label(label)
312 |             class_means[str_label] = class_mean[label] / n_class_patches_read[label]
313 |             class_stds[str_label] = class_std[label] / n_class_patches_read[label]
314 |             self.logger.info('mean of the {} class is {}, mean of its std is {}'.format(str_label, class_means[str_label], class_stds[str_label]))
315 | 
316 |         self.logger.info('average value: {}, std value: {}, based on {} images'.format(means, stds, n_patches_read))
317 | 
318 |         return means, stds, class_means, class_stds
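
    # --- Editor's sketch (hypothetical helper, not part of the original API): the
    # --- core of the mix-up normalization below is NumPy fancy indexing, where a
    # --- 1-D look-up table is indexed by an integer image so that every pixel's
    # --- intensity is remapped at once.
    @staticmethod
    def _example_lut_indexing():
        lut = np.arange(256, dtype=np.uint8)[::-1]  # toy LUT: invert intensities
        channel = np.array([[0, 128], [255, 64]], dtype=np.uint8)
        remapped = lut[channel]  # shape is preserved, values are remapped
        return remapped  # [[255, 127], [0, 191]]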
319 | 
320 | 
321 |     def _mixup_normalization(self, im, src_cls):
322 |         """
323 |         :param im: PIL image to be normalized to the appearance of a random class
324 |         :param src_cls: class (int) of the input image
325 |         :return: transformed_im: normalized PIL image
326 |         :return: dst_class: class (int) the image appearance was transformed to
327 |         """
328 | 
329 |         n_cls = len(self.str_labels)
330 | 
331 |         # randomly choose a destination class among the allowed destination classes
332 |         allowed_dst_class_labs = self._allowed_dst_class(src_cls)
333 |         dst_class_idx = random.randrange(len(allowed_dst_class_labs))
334 |         dst_class = allowed_dst_class_labs[dst_class_idx]
335 | 
336 |         im = np.array(im)  # transform PIL to ndarray
337 | 
338 |         # use the color transformation to get the channels that will be normalized to the channels of the destination class
339 |         im = self.color_transform.forward(im)
340 | 
341 |         n_channels = im.shape[2]
342 |         assert n_channels == len(self.n_channels_mixup)
343 |         assert n_channels == self.transform_lut.shape[2]
344 | 
345 | 
346 |         transformed_im = np.array(im, copy=True)  # copy of the array
347 | 
348 |         # per-channel normalization
349 |         for channel in range(n_channels):
350 |             if self.n_channels_mixup[channel]:
351 |                 im_channel = im[:, :, channel]
352 |                 transformed_im[:, :, channel] = self.transform_lut[src_cls, dst_class, channel][im_channel]
353 | 
354 |         # transform the normalized channels back
355 |         transformed_im = self.color_transform.backward(transformed_im)
356 |         # output a PIL image
357 |         transformed_im = Image.fromarray(transformed_im)
358 |         return transformed_im, dst_class
359 | 
360 |     def _allowed_dst_class(self, src_cls):
361 |         """
362 |         Outputs the allowed integer class labels for a given input integer label for mix-up augmentation
363 |         (based on the label groups in self.mixup_classes)
364 |         :param src_cls: integer class label
365 |         :return: dst_cls: list of integer class labels
366 |         """
367 | 
368 |         dst_cls = None
369 |         for classes in self.mixup_classes:
370 |             if src_cls in classes:
371 |                 dst_cls = classes
372 |                 break
373 | 
374 |         assert dst_cls is not None
375 |         return dst_cls
376 | 
377 | 
378 |     def _list_of_lists_strlab_2_list_of_lists_intlab(self, classes_str):
379 |         """
380 |         :param classes_str: list of lists that contains string class labels.
381 |         :return: list of lists that contains integer class labels
382 |         """
383 |         classes_int = []
384 |         for lst in classes_str:
385 |             classes_int.append([])
386 |             for cls in lst:
387 |                 classes_int[-1].append(self.get_int_label(cls))
388 | 
389 |         return classes_int
390 | 
391 |     def prepare_mixup(self, mixup_classes=None):
392 | 
393 |         """
394 |         Prepares the normalization transformation between every pair of classes, per channel
395 | 
396 |         :param mixup_classes: list of lists, each of which contains the string labels of classes between which mix-up will be done
397 | 
398 |         """
399 | 
400 |         # set the groups of mix-up classes
401 |         if mixup_classes is None:
402 |             self.mixup_classes = self._list_of_lists_strlab_2_list_of_lists_intlab([self.str_labels])
403 |         else:
404 |             self.mixup_classes = self._list_of_lists_strlab_2_list_of_lists_intlab(mixup_classes)
405 | 
406 |         flattened_mixup_classes = [num for sublist in self.mixup_classes for num in sublist]
407 |         assert len(flattened_mixup_classes) == len(self.str_labels)
408 |         assert len(set(flattened_mixup_classes)) == len(self.str_labels)
409 | 
410 |         # compute cumulative distribution functions for all classes
411 |         cdf = self._compute_cdfs()
412 | 
413 |         n_cls = len(cdf)
414 |         n_intensities = len(cdf[0, 0])
415 |         n_channels = len(cdf[0])
416 | 
417 |         # compute the normalization transformation (a look-up table) per channel, for every pair of classes
418 |         self.transform_lut = np.empty((n_cls, n_cls, n_channels, n_intensities))
419 |         for source_cls in range(n_cls):
420 |             for dst_cls in range(n_cls):
421 |                 for channel in range(n_channels):
422 |                     self.transform_lut[source_cls, dst_cls, channel] = self._compute_cdfs_lut(cdf[source_cls, channel], cdf[dst_cls, channel])
423 | 
424 |         self.logger.info('mix-up is prepared, all images will be transformed with mix-up')
425 | 
426 |     @staticmethod
427 |     def _compute_cdfs_lut(cdf_source, cdf_target):
428 |         """
429 |         :param cdf_source, cdf_target: cumulative distribution functions of the classes to be transformed from and to
430 |         :return: look-up table mapping source intensities to the matching target intensities
431 |         """
432 | 
433 |         assert len(cdf_source) == len(cdf_target)
434 |         cdf_target_intensities = np.arange(len(cdf_target))
435 |         cdf_source_intensities = np.interp(cdf_source, cdf_target, cdf_target_intensities)
436 | 
437 |         return cdf_source_intensities
438 | 
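    # --- Editor's sketch (hypothetical helper, not part of the original API):
    # --- _compute_cdfs_lut above performs classic histogram matching. For a source
    # --- intensity v, cdf_source[v] is its quantile, and np.interp finds the
    # --- (fractional) target intensity that has the same quantile under cdf_target.
    @staticmethod
    def _example_histogram_matching():
        # source intensities are uniform over 4 levels; the target is skewed dark
        cdf_source = np.array([0.25, 0.50, 0.75, 1.00])
        cdf_target = np.array([0.70, 0.80, 0.90, 1.00])
        lut = np.interp(cdf_source, cdf_target, np.arange(4))
        return lut  # array([0. , 0. , 0.5, 3. ]): low/mid intensities are pulled toward the darker look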
439 | 
440 |     def _compute_cdfs(self):
441 |         """
442 |         Computes the cumulative distribution function for each class and each channel
443 |         """
444 | 
445 |         n_classes = len(self.str_labels)
446 |         n_channels = len(self.n_channels_mixup)
447 |         n_intensity_levels = self.color_transform.get_n_intensity_levels()
448 | 
449 |         cdf = np.zeros((n_classes, n_channels, n_intensity_levels))
450 | 
451 |         samples_per_class = np.zeros((n_classes, 1, 1))
452 | 
453 |         self.logger.info('generating the set of class distributions for mix-up augmentation')
454 |         # calculating the average cumulative histogram over all images, per class, per channel
455 |         for i in tqdm(range(len(self))):
456 | 
457 |             im, label, _ = self._get_raw_image(i)
458 |             if im.mode != 'RGB':
459 |                 self.logger.warning("images that are not 8-bit RGB have not been tested for mix-up yet")
460 | 
461 |             # from PIL to ndarray
462 |             im = np.array(im)
463 |             im = self.color_transform.forward(im)
464 |             assert np.iinfo(im.dtype).max + 1 == n_intensity_levels
465 |             assert np.iinfo(im.dtype).min == 0
466 | 
467 |             samples_per_class[label] += 1
468 | 
469 |             im_size = np.prod(im.shape[:2])
470 | 
471 |             for channel in range(n_channels):  # renamed from 'i' to avoid shadowing the outer loop variable
472 |                 histogram, _ = np.histogram(im[:, :, channel].ravel(), bins=n_intensity_levels, range=(0, np.iinfo(im.dtype).max + 1))
473 |                 cumulative_histogram = np.cumsum(histogram) / im_size
474 |                 cdf[label, channel] += cumulative_histogram
475 | 
476 | 
477 |         # get the average cumulative histogram (from the summed histograms) over all images
478 |         cdf = cdf / np.repeat(np.repeat(samples_per_class, n_channels, 1), n_intensity_levels, 2)
479 |         assert np.all(np.squeeze(samples_per_class) == np.array(self.get_number_samples_per_class()))
480 | 
481 |         return cdf
482 | 
483 | 
484 |     def __len__(self):
485 |         return self.len
486 | 
487 |     def _try_read_image(self, path, trials=10):
488 | 
489 |         pause = 5.0
490 | 
491 |         for i in range(trials):  # try to read an image a few times, for the case of a lost connection to a linked folder
492 | 
493 |             try:
494 |                 img = self.ImReader(path)
495 |             except FileNotFoundError:
496 |                 if trials != 1:
497 |                     self.logger.warning('file {} was not found, probably due to a lost connection to the linked data folder'.format(path))
498 |                     self.logger.warning('{}/{} read trial was unsuccessful'.format(i + 1, trials))
499 |                     self.logger.warning('waiting for {} sec'.format(pause))
500 |                     time.sleep(pause)
501 |                     self.logger.warning('resuming')
502 |                 img = None
503 |             except RuntimeError as er:
504 |                 if trials != 1:
505 |                     self.logger.warning('{}/{} read trial was unsuccessful for {}'.format(i + 1, trials, path))
506 |                     self.logger.warning('unknown error: {}'.format(er))
507 |                     self.logger.warning('waiting for {} sec'.format(pause))
508 |                     time.sleep(pause)
509 |                     self.logger.warning('resuming')
510 |                 img = None
511 |             else:
512 |                 break
513 | 
514 |         return img
515 | 
516 |     def _get_raw_image(self, n):
517 | 
518 |         idx = self.idx[n]
519 |         img_path = self.file_names[idx]
520 |         img_label = self.labels[idx]
521 | 
522 |         img = self._try_read_image(img_path)
523 | 
524 |         return img, img_label, img_path
525 | 
526 | 
527 |     def __getitem__(self, n):
528 | 
529 |         img, img_label, img_path = self._get_raw_image(n)
530 | 
531 |         if self.transform_lut is not None:
532 |             img, aug_dst = self._mixup_normalization(img, img_label)
533 |         else:
534 |             aug_dst = None
535 | 
536 |         if self.transform:
537 |             img = self.transform(img)
538 | 
539 |         sample = {'image': img, 'label': img_label, 'string_label': self.get_str_label(img_label), 'image_name': img_path}
540 | 
541 |         # this is for debugging only
542 |         if self.transform_lut is not None:
543 |             sample['_debug_mixup_dst'] = aug_dst
544 | 
545 | 
546 |         return sample
547 | 
548 |     def shuffle(self):
549 |         random.shuffle(self.idx)
550 | 
551 |     def _get_shuffled_idx(self):
552 |         idx = random.sample(self.idx, len(self.idx))
553 | 
554 |         return idx
555 | 
556 | 
557 | 
558 |     def split_set(self, n_val, transform_validation='same', val_dataset_name='', train_dataset_name='', shuffle=False):
559 | 
560 |         """
561 |         Separates the dataset into two datasets, for training and validation.
562 |         One part of the images in the original dataset is taken for the newly created (validation) dataset,
563 |         the other part is taken by the second created (training) dataset.
564 |         The original dataset is not changed.
565 | 
566 |         :param transform_validation: transformation to be applied to the created validation dataset. If not supplied or if 'same',
567 |             the same transformation as for the original set is applied; if None, no transformation will be applied
568 |         :param n_val: the number of images in the validation dataset (overall number from all classes).
569 |         :param val_dataset_name: name of the validation dataset that will be used for logging
570 |         :param train_dataset_name: name of the training dataset that will be used for logging
571 |         :param shuffle: If True, randomizes the choice of images taken for the newly created datasets.
572 |             If False, the first part of the images is taken from the dataset for the created validation dataset, and the second part for the training dataset.
573 |             The ordering of images in the created datasets is also randomized if shuffle is True. shuffle=True might be useful if several different split datasets are required
574 |         :return cls2: newly created separate dataset for validation
575 |         :return cls1: newly created separate dataset for training
576 |         """
577 | 
578 |         assert n_val < len(self.idx), 'not enough images for validation'
579 | 
580 |         # create two new identical datasets
581 |         cls2 = copy.deepcopy(self)  # validation dataset
582 |         cls1 = copy.deepcopy(self)  # training dataset
583 | 
584 |         # set the transformation of the validation dataset
585 |         if transform_validation:
586 |             if transform_validation == 'same':
587 |                 cls2.transform = self.transform
588 |             else:
589 |                 cls2.transform = transform_validation
590 |         else:
591 |             cls2.transform = None
592 | 
593 |         # -------
594 |         if shuffle:
595 |             idx = self._get_shuffled_idx()
596 |         else:
597 |             idx = self.idx
598 | 
599 |         idx2 = idx[:n_val]
600 |         idx1 = idx[n_val:]
601 | 
602 |         cls2.file_names = [cls2.file_names[i] for i in idx2]
603 |         cls2.labels = [cls2.labels[i] for i in idx2]
604 |         cls2.len = len(cls2.file_names)
605 |         if shuffle:
606 |             cls2.idx = cls2._random_index()
607 |         else:
608 |             cls2.idx = list(range(cls2.len))
609 | 
610 |         cls1.file_names = [cls1.file_names[i] for i in idx1]
611 |         cls1.labels = [cls1.labels[i] for i in idx1]
612 |         cls1.len = len(cls1.file_names)
613 |         if shuffle:
614 |             cls1.idx = cls1._random_index()
615 |         else:
616 |             cls1.idx = list(range(cls1.len))
617 | 
618 |         if val_dataset_name:
619 |             cls2.logger = logging.getLogger('HistImagesDataset' + ':' + val_dataset_name)
620 |         else:
621 |             cls2.logger = logging.getLogger('HistImagesDataset')
622 | 
623 |         if train_dataset_name:
624 |             cls1.logger = logging.getLogger('HistImagesDataset' + ':' + train_dataset_name)
625 |         else:
626 |             cls1.logger = logging.getLogger('HistImagesDataset')
627 | 
628 |         cls1.transform_lut = None
629 |         cls2.transform_lut = None
630 | 
631 |         # # adding the suffix 'val' to all labels:
632 |         # for i in range(len(self.str_labels)):
633 |         #     cls2.str_labels[i] += '_val'
634 | 
635 |         self.logger.info('dataset was split: {} images in the validation and {} in the training datasets'.format(cls2.len, cls1.len))
636 | 
637 |         cls1.logger.info('split training set was initialized. {} classes, {} images per class'.format(cls1.str_labels, cls1.get_number_samples_per_class()))
638 |         cls2.logger.info('split validation set was initialized. {} classes, {} images per class'.format(cls2.str_labels, cls2.get_number_samples_per_class()))
639 | 
640 |         return cls2, cls1
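

# --- Editor's usage sketch (hypothetical paths and labels): building a dataset,
# --- enabling mix-up normalization, and splitting off a validation part. The
# --- check_up() function referenced in the docstrings above is not shown here.
def _example_dataset_usage():
    from torchvision import transforms

    imsets = (
        {'folder': 'data/train/liver', 'label': 'liver', 'ext': 'png'},
        {'folder': 'data/train/kidney', 'label': 'kidney', 'ext': 'png'},
    )
    dataset = HistImagesDataset(*imsets, n_samples=100, transform=transforms.ToTensor(),
                                dataset_name='example')

    sample = dataset[0]  # dict with 'image', 'label', 'string_label', 'image_name'

    # mix-up color normalization between all classes; pass groups of string
    # labels, e.g. [['liver'], ['kidney']], to restrict mixing within groups
    dataset.prepare_mixup()

    val_set, train_set = dataset.split_set(20, val_dataset_name='val',
                                           train_dataset_name='train', shuffle=True)
    return sample, val_set, train_set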
641 | 
642 | 
643 | def samples_per_location_from_samples_per_class(*imsets, samples_per_class: int) -> list:
644 |     """
645 |     :param imsets: sequence of dictionaries with the fields 'folder' (locations/folders with the data) and 'label' (for each location/folder),
646 |         like in the initializers of the HistImagesDataset or WSIiterableDataset classes
647 |     :param samples_per_class: number of samples to be taken for each class (label)
648 |     :return: list of the numbers of samples to be taken from each location/folder with input data
649 |     """
650 | 
651 |     assert isinstance(samples_per_class, int), "samples_per_class must be an integer"
652 |     str_labels = []
653 |     counter = {}
654 |     # count the number of locations for each class
655 |     for location in imsets:
656 | 
657 |         str_label = location['label']
658 |         if str_label not in str_labels:
659 |             str_labels.append(str_label)
660 |             counter[str_label] = 1
661 |         else:
662 |             counter[str_label] += 1
663 | 
664 |     samples_per_location = []
665 |     for location in imsets:
666 |         str_label = location['label']
667 |         samples_per_location.append(round(samples_per_class / counter[str_label]))
668 | 
669 |     return samples_per_location
670 | 
671 | 
672 | def sample_with_possible_repetition(file_names: list, n_samples: int) -> list:
673 |     """
674 |     :param file_names: list
675 |     :param n_samples: number of samples to be taken from the list. If 'n_samples' is larger than the
676 |         length of the list, the output list will contain multiple copies of elements from the input 'file_names' list
677 |     :return: list of sampled elements from the 'file_names' list
678 |     """
679 | 
680 |     n_elements = len(file_names)
681 |     n_cycles = n_samples // n_elements
682 |     n_samples_last = n_samples - n_cycles * n_elements
683 |     res = []
684 |     for _ in range(n_cycles):
685 |         res += random.sample(file_names, n_elements)
686 | 
687 |     res += random.sample(file_names, n_samples_last)
688 | 
689 |     return res
690 | 
--------------------------------------------------------------------------------