├── .gitignore ├── documentation ├── experiment_walkthrough.md ├── nnOOD_overview_v2.png └── synthetic_task_guide.md ├── nnood ├── __init__.py ├── configuration.py ├── data │ ├── __init__.py │ ├── dataset_conversion │ │ ├── __init__.py │ │ ├── chestxray14_lists │ │ │ ├── BBox_List_2017.csv │ │ │ ├── anomaly_FemaleAdultPA_test_list.txt │ │ │ ├── anomaly_MaleAdultPA_test_list.txt │ │ │ ├── norm_FemaleAdultPA_train_list.txt │ │ │ └── norm_MaleAdultPA_train_list.txt │ │ ├── convert_chestxray14.py │ │ ├── convert_mvtec.py │ │ └── utils.py │ ├── readme.md │ └── sanity_checks.py ├── evaluation │ ├── __init__.py │ ├── evaluator.py │ ├── metrics.py │ ├── nnOOD_evaluate_folder.py │ ├── nnOOD_run_testing.py │ └── readme.md ├── experiment_planning │ ├── DatasetAnalyser.py │ ├── __init__.py │ ├── experiment_planner.py │ ├── modality_conversion.py │ ├── nnOOD_plan_and_preprocess.py │ ├── nnOOD_update_plans_number.py │ └── utils.py ├── inference │ ├── __init__.py │ ├── export_utils.py │ ├── model_restore.py │ ├── predict.py │ └── predict_simple.py ├── network_architecture │ ├── __init__.py │ ├── generic_UNet.py │ ├── initialisation.py │ └── neural_network.py ├── paths.py ├── preprocessing │ ├── __init__.py │ ├── foreground_mask.py │ ├── normalisation.py │ └── preprocessing.py ├── self_supervised_task │ ├── __init__.py │ ├── cutpaste.py │ ├── fpi.py │ ├── nsa.py │ ├── nsa_utils.py │ ├── opencv_nsa.py │ ├── opencv_pii.py │ ├── patch_blender.py │ ├── patch_ex.py │ ├── patch_labeller.py │ ├── patch_shape_maker.py │ ├── patch_transforms │ │ ├── __init__.py │ │ ├── base_transform.py │ │ ├── colour_transforms.py │ │ └── spatial_transforms.py │ ├── pii.py │ ├── rect_fpi.py │ ├── self_sup_task.py │ └── utils.py ├── training │ ├── __init__.py │ ├── data_augmentation │ │ ├── __init__.py │ │ ├── custom_transforms.py │ │ ├── default_data_augmentation.py │ │ └── downsampling.py │ ├── dataloading │ │ ├── __init__.py │ │ └── dataset_loading.py │ ├── loss_functions │ │ ├── __init__.py │ │ └── deep_supervision.py │ ├── network_training │ │ ├── __init__.py │ │ ├── network_trainer.py │ │ ├── nnOODTrainer.py │ │ └── nnOODTrainerDS.py │ └── nnOOD_run_training.py └── utils │ ├── __init__.py │ ├── default_configuration.py │ ├── file_operations.py │ ├── miscellaneous.py │ └── to_torch.py ├── notebooks ├── dataloader_test.ipynb ├── mvtec_obj_stats ├── object_mask_helper.ipynb ├── patch_interpolation_helper_test.ipynb ├── readme.md ├── results_notebooks │ ├── chestxray14.ipynb │ └── mvtec.ipynb ├── sampling_tests.ipynb ├── self_sup_task_visualiser.ipynb └── trainer_test.ipynb ├── readme.md └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
-------------------------------------------------------------------------------- /documentation/experiment_walkthrough.md: --------------------------------------------------------------------------------
1 | 
2 | # Experiment walkthrough
3 | 
4 | Before running any programs, define nnOOD's environment variables:
5 | - `nnood_raw_data_base` - Folder containing the raw data of each dataset.
6 | - `nnood_preprocessed_data_base` - Folder containing preprocessed datasets and experiment plans.
7 | - `nnood_results_base` - Folder containing the outputs of experiments: trained models, logs and any test results.
8 | 
9 | ## Dataset conversion
10 | 
11 | nnOOD uses a similar dataset structure to nnU-Net. Each dataset has its own folder within the raw dataset folder,
12 | `$nnood_raw_data_base/DATASET_NAME`, which must contain:
13 | - `imagesTr/` folder, containing normal training images. Images must follow the naming convention
14 | `SAMPLE_ID_MMMM.[png|nii.gz]`, where MMMM is the modality number.
15 | - `imagesTs/` folder, containing test images, following the same naming convention as `imagesTr`.
16 | - `labelsTs/` folder, containing the labels for the test images. Labels must follow the naming convention
17 | `SAMPLE_ID.[png|nii.gz]`. If a label is not provided for a test image, it is assumed to be entirely normal.
18 | - `dataset.json` file, describing the dataset, as documented in `nnood/data/readme.md`.
19 | - We recommend using `nnood/data/dataset_conversion/utils.generate_dataset_json` to produce this, as sketched below.
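
For example, a minimal conversion script might end with a call like this (the dataset location, name, modalities and
augmentations are illustrative - see the docstring of `generate_dataset_json` for all options):

```python
from pathlib import Path

from nnood.data.dataset_conversion.utils import generate_dataset_json

# Hypothetical dataset location and settings - adjust these to your data.
dataset_dir = Path('/path/to/nnood_raw_data_base') / 'my_dataset'
generate_dataset_json(dataset_dir / 'dataset.json',
                      dataset_dir / 'imagesTr',
                      dataset_dir / 'imagesTs',
                      ('png',),  # one entry per modality number
                      'my_dataset',
                      dataset_description='Example normal-only dataset.',
                      data_augs={'mirror': {'mirror_axes': [0, 1]}},
                      has_uniform_background=False)
```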
20 | 
21 | ## Experiment planning and preprocessing
22 | 
23 | As with nnU-Net, nnOOD extracts a dataset fingerprint, creates an experiment plan and preprocesses the dataset ready for
24 | training.
25 | 
26 | ```bash
27 | python /nnood/experiment_planning/nnOOD_plan_and_preprocess.py -d DATASET_NAME --verify_dataset_integrity
28 | ```
29 | 
30 | This stores the experiment plans and data within `$nnood_preprocessed_data_base/DATASET_NAME`.
31 | The experiment plan is independent of the self-supervised tasks.
32 | 
33 | ## Model training
34 | 
35 | nnOOD uses 5-fold cross-validation to try to get a more consistent measurement of a task on a given dataset.
36 | To train one fold of this process, use the command:
37 | 
38 | ```bash
39 | python /nnood/training/nnOOD_run_training.py CONFIGURATION TRAINER_CLASS_NAME DATASET_NAME TASK_NAME FOLD
40 | ```
41 | 
42 | CONFIGURATION can currently only be `lowres` or `fullres`, depending on whether the dataset recommends one or two
43 | stages, as at this point we have not been able to implement cascade training.
44 | 
45 | TRAINER_CLASS_NAME chooses between `nnOODTrainer` and `nnOODTrainerDS`, which determines whether deep supervision is
46 | used. Based on nnU-Net's results we recommend using `nnOODTrainerDS`, unless the user needs to be especially
47 | careful with memory consumption.
48 | 
49 | TASK_NAME is the name of the class which extends and implements `SelfSupTask` (such as `FPI`, `CutPaste`, etc.) and
50 | must be stored within `nnood/self_supervised_task`.
51 | 
52 | Other options can be viewed using the `-h` flag.
53 | 
54 | The outputs of this experiment are stored in:
55 | ```
56 | $nnood_results_base/DATASET_NAME/TASK_NAME/CONFIGURATION/TRAINER_CLASS_NAME__default_plans_identifier/fold_FOLD/
57 | ```
58 | where `default_plans_identifier` is defined in `nnood/paths.py`.
59 | 
60 | The outputs are:
61 | - `training_log_TIMESTAMP.txt` - log of outputs during the experiment
62 | - `model_(best|final_checkpoint).(model|pkl)` - files containing the trained model and training progress
63 | 
64 | ## Testing models
65 | 
66 | Test the ensemble of the 5 models trained during cross-validation on a certain task with the command:
67 | ```bash
68 | python /nnood/evaluation/nnOOD_run_testing.py -d DATASET_NAME -t TASK_NAME -tr TRAINER_CLASS_NAME
69 | ```
70 | 
71 | Other options can be viewed using the `-h` flag.
72 | 
73 | The models are evaluated on the dataset's test set, given in `imagesTs` and `labelsTs`, as described in
74 | [dataset conversion](#dataset-conversion).
75 | 
76 | This computes the following:
77 | - A prediction for each test image, saved as `SAMPLE_ID.[png|nii.gz]`
78 | - `summary.json` - a record of the metrics on the test dataset (AUROC, AP) along with a timestamp for when the test
79 | took place.
80 | 
81 | All are saved to:
82 | ```
83 | $nnood_results_base/DATASET_NAME/TASK_NAME/testResults/default_plans_identifier/
84 | ```
-------------------------------------------------------------------------------- /documentation/nnOOD_overview_v2.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/matt-baugh/nnOOD/a953bcad86c59cd016169141a24631cc3ded02ff/documentation/nnOOD_overview_v2.png
-------------------------------------------------------------------------------- /documentation/synthetic_task_guide.md: --------------------------------------------------------------------------------
1 | 
2 | # Synthetic task guide
3 | 
4 | All self-supervised tasks must extend and implement SelfSupTask, defined in
5 | [nnood/self_supervised_task/self_sup_task.py](../nnood/self_supervised_task/self_sup_task.py) as having an `apply`, `loss`, and `calibrate` function (details in the docstrings of the
6 | methods). A toy sketch of the general idea is given below.
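
As a rough illustration (the names and signatures here are simplified stand-ins, not the actual nnOOD interface -
see the docstrings in `self_sup_task.py`), an FPI-style `apply` step blends a patch from one normal image into
another and returns the augmented image together with its pixel-wise label:

```python
import numpy as np


def toy_apply(sample: np.ndarray, other: np.ndarray, rng=np.random):
    """Toy FPI-style task: interpolate a random rectangle of `other` into `sample`.
    Both are channels-first arrays of the same (at least ~16 px) spatial shape. Illustrative only."""
    aug = sample.astype(np.float32)
    label = np.zeros(sample.shape[-2:], dtype=np.float32)
    h, w = sample.shape[-2:]
    ph, pw = rng.randint(h // 8, h // 2), rng.randint(w // 8, w // 2)  # patch size
    y, x = rng.randint(0, h - ph), rng.randint(0, w - pw)  # patch corner
    alpha = rng.uniform(0.1, 0.9)  # interpolation factor
    aug[..., y:y + ph, x:x + pw] = ((1 - alpha) * sample[..., y:y + ph, x:x + pw]
                                    + alpha * other[..., y:y + ph, x:x + pw])
    label[y:y + ph, x:x + pw] = alpha  # FPI labels pixels with the interpolation factor
    return aug, label
```

Roughly speaking, `loss` then compares the model's predicted anomaly map against this label, and `calibrate` gives
the task a chance to adapt its parameters to the dataset before training.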
7 | 
8 | To aid in experimenting with patch-blending tasks we have decomposed the existing tasks, making it easier to swap
9 | out individual parts of a task and understand which are contributing the most to performance.
10 | 
11 | The classes of component are:
12 | - `PatchShapeMaker` - returns a mask which is used to extract the source patch.
13 | - `PatchTransforms` - applies a transform to the extracted patch, which can either be spatial or alter the content
14 | of the patch. A list of these is used to define the task.
15 | - `PatchBlender` - integrates the source patch into the target image at the given location.
16 | - `PatchLabeller` - given the original image and the altered image, computes the pixel-wise label for the image.
17 | 
18 | These components are combined using `nnood/self_supervised_task/patch_ex.patch_ex`, which also has a number of other
19 | parameters, such as the number of anomalies introduced (see the docstring for full details).
20 | 
21 | Here are some example tasks reimplemented using this framework:
22 | - [Foreign Patch Interpolation](../nnood/self_supervised_task/fpi.py)
23 | - [CutPaste](../nnood/self_supervised_task/cutpaste.py)
24 | - [Poisson Image Interpolation](../nnood/self_supervised_task/pii.py)
25 | - [Natural Synthetic Anomalies](../nnood/self_supervised_task/nsa.py) (both source and mixed gradient variants).
26 | 
-------------------------------------------------------------------------------- /nnood/__init__.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from . import *
-------------------------------------------------------------------------------- /nnood/configuration.py: --------------------------------------------------------------------------------
1 | import os
2 | 
3 | THREAD_NUM_VAR = 'nnood_def_n_proc'
4 | 
5 | default_num_processes = int(os.environ[THREAD_NUM_VAR]) if THREAD_NUM_VAR in os.environ else 8
6 | RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD = 3
-------------------------------------------------------------------------------- /nnood/data/__init__.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from . import *
-------------------------------------------------------------------------------- /nnood/data/dataset_conversion/__init__.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from .
import * 3 | -------------------------------------------------------------------------------- /nnood/data/dataset_conversion/convert_chestxray14.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | from typing import Union, Optional 5 | 6 | import cv2 7 | import numpy as np 8 | import pandas as pd 9 | 10 | from nnood.data.dataset_conversion.utils import generate_dataset_json 11 | from nnood.paths import raw_data_base, DATASET_JSON_FILE 12 | 13 | # Dataset available at: 14 | # https://nihcc.app.box.com/v/ChestXray-NIHCC/folder/36938765345 15 | 16 | script_dir = Path(os.path.realpath(__file__)).parent 17 | 18 | xray_list_dir = script_dir / 'chestxray14_lists' 19 | 20 | train_male_list_path = xray_list_dir / 'norm_MaleAdultPA_train_list.txt' 21 | test_male_list_path = xray_list_dir / 'anomaly_MaleAdultPA_test_list.txt' 22 | 23 | train_female_list_path = xray_list_dir / 'norm_FemaleAdultPA_train_list.txt' 24 | test_female_list_path = xray_list_dir / 'anomaly_FemaleAdultPA_test_list.txt' 25 | 26 | bbox_data_file_path = xray_list_dir / 'BBox_List_2017.csv' 27 | bbox_csv = pd.read_csv(bbox_data_file_path, index_col=0, usecols=['Image Index', 'Bbox [x', 'y', 'w', 'h]']) 28 | 29 | train_test_dict = { 30 | 'male': (train_male_list_path, test_male_list_path), 31 | 'female': (train_female_list_path, test_female_list_path) 32 | } 33 | 34 | 35 | def organise_xray_data(in_dir: Union[str, Path], data_type: str): 36 | 37 | in_dir_path = Path(in_dir) 38 | assert in_dir_path.is_dir(), 'Not a valid directory: ' + in_dir 39 | 40 | train_list_path, test_list_path = train_test_dict[data_type] 41 | 42 | out_dir_path = Path(raw_data_base) / f'chestXray14_PA_{data_type}' 43 | out_dir_path.mkdir(parents=True, exist_ok=True) 44 | 45 | out_train_path = out_dir_path / 'imagesTr' 46 | out_train_path.mkdir(parents=True, exist_ok=True) 47 | 48 | def load_and_save(f_name: str, curr_out_dir_path: Path, mask_out_dir_path: Optional[Path]): 49 | f_path = in_dir_path / f_name 50 | assert f_path.is_file(), f'Missing file: {f_path}' 51 | 52 | opencv_img = cv2.imread(f_path.__str__(), cv2.IMREAD_GRAYSCALE) 53 | assert len(opencv_img.shape) == 2, 'Greyscale shape not 2?? 
' + str(opencv_img.shape)
54 | 
55 |         sample_id = f_path.stem
56 | 
57 |         f_out_path = curr_out_dir_path / (sample_id + '_0000.png')
58 |         cv2.imwrite(f_out_path.__str__(), opencv_img)
59 | 
60 |         if mask_out_dir_path is not None:
61 |             assert f_name in bbox_csv.index, 'Missing bbox data for ' + f_name
62 |             sample_mask = np.zeros_like(opencv_img)
63 | 
64 |             curr_mask_data = bbox_csv.loc[f_name].to_numpy().round().astype(int)
65 |             if len(curr_mask_data.shape) == 1:
66 |                 curr_mask_data = [curr_mask_data]
67 | 
68 |             for bbox_x, bbox_y, bbox_w, bbox_h in curr_mask_data:
69 |                 sample_mask[bbox_y: bbox_y + bbox_h, bbox_x: bbox_x + bbox_w] = 1
70 | 
71 |             mask_path = mask_out_dir_path / f_name
72 |             cv2.imwrite(mask_path.__str__(), sample_mask)
73 | 
74 |         return sample_id
75 | 
76 |     # # Load healthy files
77 |     # with open(train_list_path, 'r') as train_list_file:
78 |     #     for f in train_list_file.readlines():
79 |     #         load_and_save(f.strip(), out_train_path, None)
80 | 
81 |     out_test_path = out_dir_path / 'imagesTs'
82 |     out_test_path.mkdir(parents=True, exist_ok=True)
83 | 
84 |     out_test_labels_path = out_dir_path / 'labelsTs'
85 |     out_test_labels_path.mkdir(parents=True, exist_ok=True)
86 | 
87 |     # # Load test files
88 |     # with open(test_list_path, 'r') as test_list_file:
89 |     #     for f in test_list_file.readlines():
90 |     #         f = f.strip()
91 |     #         # Only include in test set if has bounding box
92 |     #         if f in bbox_csv.index:
93 |     #             load_and_save(f.strip(), out_test_path, out_test_labels_path)
94 | 
95 |     data_augs = {
96 |         'scaling': {'scale_range': [0.97, 1.03]}
97 |     }
98 | 
99 |     generate_dataset_json(out_dir_path / DATASET_JSON_FILE, out_train_path, out_test_path, ('png-xray',),
100 |                           out_dir_path.name,
101 |                           dataset_description='Images from the NIH Chest X-ray dataset; limited to posteroanterior '
102 |                                               f'views of {data_type} adult patients (over 18), with the test set only '
103 |                                               'including patients which had a bounding box provided.',
104 |                           data_augs=data_augs)
105 | 
106 | # CHANGE THESE TO MATCH YOUR DATA!!!
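# The path below is only an example location - replace it with the directory where you
# downloaded the ChestXray-NIHCC images before running this script.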
107 | organise_xray_data('/vol/biodata/data/chest_xray/ChestXray-NIHCC/images', 'male') 108 | organise_xray_data('/vol/biodata/data/chest_xray/ChestXray-NIHCC/images', 'female') 109 | -------------------------------------------------------------------------------- /nnood/data/dataset_conversion/convert_mvtec.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import sys 3 | import os 4 | from pathlib import Path 5 | import shutil 6 | 7 | from nnood.paths import raw_data_base, DATASET_JSON_FILE 8 | from nnood.data.dataset_conversion.utils import generate_dataset_json 9 | 10 | CLASS_NAMES = ['bottle', 'cable', 'capsule', 'carpet', 'grid', 11 | 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 12 | 'tile', 'toothbrush', 'transistor', 'wood', 'zipper'] 13 | 14 | OBJECTS = ['bottle', 'cable', 'capsule', 'hazelnut', 'metal_nut', 15 | 'pill', 'screw', 'toothbrush', 'transistor', 'zipper'] 16 | TEXTURES = ['carpet', 'grid', 'leather', 'tile', 'wood'] 17 | 18 | HAS_UNIFORM_BACKGROUND = ['bottle', 'capsule', 'hazelnut', 'metal_nut', 'pill', 'screw', 'toothbrush', 'zipper'] 19 | 20 | # Bottle is aligned but it's symmetric under rotation 21 | UNALIGNED_OBJECTS = ['bottle', 'hazelnut', 'metal_nut', 'screw'] 22 | 23 | H_FLIP = TEXTURES + ['bottle', 'hazelnut', 'toothbrush', 'transistor', 'zipper'] 24 | V_FLIP = TEXTURES + ['bottle', 'hazelnut'] 25 | ROTATE = UNALIGNED_OBJECTS + ['leather', 'tile', 'grid'] 26 | SMALL_ROTATE = ['carpet', 'wood'] 27 | 28 | GREYSCALE = ['grid', 'screw', 'zipper'] 29 | 30 | 31 | for g in [OBJECTS, TEXTURES, UNALIGNED_OBJECTS, H_FLIP, V_FLIP, ROTATE, SMALL_ROTATE]: 32 | assert all([e in CLASS_NAMES for e in g]), f'Element of {g} not in CLASS_NAMES' 33 | 34 | 35 | def organise_class(in_dir: Union[str, Path]): 36 | 37 | assert os.path.isdir(in_dir), 'Not a valid directory: ' + in_dir 38 | in_dir_path = Path(in_dir) 39 | 40 | in_train_path = in_dir_path / 'train' / 'good' 41 | assert in_train_path.is_dir() 42 | 43 | in_test_examples_path = in_dir_path / 'test' 44 | assert in_test_examples_path.is_dir() 45 | 46 | in_test_labels_path = in_dir_path / 'ground_truth' 47 | assert in_test_labels_path.is_dir() 48 | 49 | test_dirs = [d for d in in_test_examples_path.iterdir() if d.is_dir()] 50 | assert len(test_dirs) > 1, 'Test must include good and bad examples' 51 | 52 | object_class = in_dir_path.name 53 | 54 | out_dir_path = Path(raw_data_base) / ('mvtec_ad_' + object_class) 55 | out_dir_path.mkdir(parents=True, exist_ok=True) 56 | 57 | out_train_path = out_dir_path / 'imagesTr' 58 | out_train_path.mkdir(parents=True, exist_ok=True) 59 | 60 | # Copy normal training data 61 | for f in in_train_path.iterdir(): 62 | file_name = f.name 63 | number, ext = file_name.split('.') 64 | shutil.copy(f, out_train_path / f'normal_{number}_0000.{ext}') 65 | 66 | out_test_path = out_dir_path / 'imagesTs' 67 | out_test_path.mkdir(parents=True, exist_ok=True) 68 | 69 | out_test_labels_path = out_dir_path / 'labelsTs' 70 | out_test_labels_path.mkdir(parents=True, exist_ok=True) 71 | 72 | # Copy testing data 73 | for d in test_dirs: 74 | folder_name = d.name 75 | 76 | if folder_name != 'good': 77 | # Check labels folder exists 78 | test_class_label_dir = in_test_labels_path / folder_name 79 | assert test_class_label_dir.is_dir(), 'Missing labels folder: ' + test_class_label_dir.__str__() 80 | 81 | for f in d.iterdir(): 82 | file_name = f.name 83 | id_num, ext = file_name.split('.') 84 | 85 | if folder_name != 'good': 86 | # Verify and 
copy label for test example 87 | test_label_path = in_test_labels_path / folder_name / f'{id_num}_mask.{ext}' 88 | shutil.copy(test_label_path, out_test_labels_path / f'{folder_name}_{id_num}.{ext}') 89 | 90 | shutil.copy(f, out_test_path / f'{folder_name}_{id_num}_0000.{ext}') 91 | 92 | data_augs = {} 93 | 94 | if object_class in ROTATE: 95 | data_augs['rotation'] = {'rot_max': 5} 96 | elif object_class in SMALL_ROTATE: 97 | data_augs['rotation'] = {'rot_max': 2} 98 | 99 | if object_class in H_FLIP or object_class in V_FLIP: 100 | axes = [] 101 | if object_class in V_FLIP: 102 | axes.append(0) 103 | 104 | if object_class in H_FLIP: 105 | axes.append(1) 106 | 107 | data_augs['mirror'] = {'mirror_axes': axes} 108 | 109 | png_type = 'png-bw' if object_class in GREYSCALE else 'png' 110 | 111 | generate_dataset_json(out_dir_path / DATASET_JSON_FILE, out_train_path, out_test_path, (png_type,), in_dir_path.name, 112 | licence='CC BY-NC-SA 4.0', 113 | dataset_description='Images from the MVTec Anomaly Detection Dataset for the class ' + 114 | in_dir_path.name, dataset_reference='MVTec Software GmbH', 115 | dataset_release='1.0 16/04/2021', data_augs=data_augs, 116 | has_uniform_background=object_class in HAS_UNIFORM_BACKGROUND) 117 | 118 | 119 | if __name__ == '__main__': 120 | # Folder of image class, or root mvtec dataset folder 121 | in_root_dir: str = sys.argv[1] 122 | 123 | if len(sys.argv) == 3 and sys.argv[2] == 'full_dataset': 124 | print('Processing entire MVTec AD Dataset') 125 | in_root_path = Path(in_root_dir) 126 | 127 | for in_class_dir in in_root_path.iterdir(): 128 | if not in_class_dir.is_dir(): 129 | continue 130 | 131 | print(f'Processing {in_class_dir}...') 132 | 133 | organise_class(in_class_dir) 134 | 135 | else: 136 | print('Processing single class...') 137 | organise_class(in_root_dir) 138 | print('Done!') 139 | -------------------------------------------------------------------------------- /nnood/data/dataset_conversion/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | from pathlib import Path 3 | 4 | from nnood.utils.file_operations import save_json 5 | 6 | 7 | def generate_dataset_json(output_file: Union[str, Path], train_dir: Union[str, Path], 8 | test_dir: Optional[Union[str, Path]], modalities: Tuple[str], dataset_name: str, 9 | licence: str = 'hands off!', dataset_description: str = '', dataset_reference='', 10 | dataset_release: str = '0.0', data_augs: dict = {}, has_uniform_background: bool = False): 11 | """ 12 | :param output_file: This needs to be the full path to the dataset.json you intend to write, so 13 | output_file='DATASET_PATH/dataset.json' where the folder DATASET_PATH points to is the one with the 14 | imagesTr and optional imagesTs. 15 | :param train_dir: Path to the imagesTr folder of that dataset. 16 | :param test_dir: Path to the imagesTs folder of that dataset. Can be None. 17 | :param modalities: Tuple of strings with modality names. must be in the same order as the images (first entry 18 | corresponds to _0000.nii.gz, etc). Example: ('T1', 'T2', 'FLAIR'). 19 | :param dataset_name: The name of the dataset. 20 | :param licence: Dataset licence. 21 | :param dataset_description: Brief description of dataset. 22 | :param dataset_reference: Website of the dataset, if available. 23 | :param dataset_release: 24 | :param data_augs: Dictionary of valid data augmentations for data (which keep samples within normal distribution). 
25 |     :param has_uniform_background: Whether images have a uniform background. If true, the preprocessing creates a
26 |     foreground mask. In this process it assumes the corners of the image are included in the background.
27 |     :return:
28 |     """
29 | 
30 |     train_path = Path(train_dir)
31 | 
32 |     train_ids = set(['_'.join(f.name.split('_')[:-1]) for f in train_path.iterdir() if f.is_file()])
33 |     # test_dir is optional (it may be None), so only gather test ids when it is given
34 |     test_ids = set() if test_dir is None else \
35 |         set(['_'.join(f.name.split('_')[:-1]) for f in Path(test_dir).iterdir() if f.is_file()])
36 | 
37 |     dataset_json = {
38 |         'name': dataset_name,
39 |         'description': dataset_description,
40 |         'reference': dataset_reference,
41 |         'licence': licence,
42 |         'release': dataset_release,
43 |         'tensorImageSize': '3D' if any('png' in m for m in modalities) else '4D',
44 |         'modality': {str(i): modalities[i] for i in range(len(modalities))},
45 |         'numTraining': len(train_ids),
46 |         'numTest': len(test_ids),
47 |         'training': [ident for ident in train_ids],
48 |         'test': [ident for ident in test_ids],
49 |         'data_augs': data_augs,
50 |         'has_uniform_background': has_uniform_background
51 |     }
52 |     save_json(dataset_json, output_file)
-------------------------------------------------------------------------------- /nnood/data/readme.md: --------------------------------------------------------------------------------
1 | # Contains code for dataset conversion
2 | 
3 | I recommend using dataset_conversion/utils.generate_dataset_json to match the correct format (an example dataset.json is shown at the end of this file).
4 | 
5 | dataset.json components:
6 | - name: str - dataset name
7 | - description: str - dataset description
8 | - reference: str - reference for dataset source
9 | - licence: str - dataset licence
10 | - release: str - date of dataset release
11 | - tensorImageSize: str - dimensionality of data including channels, either '3D' or '4D'
12 | - modality: Dict[str, str] - dictionary matching file (modality) number to modality name
13 | - numTraining: int - number of training examples
14 | - numTest: int - number of test examples
15 | - training: List[str] - list of training sample ids
16 | - test: List[str] - list of test sample ids
17 | - data_augs: Dict[str, Dict[str, Any]] - Dictionary, mapping data augmentation name to parameters and values.
18 |   - Possible transforms + parameters:
19 |     - elastic:
20 |       - deform_alpha: List[int] - Alpha values range e.g. [0, 900]
21 |       - deform_sigma: List[int] - Sigma values range e.g. [9, 13]
22 |     - scaling:
23 |       - scale_range: List[float] - Scale factor range e.g. [0.85, 1.25]
24 |     - rotation:
25 |       - rot_max: int - Maximum amount of rotation around axis e.g. 15
26 |     - gamma:
27 |       - gamma_range: List[float] - Range of gamma transform values e.g. [0.7, 1.5]
28 |     - mirror:
29 |       - mirror_axes: List[int] - List of valid axes to mirror, where axis 0 is the axis immediately after the channels
30 |         dimension (y axis for 2D images, z axis for 3D) e.g. [0, 1, 2].
31 |     - additive_brightness:
32 |       - additive_brightness_mu: float - Mean of additive brightness noise
33 |       - additive_brightness_sigma: float - Standard deviation of additive brightness noise
34 |   - Potentially implemented in future:
35 |     - gaussian_noise
36 |     - gaussian_blur
37 |     - brightness_multiplicative
38 |     - contrast_aug
39 |     - sim_low_res
40 | 
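A dataset.json for one of the MVTec AD classes (as written by convert_mvtec.py) looks roughly like the following -
the counts are illustrative and the sample id lists are shortened:

```json
{
    "name": "bottle",
    "description": "Images from the MVTec Anomaly Detection Dataset for the class bottle",
    "reference": "MVTec Software GmbH",
    "licence": "CC BY-NC-SA 4.0",
    "release": "1.0 16/04/2021",
    "tensorImageSize": "3D",
    "modality": {"0": "png"},
    "numTraining": 209,
    "numTest": 83,
    "training": ["normal_000", "normal_001"],
    "test": ["good_000", "broken_large_000"],
    "data_augs": {
        "rotation": {"rot_max": 5},
        "mirror": {"mirror_axes": [0, 1]}
    },
    "has_uniform_background": true
}
```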
-------------------------------------------------------------------------------- /nnood/data/sanity_checks.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | from pathlib import Path
3 | from typing import List
4 | 
5 | import nibabel as nib
6 | import SimpleITK as sitk
7 | 
8 | from nnood.utils.file_operations import load_json
9 | 
10 | 
11 | def verify_all_same_orientation(all_files: List[Path]):
12 | 
13 |     # Assumes files are either .nii.gz or .png
14 |     nii_files = [f for f in all_files if f.suffix != '.png']
15 |     orientations = []
16 |     for n in nii_files:
17 |         img = nib.load(n)
18 |         affine = img.affine
19 |         orientation = nib.aff2axcodes(affine)
20 |         orientations.append(orientation)
21 |     # now we need to check whether they are all the same
22 |     orientations = np.array(orientations)
23 |     unique_orientations = np.unique(orientations, axis=0)
24 |     all_same = len(unique_orientations) == 1
25 |     return all_same, unique_orientations
26 | 
27 | 
28 | def verify_same_geometry(img_1: sitk.Image, img_2: sitk.Image):
29 |     ori1, spacing1, direction1, size1 = img_1.GetOrigin(), img_1.GetSpacing(), img_1.GetDirection(), img_1.GetSize()
30 |     ori2, spacing2, direction2, size2 = img_2.GetOrigin(), img_2.GetSpacing(), img_2.GetDirection(), img_2.GetSize()
31 | 
32 |     same_ori = np.all(np.isclose(ori1, ori2))
33 |     if not same_ori:
34 |         print('The origin does not match between the images:')
35 |         print(ori1)
36 |         print(ori2)
37 | 
38 |     same_spac = np.all(np.isclose(spacing1, spacing2))
39 |     if not same_spac:
40 |         print('The spacing does not match between the images')
41 |         print(spacing1)
42 |         print(spacing2)
43 | 
44 |     same_dir = np.all(np.isclose(direction1, direction2))
45 |     if not same_dir:
46 |         print('The direction does not match between the images')
47 |         print(direction1)
48 |         print(direction2)
49 | 
50 |     same_size = np.all(np.isclose(size1, size2))
51 |     if not same_size:
52 |         print('The size does not match between the images')
53 |         print(size1)
54 |         print(size2)
55 | 
56 |     if same_ori and same_spac and same_dir and same_size:
57 |         return True
58 |     else:
59 |         return False
60 | 
61 | 
62 | def verify_dataset_integrity(dataset_folder: Path):
63 |     """
64 |     folder needs the imagesTr folder and dataset.json.
65 |     Optional imagesTs and labelsTs folders.
66 |     Checks if all training cases are present.
67 |     Checks if all test cases and their labels (if any) are present.
68 |     For each case, checks whether all modalities are present.
69 |     For each case, checks whether the pixel grids are aligned.
70 |     :param dataset_folder:
71 |     :return:
72 |     """
73 |     print('Verifying dataset in ', dataset_folder)
74 | 
75 |     assert dataset_folder.is_dir(), 'Dataset folder doesn\'t exist!'
76 | 
77 |     dataset_file = dataset_folder / 'dataset.json'
78 |     assert dataset_file.is_file(), 'Missing dataset.json'
79 | 
80 |     train_folder = dataset_folder / 'imagesTr'
81 |     assert train_folder.is_dir(), 'Missing training folder imagesTr'
82 | 
83 |     dataset = load_json(dataset_file)
84 |     modalities = dataset['modality']
85 |     num_modalities = len(modalities)
86 |     file_suffixes = ['png' if 'png' in modalities[str(i)] else 'nii.gz' for i in range(num_modalities)]
87 | 
88 |     training_ids = dataset['training']
89 |     assert len(training_ids) == len(np.unique(training_ids)), 'Duplicate training ids in dataset.json!'
90 |     training_files = [f for f in train_folder.iterdir() if f.is_file()]
91 | 
92 |     num_train_files = len(training_files)
93 |     num_expected_train_files = num_modalities * len(training_ids)
94 |     assert num_train_files <= num_expected_train_files, 'Extra files in training folder (should be just training ' \
95 |                                                         f'data modalities): {num_train_files} > ' \
96 |                                                         f'{num_expected_train_files}'
97 |     assert num_train_files >= num_expected_train_files, f'Missing files in training folder: {num_train_files} < ' \
98 |                                                         f'{num_expected_train_files}'
99 | 
100 |     test_ids = dataset['test']
101 |     assert len(test_ids) == len(np.unique(test_ids)), 'Duplicate test ids in dataset.json!'
102 | 
103 |     geometries_OK = True
104 |     has_nan = False
105 |     all_files = []
106 | 
107 |     print('Verifying training set')
108 |     for c in training_ids:
109 |         print('Checking case', c)
110 | 
111 |         # Check if all files are present
112 |         expected_image_files = [train_folder / f'{c}_{i:04d}.{file_suffixes[i]}' for i in range(num_modalities)]
113 |         all_files += expected_image_files
114 | 
115 |         for f in expected_image_files:
116 |             assert f in training_files, f'Missing file {f}'
117 | 
118 |         images_itk = [sitk.ReadImage(i.__str__()) for i in expected_image_files]
119 | 
120 |         for i, img in enumerate(images_itk):
121 |             nans_in_image = np.any(np.isnan(sitk.GetArrayFromImage(img)))
122 |             has_nan = has_nan | nans_in_image
123 |             same_geometry = verify_same_geometry(img, images_itk[0])
124 | 
125 |             if not same_geometry:
126 |                 geometries_OK = False
127 |                 print(f'The geometry of the image {expected_image_files[i]} does not match the geometry of the '
128 |                       'first modality. The pixel arrays will not be aligned and nnOOD cannot use this data. Please '
129 |                       'make sure your image modalities are coregistered and have the same geometry')
130 | 
131 |             if nans_in_image:
132 |                 print(f'There are NAN values in image {expected_image_files[i]}')
133 | 
134 |     # check test set, but only if there actually is a test set
135 |     if len(test_ids) > 0:
136 |         print('Verifying test set')
137 | 
138 |         test_folder = dataset_folder / 'imagesTs'
139 |         assert test_folder.is_dir(), 'Test ids present, but no imagesTs folder!'
140 | 
141 |         test_labels_folder = dataset_folder / 'labelsTs'
142 |         assert test_labels_folder.is_dir(), 'Test ids present, but no labelsTs folder! (if the examples are without ' \
143 |                                             'ground truths, and you just want to predict them as examples, don\'t ' \
144 |                                             'label them as tests)'
145 | 
146 |         test_files = [f for f in test_folder.iterdir() if f.is_file()]
147 |         num_test_files = len(test_files)
148 |         num_expected_test_files = num_modalities * len(test_ids)
149 |         assert num_test_files <= num_expected_test_files, 'Extra files in test folder (should be just test data ' \
150 |                                                           f'modalities): {num_test_files} > {num_expected_test_files}'
151 |         assert num_test_files >= num_expected_test_files, f'Missing files in test folder: {num_test_files} < ' \
152 |                                                           f'{num_expected_test_files}'
153 | 
154 |         test_label_files = [f for f in test_labels_folder.iterdir() if f.is_file()]
155 |         assert len(test_label_files) >= 1, 'Must have at least 1 ground truth label (cannot be all normal images)!'
156 | 
157 |         for c in test_ids:
158 |             # Check if all files are present
159 |             expected_image_files = [test_folder / f'{c}_{i:04d}.{file_suffixes[i]}' for i in range(num_modalities)]
160 |             all_files += expected_image_files
161 | 
162 |             for f in expected_image_files:
163 |                 assert f in test_files, f'Missing file {f}'
164 | 
165 |             images_itk = [sitk.ReadImage(i.__str__()) for i in expected_image_files]
166 | 
167 |             files_to_check = images_itk
168 |             # Track the path of each checked file, so error messages stay valid when the label is appended
169 |             file_paths_to_check = list(expected_image_files)
170 | 
171 |             # Bit of an assumption that label has same suffix as first image file
172 |             # Unlikely that examples mix .png and .nii.gz
173 |             curr_label_file = test_labels_folder / f'{c}.{file_suffixes[0]}'
174 |             if curr_label_file in test_label_files:
175 |                 all_files.append(curr_label_file)
176 |                 files_to_check.append(sitk.ReadImage(curr_label_file.__str__()))
177 |                 file_paths_to_check.append(curr_label_file)
178 | 
179 |             for i, img in enumerate(files_to_check):
180 |                 nans_in_image = np.any(np.isnan(sitk.GetArrayFromImage(img)))
181 |                 has_nan = has_nan | nans_in_image
182 |                 same_geometry = verify_same_geometry(img, files_to_check[0])
183 | 
184 |                 if not same_geometry:
185 |                     geometries_OK = False
186 |                     print(
187 |                         f'The geometry of the image {file_paths_to_check[i]} does not match the geometry of the '
188 |                         'first file checked for this case. The pixel arrays will not be aligned and nnOOD cannot '
189 |                         'use this data. Please make sure your image modalities are coregistered and have the same '
190 |                         'geometry as the label')
191 | 
192 |                 if nans_in_image:
193 |                     print(f'There are NAN values in image {file_paths_to_check[i]}')
194 | 
195 |     all_same, unique_orientations = verify_all_same_orientation(all_files)
196 |     if not all_same:
197 |         print('WARNING: Not all images in the dataset have the same axis ordering. We very strongly recommend you '
198 |               'correct that by reorienting the data. fslreorient2std should do the trick')
199 |     # save unique orientations to dataset.json
200 |     if not geometries_OK:
201 |         raise Warning('GEOMETRY MISMATCH FOUND! CHECK THE TEXT OUTPUT! This does not cause an error at this point but '
202 |                       'you should definitely check whether your geometries are alright!')
203 |     else:
204 |         print('Dataset OK')
205 | 
206 |     if has_nan:
207 |         raise RuntimeError('Some images have nan values in them. This will break the training. See text output above '
208 |                            'to see which ones')
-------------------------------------------------------------------------------- /nnood/evaluation/__init__.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from .
import * 3 | -------------------------------------------------------------------------------- /nnood/evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from datetime import datetime 3 | import hashlib 4 | import json 5 | from pathlib import Path 6 | import random 7 | from typing import Dict, List, Optional, Tuple 8 | 9 | import SimpleITK as sitk 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | from nnood.evaluation.metrics import ALL_METRICS 14 | from nnood.training.dataloading.dataset_loading import load_npy_or_npz 15 | from nnood.utils.file_operations import save_json 16 | 17 | TARGET_METRIC_BATCH_SIZE = 25 18 | 19 | 20 | def _load_file_pairs_list(pred_ref_pairs: List[Tuple[Path, Optional[Path]]]) -> List[Tuple[np.ndarray, np.ndarray]]: 21 | def _load_file(f: Path) -> np.ndarray: 22 | 23 | if f.suffix == '.npz' or f.suffix == '.npy': 24 | # A file saved in a numpy representation must be in the correct form (channels first). 25 | return load_npy_or_npz(f, 'r') 26 | 27 | sitk_img = sitk.ReadImage(f.__str__()) 28 | img = sitk.GetArrayFromImage(sitk_img) 29 | 30 | sitk_shape = np.array(sitk_img.GetSize())[::-1] 31 | img_shape = np.array(img.shape) 32 | 33 | if len(img_shape) != len(sitk_shape): 34 | # Must be a vector image, represented with channels last. Convert to channels first. 35 | return np.moveaxis(img, -1, 0) 36 | else: 37 | # No channels, so no swap needed. 38 | return img 39 | 40 | def _load_file_pair(p_r_pair: Tuple[Path, Optional[Path]]) -> Tuple[np.ndarray, np.ndarray]: 41 | p, r = p_r_pair 42 | pred = _load_file(p) 43 | 44 | # If ref is None, then image is normal, so return array of zeroes. 45 | if r is None: 46 | return pred, np.zeros_like(pred) 47 | else: 48 | return pred, _load_file(r) 49 | 50 | return [_load_file_pair(p) for p in pred_ref_pairs] 51 | 52 | 53 | def compute_metric_scores(pred_ref_file_pairs: List[Tuple[Path, Optional[Path]]], **metric_kwargs) -> Dict: 54 | pred_ref_img_pairs = _load_file_pairs_list(pred_ref_file_pairs) 55 | 56 | preds = [p for p, _ in pred_ref_img_pairs] 57 | ref_labels = [l for _, l in pred_ref_img_pairs] 58 | 59 | all_ref_labels = np.concatenate([r.flatten() for r in ref_labels]) 60 | label_values = np.unique(all_ref_labels) 61 | 62 | # Evaluation labels should be binary (whether or not anomaly is present) 63 | unexpected_labels = [v for v in label_values if v not in [0, 1]] 64 | if len(unexpected_labels) > 0: 65 | # print('Binarising reference labels as found values other than [0, 1]: ', unexpected_labels) 66 | for r_l in ref_labels: 67 | r_l[r_l != 0] = 1 68 | 69 | # Default to computing all metrics 70 | chosen_metrics = metric_kwargs.get('metrics', ALL_METRICS.keys()) 71 | results = OrderedDict() 72 | 73 | for m in tqdm(chosen_metrics, desc='Computing different metrics...'): 74 | results[m] = ALL_METRICS[m](preds, ref_labels, **metric_kwargs) 75 | 76 | return results 77 | 78 | 79 | def aggregate_scores(pred_ref_file_pairs: List[Tuple[Path, Optional[Path]]], 80 | json_output_file=None, 81 | json_name='', 82 | json_description='', 83 | json_author='Anonymous', 84 | json_task='', 85 | **metric_kwargs) -> Dict: 86 | """ 87 | test = predicted image 88 | :param pred_ref_file_pairs: 89 | :param json_output_file: 90 | :param json_name: 91 | :param json_description: 92 | :param json_author: 93 | :param json_task: 94 | :param metric_kwargs: 95 | :return: 96 | """ 97 | 98 | random.shuffle(pred_ref_file_pairs) 99 | 100 | num_pairs = 
len(pred_ref_file_pairs) 101 | if num_pairs < TARGET_METRIC_BATCH_SIZE: 102 | print('Computing metrics over entire dataset') 103 | results = compute_metric_scores(pred_ref_file_pairs, **metric_kwargs) 104 | else: 105 | num_batches = round(num_pairs / TARGET_METRIC_BATCH_SIZE) 106 | batch_size = int(num_pairs / num_batches) 107 | 108 | print(f'Computing metrics over {num_batches} batches, each of size around {batch_size}') 109 | 110 | all_results = [] 111 | last_index = 0 112 | for i in range(num_batches - 1): 113 | print(f'Computing batch {i}') 114 | last_index = (i + 1) * batch_size 115 | all_results.append(compute_metric_scores(pred_ref_file_pairs[i * batch_size: last_index], **metric_kwargs)) 116 | 117 | print('Computing final batch') 118 | all_results.append(compute_metric_scores(pred_ref_file_pairs[last_index: num_pairs], **metric_kwargs)) 119 | 120 | results = {} 121 | for k in all_results[0].keys(): 122 | results[k] = np.mean([r[k] for r in all_results]) 123 | 124 | # We create a hopefully unique id by hashing the entire output dictionary 125 | if json_output_file is not None: 126 | json_dict = OrderedDict() 127 | json_dict['name'] = json_name 128 | json_dict['description'] = json_description 129 | timestamp = datetime.today() 130 | json_dict['timestamp'] = str(timestamp) 131 | json_dict['task'] = json_task 132 | json_dict['author'] = json_author 133 | json_dict['results'] = results 134 | json_dict['id'] = hashlib.md5(json.dumps(json_dict).encode('utf-8')).hexdigest()[:12] 135 | save_json(json_dict, json_output_file) 136 | 137 | return results 138 | -------------------------------------------------------------------------------- /nnood/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | import bisect 3 | 4 | import cv2 5 | import numpy as np 6 | from sklearn import metrics 7 | 8 | 9 | def _flatten_all(predictions: Union[List[np.ndarray], np.ndarray], labels: Union[List[np.ndarray], np.ndarray]): 10 | assert type(predictions) == type(labels), f'Types of predictions and labels in evaluation metric must match: ' \ 11 | f'{type(predictions)}, {type(labels)}' 12 | if type(predictions) is list: 13 | predictions = np.concatenate([p.flatten() for p in predictions]) 14 | labels = np.concatenate([l.flatten() for l in labels]) 15 | else: 16 | predictions = predictions.flatten() 17 | labels = labels.flatten() 18 | 19 | assert len(predictions) == len(labels), f'Length of predictions and labels in evaluation metric must match: ' \ 20 | f'{len(predictions)}, {len(labels)}' 21 | return predictions, labels 22 | 23 | 24 | def auroc(predictions: Union[List[np.ndarray], np.ndarray], labels: Union[List[np.ndarray], np.ndarray], **kwargs): 25 | predictions, labels = _flatten_all(predictions, labels) 26 | if kwargs.get('return_curve', False): 27 | fpr, tpr, _ = metrics.roc_curve(labels, predictions) 28 | return metrics.roc_auc_score(labels, predictions), fpr, tpr 29 | else: 30 | return metrics.roc_auc_score(labels, predictions) 31 | 32 | 33 | def average_precision(predictions: Union[List[np.ndarray], np.ndarray], labels: Union[List[np.ndarray], np.ndarray], 34 | **kwargs): 35 | predictions, labels = _flatten_all(predictions, labels) 36 | if kwargs.get('return_curve', False): 37 | precision, recall, _ = metrics.precision_recall_curve(labels, predictions) 38 | # Don't use auc function, as that uses trapezoidal rule which can be too optimistic 39 | return metrics.average_precision_score(labels, predictions), precision, 
recall 40 | else: 41 | return metrics.average_precision_score(labels, predictions) 42 | 43 | 44 | def per_region_overlap(predictions: List[np.ndarray], labels: List[np.ndarray], **kwargs): 45 | max_fpr = kwargs.get('pro_max_fpr', 0.3) 46 | max_components = kwargs.get('pro_max_components', 25) 47 | 48 | flat_preds, flat_labels = _flatten_all(predictions, labels) 49 | fpr, _, thresholds = metrics.roc_curve(flat_labels, flat_preds) 50 | split = len(fpr[fpr < max_fpr]) 51 | # last thresh has fpr >= max_fpr 52 | fpr = fpr[:(split + 1)] 53 | thresholds = thresholds[:(split + 1)] 54 | neg_thresholds = -thresholds 55 | for p in predictions: 56 | p[p < thresholds[-1]] = 0 57 | 58 | # calculate per-component-overlap for each threshold and match to global thresholds 59 | pro = np.zeros_like(fpr) 60 | total_components = 0 61 | for j in range(len(labels)): 62 | num_labels, label_img = cv2.connectedComponents(np.uint8(labels[j])) 63 | if num_labels > max_components: 64 | print(f'Invalid label map: too many components ({num_labels}) skipping sample {j}.') 65 | if num_labels == 1: # only background 66 | continue 67 | total_components += num_labels - 1 68 | 69 | y_score = predictions[j].flatten() 70 | desc_score_indices = np.argsort(y_score, kind='mergesort')[::-1] 71 | y_score = y_score[desc_score_indices] 72 | distinct_value_indices = np.where(np.diff(y_score))[0] 73 | threshold_idxs = np.r_[distinct_value_indices, y_score.size - 1] 74 | thresholds_j = y_score[threshold_idxs] 75 | for k in range(1, num_labels): 76 | y_true = np.uint8(label_img == k).flatten() 77 | y_true = y_true[desc_score_indices] 78 | tps = np.cumsum(y_true)[threshold_idxs] 79 | tpr = tps / tps[-1] 80 | 81 | # match tprs to global thresholds so that we can calculate pro 82 | right = len(thresholds) 83 | for tpr_t, t in zip(tpr[::-1], thresholds_j[::-1]): # iterate in ascending order 84 | if t < thresholds[-1]: # remove too small thresholds 85 | continue 86 | i = bisect.bisect_left(neg_thresholds, -t, hi=right) # search for negated as thresholds desc 87 | pro[i: right] += tpr_t 88 | right = i 89 | pro /= total_components 90 | 91 | if fpr[-1] > max_fpr: # interpolate last value 92 | pro[-1] = ((max_fpr - fpr[-2]) * pro[-1] + (fpr[-1] - max_fpr) * pro[-2]) / (fpr[-1] - fpr[-2]) 93 | fpr[-1] = max_fpr 94 | 95 | if kwargs.get('return_curve', False): 96 | return metrics.auc(fpr, pro) / max_fpr, fpr, pro 97 | else: 98 | return metrics.auc(fpr, pro) / max_fpr 99 | 100 | 101 | def all_score_stats(predictions: List[np.ndarray], labels: List[np.ndarray], **kwargs): 102 | predictions, _ = _flatten_all(predictions, labels) 103 | return np.mean(predictions, dtype=float), np.std(predictions, dtype=float) 104 | 105 | 106 | def anomaly_score_stats(predictions: List[np.ndarray], labels: List[np.ndarray], **kwargs): 107 | predictions, labels = _flatten_all(predictions, labels) 108 | anomaly_scores = predictions[labels.astype(bool)] 109 | return np.mean(anomaly_scores, dtype=float), np.std(anomaly_scores, dtype=float) 110 | 111 | 112 | def normal_score_stats(predictions: List[np.ndarray], labels: List[np.ndarray], **kwargs): 113 | predictions, labels = _flatten_all(predictions, labels) 114 | normal_scores = predictions[np.invert(labels.astype(bool))] 115 | return np.mean(normal_scores, dtype=float), np.std(normal_scores, dtype=float) 116 | 117 | 118 | ALL_METRICS = { 119 | 'AUROC': auroc, 120 | 'AP score': average_precision, 121 | # 'AU-PRO': per_region_overlap, # Uses OpenCV so only works on 2D, uint8 images. 
122 | }
123 | 
-------------------------------------------------------------------------------- /nnood/evaluation/nnOOD_evaluate_folder.py: --------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Optional
3 | 
4 | from nnood.evaluation.evaluator import aggregate_scores
5 | 
6 | 
7 | def evaluate_folder(folder_with_gts: str, folder_with_predictions: str, ref_suffix: str, pred_suffix: str,
8 |                     **metric_kwargs):
9 |     """
10 |     writes a summary.json to folder_with_predictions
11 |     :param folder_with_gts: folder where the ground truth labels are saved. Must be nifti or png files.
12 |     :param folder_with_predictions: folder where the predicted scores are saved.
13 |     :param ref_suffix:
14 |     :param pred_suffix:
15 |     :return:
16 |     """
17 |     folder_with_gts = Path(folder_with_gts)
18 |     folder_with_predictions = Path(folder_with_predictions)
19 |     # Can't just check the suffix, as some suffixes include '.' within them (like .nii.gz) and pathlib only counts
20 |     # the final .XXX as the suffix
21 |     files_pred = [f for f in folder_with_predictions.iterdir() if f.is_file() and f.name.endswith(pred_suffix)]
22 | 
23 |     files_gt = [folder_with_gts / (f.name[:-len(pred_suffix)] + ref_suffix) for f in files_pred]
24 | 
25 |     missing_gts = []
26 | 
27 |     def check_gt_exists(f: Path) -> Optional[Path]:
28 |         if f.is_file():
29 |             return f
30 |         else:
31 |             missing_gts.append(f.name)
32 |             return None
33 | 
34 |     files_gt = list(map(check_gt_exists, files_gt))
35 | 
36 |     # noinspection PySimplifyBooleanCheck
37 |     if missing_gts != []:
38 |         print(f'Files missing gt, assumed to be entirely normal ({len(missing_gts)} in total).')
39 |         for f in missing_gts:
40 |             print(f)
41 | 
42 |     test_ref_pairs = list(zip(files_pred, files_gt))
43 |     res = aggregate_scores(test_ref_pairs, json_output_file=folder_with_predictions / 'summary.json',
44 |                            **metric_kwargs)
45 |     print()
46 |     print('Evaluation results:')
47 |     print(res)
48 | 
49 | 
50 | def main():
51 |     import argparse
52 |     parser = argparse.ArgumentParser('Evaluates the anomaly scores located in the folder pred. Output of this script '
53 |                                      'is a json file. At the very bottom of the json file is going to be a \'mean\' '
54 |                                      'entry with average metrics across all cases.')
55 |     parser.add_argument('-ref', required=True, type=str, help='Folder containing the reference labels.')
56 |     parser.add_argument('-pred', required=True, type=str, help='Folder containing the predicted scores. File names '
57 |                                                                'must match between the folders!')
58 |     parser.add_argument('-r_s', '--ref_suffix', required=True, type=str, help='File suffix of the reference images.')
59 |     parser.add_argument('-p_s', '--pred_suffix', required=True, type=str, help='File suffix of the predictions.')
60 | 
61 |     args = parser.parse_args()
62 |     evaluate_folder(args.ref, args.pred, args.ref_suffix, args.pred_suffix)
63 | 
64 | 
65 | if __name__ == '__main__':
66 |     main()
67 | 
-------------------------------------------------------------------------------- /nnood/evaluation/nnOOD_run_testing.py: --------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | 
4 | import numpy as np
5 | import torch
6 | 
7 | from nnood.evaluation.nnOOD_evaluate_folder import evaluate_folder
8 | from nnood.inference.predict import predict_from_folder
9 | from nnood.paths import default_plans_identifier, preprocessed_data_base, raw_data_base, results_base
10 | from nnood.utils.file_operations import load_pickle
11 | 
12 | 
13 | def main():
14 |     parser = argparse.ArgumentParser()
15 |     parser.add_argument('-d', '--dataset', help='Dataset which the model was trained on.', required=True)
16 |     parser.add_argument('-t', '--task_name', help='Self-supervised task which the model was trained on', required=True)
17 |     parser.add_argument('-tr', '--trainer_class_name',
18 |                         help='Name of the nnOODTrainer used for full resolution and low resolution U-Net. If you are '
19 |                              'running inference with the cascade and the folder pointed to by --lowres_maps '
20 |                              'does not contain the anomaly maps generated by the low resolution U-Net then the low '
21 |                              'resolution anomaly maps will be automatically generated. For this case, make sure to set '
22 |                              'the trainer class here that matches your --cascade_trainer_class_name.',
23 |                         required=True)
24 |     parser.add_argument('-ctr', '--cascade_trainer_class_name',
25 |                         help='Trainer class name used for predicting the full resolution U-Net part of the cascade.',
26 |                         default=None, required=False)
27 |     parser.add_argument('-id', '--test_identifier', default=None, required=False,
28 |                         help='Use identifier when making results folder, optional.')
29 |     parser.add_argument('-p', '--plans_identifier', help='do not touch this unless you know what you are doing',
30 |                         default=default_plans_identifier, required=False)
31 |     parser.add_argument('-f', '--folds', nargs='+', default='None',
32 |                         help='folds to use for prediction. Default is None which means that folds will be detected '
33 |                              'automatically in the model output folder')
34 |     parser.add_argument('--num_threads_preprocessing', required=False, default=6, type=int,
35 |                         help='Determines how many background processes will be used for data preprocessing. Reduce '
36 |                              'this if you run into out of memory (RAM) problems. Default: 6')
37 |     parser.add_argument('--num_threads_save', required=False, default=2, type=int,
38 |                         help='Determines how many background processes will be used for exporting. Reduce this if '
39 |                              'you run into out of memory (RAM) problems. Default: 2')
40 |     parser.add_argument('--disable_tta', required=False, default=False, action='store_true',
41 |                         help='set this flag to disable test time data augmentation via mirroring.
Speeds up inference ' 42 | 'by roughly factor 4 (2D) or 8 (3D)') 43 | parser.add_argument('--overwrite_existing', required=False, default=False, action='store_true', 44 | help='Set this flag if the target folder contains predictions that you would like to overwrite') 45 | parser.add_argument('--all_in_gpu', type=str, default='None', required=False, help='can be None, False or True. ' 46 | 'Do not touch.') 47 | parser.add_argument('--step_size', type=float, default=0.5, required=False, help='don\'t touch') 48 | parser.add_argument('-chk', 49 | help='checkpoint name, default: model_final_checkpoint', 50 | required=False, 51 | default='model_final_checkpoint') 52 | parser.add_argument('--disable_mixed_precision', default=False, action='store_true', required=False, 53 | help='Predictions are done with mixed precision by default. This improves speed and reduces ' 54 | 'the required vram. If you want to disable mixed precision you can set this flag. Note ' 55 | 'that this is not recommended (mixed precision is ~2x faster!)') 56 | parser.add_argument('--lowres_only', required=False, default=False, action='store_true', 57 | help='Set this flag if you want to only use the lowres stage of a 2 step pipeline') 58 | 59 | args = parser.parse_args() 60 | task_name = args.task_name 61 | dataset = args.dataset 62 | plans_identifier = args.plans_identifier 63 | folds = args.folds 64 | num_threads_preprocessing = args.num_threads_preprocessing 65 | num_threads_save = args.num_threads_save 66 | disable_tta = args.disable_tta 67 | step_size = args.step_size 68 | overwrite_existing = args.overwrite_existing 69 | all_in_gpu = args.all_in_gpu 70 | trainer_class_name = args.trainer_class_name 71 | cascade_trainer_class_name = args.cascade_trainer_class_name 72 | lowres_only = args.lowres_only 73 | 74 | if isinstance(folds, list): 75 | if folds[0] == 'all' and len(folds) == 1: 76 | pass 77 | else: 78 | folds = [int(i) for i in folds] 79 | elif folds == 'None': 80 | folds = None 81 | else: 82 | raise ValueError('Unexpected value for argument folds') 83 | 84 | assert all_in_gpu in ['None', 'False', 'True'] 85 | if all_in_gpu == 'None': 86 | all_in_gpu = None 87 | elif all_in_gpu == 'True': 88 | all_in_gpu = True 89 | elif all_in_gpu == 'False': 90 | all_in_gpu = False 91 | 92 | plans_file = Path(preprocessed_data_base, dataset, plans_identifier) 93 | assert plans_file.is_file(), f'Missing plans file: {plans_file}' 94 | 95 | input_folder = Path(raw_data_base, dataset, 'imagesTs') 96 | labels_folder = Path(raw_data_base, dataset, 'labelsTs') 97 | assert input_folder.is_dir(), f'Missing test images folder: {input_folder}' 98 | assert labels_folder.is_dir(), f'Missing test labels folder: {labels_folder}' 99 | 100 | plans = load_pickle(plans_file) 101 | possible_stages = list(plans['plans_per_stage'].keys()) 102 | 103 | models = ['fullres'] if len(possible_stages) == 1 else ['lowres', 'cascade_fullres'] 104 | 105 | if lowres_only: 106 | assert 'lowres' in models, 'Cannot run lowres only on a pipeline without a lowres stage!' 
107 | models = ['lowres'] 108 | 109 | if 'cascade_fullres' in models: 110 | assert cascade_trainer_class_name is not None, 'Cannot use cascade_fullres model without defining' \ 111 | 'cascade_trainer_class_name' 112 | 113 | output_folder_base = Path(results_base, dataset, task_name, 'testResults', plans_identifier) 114 | if args.test_identifier is not None: 115 | output_folder_base /= args.test_identifier 116 | 117 | lowres_scores = None 118 | 119 | for model in models: 120 | print(f'Starting predictions for {model}') 121 | 122 | curr_trainer = cascade_trainer_class_name if model == 'cascade_fullres' else trainer_class_name 123 | curr_output = output_folder_base / ('lowres_predictions' if model == 'lowres' and not lowres_only else '') 124 | 125 | if model == 'cascade_fullres': 126 | assert lowres_scores.is_dir(), 'Somehow attempting cascade_fullres without lowres_scores being a dir.' 127 | 128 | model_folder = Path(results_base, dataset, task_name, model, curr_trainer + '__' + plans_identifier) 129 | print(f'Model is stored in: {model_folder}') 130 | assert model_folder.is_dir(), f'Model output folder not found, expected: {model_folder}' 131 | 132 | predict_from_folder(model_folder, input_folder, curr_output, folds, True, num_threads_preprocessing, 133 | num_threads_save, lowres_scores, 0, 1, not disable_tta, 134 | mixed_precision=not args.disable_mixed_precision, overwrite_existing=overwrite_existing, 135 | overwrite_all_in_gpu=all_in_gpu, step_size=step_size, checkpoint_name=args.chk) 136 | 137 | if model == 'lowres': 138 | lowres_scores = curr_output 139 | torch.cuda.empty_cache() 140 | 141 | label_suffix = 'png' if np.array(['png' in p for p in plans['modalities'].values()]).any() else 'nii.gz' 142 | 143 | print(f'Starting test evaluation, with label suffix {label_suffix}') 144 | evaluate_folder(labels_folder.__str__(), output_folder_base.__str__(), label_suffix, 'npz') 145 | 146 | 147 | if __name__ == '__main__': 148 | main() 149 | -------------------------------------------------------------------------------- /nnood/evaluation/readme.md: -------------------------------------------------------------------------------- 1 | Entrypoints: 2 | - nnood_evaluate_folder 3 | - evaluate performance of folder of predictions against folder of ground truths 4 | 5 | - nnood_run_testing 6 | - combines prediction and evaluation of test images for dataset (presuming they are organised in the standard data format). 
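
Example invocations (see each script's `-h` output for all options; names in capitals are placeholders):

```bash
python nnood/evaluation/nnOOD_evaluate_folder.py -ref LABELS_DIR -pred PREDICTIONS_DIR -r_s nii.gz -p_s npz
python nnood/evaluation/nnOOD_run_testing.py -d DATASET_NAME -t TASK_NAME -tr nnOODTrainerDS
```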
-------------------------------------------------------------------------------- /nnood/experiment_planning/DatasetAnalyser.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from multiprocessing import Pool 3 | from typing import List, Dict, Optional, Tuple 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import SimpleITK as sitk 8 | 9 | from nnood.preprocessing.normalisation import GLOBAL_NORMALISATION_MODALITIES 10 | from nnood.utils.file_operations import load_json, save_pickle 11 | from nnood.paths import DATASET_JSON_FILE, DATASET_PROPERTIES_FILE 12 | 13 | 14 | class DatasetAnalyser: 15 | def __init__(self, raw_data_path: Path, num_processes: int): 16 | 17 | self.raw_data_path = raw_data_path 18 | self.num_processes = num_processes 19 | self.sizes = None 20 | self.spacings = None 21 | 22 | self.dataset_json = load_json(raw_data_path / DATASET_JSON_FILE) 23 | 24 | self.sample_identifiers: List[str] = self.dataset_json['training'] 25 | 26 | def get_modalities(self): 27 | str_modalities = self.dataset_json['modality'] 28 | return {int(k): str_modalities[k] for k in str_modalities} 29 | 30 | @staticmethod 31 | def requires_global_normalisation(modality: str): 32 | return modality.lower() in GLOBAL_NORMALISATION_MODALITIES 33 | 34 | @staticmethod 35 | def get_voxels_in_foreground(image_array: np.ndarray) -> np.ndarray: 36 | 37 | assert len(image_array.shape) in [2, 3], f'Image must have 2 or 3 dimensions: {len(image_array.shape)}' 38 | 39 | mask = image_array > 0 40 | # Only include every 5th voxel, reducing memory usage with little effect on the global statistics 41 | return image_array[mask][::5] 42 | 43 | def analyse_sample(self, sample_id: str) -> Tuple[OrderedDict, Dict[str, Optional[np.ndarray]]]: 44 | 45 | modalities = self.get_modalities() 46 | properties = OrderedDict() 47 | properties['sample_id'] = sample_id 48 | properties['data_files'] = [] 49 | channel_intensity_properties = [] 50 | modality_intensities = OrderedDict() 51 | 52 | for i in range(len(modalities)): 53 | mod = modalities[i] 54 | suffix = 'png' if 'png' in mod else 'nii.gz' 55 | file_path = self.raw_data_path / 'imagesTr' / f'{sample_id}_{i:04d}.{suffix}' 56 | 57 | properties['data_files'].append(file_path) 58 | 59 | # Assume the size/spacing statistics are equal across all modalities of a sample 60 | itk_image = sitk.ReadImage(file_path.__str__()) 61 | if i == 0: 62 | properties['original_size'] = np.array(itk_image.GetSize())[::-1] 63 | properties['original_spacing'] = np.array(itk_image.GetSpacing())[::-1] 64 | properties['itk_origin'] = itk_image.GetOrigin() 65 | properties['itk_spacing'] = itk_image.GetSpacing() 66 | properties['itk_direction'] = itk_image.GetDirection() 67 | 68 | image_array = sitk.GetArrayFromImage(itk_image).astype(np.float32) 69 | 70 | if 'png' in mod: 71 | image_array /= 255 72 | 73 | if DatasetAnalyser.requires_global_normalisation(mod): 74 | modality_intensities[mod] = DatasetAnalyser.get_voxels_in_foreground(image_array) 75 | else: 76 | modality_intensities[mod] = None 77 | 78 | if 'png' in mod and len(image_array.shape) != 2: 79 | for c in range(image_array.shape[-1]): 80 | d = OrderedDict() 81 | curr_channel = image_array[:, :, c] 82 | d['mean'] = np.mean(curr_channel) 83 | d['sd'] = np.std(curr_channel) 84 | channel_intensity_properties.append(d) 85 | else: 86 | d = OrderedDict() 87 | d['mean'] = np.mean(image_array) 88 | d['sd'] = np.std(image_array) 89 | channel_intensity_properties.append(d) 90 | 91 |
properties['channel_intensity_properties'] = OrderedDict()  # per-channel mean/sd (see the 'z-score' scheme in preprocessing/normalisation.py) 92 | for c in range(len(channel_intensity_properties)): 93 | properties['channel_intensity_properties'][c] = channel_intensity_properties[c] 94 | 95 | return properties, modality_intensities 96 | 97 | def calc_intensity_properties(self, all_sample_intensities: List[Dict[str, Optional[np.ndarray]]])\ 98 | -> Dict[int, Optional[OrderedDict]]: 99 | 100 | modality_statistics = OrderedDict() 101 | 102 | mods = self.get_modalities() 103 | for i in range(len(mods)): 104 | mod = mods[i] 105 | if all_sample_intensities[0][mod] is None: 106 | modality_statistics[i] = None  # key by modality index, matching the populated branch below 107 | else: 108 | all_mod_intensities = np.concatenate(list(map(lambda d: d[mod], all_sample_intensities))) 109 | 110 | modality_statistics[i] = OrderedDict() 111 | modality_statistics[i]['median'] = np.median(all_mod_intensities) 112 | modality_statistics[i]['mean'] = np.mean(all_mod_intensities) 113 | modality_statistics[i]['sd'] = np.std(all_mod_intensities) 114 | modality_statistics[i]['min'] = np.min(all_mod_intensities) 115 | modality_statistics[i]['max'] = np.max(all_mod_intensities) 116 | modality_statistics[i]['percentile_99_5'] = np.percentile(all_mod_intensities, 99.5) 117 | modality_statistics[i]['percentile_00_5'] = np.percentile(all_mod_intensities, 00.5) 118 | 119 | return modality_statistics 120 | 121 | def analyse_dataset(self): 122 | 123 | with Pool(self.num_processes) as pool: 124 | all_sample_analytics = pool.map(self.analyse_sample, self.sample_identifiers) 125 | 126 | # List of tuples to tuple of lists 127 | all_sample_properties, all_sample_intensities = map(list, zip(*all_sample_analytics)) 128 | 129 | dataset_properties = OrderedDict() 130 | dataset_properties['all_sizes'] = list(map(lambda ps: ps['original_size'], all_sample_properties)) 131 | dataset_properties['all_spacings'] = list(map(lambda ps: ps['original_spacing'], all_sample_properties)) 132 | # dataset_properties['all_data_files'] = map(lambda ps: ps['data_files'], all_sample_properties) 133 | dataset_properties['sample_identifiers'] = self.sample_identifiers 134 | dataset_properties['modalities'] = self.get_modalities() 135 | dataset_properties['intensity_properties'] = self.calc_intensity_properties(all_sample_intensities) 136 | dataset_properties['tensor_dimensions'] = self.dataset_json['tensorImageSize'] 137 | dataset_properties['data_augs'] = self.dataset_json['data_augs'] 138 | dataset_properties['has_uniform_background'] = self.dataset_json['has_uniform_background'] 139 | 140 | dataset_properties['sample_properties'] = OrderedDict() 141 | for sample_id in self.sample_identifiers: 142 | dataset_properties['sample_properties'][sample_id] = next(s_p for s_p in all_sample_properties 143 | if s_p['sample_id'] == sample_id) 144 | 145 | save_pickle(dataset_properties, self.raw_data_path / DATASET_PROPERTIES_FILE) 146 | return dataset_properties 147 | -------------------------------------------------------------------------------- /nnood/experiment_planning/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .
import * 3 | -------------------------------------------------------------------------------- /nnood/experiment_planning/modality_conversion.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | # Some 'modalities' actually contain multiple modalities 4 | MODALITY_TRANSLATOR = { 5 | 'png': ('png-r', 'png-g', 'png-b'), 6 | 'ct': ('ct',), 7 | 'mri': ('mri',), 8 | 'png-bw': ('png-bw',), # Greyscale image 9 | 'png-xray': ('png-xray',) 10 | } 11 | 12 | 13 | def num_modality_components(mod: str): 14 | 15 | if mod not in MODALITY_TRANSLATOR: 16 | print(f'Unknown modality {mod}, assuming it only has 1 component') 17 | return 1 18 | 19 | return len(MODALITY_TRANSLATOR[mod]) 20 | 21 | 22 | def get_channel_list(modalities: OrderedDict): 23 | channel_list = [] 24 | for i in modalities.keys(): 25 | mod_components = MODALITY_TRANSLATOR[modalities[i]] 26 | for m in mod_components: 27 | channel_list.append(m) 28 | return channel_list 29 | -------------------------------------------------------------------------------- /nnood/experiment_planning/nnOOD_plan_and_preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import shutil 3 | from pathlib import Path 4 | 5 | from nnood.configuration import default_num_processes 6 | from nnood.data.sanity_checks import verify_dataset_integrity 7 | from nnood.paths import raw_data_base, preprocessed_data_base 8 | from nnood.experiment_planning.DatasetAnalyser import DatasetAnalyser 9 | from nnood.experiment_planning.experiment_planner import ExperimentPlanner 10 | 11 | # Plan experiment, and convert dataset to .npz format (gathering modalities of each sample) 12 | if __name__ == '__main__': 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('-d', '--dataset_names', nargs='+', help='List of datasets you wish to run planning and' 16 | ' preprocessing for.') 17 | parser.add_argument('-no_pp', action='store_true', 18 | help='Set this flag if you don\'t want to run the preprocessing. If this is set then this script ' 19 | 'will only run the experiment planning and create the plans file') 20 | # parser.add_argument('-tl', type=int, required=False, default=8, 21 | # help='Number of processes used for preprocessing the low resolution data for the 3D low ' 22 | # 'resolution U-Net. This can be larger than -tf. Don't overdo it or you will run out of ' 23 | # 'RAM') 24 | parser.add_argument('-p', type=int, required=False, default=default_num_processes, 25 | help='Number of processes used for preprocessing the full resolution data. Don\'t overdo it or ' 26 | 'you will run out of RAM') 27 | parser.add_argument('--verify_dataset_integrity', required=False, default=False, action='store_true', 28 | help='Set this flag to check the dataset integrity. This is useful and should be done once for ' 29 | 'each dataset!') 30 | parser.add_argument('--disable_skip', type=int, required=False, default=0, 31 | help='Number of skip connections to disable (starting from top of U-Net). 
Remember to change ' 32 | 'plans identifier when generating new plans!') 33 | 34 | args = parser.parse_args() 35 | dataset_names = args.dataset_names 36 | run_preprocessing = not args.no_pp 37 | num_processes = args.p 38 | verify_dataset_integ = args.verify_dataset_integrity 39 | disable_skip = args.disable_skip 40 | 41 | raw_data_path = Path(raw_data_base) 42 | 43 | dataset_paths = [] 44 | for d_n in dataset_names: 45 | d_path = raw_data_path / d_n 46 | 47 | if verify_dataset_integ: 48 | verify_dataset_integrity(d_path) 49 | 50 | dataset_paths.append(d_path) 51 | 52 | for d_path in dataset_paths: 53 | print('Planning for dataset at: ', d_path) 54 | 55 | print('Analysing dataset...') 56 | dataset_analyser = DatasetAnalyser(d_path, num_processes=num_processes) 57 | _ = dataset_analyser.analyse_dataset() 58 | 59 | print('Copying dataset.json and properties files') 60 | preprocessed_data_dir = Path(preprocessed_data_base) / d_path.name 61 | preprocessed_data_dir.mkdir(exist_ok=True) 62 | shutil.copy(d_path / 'dataset.json', preprocessed_data_dir) 63 | shutil.copy(d_path / 'dataset_properties.pkl', preprocessed_data_dir) 64 | 65 | print('Planning experiment...') 66 | exp_planner = ExperimentPlanner(d_path, preprocessed_data_dir, num_processes, disable_skip) 67 | exp_planner.plan_experiment() 68 | if run_preprocessing: 69 | print('Running preprocessing...') 70 | exp_planner.run_preprocessing() 71 | 72 | print('Done') 73 | -------------------------------------------------------------------------------- /nnood/experiment_planning/nnOOD_update_plans_number.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import re 4 | import shutil 5 | 6 | from nnood.paths import default_plans_identifier, preprocessed_data_base 7 | 8 | # Copy old plan to current plans identifier. 9 | # Use if you've made a change in another aspect of the pipeline (like the trainer) 10 | if __name__ == '__main__': 11 | 12 | match = re.search(r'(.*\.)(\d+)$', default_plans_identifier) 13 | assert match is not None, f"Default plans identifier doesn't match expected pattern: {default_plans_identifier}" 14 | plans_id_stem = match.group(1) 15 | curr_id_num = int(match.group(2)) 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('-c', '--copy_from', required=False, default=plans_id_stem + str(curr_id_num - 1), 19 | help='Identifier of plans to copy from. 
Defaults to previous.') 20 | parser.add_argument('-d', '--dataset_names', nargs='+', help='List of datasets you wish to update the plans for.') 21 | 22 | args = parser.parse_args() 23 | copy_from = args.copy_from 24 | dataset_names = args.dataset_names 25 | 26 | print('Copying plans from: ', copy_from) 27 | print('New plans identifier: ', default_plans_identifier) 28 | print('Updating plans of datasets:') 29 | 30 | for d_n in dataset_names: 31 | print(d_n) 32 | 33 | preprocessed_data_path = Path(preprocessed_data_base) 34 | 35 | for d_n in dataset_names: 36 | preprocessed_data_dir = preprocessed_data_path / d_n 37 | assert preprocessed_data_dir.is_dir(), f'Missing directory for preprocessed data: {preprocessed_data_dir}' 38 | 39 | old_plans_path = preprocessed_data_dir / copy_from 40 | assert old_plans_path.is_file(), f'Missing plans to copy from: {old_plans_path}' 41 | 42 | new_plans_path = preprocessed_data_dir / default_plans_identifier 43 | assert not new_plans_path.is_file(), f'Plans with current identifier already exist: {new_plans_path}' 44 | 45 | shutil.copy(old_plans_path, new_plans_path) 46 | print('Done') 47 | -------------------------------------------------------------------------------- /nnood/experiment_planning/utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import numpy as np 4 | 5 | 6 | def get_pool_and_conv_props(patch_size: np.ndarray, min_feature_map_size: int, max_num_pool: int, spacing: np.ndarray): 7 | """ 8 | :param patch_size: patch size per axis; may be padded so that each axis is divisible by the total pooling factor 9 | :param min_feature_map_size: min edge length of feature maps in bottleneck 10 | :param max_num_pool: maximum number of pooling operations per axis 11 | :param spacing: voxel spacing per axis 12 | :return: (num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, padded patch_size, must_be_divisible_by) 13 | """ 14 | dim = len(spacing) 15 | 16 | current_spacing = deepcopy(list(spacing)) 17 | current_size = deepcopy(list(patch_size)) 18 | 19 | pool_op_kernel_sizes = [] 20 | conv_kernel_sizes = [] 21 | 22 | num_pool_per_axis = [0] * dim 23 | 24 | while True: 25 | min_spacing = min(current_spacing) 26 | valid_axes_for_pool = [i for i in range(dim) if current_spacing[i] / min_spacing < 2] 27 | axes = [] 28 | for a in range(dim): 29 | my_spacing = current_spacing[a] 30 | partners = [i for i in range(dim) if 31 | current_spacing[i] / my_spacing < 2 and my_spacing / current_spacing[i] < 2] 32 | if len(partners) > len(axes): 33 | axes = partners 34 | conv_kernel_size = [3 if i in axes else 1 for i in range(dim)] 35 | 36 | # Exclude axes which cannot be pooled further 37 | valid_axes_for_pool = [i for i in valid_axes_for_pool if current_size[i] >= 2 * min_feature_map_size 38 | and num_pool_per_axis[i] < max_num_pool] 39 | 40 | if len(valid_axes_for_pool) == 0: 41 | break 42 | 43 | other_axes = [i for i in range(dim) if i not in valid_axes_for_pool] 44 | 45 | pool_kernel_sizes = [0] * dim 46 | for v in valid_axes_for_pool: 47 | pool_kernel_sizes[v] = 2 48 | num_pool_per_axis[v] += 1 49 | current_spacing[v] *= 2 50 | current_size[v] = np.ceil(current_size[v] / 2) 51 | for nv in other_axes: 52 | pool_kernel_sizes[nv] = 1 53 | 54 | pool_op_kernel_sizes.append(pool_kernel_sizes) 55 | conv_kernel_sizes.append(conv_kernel_size) 56 | 57 | must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis) 58 | patch_size = pad_shape(patch_size, must_be_divisible_by) 59 | 60 | # Add bottleneck conv 61 | conv_kernel_sizes.append([3] * dim) 62 | 63 | # Note: this computes the number of conv/pools for one side of the U-Net, not both 64 | return num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, must_be_divisible_by 65 | 
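# Illustrative example (added for clarity; not part of the original module): with num_pool_per_axis = [5, 5], get_shape_must_be_divisible_by([5, 5]) below returns [32, 32], and pad_shape([317, 244], [32, 32]) returns [320, 256]; an already-divisible shape such as [320, 256] is returned unchanged.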
66 | 67 | def get_shape_must_be_divisible_by(net_numpool_per_axis): 68 | return 2 ** np.array(net_numpool_per_axis) 69 | 70 | 71 | def pad_shape(shape, must_be_divisible_by): 72 | """ 73 | Pads shape so that it is divisible by must_be_divisible_by 74 | :param shape: 75 | :param must_be_divisible_by: 76 | :return: 77 | """ 78 | if not isinstance(must_be_divisible_by, (tuple, list, np.ndarray)): 79 | must_be_divisible_by = [must_be_divisible_by] * len(shape) 80 | else: 81 | assert len(must_be_divisible_by) == len(shape) 82 | 83 | new_shape = [shape[i] + must_be_divisible_by[i] - shape[i] % must_be_divisible_by[i] for i in range(len(shape))] 84 | 85 | for i in range(len(shape)): 86 | if shape[i] % must_be_divisible_by[i] == 0: 87 | new_shape[i] -= must_be_divisible_by[i] 88 | new_shape = np.array(new_shape).astype(int) 89 | return new_shape 90 | 91 | 92 | def summarise_plans(plans): 93 | print("modalities: ", plans['modalities']) 94 | print("normalization_schemes", plans['normalization_schemes']) 95 | print("stages...\n") 96 | 97 | for i in range(len(plans['plans_per_stage'])): 98 | print("stage: ", i) 99 | print(plans['plans_per_stage'][i]) 100 | print("") 101 | -------------------------------------------------------------------------------- /nnood/inference/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import * 3 | -------------------------------------------------------------------------------- /nnood/inference/export_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | from pathlib import Path 4 | from typing import Union 5 | 6 | import numpy as np 7 | import SimpleITK as sitk 8 | 9 | from nnood.preprocessing.preprocessing import get_lowres_axis, get_do_separate_z, resample_data 10 | from nnood.utils.file_operations import save_pickle 11 | 12 | 13 | def save_data_as_file(data: Union[str, Path, np.ndarray], out_file_name: Path, 14 | properties_dict: dict, order: int = 1, 15 | postprocess_fn: callable = None, postprocess_args: tuple = None, 16 | resampled_npz_file_name: Path = None, 17 | non_postprocessed_file_name: Path = None, force_separate_z: bool = None, 18 | interpolation_order_z: int = 0, verbose: bool = True): 19 | """ 20 | This is a utility for writing data to nifti/png and npz. It requires the data to have been preprocessed by 21 | GenericPreprocessor because it depends on the property dictionary output (properties_dict) to know the geometry of the original 22 | data. data does not have to have the same size in pixels as the original data; it will be 23 | resampled to match. This is generally useful because the spacings our networks operate on are most of the time 24 | not the native spacings of the image data. 25 | If postprocess_fn is not None then postprocess_fn(data, *postprocess_args) 26 | will be called before nifti export 27 | There is a problem with python process communication that prevents us from communicating objects 28 | larger than 2 GB between processes (basically when the length of the pickle string that will be sent is 29 | communicated by the multiprocessing.Pipe object then the placeholder (I think) does not allow for long 30 | enough strings. This could be fixed by changing i to l (for long) but that would require manually 31 | patching system python code.) 
We circumvent that problem here by saving the prediction to an npy file that will 32 | then be read (and finally deleted) by the Process. save_data_as_file can take either 33 | a filename or an np.ndarray for data and will handle this automatically 34 | :param data: Image with shape (c, [z, ], y, x) - channels first. 35 | :param out_file_name: 36 | :param properties_dict: 37 | :param order: 38 | :param postprocess_fn: 39 | :param postprocess_args: 40 | :param resampled_npz_file_name: 41 | :param non_postprocessed_file_name: 42 | :param force_separate_z: if None then we dynamically decide how to resample along z, if True/False then always 43 | /never resample along z separately. Do not touch unless you know what you are doing 44 | :param interpolation_order_z: if separate z resampling is done then this is the order for resampling in z 45 | :param verbose: 46 | :return: 47 | """ 48 | if verbose: 49 | print("force_separate_z:", force_separate_z, "interpolation order:", order) 50 | 51 | if isinstance(data, (str, Path)): 52 | data = Path(data) 53 | assert data.is_file() 54 | del_file = deepcopy(data) 55 | data = np.load(data) 56 | os.remove(del_file) 57 | 58 | # first resample, then put result into bbox of cropping, then save 59 | current_shape = data.shape 60 | shape_original = properties_dict.get('original_size') 61 | 62 | if np.any([i != j for i, j in zip(np.array(current_shape[1:]), np.array(shape_original))]): 63 | if force_separate_z is None: 64 | if get_do_separate_z(properties_dict.get('original_spacing')): 65 | do_separate_z = True 66 | lowres_axis = get_lowres_axis(properties_dict.get('original_spacing')) 67 | elif get_do_separate_z(properties_dict.get('spacing_after_resampling')): 68 | do_separate_z = True 69 | lowres_axis = get_lowres_axis(properties_dict.get('spacing_after_resampling')) 70 | else: 71 | do_separate_z = False 72 | lowres_axis = None 73 | else: 74 | do_separate_z = force_separate_z 75 | if do_separate_z: 76 | lowres_axis = get_lowres_axis(properties_dict.get('original_spacing')) 77 | else: 78 | lowres_axis = None 79 | 80 | if lowres_axis is not None and len(lowres_axis) != 1: 81 | # this happens for spacings like (0.24, 1.25, 1.25) for example. 
In that case we do not want to resample 82 | # separately in the out of plane axis 83 | do_separate_z = False 84 | 85 | if verbose: 86 | print("separate z:", do_separate_z, "lowres axis", lowres_axis) 87 | data_old_spacing = resample_data(data, shape_original, axis=lowres_axis, order=order, 88 | do_separate_z=do_separate_z, order_z=interpolation_order_z) 89 | else: 90 | if verbose: 91 | print("no resampling necessary") 92 | data_old_spacing = data 93 | 94 | if resampled_npz_file_name is not None: 95 | np.savez_compressed(resampled_npz_file_name, data=data_old_spacing) 96 | save_pickle(properties_dict, resampled_npz_file_name.with_suffix('.pkl')) 97 | 98 | # Currently we don't have a separate cropping stage, due to the lack of map to say where background is 99 | # I'll leave this logic commented in case we want it back later 100 | 101 | # bbox = properties_dict.get('crop_bbox') 102 | # if bbox is not None: 103 | # data_old_size = np.zeros(shape_original_before_cropping) 104 | # for c in range(3): 105 | # bbox[c][1] = np.min((bbox[c][0] + data_old_spacing.shape[c], shape_original_before_cropping[c])) 106 | # data_old_size[bbox[0][0]:bbox[0][1], 107 | # bbox[1][0]:bbox[1][1], 108 | # bbox[2][0]:bbox[2][1]] = data_old_spacing 109 | # else: 110 | # data_old_size = data_old_spacing 111 | 112 | if postprocess_fn is not None: 113 | data_old_size_postprocessed = postprocess_fn(np.copy(data_old_spacing), *postprocess_args) 114 | else: 115 | data_old_size_postprocessed = data_old_spacing 116 | 117 | def _save_array(data_to_save: np.ndarray, file_path: Path): 118 | # Move to channels last, to match SITK representation 119 | data_to_save = np.moveaxis(data_to_save, 0, -1) 120 | 121 | if file_path.suffix == '.png': 122 | # To be saved as a png the image must be uint8, with values in range [0,255] 123 | data_to_save = (data_to_save * 255).astype(np.uint8) 124 | 125 | data_to_save_itk = sitk.GetImageFromArray(data_to_save, isVector=True) 126 | data_to_save_itk.SetSpacing(properties_dict['itk_spacing']) 127 | data_to_save_itk.SetOrigin(properties_dict['itk_origin']) 128 | data_to_save_itk.SetDirection(properties_dict['itk_direction']) 129 | sitk.WriteImage(data_to_save_itk, file_path.__str__()) 130 | 131 | _save_array(data_old_size_postprocessed, out_file_name) 132 | 133 | if (non_postprocessed_file_name is not None) and (postprocess_fn is not None): 134 | _save_array(data_old_spacing, non_postprocessed_file_name) 135 | -------------------------------------------------------------------------------- /nnood/inference/model_restore.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | 5 | import nnood 6 | from nnood.training.network_training.nnOODTrainer import nnOODTrainer 7 | from nnood.utils.file_operations import load_pickle 8 | from nnood.utils.miscellaneous import recursive_find_python_class 9 | 10 | 11 | def restore_model(pkl_file: Path, checkpoint=None, train=False, fp16=None) -> nnOODTrainer: 12 | """ 13 | This is a utility function to load any nnOODTrainer from a pkl. It will recursively search 14 | nnood.training.network_training for the file that contains the trainer and instantiate it with the arguments saved 15 | in the pkl file. If checkpoint is specified, it will furthermore load the checkpoint file in train/test mode (as 16 | specified by train). The pkl file required here is the one that will be saved automatically when calling 17 | nnOODTrainer.save_checkpoint. 
18 | :param pkl_file: 19 | :param checkpoint: 20 | :param train: 21 | :param fp16: if None then we take no action. If True/False we overwrite what the model has in its init 22 | :return: 23 | """ 24 | info = load_pickle(pkl_file) 25 | init = info['init'] 26 | trainer_class_name = info['name'] 27 | 28 | trainer_class = recursive_find_python_class([Path(nnood.__path__[0], 'training', 'network_training').__str__()], 29 | trainer_class_name, current_module='nnood.training.network_training') 30 | 31 | if trainer_class is None: 32 | raise RuntimeError('Could not find the model trainer specified in checkpoint in ' 33 | 'nnood.training.network_training. If it is not located there, please move it or change the ' 34 | 'code of restore_model. Your model can be located in any directory within ' 35 | 'nnood.training.network_training (search is recursive).\nDebug info:\ncheckpoint file: ' 36 | f'{checkpoint}\nName of trainer: {trainer_class_name}') 37 | 38 | assert issubclass(trainer_class, nnOODTrainer), 'The network trainer was found but is not a subclass of ' \ 39 | 'nnOODTrainer. Please make it so!' 40 | 41 | # From nnUNet, meaning lost to time: 42 | # ToDo Fabian make saves use kwargs, please... 43 | 44 | trainer = trainer_class(*init) 45 | 46 | # We can hack fp16 overwriting into the trainer without changing the init arguments because nothing happens with 47 | # fp16 in the init, it just saves it to a member variable 48 | if fp16 is not None: 49 | trainer.fp16 = fp16 50 | 51 | trainer.process_plans(info['plans']) 52 | if checkpoint is not None: 53 | trainer.load_checkpoint(checkpoint, train) 54 | return trainer 55 | 56 | 57 | def load_model_and_checkpoint_files(folder: Path, folds=None, mixed_precision=None, checkpoint_name='model_best'): 58 | """ 59 | Used if you need to ensemble the five models of a cross-validation. This will restore the model from the 60 | checkpoint in fold 0, load all parameters of the five folds in RAM and return both. This will allow for fast 61 | switching between parameters (as opposed to loading them from disk each time). 62 | This is best used for inference and test prediction 63 | :param folder: 64 | :param folds: 65 | :param mixed_precision: if None then we take no action. If True/False we overwrite what the model has in its init 66 | :param checkpoint_name: 67 | :return: 68 | """ 69 | if isinstance(folds, str): 70 | folds = [folder / 'all'] 71 | assert folds[0].is_dir(), f'no output folder for fold {folds} found' 72 | elif isinstance(folds, (list, tuple)): 73 | if len(folds) == 1 and folds[0] == 'all': 74 | folds = [folder / 'all'] 75 | else: 76 | folds = [folder / f'fold_{i}' for i in folds] 77 | elif isinstance(folds, int): 78 | folds = [folder / f'fold_{folds}'] 79 | elif folds is None: 80 | print('folds is None so we will automatically look for output folders (not using \'all\'!)') 81 | folds = [f for f in folder.iterdir() if f.is_dir() and f.name.startswith('fold')] 82 | print('found the following folds: ', folds) 83 | else: 84 | raise ValueError(f'Unknown value for folds. Type: {type(folds)}. 
Expected: list of int, int, str or None') 85 | 86 | assert all([f.is_dir() for f in folds]), 'list of folds specified but not all output folders are present' 87 | 88 | trainer = restore_model(folds[0] / f'{checkpoint_name}.pkl', fp16=mixed_precision) 89 | trainer.output_folder = folder 90 | trainer.output_folder_base = folder 91 | # I think fold is set as otherwise load_best_checkpoint raises an exception, even though trainer.fold isn't used 92 | # during inference 93 | trainer.update_fold(0) 94 | trainer.initialize(False) 95 | all_model_files = [f / f'{checkpoint_name}.model' for f in folds] 96 | print('using the following model files: ', all_model_files) 97 | all_params = [torch.load(i, map_location=torch.device('cpu')) for i in all_model_files] 98 | return trainer, all_params 99 | -------------------------------------------------------------------------------- /nnood/inference/predict_simple.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import torch 5 | 6 | from nnood.inference.predict import predict_from_folder 7 | from nnood.paths import default_plans_identifier, results_base 8 | 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('-i', '--input_folder', help='Must contain all modalities for each case in the correct' 13 | ' order (same as training). Files must be named CASENAME_XXXX.ext, ' 14 | 'where XXXX is the modality identifier (0000, 0001, etc.) and ' 15 | 'ext is the file extension', required=True) 16 | parser.add_argument('-o', '--output_folder', help='folder for saving predictions', required=True) 17 | parser.add_argument('-d', '--dataset', help='Dataset which the model was trained on.', required=True) 18 | parser.add_argument('-t', '--task_name', help='Self-supervised task which the model was trained on', required=True) 19 | parser.add_argument('-tr', '--trainer_class_name', 20 | help='Name of the nnOODTrainer used for full resolution and low resolution U-Net. If you are ' 21 | 'running inference with the cascade and the folder pointed to by --lowres_scores ' 22 | 'does not contain the anomaly maps generated by the low resolution U-Net then the low ' 23 | 'resolution anomaly maps will be automatically generated. For this case, make sure to set ' 24 | 'the trainer class here that matches your --cascade_trainer_class_name.', 25 | required=True) 26 | parser.add_argument('-ctr', '--cascade_trainer_class_name', 27 | help='Trainer class name used for predicting the full resolution U-Net part of the cascade.', 28 | default=None, required=False) 29 | 30 | parser.add_argument('-m', '--model', help='lowres, fullres or cascade_fullres. Default: fullres', 31 | default='fullres', required=False) 32 | 33 | parser.add_argument('-p', '--plans_identifier', help='do not touch this unless you know what you are doing', 34 | default=default_plans_identifier, required=False) 35 | 36 | parser.add_argument('-f', '--folds', nargs='+', default='None', 37 | help='folds to use for prediction. 
Default is None which means that folds will be detected ' 38 | 'automatically in the model output folder') 39 | 40 | parser.add_argument('-z', '--save_npz', required=False, action='store_true', 41 | help='use this if you want to ensemble these predictions with those of other models.') 42 | 43 | parser.add_argument('-l', '--lowres_scores', required=False, default='None', 44 | help='if model is the highres stage of the cascade then you can use this folder to provide ' 45 | 'predictions from the low resolution U-Net (as numpy files). If this is left at default, ' 46 | 'the predictions will be generated automatically (provided that the low resolution U-Net ' 47 | 'network weights are present)') 48 | 49 | parser.add_argument('--part_id', type=int, required=False, default=0, 50 | help='Used to parallelize the prediction of the folder over several GPUs. If you want to use n ' 51 | 'GPUs to predict this folder you need to run this command n times with --part_id=0, ... ' 52 | 'n-1 and --num_parts=n (each with a different GPU, for example via ' 53 | 'CUDA_VISIBLE_DEVICES=X)') 54 | 55 | parser.add_argument('--num_parts', type=int, required=False, default=1, 56 | help='Used to parallelize the prediction of the folder over several GPUs. If you want to use n ' 57 | 'GPUs to predict this folder you need to run this command n times with --part_id=0, ... ' 58 | 'n-1 and --num_parts=n (each with a different GPU, via CUDA_VISIBLE_DEVICES=X)') 59 | 60 | parser.add_argument('--num_threads_preprocessing', required=False, default=6, type=int, 61 | help='Determines how many background processes will be used for data preprocessing. Reduce this if ' 62 | 'you run into out of memory (RAM) problems. Default: 6') 63 | 64 | parser.add_argument('--num_threads_save', required=False, default=2, type=int, 65 | help='Determines how many background processes will be used for exporting. Reduce this if you run ' 66 | 'into out of memory (RAM) problems. Default: 2') 67 | 68 | parser.add_argument('--disable_tta', required=False, default=False, action='store_true', 69 | help='set this flag to disable test time data augmentation via mirroring. Speeds up inference ' 70 | 'by roughly factor 4 (2D) or 8 (3D)') 71 | 72 | parser.add_argument('--overwrite_existing', required=False, default=False, action='store_true', 73 | help='Set this flag if the target folder contains predictions that you would like to overwrite') 74 | 75 | parser.add_argument('--all_in_gpu', type=str, default='None', required=False, help='can be None, False or True. ' 76 | 'Do not touch.') 77 | parser.add_argument('--step_size', type=float, default=0.5, required=False, help='don\'t touch') 78 | parser.add_argument('-chk', 79 | help='checkpoint name, default: model_final_checkpoint', 80 | required=False, 81 | default='model_final_checkpoint') 82 | parser.add_argument('--disable_mixed_precision', default=False, action='store_true', required=False, 83 | help='Predictions are done with mixed precision by default. This improves speed and reduces ' 84 | 'the required vram. If you want to disable mixed precision you can set this flag. 
Note ' 85 | 'that this is not recommended (mixed precision is ~2x faster!)') 86 | 87 | args = parser.parse_args() 88 | input_folder = args.input_folder 89 | output_folder = args.output_folder 90 | dataset = args.dataset 91 | part_id = args.part_id 92 | num_parts = args.num_parts 93 | folds = args.folds 94 | save_npz = args.save_npz 95 | lowres_scores = args.lowres_scores 96 | num_threads_preprocessing = args.num_threads_preprocessing 97 | num_threads_save = args.num_threads_save 98 | disable_tta = args.disable_tta 99 | step_size = args.step_size 100 | overwrite_existing = args.overwrite_existing 101 | all_in_gpu = args.all_in_gpu 102 | model = args.model 103 | trainer_class_name = args.trainer_class_name 104 | cascade_trainer_class_name = args.cascade_trainer_class_name 105 | task_name = args.task_name 106 | 107 | assert model in ['lowres', 'fullres', 'cascade_fullres'], '-m must be lowres, fullres or cascade_fullres' 108 | 109 | input_folder = Path(input_folder) 110 | output_folder = Path(output_folder) 111 | 112 | if lowres_scores == 'None': 113 | lowres_scores = None 114 | else: 115 | lowres_scores = Path(lowres_scores) 116 | 117 | if isinstance(folds, list): 118 | if folds[0] == 'all' and len(folds) == 1: 119 | pass 120 | else: 121 | folds = [int(i) for i in folds] 122 | elif folds == 'None': 123 | folds = None 124 | else: 125 | raise ValueError('Unexpected value for argument folds') 126 | 127 | assert all_in_gpu in ['None', 'False', 'True'] 128 | if all_in_gpu == 'None': 129 | all_in_gpu = None 130 | elif all_in_gpu == 'True': 131 | all_in_gpu = True 132 | elif all_in_gpu == 'False': 133 | all_in_gpu = False 134 | 135 | # we need to catch the case where model is cascade fullres and the low resolution folder has not been set. 136 | # In that case we need to try and predict with lowres first 137 | if model == 'cascade_fullres' and lowres_scores is None: 138 | print('lowres_scores is None. Attempting to predict lowres first...') 139 | assert part_id == 0 and num_parts == 1, 'if you don\'t specify a --lowres_scores folder for the ' \ 140 | 'inference of the cascade, custom values for part_id and num_parts ' \ 141 | 'are not supported. If you wish to have multiple parts, please ' \ 142 | 'run the lowres inference first (separately)' 143 | model_folder = Path(results_base, dataset, task_name, 'lowres', trainer_class_name + '__' + 144 | args.plans_identifier) 145 | assert model_folder.is_dir(), 'model output folder not found. Expected: %s' % model_folder 146 | lowres_output_folder = output_folder / 'lowres_predictions' 147 | predict_from_folder(model_folder, input_folder, lowres_output_folder, folds, False, num_threads_preprocessing, 148 | num_threads_save, None, part_id, num_parts, not disable_tta, 149 | mixed_precision=not args.disable_mixed_precision, overwrite_existing=overwrite_existing, 150 | overwrite_all_in_gpu=all_in_gpu, step_size=step_size, checkpoint_name=args.chk) 151 | lowres_scores = lowres_output_folder 152 | torch.cuda.empty_cache() 153 | print('lowres done') 154 | 155 | if model == 'cascade_fullres': 156 | assert cascade_trainer_class_name is not None, 'Cannot use cascade_fullres model without defining ' \ 157 | 'cascade_trainer_class_name' 158 | trainer = cascade_trainer_class_name 159 | else: 160 | trainer = trainer_class_name 161 | 162 | model_folder = Path(results_base, dataset, task_name, model, trainer + '__' + args.plans_identifier) 163 | print('using model stored in ', model_folder) 164 | assert model_folder.is_dir(), 'model output folder not found. 
Expected: %s' % model_folder 165 | 166 | predict_from_folder(model_folder, input_folder, output_folder, folds, save_npz, num_threads_preprocessing, 167 | num_threads_save, lowres_scores, part_id, num_parts, not disable_tta, 168 | mixed_precision=not args.disable_mixed_precision, overwrite_existing=overwrite_existing, 169 | overwrite_all_in_gpu=all_in_gpu, step_size=step_size, checkpoint_name=args.chk) 170 | 171 | 172 | if __name__ == '__main__': 173 | main() 174 | -------------------------------------------------------------------------------- /nnood/network_architecture/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import * 3 | -------------------------------------------------------------------------------- /nnood/network_architecture/initialisation.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class InitWeights_He(object): 5 | def __init__(self, neg_slope=1e-2): 6 | self.neg_slope = neg_slope 7 | 8 | def __call__(self, module): 9 | if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)): 10 | module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) 11 | if module.bias is not None: 12 | module.bias = nn.init.constant_(module.bias, 0) -------------------------------------------------------------------------------- /nnood/paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | from pathlib import Path 4 | 5 | RAW_ENVIRON_VAR = 'nnood_raw_data_base' 6 | PREPROCESSED_ENVIRON_VAR = 'nnood_preprocessed_data_base' 7 | RESULTS_ENVIRON_VAR = 'nnood_results_base' 8 | 9 | DATASET_JSON_FILE = 'dataset.json' 10 | DATASET_PROPERTIES_FILE = 'dataset_properties.pkl' 11 | 12 | default_plans_identifier = 'nnood_plans_v1.0' 13 | default_data_identifier = 'nnood_data_v1.0' 14 | 15 | def setup_directory(var_name: str) -> Optional[str]: 16 | var_value = os.environ[var_name] if var_name in os.environ else None 17 | 18 | if var_value is not None: 19 | dir_path = Path(var_value) 20 | dir_path.mkdir(parents=True, exist_ok=True) 21 | else: 22 | print(var_name + ' is not defined, preventing nnood from completing any actions involving its files') 23 | return var_value 24 | 25 | 26 | raw_data_base = setup_directory(RAW_ENVIRON_VAR) 27 | preprocessed_data_base = setup_directory(PREPROCESSED_ENVIRON_VAR) 28 | results_base = setup_directory(RESULTS_ENVIRON_VAR) 29 | -------------------------------------------------------------------------------- /nnood/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . 
import * 3 | -------------------------------------------------------------------------------- /nnood/preprocessing/foreground_mask.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import numpy as np 4 | from skimage import filters, measure, morphology, segmentation 5 | 6 | from nnood.utils.miscellaneous import make_hypersphere_mask 7 | 8 | 9 | def get_object_mask(img): 10 | # Excluding channels dimension 11 | num_channels = img.shape[0] 12 | image_shape = np.array(img.shape)[1:] 13 | 14 | # 5% of each dimension length 15 | corner_lens = image_shape // 20 16 | 17 | # Used for upper_bounds, so use dim lengths 18 | img_corners = list(itertools.product(*[[0, s] for s in image_shape])) 19 | 20 | all_corner_ranges = [[(0, l) if c == 0 else (c - l, c) for c, l in zip(c_coord, corner_lens)] 21 | for c_coord in img_corners] 22 | corner_patch_slices = [tuple([slice(lb, ub) for lb, ub in cr]) for cr in all_corner_ranges] 23 | 24 | num_corner_seed_points = 2 ** len(image_shape) 25 | 26 | masks = [] 27 | 28 | for i in range(num_channels): 29 | 30 | sobel_channel = filters.sobel(img[i]) 31 | 32 | curr_channel_masks = [] 33 | 34 | for c_c, c_r, c_s in zip(img_corners, all_corner_ranges, corner_patch_slices): 35 | 36 | patch_tolerance = sobel_channel[c_s].std() 37 | 38 | for _ in range(num_corner_seed_points): 39 | random_c = tuple([np.random.randint(lb, ub) for lb, ub in c_r]) 40 | 41 | curr_channel_masks.append(segmentation.flood(sobel_channel, random_c, tolerance=patch_tolerance)) 42 | 43 | masks.append(np.any(np.stack(curr_channel_masks), axis=0)) 44 | 45 | bg_mask = np.all(np.stack(masks), axis=0) 46 | fg_mask = np.logical_not(bg_mask) 47 | 48 | def get_biggest_connected_component(m): 49 | label_m = measure.label(m) 50 | region_sizes = np.bincount(label_m.flatten()) 51 | # Zero size of background, so is ignored when finding biggest region 52 | region_sizes[0] = 0 53 | biggest_region_label = np.argmax(region_sizes) 54 | return label_m == biggest_region_label 55 | 56 | init_biggest_region = get_biggest_connected_component(fg_mask) 57 | 58 | # Apply binary opening to mask of largest object, to smooth out edges / disconnect any spurious regions 59 | opening_structure_r = max(np.median(image_shape) // 250, 1) 60 | opening_structure = make_hypersphere_mask(opening_structure_r, len(image_shape)) 61 | opened_biggest_region = morphology.binary_opening(init_biggest_region, footprint=opening_structure) 62 | 63 | return get_biggest_connected_component(opened_biggest_region) 64 | -------------------------------------------------------------------------------- /nnood/preprocessing/normalisation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | IMAGENET_STATS = { 4 | 'png-r': (0.485, 0.229), 5 | 'png-g': (0.456, 0.224), 6 | 'png-b': (0.406, 0.225) 7 | } 8 | 9 | GLOBAL_NORMALISATION_MODALITIES = ['ct', 'png-bw'] 10 | 11 | 12 | def normalise_png_channel(channel_data: np.ndarray, channel: str, norm_fn): 13 | c_mean, c_std = IMAGENET_STATS[channel] 14 | return norm_fn(channel_data, c_mean, c_std) 15 | 16 | 17 | # Either normalises or denormalises, depending on is_norm 18 | def _norm_helper(data, normalisation_scheme_per_modality, intensity_properties, channel_properties, is_norm: bool): 19 | 20 | norm_fn = (lambda d, mean, std: (d - mean) / std) if is_norm else (lambda d, mean, std: d * std + mean) 21 | norm_fn_stable = lambda d, mean, std: norm_fn(d, mean, std + 1e-8) 22 | 23 | assert 
len(normalisation_scheme_per_modality) == len(data), 'normalisation_scheme_per_modality ' \ 24 | 'must have as many entries as data has ' \ 25 | 'modalities' 26 | 27 | result = [] 28 | # Without a GT segmentation, we cannot use the same nonzero_mask for normalisation 29 | for c in range(len(data)): 30 | scheme = normalisation_scheme_per_modality[c] 31 | curr_channel = data[c] 32 | 33 | if scheme == 'ct': 34 | # clip to lb and ub from train data foreground and use foreground mn and sd from training data 35 | assert intensity_properties is not None, 'Cannot normalise CT without intensity properties' 36 | lower_bound = intensity_properties[c]['percentile_00_5'] 37 | upper_bound = intensity_properties[c]['percentile_99_5'] 38 | 39 | if is_norm: 40 | curr_channel = np.clip(curr_channel, lower_bound, upper_bound) 41 | else: 42 | print('WARNING: when denormalising a CT image we cannot invert the clipping, so result will not be ' 43 | 'exact') 44 | 45 | result.append(norm_fn(curr_channel, intensity_properties[c]['mean'], intensity_properties[c]['sd'])) 46 | elif scheme == 'global-z': 47 | assert intensity_properties is not None, f'Cannot normalise modality {c} without intensity properties' 48 | 49 | result.append(norm_fn(curr_channel, intensity_properties[c]['mean'], intensity_properties[c]['sd'])) 50 | elif scheme == 'noNorm': 51 | result.append(curr_channel)  # pass the channel through unchanged, keeping channel counts aligned 52 | elif scheme in IMAGENET_STATS.keys(): 53 | result.append(normalise_png_channel(curr_channel, scheme, norm_fn)) 54 | elif scheme == 'z-score': 55 | result.append(norm_fn_stable(curr_channel, channel_properties[c]['mean'], channel_properties[c]['sd'])) 56 | else: 57 | assert False, f'Unrecognised normalisation scheme: {scheme}' 58 | 59 | return np.stack(result) 60 | 61 | 62 | def normalise(data, normalisation_scheme_per_modality, intensity_properties, channel_properties): 63 | return _norm_helper(data, normalisation_scheme_per_modality, intensity_properties, channel_properties, True) 64 | 65 | 66 | def denormalise(data, normalisation_scheme_per_modality, intensity_properties, channel_properties): 67 | return _norm_helper(data, normalisation_scheme_per_modality, intensity_properties, channel_properties, False) 68 | 69 | 70 | def modality_norm_scheme(mod: str): 71 | 72 | if mod == 'CT' or mod == 'ct': 73 | return 'ct' 74 | elif mod in GLOBAL_NORMALISATION_MODALITIES: 75 | return 'global-z' 76 | elif mod == 'noNorm': 77 | return 'noNorm' 78 | elif mod in IMAGENET_STATS.keys(): 79 | return mod 80 | else: 81 | return 'z-score' 82 | -------------------------------------------------------------------------------- /nnood/self_supervised_task/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . 
import * 3 | -------------------------------------------------------------------------------- /nnood/self_supervised_task/cutpaste.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from typing import List, Tuple 7 | 8 | from nnood.configuration import default_num_processes 9 | from nnood.training.dataloading.dataset_loading import load_npy_or_npz 10 | from nnood.self_supervised_task.patch_blender import SwapPatchBlender 11 | from nnood.self_supervised_task.patch_ex import patch_ex 12 | from nnood.self_supervised_task.patch_shape_maker import UnequalUniformPatchMaker 13 | from nnood.self_supervised_task.patch_labeller import BinaryPatchLabeller 14 | from nnood.self_supervised_task.patch_transforms.colour_transforms import AdjustContrast, PsuedoAdjustBrightness 15 | from nnood.self_supervised_task.patch_transforms.spatial_transforms import TranslatePatch 16 | from nnood.self_supervised_task.self_sup_task import SelfSupTask 17 | 18 | 19 | def cutpaste_sample_dimensions(_: List[Tuple[int, int]], img_dims: np.ndarray): 20 | num_dimensions = len(img_dims) 21 | 22 | patch_area = np.random.uniform(0.02, 0.15) * np.prod(img_dims) 23 | dim_aspect_ratios = [np.random.choice([np.random.uniform(0.3, 1), np.random.uniform(1, 3.3)]) 24 | for _ in range(num_dimensions - 1)] 25 | 26 | area_root = np.power(patch_area, 1 / num_dimensions) 27 | 28 | shape = [] 29 | for i in range(num_dimensions): 30 | # First is multiplied by sqrt(AR[0]) 31 | # Last is multiplied by sqrt(1 / AR[-1]) 32 | # All others are multiplied by sqrt(AR[i] / AR[i - 1]) 33 | # Causes product(shape) = Area 34 | 35 | num = 1.0 if i == num_dimensions - 1 else dim_aspect_ratios[i] 36 | den = 1.0 if i == 0 else dim_aspect_ratios[i - 1] 37 | 38 | shape.append(np.round(area_root * np.sqrt(num / den)).astype(int)) 39 | 40 | # Shuffle dimensions, to avoid early/later dimensions being more extreme than middle ones 41 | np.random.shuffle(shape) 42 | 43 | return shape 44 | 45 | 46 | def _load_get_min(f): 47 | return load_npy_or_npz(f, 'r').min() 48 | 49 | 50 | class CutPaste(SelfSupTask): 51 | 52 | def __init__(self): 53 | self.shape_maker = UnequalUniformPatchMaker(sample_dist=cutpaste_sample_dimensions, calc_dims_together=True) 54 | self.transformations = [AdjustContrast(0.1), TranslatePatch()] 55 | self.blender = SwapPatchBlender() 56 | self.labeller = BinaryPatchLabeller() 57 | self.calibrated = False 58 | 59 | def calibrate(self, dataset, exp_plans): 60 | 61 | if not self.calibrated: 62 | print('Calibrating CutPaste...') 63 | dataset_min_val = np.inf 64 | 65 | files_to_load = [v['data_file'] for v in dataset.values()] 66 | 67 | with Pool(default_num_processes) as p: 68 | mins = p.map(_load_get_min, files_to_load) 69 | 70 | dataset_min_val = np.minimum(dataset_min_val, np.amin(mins)) 71 | 72 | assert dataset_min_val != np.inf, 'Minimum dataset value is np.inf, is the dataset empty?'
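# (The dataset-wide minimum computed above is handed to the brightness transform below; presumably this lets the pseudo-brightness shift stay within the observed intensity range - an inference from the parameter, not stated in the source.)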
73 | 74 | self.transformations = [PsuedoAdjustBrightness(0.1, dataset_min_val)] + self.transformations 75 | self.calibrated = True 76 | else: 77 | print('WARNING: CutPaste has already been calibrated, cannot be done again.') 78 | 79 | def apply(self, sample, sample_mask, sample_properties, sample_fn=None, dest_bbox=None, return_locations=False): 80 | # Note: width_bounds_pct isn't used to sample patch dimensions (CutPaste chooses them from the patch area and 81 | # aspect ratios), so it only determines how close patches may be placed to the image edge; the upper bound is 82 | # meaningless. 83 | 84 | if not self.calibrated: 85 | print('WARNING: CutPaste has not been calibrated, so cannot use the PsuedoAdjustBrightness transform') 86 | 87 | result = patch_ex(sample, same=True, shape_maker=self.shape_maker, patch_transforms=self.transformations, 88 | blender=self.blender, labeller=self.labeller, binary_factor=True, dest_bbox=dest_bbox, 89 | return_anomaly_locations=return_locations, width_bounds_pct=(0.15, 0.30)) 90 | return result 91 | 92 | def loss(self, pred, target): 93 | return F.binary_cross_entropy_with_logits(pred, target) 94 | 95 | def label_is_seg(self): 96 | return False 97 | 98 | def inference_nonlin(self, data): 99 | return torch.sigmoid(data) 100 | 101 | 102 | if __name__ == '__main__': 103 | np.seterr(all='raise') 104 | test_input = np.random.random((3, 50, 50)) 105 | test_output = CutPaste()(test_input)[0] 106 | 107 | assert not (test_input == test_output).all() 108 | -------------------------------------------------------------------------------- /nnood/self_supervised_task/fpi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from nnood.self_supervised_task.patch_blender import UniformPatchBlender 5 | from nnood.self_supervised_task.patch_ex import patch_ex 6 | from nnood.self_supervised_task.patch_shape_maker import EqualUniformPatchMaker 7 | from nnood.self_supervised_task.patch_labeller import ContinuousPatchLabeller 8 | from nnood.self_supervised_task.self_sup_task import SelfSupTask 9 | 10 | 11 | class FPI(SelfSupTask): 12 | 13 | def __init__(self): 14 | self.shape_maker = EqualUniformPatchMaker() 15 | self.prev_sample = None 16 | self.blender = UniformPatchBlender() 17 | self.labeller = ContinuousPatchLabeller() 18 | 19 | def apply(self, sample, sample_mask, sample_properties, sample_fn=None, dest_bbox=None, return_locations=False): 20 | src = sample_fn(False)[0] if self.prev_sample is None else self.prev_sample 21 | if isinstance(src, tuple): 22 | src = src[0] 23 | 24 | result = patch_ex(sample, src, shape_maker=self.shape_maker, blender=self.blender, labeller=self.labeller, 25 | binary_factor=False, dest_bbox=dest_bbox, extract_within_bbox=True, 26 | return_anomaly_locations=return_locations, width_bounds_pct=(0.1, 0.4)) 27 | self.prev_sample = sample 28 | return result 29 | 30 | def loss(self, pred, target): 31 | return F.binary_cross_entropy_with_logits(pred, target) 32 | 33 | def label_is_seg(self): 34 | return False 35 | 36 | def inference_nonlin(self, data): 37 | return torch.sigmoid(data) 38 | -------------------------------------------------------------------------------- /nnood/self_supervised_task/nsa.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | 3 | import numpy as np 4 | import torch 5 | from torch.nn import functional as F 6 | 7 | from nnood.configuration import default_num_processes 8 | from 
nnood.self_supervised_task.nsa_utils import nsa_sample_dimension, compute_nsa_mask_params 9 | from nnood.self_supervised_task.patch_labeller import IntensityPatchLabeller, LogisticIntensityPatchLabeller 10 | from nnood.self_supervised_task.patch_blender import PoissonPatchBlender 11 | from nnood.self_supervised_task.patch_ex import patch_ex 12 | from nnood.self_supervised_task.patch_shape_maker import UnequalUniformPatchMaker 13 | from nnood.self_supervised_task.patch_transforms.spatial_transforms import ResizePatch, TranslatePatch 14 | from nnood.self_supervised_task.self_sup_task import SelfSupTask 15 | from nnood.training.dataloading.dataset_loading import load_npy_or_npz 16 | 17 | 18 | class NSA(SelfSupTask): 19 | def __init__(self, mix_gradients=False): 20 | self.mix_gradients = mix_gradients 21 | self.min_obj_overlap_pct = 0.25 22 | self.shape_maker = UnequalUniformPatchMaker(sample_dist=nsa_sample_dimension) 23 | self.transforms = [ResizePatch(min_overlap_pct=0, validate_position=False), 24 | TranslatePatch(min_overlap_pct=self.min_obj_overlap_pct)] 25 | self.blender = PoissonPatchBlender(self.mix_gradients) 26 | 27 | self.prev_sample = self.prev_sample_mask = None 28 | 29 | self._calibrated = False 30 | self.width_bounds_pct = self.labeller = self.min_obj_pct = self.class_has_foreground = self.num_patches = None 31 | 32 | def apply(self, sample, sample_mask, sample_properties, sample_fn=None, dest_bbox=None, return_locations=False): 33 | src = sample_fn(sample_mask is not None)[0] if self.prev_sample is None else self.prev_sample 34 | 35 | if sample_mask is not None: 36 | if self.prev_sample is None: 37 | src_mask = src[1] 38 | src = src[0] 39 | else: 40 | src_mask = self.prev_sample_mask 41 | else: 42 | src_mask = None 43 | 44 | assert self._calibrated, 'NSA task requires calibration!' 45 | 46 | result = patch_ex(sample, src, 47 | shape_maker=self.shape_maker, 48 | patch_transforms=self.transforms, 49 | blender=self.blender, 50 | labeller=self.labeller, 51 | binary_factor=True, 52 | # 0 overlap for initial patch source location, as we check that when translating patch. 
53 | min_overlap_pct=0.0, 54 | width_bounds_pct=self.width_bounds_pct, 55 | min_object_pct=self.min_obj_pct if self.class_has_foreground else None, 56 | dest_bbox=dest_bbox, 57 | num_patches=self.num_patches, 58 | skip_background=(sample_mask, src_mask) if self.class_has_foreground else None, 59 | return_anomaly_locations=return_locations) 60 | 61 | self.prev_sample = sample 62 | self.prev_sample_mask = sample_mask 63 | return result 64 | 65 | def _load_img_m_pair(self, f): 66 | curr_img = load_npy_or_npz(f, 'r', self.class_has_foreground) 67 | 68 | if self.class_has_foreground: 69 | return curr_img # Actually a tuple of image and mask 70 | else: 71 | return curr_img, None 72 | 73 | def _collect_continuous_NSA_examples(self, files_to_load): 74 | last_img, last_img_m = self._load_img_m_pair(files_to_load[0]['data_file']) 75 | 76 | patch_changes = [] 77 | temp_labeller = IntensityPatchLabeller() 78 | 79 | for j in files_to_load[1:]: 80 | new_img, new_img_m = self._load_img_m_pair(j['data_file']) 81 | 82 | _, cont_label = patch_ex(new_img, last_img, 83 | shape_maker=self.shape_maker, 84 | patch_transforms=self.transforms, 85 | blender=self.blender, 86 | labeller=temp_labeller, 87 | binary_factor=True, 88 | min_overlap_pct=0.0, 89 | width_bounds_pct=self.width_bounds_pct, 90 | min_object_pct=self.min_obj_pct, 91 | num_patches=self.num_patches, 92 | skip_background=(new_img_m, last_img_m) 93 | if self.class_has_foreground else None) 94 | 95 | patch_changes.append(cont_label[cont_label > 0]) 96 | 97 | last_img = new_img 98 | last_img_m = new_img_m 99 | 100 | return np.concatenate(patch_changes) 101 | 102 | def calibrate(self, dataset, exp_plans): 103 | if not self._calibrated: 104 | 105 | data_num_dims = len(exp_plans['transpose_forward']) 106 | self.class_has_foreground = exp_plans['dataset_properties']['has_uniform_background'] 107 | 108 | # Compute NSA parameters based on the object masks 109 | self.width_bounds_pct, self.num_patches, self.min_obj_pct = \ 110 | compute_nsa_mask_params(self.class_has_foreground, dataset, data_num_dims) 111 | 112 | # Measure distribution of changes caused by NSA anomalies 113 | keys = list(dataset.keys()) 114 | 115 | with Pool(default_num_processes) as pool: 116 | num_test_samples = 500 117 | samples_per_process = num_test_samples // default_num_processes 118 | 119 | all_patch_changes = pool.map(self._collect_continuous_NSA_examples, 120 | [[dataset[keys[j % len(keys)]] 121 | for j in range(i, i + samples_per_process)] 122 | for i in range(0, num_test_samples, samples_per_process)]) 123 | 124 | all_patch_changes = np.concatenate(all_patch_changes) 125 | 126 | # Calculate logistic function parameters such that: 127 | # - lower bound of patch labels is 0.1. 128 | # - patches saturate at 40th percentile of changes observed. 
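# (Derivation, added for clarity: with label f(x) = 1 / (1 + exp(-scale * (x - x0))), requiring f(0) = 0.1 gives x0 = ln(9) / scale, and then requiring saturation f(p40) = 0.99 gives scale = ln(99 * 9) / p40, matching the two lines below.)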
129 | 130 | scale = np.log(99 * 9) / np.percentile(all_patch_changes, 40) 131 | x0 = np.log(9) / scale 132 | self.labeller = LogisticIntensityPatchLabeller(scale, x0) 133 | 134 | self._calibrated = True 135 | else: 136 | print('WARNING: NSA has already been calibrated, cannot be done again.') 137 | 138 | def loss(self, pred, target): 139 | return F.binary_cross_entropy_with_logits(pred, target) 140 | 141 | def label_is_seg(self): 142 | return False 143 | 144 | def inference_nonlin(self, data): 145 | return torch.sigmoid(data) 146 | 147 | 148 | class NSAMixed(NSA): 149 | 150 | def __init__(self): 151 | super().__init__(mix_gradients=True) 152 | -------------------------------------------------------------------------------- /nnood/self_supervised_task/nsa_utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from multiprocessing import Pool 3 | 4 | import numpy as np 5 | 6 | from nnood.configuration import default_num_processes 7 | from nnood.training.dataloading.dataset_loading import load_npy_or_npz 8 | 9 | 10 | def get_avg_mask_bounds(m): 11 | non_zero_coords = np.where(m) 12 | dims = len(m.shape) 13 | 14 | unique_non_zero_coords = np.array([np.unique(cs) for cs in non_zero_coords], 15 | dtype=object) 16 | 17 | obj_avg_dim_lens = [] 18 | 19 | for d in range(dims): 20 | other_ds = [d2 for d2 in range(dims) if d2 != d] 21 | 22 | other_coord_combs = itertools.product(*(unique_non_zero_coords[other_ds])) 23 | 24 | curr_obj_dim_lens = [] 25 | 26 | # For each combination of other coordinates 27 | for coord_comb in other_coord_combs: 28 | 29 | test = [non_zero_coords[d2] == c2 for d2, c2 in zip(other_ds, coord_comb)] 30 | # Get indices of coordinates which match this combination 31 | coord_inds = np.all(test, 32 | axis=0) 33 | if not coord_inds.any(): 34 | continue 35 | 36 | # Get corresponding coordinates. 
Min and Max are beginning and end, 37 | # due to ordering made by flattening 38 | min_c, max_c = non_zero_coords[d][coord_inds][[0, -1]] 39 | 40 | curr_obj_dim_lens.append(max_c - min_c + 1) 41 | 42 | obj_avg_dim_lens.append(np.mean(curr_obj_dim_lens) / m.shape[d]) 43 | 44 | return obj_avg_dim_lens 45 | 46 | 47 | def load_mask_and_get_stats(f): 48 | _, s_mask = load_npy_or_npz(f, 'r', True) 49 | 50 | return get_avg_mask_bounds(s_mask), np.mean(s_mask) 51 | 52 | 53 | def nsa_sample_dimension(lb, ub, img_d): 54 | gamma_lb = 0.03 55 | gamma_shape = 2 56 | gamma_scale = 0.1 57 | 58 | gamma_sample = (gamma_lb + np.random.gamma(gamma_shape, gamma_scale)) * img_d 59 | 60 | return int(np.clip(gamma_sample, lb, ub)) 61 | 62 | 63 | def compute_nsa_mask_params(class_has_foreground, dataset, data_num_dims): 64 | if class_has_foreground: 65 | 66 | with Pool(default_num_processes) as pool: 67 | mask_stats = pool.map(load_mask_and_get_stats, [v['data_file'] for v in dataset.values()]) 68 | 69 | # Average proportional length of object along each dimension 70 | avg_obj_dim_len = np.mean([m_s[0] for m_s in mask_stats], axis=0) 71 | # Average area of object proportional to entire image 72 | avg_obj_area = np.mean([m_s[1] for m_s in mask_stats]) 73 | 74 | width_bounds_pct = [(0.06, np.clip(d_len * 4 / 3, 0.25, 0.8)) for d_len in avg_obj_dim_len] 75 | num_patches = 3 if avg_obj_area < 0.75 else 4 76 | min_obj_pct = 0.5 if avg_obj_area < 0.4 else 0.7 77 | 78 | else: 79 | width_bounds_pct = [(0.06, 0.8)] * data_num_dims 80 | num_patches = 4 81 | min_obj_pct = None 82 | 83 | return width_bounds_pct, num_patches, min_obj_pct -------------------------------------------------------------------------------- /nnood/self_supervised_task/opencv_nsa.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from torch.nn import functional as F 7 | 8 | from nnood.configuration import default_num_processes 9 | from nnood.self_supervised_task.nsa_utils import nsa_sample_dimension, compute_nsa_mask_params 10 | from nnood.self_supervised_task.patch_labeller import IntensityPatchLabeller, LogisticIntensityPatchLabeller 11 | from nnood.self_supervised_task.patch_blender import OpenCVPoissonPatchBlender 12 | from nnood.self_supervised_task.patch_ex import patch_ex 13 | from nnood.self_supervised_task.patch_shape_maker import UnequalUniformPatchMaker 14 | from nnood.self_supervised_task.patch_transforms.spatial_transforms import ResizePatch, TranslatePatch 15 | from nnood.self_supervised_task.self_sup_task import SelfSupTask 16 | from nnood.training.dataloading.dataset_loading import load_npy_or_npz 17 | from nnood.utils.file_operations import load_pickle 18 | 19 | 20 | class OpenCVNSA(SelfSupTask): 21 | def __init__(self, mode=cv2.NORMAL_CLONE): 22 | self.mode = mode 23 | self.min_obj_overlap_pct = 0.25 24 | self.shape_maker = UnequalUniformPatchMaker(sample_dist=nsa_sample_dimension) 25 | self.transforms = [ResizePatch(min_overlap_pct=0, validate_position=False), 26 | TranslatePatch(min_overlap_pct=self.min_obj_overlap_pct)] 27 | 28 | self.prev_sample = self.prev_sample_mask = None 29 | 30 | self._calibrated = False 31 | self.width_bounds_pct = self.blender = self.labeller = self.min_obj_pct = self.class_has_foreground \ 32 | = self.num_patches = None 33 | 34 | def apply(self, sample, sample_mask, sample_properties, sample_fn=None, dest_bbox=None, return_locations=False): 35 | src = sample_fn(sample_mask 
is not None)[0] if self.prev_sample is None else self.prev_sample 36 | 37 | if sample_mask is not None: 38 | if self.prev_sample is None: 39 | src_mask = src[1] 40 | src = src[0] 41 | else: 42 | src_mask = self.prev_sample_mask 43 | else: 44 | src_mask = None 45 | 46 | assert self._calibrated, 'OpenCVNSA task requires calibration!' 47 | 48 | self.blender.norm_args = (self.blender.norm_args[0], self.blender.norm_args[1], 49 | sample_properties['channel_intensity_properties']) 50 | 51 | result = patch_ex(sample, src, 52 | shape_maker=self.shape_maker, 53 | patch_transforms=self.transforms, 54 | blender=self.blender, 55 | labeller=self.labeller, 56 | binary_factor=True, 57 | # 0 overlap for initial patch source location, as we check that when translating patch. 58 | min_overlap_pct=0.0, 59 | width_bounds_pct=self.width_bounds_pct, 60 | min_object_pct=self.min_obj_pct if self.class_has_foreground else None, 61 | dest_bbox=dest_bbox, 62 | num_patches=self.num_patches, 63 | skip_background=(sample_mask, src_mask) if self.class_has_foreground else None, 64 | return_anomaly_locations=return_locations) 65 | 66 | self.prev_sample = sample 67 | self.prev_sample_mask = sample_mask 68 | return result 69 | 70 | def _load_img_m_pair(self, f): 71 | curr_img = load_npy_or_npz(f, 'r', self.class_has_foreground) 72 | 73 | if self.class_has_foreground: 74 | return curr_img # Actually a tuple of image and mask 75 | else: 76 | return curr_img, None 77 | 78 | def _collect_continuous_NSA_examples(self, files_to_load): 79 | last_img, last_img_m = self._load_img_m_pair(files_to_load[0]['data_file']) 80 | 81 | patch_changes = [] 82 | temp_labeller = IntensityPatchLabeller() 83 | temp_blender = OpenCVPoissonPatchBlender(self.mode, 84 | (self.blender.norm_args[0], 85 | self.blender.norm_args[1], 86 | None)) 87 | 88 | for j in files_to_load[1:]: 89 | new_img, new_img_m = self._load_img_m_pair(j['data_file']) 90 | new_prop = j['properties'] if 'properties' in j.keys() else load_pickle(j['properties_file']) 91 | temp_blender.norm_args = (temp_blender.norm_args[0], 92 | temp_blender.norm_args[1], 93 | new_prop['channel_intensity_properties']) 94 | 95 | _, cont_label = patch_ex(new_img, last_img, 96 | shape_maker=self.shape_maker, 97 | patch_transforms=self.transforms, 98 | blender=temp_blender, 99 | labeller=temp_labeller, 100 | binary_factor=True, 101 | min_overlap_pct=0.0, 102 | width_bounds_pct=self.width_bounds_pct, 103 | min_object_pct=self.min_obj_pct, 104 | num_patches=self.num_patches, 105 | skip_background=(new_img_m, last_img_m) 106 | if self.class_has_foreground else None) 107 | 108 | patch_changes.append(cont_label[cont_label > 0]) 109 | 110 | last_img = new_img 111 | last_img_m = new_img_m 112 | 113 | return np.concatenate(patch_changes) 114 | 115 | def calibrate(self, dataset, exp_plans): 116 | if not self._calibrated: 117 | 118 | data_num_dims = len(exp_plans['transpose_forward']) 119 | self.class_has_foreground = exp_plans['dataset_properties']['has_uniform_background'] 120 | 121 | self.blender = OpenCVPoissonPatchBlender(self.mode, 122 | (exp_plans['normalization_schemes'], 123 | exp_plans['dataset_properties']['intensity_properties'], 124 | None)) 125 | 126 | # Compute NSA parameters based on the object masks 127 | self.width_bounds_pct, self.num_patches, self.min_obj_pct = \ 128 | compute_nsa_mask_params(self.class_has_foreground, dataset, data_num_dims) 129 | 130 | # Measure distribution of changes caused by NSA anomalies 131 | keys = list(dataset.keys()) 132 | 133 | with Pool(default_num_processes) as 
pool: 134 | num_test_samples = 500 135 | samples_per_process = num_test_samples // default_num_processes 136 | 137 | all_patch_changes = pool.map(self._collect_continuous_NSA_examples, 138 | [[dataset[keys[j % len(keys)]] 139 | for j in range(i, i + samples_per_process)] 140 | for i in range(0, num_test_samples, samples_per_process)]) 141 | 142 | all_patch_changes = np.concatenate(all_patch_changes) 143 | 144 | # Calculate logistic function parameters such that: 145 | # - lower bound of patch labels is 0.1. 146 | # - patches saturate at 40th percentile of changes observed. 147 | 148 | scale = np.log(99 * 9) / np.percentile(all_patch_changes, 40) 149 | x0 = np.log(9) / scale 150 | self.labeller = LogisticIntensityPatchLabeller(scale, x0) 151 | 152 | self._calibrated = True 153 | else: 154 | print('WARNING: OpenCVNSA has already been calibrated, cannot be done again.') 155 | 156 | def loss(self, pred, target): 157 | return F.binary_cross_entropy_with_logits(pred, target) 158 | 159 | def label_is_seg(self): 160 | return False 161 | 162 | def inference_nonlin(self, data): 163 | return torch.sigmoid(data) 164 | 165 | 166 | class OpenCVNSAMixed(OpenCVNSA): 167 | 168 | def __init__(self): 169 | super().__init__(cv2.MIXED_CLONE) 170 | -------------------------------------------------------------------------------- /nnood/self_supervised_task/opencv_pii.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from nnood.self_supervised_task.patch_blender import OpenCVPoissonPatchBlender 7 | from nnood.self_supervised_task.patch_ex import patch_ex 8 | from nnood.self_supervised_task.patch_shape_maker import EqualUniformPatchMaker 9 | from nnood.self_supervised_task.patch_labeller import ContinuousPatchLabeller 10 | from nnood.self_supervised_task.self_sup_task import SelfSupTask 11 | 12 | 13 | class OpenCVPII(SelfSupTask): 14 | 15 | def __init__(self): 16 | self.shape_maker = EqualUniformPatchMaker() 17 | self.prev_sample = None 18 | self.labeller = ContinuousPatchLabeller() 19 | 20 | self._calibrated = False 21 | self.blender = None # Initialised during calibration 22 | 23 | def apply(self, sample, sample_mask, sample_properties, sample_fn=None, dest_bbox=None, return_locations=False): 24 | 25 | sample = np.array(sample) 26 | 27 | src = np.array(sample_fn(False)[0]) if self.prev_sample is None else self.prev_sample 28 | if isinstance(src, tuple): 29 | src = src[0] 30 | 31 | assert self._calibrated, 'OpenCVPII task requires calibration!' 
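# Explanatory note: the third entry of norm_args is refreshed per sample because
# channel_intensity_properties differ between images; OpenCVPoissonPatchBlender
# uses them to denormalise to uint8 [0, 255] before seamlessClone and to
# renormalise afterwards.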
32 | 33 | self.blender.norm_args = (self.blender.norm_args[0], self.blender.norm_args[1], 34 | sample_properties['channel_intensity_properties']) 35 | 36 | result = patch_ex(sample, src, shape_maker=self.shape_maker, blender=self.blender, labeller=self.labeller, 37 | binary_factor=False, dest_bbox=dest_bbox, extract_within_bbox=True, 38 | return_anomaly_locations=return_locations, width_bounds_pct=(0.1, 0.4)) 39 | self.prev_sample = sample 40 | return result 41 | 42 | def calibrate(self, dataset, exp_plans): 43 | if not self._calibrated: 44 | 45 | self.blender = OpenCVPoissonPatchBlender(cv2.NORMAL_CLONE, 46 | (exp_plans['normalization_schemes'], 47 | exp_plans['dataset_properties']['intensity_properties'], 48 | None)) 49 | 50 | self._calibrated = True 51 | else: 52 | print('WARNING: OpenCVPII has already been calibrated, cannot be done again.') 53 | 54 | def loss(self, pred, target): 55 | return F.binary_cross_entropy_with_logits(pred, target) 56 | 57 | def label_is_seg(self): 58 | return False 59 | 60 | def inference_nonlin(self, data): 61 | return torch.sigmoid(data) 62 | -------------------------------------------------------------------------------- /nnood/self_supervised_task/patch_blender.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional, Tuple 3 | 4 | import numpy as np 5 | import cv2 6 | from pietorch import blend_dst_numpy 7 | 8 | from nnood.preprocessing.normalisation import _norm_helper 9 | from nnood.self_supervised_task.utils import extract_dest_patch_mask, get_patch_image_slices 10 | 11 | 12 | class PatchBlender(ABC): 13 | 14 | @abstractmethod 15 | def blend(self, factor: float, patch: np.ndarray, patch_mask: np.ndarray, patch_corner: np.ndarray, 16 | dest: np.ndarray, patch_object_mask: Optional[np.ndarray], dest_object_mask: Optional[np.ndarray]) \ 17 | -> Tuple[np.ndarray, np.ndarray]: 18 | """ 19 | :param factor: Extent that patch is used in blending. In range [0-1]. May be ignored. 20 | :param patch: Patch extracted from source image. 21 | :param patch_mask: Mask of which elements of patch will be put/blended into the destination image. 22 | :param patch_corner: Coordinate of the minimum element of patch relative to the destination image. 23 | :param dest: Destination image, available for querying but probably shouldn't change! 24 | :param patch_object_mask: Mask showing which elements of the patch contain an object of interest. 25 | :param dest_object_mask: Mask showing which elements of the destination image contain an object of interest. 26 | :returns Tuple of blended image and final patch mask. 
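Note: implementations may return a patch_mask smaller than the one they were given (e.g. when
blending is restricted to object regions, or when blending fails and the image is returned
unchanged), so callers should rely on the returned mask.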
27 | """ 28 | pass 29 | 30 | def __call__(self, factor: float, patch: np.ndarray, patch_mask: np.ndarray, patch_corner: np.ndarray, 31 | dest: np.ndarray, patch_object_mask: Optional[np.ndarray], dest_object_mask: Optional[np.ndarray]) \ 32 | -> Tuple[np.ndarray, np.ndarray]: 33 | return self.blend(factor, patch, patch_mask, patch_corner, dest, patch_object_mask, dest_object_mask) 34 | 35 | @staticmethod 36 | def get_final_patch_mask(patch_mask: np.ndarray, patch_corner: np.ndarray, patch_object_mask: np.ndarray, 37 | dest_object_mask: np.ndarray) -> np.ndarray: 38 | return patch_mask & (patch_object_mask | extract_dest_patch_mask(patch_corner, patch_mask.shape, 39 | dest_object_mask)) 40 | 41 | 42 | class SwapPatchBlender(PatchBlender): 43 | def blend(self, _: float, patch: np.ndarray, patch_mask: np.ndarray, patch_corner: np.ndarray, 44 | dest: np.ndarray, patch_object_mask: Optional[np.ndarray], dest_object_mask: Optional[np.ndarray]) \ 45 | -> Tuple[np.ndarray, np.ndarray]: 46 | if patch_object_mask is not None and dest_object_mask is not None: 47 | patch_mask = self.get_final_patch_mask(patch_mask, patch_corner, patch_object_mask, dest_object_mask) 48 | 49 | image_slices = get_patch_image_slices(patch_corner, patch_mask.shape, patch.shape[0]) 50 | blended_img = dest.copy() 51 | before = dest[image_slices] 52 | blended_img[image_slices] -= patch_mask * before 53 | blended_img[image_slices] += patch_mask * patch 54 | return blended_img, patch_mask 55 | 56 | 57 | class UniformPatchBlender(PatchBlender): 58 | def blend(self, factor: float, patch: np.ndarray, patch_mask: np.ndarray, patch_corner: np.ndarray, 59 | dest: np.ndarray, patch_object_mask: Optional[np.ndarray], dest_object_mask: Optional[np.ndarray]) \ 60 | -> Tuple[np.ndarray, np.ndarray]: 61 | if patch_object_mask is not None and dest_object_mask is not None: 62 | patch_mask = self.get_final_patch_mask(patch_mask, patch_corner, patch_object_mask, dest_object_mask) 63 | 64 | image_slices = get_patch_image_slices(patch_corner, patch_mask.shape, patch.shape[0]) 65 | blended_img = dest.copy() 66 | before = dest[image_slices] 67 | blended_img[image_slices] -= factor * patch_mask * before 68 | blended_img[image_slices] += factor * patch_mask * patch 69 | return blended_img, patch_mask 70 | 71 | 72 | # This should not be used for training, only implemented to later compare with own, generic poisson image editing 73 | # implementation 74 | class OpenCVPoissonPatchBlender(PatchBlender): 75 | 76 | def __init__(self, mode=cv2.NORMAL_CLONE, norm_args=None): 77 | self.mode = mode 78 | # Custom normalisation arguments, to be used with _norm_helper 79 | self.norm_args = norm_args 80 | 81 | def blend(self, factor: float, patch: np.ndarray, patch_mask: np.ndarray, patch_corner: np.ndarray, 82 | dest: np.ndarray, patch_object_mask: Optional[np.ndarray], dest_object_mask: Optional[np.ndarray]) \ 83 | -> Tuple[np.ndarray, np.ndarray]: 84 | 85 | assert len(dest.shape) == 3 and dest.shape[0] in [1, 3], 'OpenCV patch blending only works on 3 or 1 ' \ 86 | 'channeled, 2D images.' 87 | assert len(patch.shape) == 3 and patch.shape[0] in [1, 3], 'OpenCV patch blending only works on 3 or 1 ' \ 88 | 'channeled, 2D images.' 89 | 90 | blended_img = dest.copy() 91 | 92 | norm_scheme = None 93 | init_dtype = blended_img.dtype 94 | 95 | if self.norm_args is not None: 96 | # Assume custom mean/std denormalises to 0-1. 
Still need to convert to 0-255 97 | patch = np.uint8(np.round(255 * _norm_helper(patch, *self.norm_args, False))) 98 | blended_img = np.uint8(np.round(255 * _norm_helper(blended_img, *self.norm_args, False))) 99 | 100 | if dest.shape[0] == 1: 101 | patch = np.repeat(patch, 3, axis=0) 102 | blended_img = np.repeat(blended_img, 3, axis=0) 103 | 104 | elif init_dtype is not np.uint8: 105 | # Need to convert patch and blended_img to uint8, as OpenCV only takes those dtypes. 106 | blended_img_min, blended_img_max = np.min(blended_img), np.max(blended_img) 107 | 108 | if 0 <= blended_img_min: 109 | if blended_img_max <= 1: 110 | print('Assuming images are normalised to be in [0-1] range') 111 | norm_scheme = '0-1' 112 | patch = np.uint8(np.round(255 * patch)) 113 | blended_img = np.uint8(np.round(255 * blended_img)) 114 | elif blended_img_max <= 255: 115 | print('Assuming images are in range [0-255] and just need rounding') 116 | norm_scheme = '0-255' 117 | patch = np.uint8(np.round(patch)) 118 | blended_img = np.uint8(np.round(blended_img)) 119 | else: 120 | print('Failed to normalise image, returning unchanged image.') 121 | return dest.copy(), np.zeros_like(patch_mask) 122 | else: 123 | print('Assuming images are normalised using ImageNet statistics') 124 | norm_scheme = 'imagenet' 125 | 126 | patch = np.uint8(np.round(255 * _norm_helper(patch, ['png-r', 'png-g', 'png-b'], None, None, False))) 127 | blended_img = np.uint8(np.round(255 * _norm_helper(blended_img, ['png-r', 'png-g', 'png-b'], None, None, 128 | False))) 129 | 130 | patch_mask_scaled = np.uint8(np.ceil(factor * 255) * patch_mask) 131 | 132 | # zero border to avoid artefacts 133 | patch_mask_scaled[0] = patch_mask_scaled[-1] = patch_mask_scaled[:, 0] = patch_mask_scaled[:, -1] = 0 134 | 135 | # cv2 seamlessClone will fail if positive mask area is too small 136 | if np.sum(patch_mask_scaled > 0) < 50: 137 | print('Masked area is too small to perform poisson image editing.') 138 | return dest.copy(), np.zeros_like(patch_mask) 139 | 140 | # Coordinates are reversed, because OpenCV expects images to be [H, W, C], yet for coordinates expects (x, y) 141 | # why are you like this OpenCV 142 | centre = tuple((patch_corner + np.array(patch_mask.shape) // 2)[::-1]) 143 | 144 | # Move to channels last for opencv 145 | blended_img = np.moveaxis(blended_img, 0, -1) 146 | patch = np.moveaxis(patch, 0, -1) 147 | 148 | try: 149 | blended_img = cv2.seamlessClone(patch, blended_img, patch_mask_scaled, centre, self.mode) 150 | except cv2.error as e: 151 | print('WARNING, tried bad interpolation mask and got:', e) 152 | print('Info dump:') 153 | print('Dest orig shape: ', dest.shape) 154 | print('Dest curr shape: ', blended_img.shape) 155 | print('Patch shape (after moving axis): ', patch.shape) 156 | print('Patch mask shape: ', patch_mask.shape) 157 | print('Scaled patch mask shape: ', patch_mask_scaled.shape) 158 | print('Patch corner: ', patch_corner) 159 | print('OpenCV centre: ', centre) 160 | return dest.copy(), np.zeros_like(patch_mask) 161 | 162 | # Switch channels back 163 | blended_img = np.moveaxis(blended_img, -1, 0) 164 | 165 | if self.norm_args is not None: 166 | 167 | if dest.shape[0] == 1: 168 | blended_img = blended_img[:1] 169 | 170 | blended_img = _norm_helper(blended_img / 255, *self.norm_args, True) 171 | 172 | elif norm_scheme is not None: 173 | if norm_scheme == '0-1': 174 | blended_img = (blended_img / 255).astype(init_dtype) 175 | elif norm_scheme == '0-255': 176 | blended_img = blended_img.astype(init_dtype) 177 | elif 
norm_scheme == 'imagenet': 178 | blended_img = _norm_helper(blended_img / 255, ['png-r', 'png-g', 'png-b'], None, None, True) 179 | else: 180 | assert False, f'Somehow got invalid norm scheme? : {norm_scheme}' 181 | 182 | return blended_img, patch_mask 183 | 184 | 185 | class PoissonPatchBlender(PatchBlender): 186 | 187 | def __init__(self, mixed_gradients=False): 188 | self.mix_gradients = mixed_gradients 189 | 190 | def blend(self, factor: float, patch: np.ndarray, patch_mask: np.ndarray, patch_corner: np.ndarray, 191 | dest: np.ndarray, patch_object_mask: Optional[np.ndarray], dest_object_mask: Optional[np.ndarray]) \ 192 | -> Tuple[np.ndarray, np.ndarray]: 193 | interp_mask = factor * patch_mask 194 | blended_img = blend_dst_numpy(dest, patch, interp_mask, patch_corner, mix_gradients=self.mix_gradients, 195 | channels_dim=0) 196 | 197 | return blended_img, patch_mask 198 | -------------------------------------------------------------------------------- /nnood/self_supervised_task/patch_ex.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | 5 | from nnood.self_supervised_task.patch_blender import PatchBlender, SwapPatchBlender 6 | from nnood.self_supervised_task.patch_labeller import PatchLabeller, BinaryPatchLabeller 7 | from nnood.self_supervised_task.patch_shape_maker import PatchShapeMaker, EqualUniformPatchMaker 8 | from nnood.self_supervised_task.patch_transforms.base_transform import PatchTransform 9 | from nnood.self_supervised_task.utils import check_object_overlap, get_patch_slices, get_patch_image_slices 10 | 11 | 12 | def patch_ex(img_dest: np.ndarray, img_src: Optional[np.ndarray] = None, same: bool = False, num_patches: int = 1, 13 | shape_maker: PatchShapeMaker = EqualUniformPatchMaker(), patch_transforms: List[PatchTransform] = [], 14 | blender: PatchBlender = SwapPatchBlender(), labeller: PatchLabeller = BinaryPatchLabeller(), 15 | return_anomaly_locations: bool = False, binary_factor: bool = True, min_overlap_pct: float = 0.25, 16 | width_bounds_pct: Union[Tuple[float, float], List[Tuple[float, float]]] = (0.05, 0.4), 17 | min_object_pct: float = 0.25, dest_bbox: Optional[np.ndarray] = None, extract_within_bbox: bool = False, 18 | skip_background: Optional[Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]] = None, verbose=True) \ 19 | -> Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, List[np.ndarray]]]: 20 | """ 21 | Create a synthetic training example from the given images by pasting/blending random patches. 22 | Args: 23 | :param img_dest: Image with shape (C[,Z],H,W) where patch should be changed. 24 | :param img_src: Optional source image with dimensions equal to img_dest; if not given, img_dest is used as the 25 | source. 25 | :param same: Use img_dest as the source even if img_src is given. 26 | :param num_patches: How many patches to add. The method will always attempt to add the first patch; 27 | for each subsequent patch it flips a coin. 28 | :param shape_maker: Used to create initial patch mask. 29 | :param patch_transforms: List of transforms to be applied to patch after it is selected with shape_maker. 30 | :param blender: Used to blend patches into destination image. 31 | :param labeller: Used to produce the label for the blended image. 32 | :param return_anomaly_locations: Bool for whether we return the list of anomaly centre coordinates.
33 | :param binary_factor: Whether to use a binary factor or sample from np.random.uniform(0.05, 0.95) 34 | :param min_overlap_pct: Minimum percentage object in patch extracted from source image and object in destination 35 | image should overlap. 36 | :param width_bounds_pct: Either tuple of limits for every dimension, or list of tuples of limits for each 37 | dimension. 38 | :param min_object_pct: Minimum percentage of object which the patch extracted from the source image must cover. 39 | :param dest_bbox: Shape (num_dims, 2). Specify which region to apply the anomalies within. If None, treat as 40 | entire image. 41 | :param extract_within_bbox: Specify whether patches are extracted from inside dest_bbox from the source image. 42 | :param skip_background : Optional tuple of foreground masks for dest and src, or if 'same' then only src object 43 | mask. 44 | :param verbose: Allows debugging prints. 45 | """ 46 | img_src = img_dest.copy() if same or (img_src is None) else img_src 47 | 48 | if skip_background is not None: 49 | if same: 50 | dest_object_mask = src_object_mask = skip_background 51 | else: 52 | dest_object_mask = skip_background[0] 53 | src_object_mask = skip_background[1] 54 | 55 | else: 56 | dest_object_mask = None 57 | src_object_mask = None 58 | 59 | # add patches 60 | mask = np.zeros_like(img_dest[0], dtype=bool) # single channel 61 | blended_img = img_dest.copy() 62 | 63 | if isinstance(width_bounds_pct, tuple): 64 | width_bounds_pct = [width_bounds_pct] * len(mask.shape) 65 | 66 | # Shape (spatial_dimensions, 2) 67 | width_bounds_pct = np.array(width_bounds_pct) 68 | 69 | if binary_factor: 70 | factor = 1.0 71 | else: 72 | factor = np.random.uniform(0.05, 0.95) 73 | 74 | anomaly_centres = [] 75 | 76 | dest_bbox_lbs = np.zeros(len(img_dest.shape) - 1, dtype=int) if dest_bbox is None else dest_bbox[:, 0] 77 | dest_bbox_ubs = np.array(img_dest.shape[1:]) if dest_bbox is None else dest_bbox[:, 1] 78 | 79 | for i in range(num_patches): 80 | if i == 0 or np.random.randint(2) > 0: # at least one patch 81 | blended_img, patch_corner, patch_mask = _patch_ex( 82 | blended_img, img_src, dest_object_mask, src_object_mask, shape_maker, patch_transforms, blender, 83 | width_bounds_pct, min_object_pct, min_overlap_pct, factor, verbose, dest_bbox_lbs, dest_bbox_ubs, 84 | extract_within_bbox) 85 | 86 | if patch_mask is not None: 87 | assert patch_corner is not None, 'patch_mask is not None, but patch_corner is???' 
88 | assert patch_mask is not None, 'Should never be triggered, just for nice typing :)' 89 | 90 | mask[get_patch_slices(patch_corner, patch_mask.shape)] |= patch_mask 91 | 92 | anomaly_centres.append(patch_corner + np.array(patch_mask.shape) // 2) 93 | 94 | mask = mask.astype(float) 95 | final_label = labeller(factor, blended_img, img_dest, mask)[None] 96 | 97 | if return_anomaly_locations: 98 | # Convert label to single channel, to match network output 99 | return blended_img, final_label, anomaly_centres 100 | else: 101 | return blended_img, final_label 102 | 103 | 104 | def _patch_ex(ima_dest: np.ndarray, ima_src: np.ndarray, dest_object_mask: Optional[np.ndarray], 105 | src_object_mask: Optional[np.ndarray], shape_maker: PatchShapeMaker, 106 | patch_transforms: List[PatchTransform], blender: PatchBlender, 107 | width_bounds_pct: np.ndarray, min_object_pct: float, min_overlap_pct: float, factor: float, 108 | verbose: bool, dest_bbox_lbs: np.ndarray, dest_bbox_ubs: np.ndarray, extract_within_bbox: bool) \ 109 | -> Tuple[np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]: 110 | skip_background = (src_object_mask is not None) and (dest_object_mask is not None) 111 | dims = np.array(ima_dest.shape) 112 | bbox_shape = dest_bbox_ubs - dest_bbox_lbs 113 | 114 | min_dim_lens = (width_bounds_pct[:, 0] * bbox_shape).round().astype(int) 115 | max_dim_lens = (width_bounds_pct[:, 1] * bbox_shape).round().astype(int) 116 | dim_bounds = list(zip(min_dim_lens, max_dim_lens)) 117 | 118 | patch_mask = shape_maker(dim_bounds, bbox_shape) 119 | 120 | found_patch = False 121 | attempts = 0 122 | 123 | src_patch_lb = dest_bbox_lbs if extract_within_bbox else np.zeros(len(ima_src.shape) - 1) 124 | src_patch_ub = dest_bbox_ubs if extract_within_bbox else np.array(ima_src.shape[1:]) 125 | 126 | # Use minimum patch size as buffer to stop patch being too close to edge 127 | patch_centre_bounds = [(lb + b, ub - b) for (b, lb, ub) in zip(min_dim_lens, src_patch_lb, src_patch_ub)] 128 | 129 | if skip_background: 130 | # Reduce search space, so patch centre is within object bounding box. Reduces the chance of missing the 131 | # object and requiring more iterations. 132 | 133 | for d in range(len(patch_centre_bounds)): 134 | curr_lb, curr_ub = patch_centre_bounds[d] 135 | other_dims = tuple([d2 for d2 in range(len(patch_centre_bounds)) if d2 != d]) 136 | obj_m_min_ind, obj_m_max_ind = np.nonzero(np.any(src_object_mask, axis=other_dims))[0][[0, -1]] 137 | 138 | patch_centre_bounds[d] = (max(curr_lb, obj_m_min_ind), min(curr_ub, obj_m_max_ind)) 139 | 140 | while not found_patch: 141 | 142 | centers = np.array([np.random.randint(lb, ub) for lb, ub in patch_centre_bounds]) 143 | patch_dims = np.array(patch_mask.shape) 144 | 145 | # Indices of patch corners relative to source, could be out of bounds!
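# (Worked example of the clipping below, with illustrative numbers: for a 1D
# image of length 10, a sampled centre of 1 and a patch of size 6,
# min_corner = -2 and max_corner = 4, so patch_min_indices = 2 and
# patch_max_indices = 6; the surviving slice [2:6] of the patch is placed at
# test_patch_corner = 0.)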
146 | min_corner = centers - patch_dims // 2 147 | max_corner = min_corner + patch_dims 148 | 149 | # Indices of valid area WITHIN patch 150 | patch_min_indices = np.maximum(-min_corner, 0) 151 | patch_max_indices = patch_dims - np.maximum(max_corner - dims[1:], 0) 152 | 153 | test_patch_mask = patch_mask[tuple([slice(lb, ub) for (lb, ub) in zip(patch_min_indices, patch_max_indices)])] 154 | test_patch_corner = np.maximum(min_corner, 0) 155 | 156 | if skip_background: 157 | test_patch_object_mask = src_object_mask[get_patch_slices(test_patch_corner, test_patch_mask.shape)] 158 | object_area = np.sum(test_patch_mask & test_patch_object_mask) 159 | obj_area_sat = (object_area / np.prod(test_patch_mask.shape) > min_object_pct) 160 | 161 | # Want both conditions to hold. If first fails, skip second and iterate faster 162 | if obj_area_sat: 163 | found_patch = check_object_overlap(test_patch_corner, test_patch_mask, test_patch_object_mask, 164 | dest_object_mask, min_overlap_pct) 165 | else: 166 | found_patch = False 167 | else: 168 | found_patch = True 169 | attempts += 1 170 | if attempts == 200: 171 | if verbose: 172 | print('No suitable patch found (initial location failed).') 173 | return ima_dest.copy(), None, None 174 | 175 | patch = ima_src[get_patch_image_slices(test_patch_corner, test_patch_mask.shape, ima_src.shape[0])] 176 | patch_mask = test_patch_mask 177 | patch_corner = test_patch_corner 178 | patch_object_mask = src_object_mask if src_object_mask is None or not skip_background else test_patch_object_mask 179 | 180 | for p_t in patch_transforms: 181 | patch, patch_mask, patch_corner, patch_object_mask = \ 182 | p_t(patch, patch_mask, patch_corner, ima_dest, dest_bbox_lbs, dest_bbox_ubs, patch_object_mask, 183 | dest_object_mask) 184 | 185 | blended_img, patch_mask = blender(factor, patch, patch_mask, patch_corner, ima_dest, patch_object_mask, 186 | dest_object_mask) 187 | 188 | return blended_img, patch_corner, patch_mask 189 | 190 | 191 | if __name__ == '__main__': 192 | patch_ex(np.random.random((3, 50, 50)), np.random.random((3, 50, 50))) 193 | -------------------------------------------------------------------------------- /nnood/self_supervised_task/patch_labeller.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | from scipy.ndimage import binary_closing, grey_closing 5 | 6 | 7 | class PatchLabeller(ABC): 8 | 9 | def __init__(self, tolerance=0.01): 10 | self.tolerance = tolerance 11 | 12 | @abstractmethod 13 | def label(self, factor: float, blended_img: np.ndarray, orig_img: np.ndarray, mask: np.ndarray) -> np.ndarray: 14 | """ 15 | :param factor: Extent that patch is used in blending. In range [0-1]. May be ignored. 16 | :param blended_img: Image with patches blended within it. 17 | :param orig_img: Original image, prior to blending. 18 | :param mask: Mask of where patches have been blended into blended_img. 
19 | """ 20 | pass 21 | 22 | def __call__(self, factor: float, blended_img: np.ndarray, orig_img: np.ndarray, mask: np.ndarray) -> np.ndarray: 23 | return self.label(factor, blended_img, orig_img, mask) 24 | 25 | def remove_no_change(self, blended_img: np.ndarray, orig_img: np.ndarray, mask: np.ndarray) -> np.ndarray: 26 | mask = (np.mean(mask * np.abs(blended_img - orig_img), axis=0) > self.tolerance).astype(int) 27 | # Remove grain from threshold choice, using scipy morphology 28 | # Equivalent to using structure of (5, 5, 5) 29 | return binary_closing(mask, structure=np.ones([3] * len(mask.shape)), iterations=2) 30 | 31 | 32 | class BinaryPatchLabeller(PatchLabeller): 33 | 34 | def label(self, factor: float, blended_img: np.ndarray, orig_img: np.ndarray, mask: np.ndarray) -> np.ndarray: 35 | return self.remove_no_change(blended_img, orig_img, mask) 36 | 37 | 38 | class ContinuousPatchLabeller(PatchLabeller): 39 | 40 | def label(self, factor: float, blended_img: np.ndarray, orig_img: np.ndarray, mask: np.ndarray) -> np.ndarray: 41 | return factor * self.remove_no_change(blended_img, orig_img, mask) 42 | 43 | 44 | class IntensityPatchLabeller(PatchLabeller): 45 | 46 | def label(self, factor: float, blended_img: np.ndarray, orig_img: np.ndarray, mask: np.ndarray) -> np.ndarray: 47 | mask = self.remove_no_change(blended_img, orig_img, mask) 48 | 49 | label = np.mean(mask * np.abs(blended_img - orig_img), axis=0) 50 | return grey_closing(label, size=[7] * len(mask.shape)) 51 | 52 | 53 | class LogisticIntensityPatchLabeller(IntensityPatchLabeller): 54 | 55 | def __init__(self, k, x0, tolerance=0.01): 56 | super().__init__(tolerance) 57 | self.k = k 58 | self.x0 = x0 59 | 60 | def label(self, factor: float, blended_img: np.ndarray, orig_img: np.ndarray, mask: np.ndarray) -> np.ndarray: 61 | intensity_label = super().label(factor, blended_img, orig_img, mask) 62 | return (intensity_label > 0).astype(int) / (1 + np.exp(-self.k * (intensity_label - self.x0))) 63 | -------------------------------------------------------------------------------- /nnood/self_supervised_task/patch_shape_maker.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Tuple 3 | 4 | import numpy as np 5 | 6 | 7 | class PatchShapeMaker(ABC): 8 | 9 | @abstractmethod 10 | def get_patch_mask(self, dim_bounds: List[Tuple[int, int]], img_dims: np.ndarray) -> np.ndarray: 11 | """ 12 | :param dim_bounds: Tuples giving lower and upper bounds for patch size in each dimension 13 | :param img_dims: Image dimensions, can be used as scaling factor. 14 | Creates a patch mask to be used in the self-supervised task. 15 | Mask must have length(dim_bounds) dimensions. 
16 | """ 17 | pass 18 | 19 | def __call__(self, dim_bounds: List[Tuple[int, int]], img_dims: np.ndarray) -> np.ndarray: 20 | return self.get_patch_mask(dim_bounds, img_dims) 21 | 22 | 23 | # For squares, cubes, etc 24 | class EqualUniformPatchMaker(PatchShapeMaker): 25 | 26 | def __init__(self, sample_dist=lambda lb, ub, _: np.random.randint(lb, ub)): 27 | self.sample_dist = sample_dist 28 | 29 | def get_patch_mask(self, dim_bounds: List[Tuple[int, int]], img_dims: np.ndarray) -> np.ndarray: 30 | lbs, ubs = zip(*dim_bounds) 31 | 32 | # As all dimensions must be equal, take maximum lower bound and minimum upperbound as patch size bounds 33 | # and give the minimum image dimension to be used as an optional scaling factor 34 | patch_dim = self.sample_dist(max(lbs), min(ubs), min(img_dims)) 35 | return np.ones([patch_dim] * len(dim_bounds), dtype=bool) 36 | 37 | 38 | # For rectangles, cuboids, etc 39 | class UnequalUniformPatchMaker(PatchShapeMaker): 40 | 41 | def __init__(self, sample_dist=lambda lb, ub, _: np.random.randint(lb, ub), calc_dims_together=False): 42 | self.sample_dist = sample_dist 43 | self.calc_dims_together = calc_dims_together 44 | 45 | def get_patch_mask(self, dim_bounds: List[Tuple[int, int]], img_dims: np.ndarray) -> np.ndarray: 46 | 47 | shape = self.sample_dist(dim_bounds, img_dims) if self.calc_dims_together else\ 48 | [self.sample_dist(lb, ub, d) for ((lb, ub), d) in zip(dim_bounds, img_dims)] 49 | return np.ones(shape, dtype=bool) 50 | 51 | -------------------------------------------------------------------------------- /nnood/self_supervised_task/patch_transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matt-baugh/nnOOD/a953bcad86c59cd016169141a24631cc3ded02ff/nnood/self_supervised_task/patch_transforms/__init__.py -------------------------------------------------------------------------------- /nnood/self_supervised_task/patch_transforms/base_transform.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional, Tuple 3 | 4 | import numpy as np 5 | 6 | 7 | class PatchTransform(ABC): 8 | 9 | def __init__(self, min_overlap_pct: Optional[float] = None): 10 | self.min_overlap_pct = min_overlap_pct 11 | 12 | @abstractmethod 13 | def transform(self, patch: np.ndarray, patch_mask: np.ndarray, patch_corner: np.ndarray, dest: np.ndarray, 14 | dest_bbox_lbs: np.ndarray, dest_bbox_ubs: np.ndarray, patch_object_mask: Optional[np.ndarray], 15 | dest_object_mask: Optional[np.ndarray]) \ 16 | -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]: 17 | """ 18 | :param patch: Patch extracted from source image. 19 | :param patch_mask: Mask of which elements of patch will be put/blended into the destination image. 20 | :param patch_corner: Coordinate of the minimum element of patch relative to the destination image. 21 | :param dest: Destination image, available for querying but probably shouldn't change! 22 | :param dest_bbox_lbs: Lower bound of box to be extracted from destination image 23 | :param dest_bbox_ubs: Upper bound of box to be extracted from destination image 24 | :param patch_object_mask: Mask showing which elements of the patch contain an object of interest. 25 | :param dest_object_mask: Mask showing which elements of the destination image contain an object of interest. 
26 | :returns Tuple containing updated patch, patch_mask, patch_corner and patch_object_mask 27 | """ 28 | pass 29 | 30 | def __call__(self, patch: np.ndarray, patch_mask: np.ndarray, patch_corner: np.ndarray, dest: np.ndarray, 31 | dest_bbox_lbs: np.ndarray, dest_bbox_ubs: np.ndarray, patch_object_mask: Optional[np.ndarray], 32 | dest_object_mask: Optional[np.ndarray]) \ 33 | -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]: 34 | return self.transform(patch, patch_mask, patch_corner, dest, dest_bbox_lbs, dest_bbox_ubs, patch_object_mask, 35 | dest_object_mask) 36 | -------------------------------------------------------------------------------- /nnood/self_supervised_task/patch_transforms/colour_transforms.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import numpy as np 4 | 5 | from nnood.self_supervised_task.patch_transforms.base_transform import PatchTransform 6 | 7 | 8 | class AdjustContrast(PatchTransform): 9 | 10 | def __init__(self, contrast: Union[float, Tuple[float, float]], min_overlap_pct: Optional[float] = None): 11 | super().__init__(min_overlap_pct) 12 | if isinstance(contrast, (tuple, list)): 13 | self.contrast_min = max(contrast[0], 0.0) 14 | self.contrast_max = min(contrast[1], 2.0) 15 | else: 16 | # Clamp the sampled factor range to [0.0, 2.0] around the identity value 1.0, matching the tuple branch 17 | self.contrast_min = max(1.0 - contrast, 0.0) 17 | self.contrast_max = min(1.0 + contrast, 2.0) 18 | 19 | def transform(self, patch: np.ndarray, patch_mask: np.ndarray, patch_corner: np.ndarray, dest: np.ndarray, 20 | dest_bbox_lbs: np.ndarray, dest_bbox_ubs: np.ndarray, patch_object_mask: Optional[np.ndarray], 21 | dest_object_mask: Optional[np.ndarray]) \ 22 | -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]: 23 | 24 | contrast_f = np.random.uniform(self.contrast_min, self.contrast_max) 25 | avg = np.mean(patch, axis=0)[patch_mask == 1].mean() 26 | new_patch = (contrast_f * patch + (1.0 - contrast_f) * avg) * patch_mask 27 | 28 | return new_patch, patch_mask, patch_corner, patch_object_mask 29 | 30 | 31 | class PsuedoAdjustBrightness(PatchTransform): 32 | 33 | def __init__(self, brightness: Union[float, Tuple[float, float]], min_dset_val: float, 34 | min_overlap_pct: Optional[float] = None): 35 | super().__init__(min_overlap_pct) 36 | if isinstance(brightness, (tuple, list)): 37 | self.brightness_min = max(brightness[0], 0.0) 38 | self.brightness_max = min(brightness[1], 2.0) 39 | else: 40 | self.brightness_min = max(1.0 - brightness, 0.0) 41 | self.brightness_max = min(1.0 + brightness, 2.0) 42 | self.min_val = min_dset_val 43 | 44 | def transform(self, patch: np.ndarray, patch_mask: np.ndarray, patch_corner: np.ndarray, dest: np.ndarray, 45 | dest_bbox_lbs: np.ndarray, dest_bbox_ubs: np.ndarray, patch_object_mask: Optional[np.ndarray], 46 | dest_object_mask: Optional[np.ndarray]) \ 47 | -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]: 48 | 49 | brightness_f = np.random.uniform(self.brightness_min, self.brightness_max) 50 | new_patch = (self.min_val + (patch - self.min_val) * brightness_f) * patch_mask 51 | 52 | return new_patch, patch_mask, patch_corner, patch_object_mask 53 | -------------------------------------------------------------------------------- /nnood/self_supervised_task/pii.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from nnood.self_supervised_task.patch_blender import PoissonPatchBlender 6 |
from nnood.self_supervised_task.patch_ex import patch_ex 7 | from nnood.self_supervised_task.patch_shape_maker import EqualUniformPatchMaker 8 | from nnood.self_supervised_task.patch_labeller import ContinuousPatchLabeller 9 | from nnood.self_supervised_task.self_sup_task import SelfSupTask 10 | 11 | 12 | class PII(SelfSupTask): 13 | 14 | def __init__(self, mixed_gradients=False): 15 | self.mixed_gradients = mixed_gradients 16 | 17 | self.shape_maker = EqualUniformPatchMaker() 18 | self.prev_sample = None 19 | self.labeller = ContinuousPatchLabeller() 20 | self.blender = PoissonPatchBlender(self.mixed_gradients) 21 | 22 | def apply(self, sample, sample_mask, sample_properties, sample_fn=None, dest_bbox=None, return_locations=False): 23 | 24 | sample = np.array(sample) 25 | 26 | src = np.array(sample_fn(False)[0]) if self.prev_sample is None else self.prev_sample 27 | if isinstance(src, tuple): 28 | src = src[0] 29 | 30 | result = patch_ex(sample, src, shape_maker=self.shape_maker, blender=self.blender, labeller=self.labeller, 31 | binary_factor=False, dest_bbox=dest_bbox, extract_within_bbox=True, 32 | return_anomaly_locations=return_locations, width_bounds_pct=(0.1, 0.4)) 33 | self.prev_sample = sample 34 | return result 35 | 36 | def loss(self, pred, target): 37 | return F.binary_cross_entropy_with_logits(pred, target) 38 | 39 | def label_is_seg(self): 40 | return False 41 | 42 | def inference_nonlin(self, data): 43 | return torch.sigmoid(data) 44 | -------------------------------------------------------------------------------- /nnood/self_supervised_task/rect_fpi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from nnood.self_supervised_task.patch_blender import UniformPatchBlender 6 | from nnood.self_supervised_task.patch_ex import patch_ex 7 | from nnood.self_supervised_task.patch_shape_maker import UnequalUniformPatchMaker 8 | from nnood.self_supervised_task.patch_labeller import ContinuousPatchLabeller 9 | from nnood.self_supervised_task.self_sup_task import SelfSupTask 10 | 11 | 12 | class RectFPI(SelfSupTask): 13 | 14 | def __init__(self): 15 | self.shape_maker = UnequalUniformPatchMaker(sample_dist=self.sample_dimension) 16 | self.prev_sample = None 17 | self.blender = UniformPatchBlender() 18 | self.labeller = ContinuousPatchLabeller() 19 | 20 | @staticmethod 21 | def sample_dimension(lb, ub, img_d): 22 | gamma_lb = 0.03 23 | gamma_shape = 2 24 | gamma_scale = 0.1 25 | 26 | gamma_sample = (gamma_lb + np.random.gamma(gamma_shape, gamma_scale)) * img_d 27 | 28 | return int(np.clip(gamma_sample, lb, ub)) 29 | 30 | def apply(self, sample, sample_mask, sample_properties, sample_fn=None, dest_bbox=None, return_locations=False): 31 | src = sample_fn(False)[0] if self.prev_sample is None else self.prev_sample 32 | if isinstance(src, tuple): 33 | src = src[0] 34 | 35 | result = patch_ex(sample, src, shape_maker=self.shape_maker, blender=self.blender, labeller=self.labeller, 36 | binary_factor=False, dest_bbox=dest_bbox, extract_within_bbox=True, 37 | return_anomaly_locations=return_locations, width_bounds_pct=(0.06, 0.8)) 38 | self.prev_sample = sample 39 | return result 40 | 41 | def loss(self, pred, target): 42 | return F.binary_cross_entropy_with_logits(pred, target) 43 | 44 | def label_is_seg(self): 45 | return False 46 | 47 | def inference_nonlin(self, data): 48 | return torch.sigmoid(data) 49 | 
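# Minimal usage sketch (illustrative only: random arrays stand in for real,
# preprocessed samples, and the empty properties dict is an assumption):
#   task = RectFPI()
#   sample = np.random.random((3, 64, 64))
#   sample_fn = lambda _: (np.random.random((3, 64, 64)), None)
#   augmented, label = task.apply(sample, None, {}, sample_fn=sample_fn)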
-------------------------------------------------------------------------------- /nnood/self_supervised_task/self_sup_task.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Callable, List, Optional, Tuple, Union 3 | 4 | import numpy as np 5 | from collections import OrderedDict 6 | 7 | 8 | class SelfSupTask(ABC): 9 | """ 10 | A generic class defining the form of a self-supervised anomaly detection task. 11 | What you need to override: 12 | - apply, loss and label_is_seg 13 | """ 14 | 15 | def calibrate(self, dataset, exp_plans): 16 | """ 17 | If any parameters of the task depend on the dataset being used, set them here (for example, whether data is 2D 18 | or 3D). May not be needed. 19 | :param dataset: dictionary of all training/validation samples, so use dataset_loading.load_npy_or_npz to 20 | actually get samples 21 | :param exp_plans: plans for experiment 22 | """ 23 | pass 24 | 25 | @abstractmethod 26 | def apply(self, sample: np.ndarray, sample_mask: Optional[np.ndarray], sample_properties: OrderedDict, 27 | sample_fn: Optional[ 28 | Callable[[bool], Tuple[Union[Tuple[np.ndarray, np.ndarray], np.ndarray], Any]]] = None, 29 | dest_bbox: Optional[np.ndarray] = None, return_locations: bool = False) \ 30 | -> Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, List[np.ndarray]]]: 31 | """ 32 | Apply the self-supervised task to the single data sample. 33 | :param sample: 34 | :param sample_mask: Optional mask showing region of interest (allows task to focus on making anomalies here). 35 | :param sample_properties: 36 | :param sample_fn: Optional function to get auxiliary samples for task (such as in FPI) 37 | :param dest_bbox: Specify which region to apply the anomalies within. If None, treat as entire image. 38 | :param return_locations: Optional boolean for whether to return list of anomaly centres 39 | :return: sample with task applied, label map and (if return_locations=True) a list of anomaly centres. 40 | """ 41 | pass 42 | 43 | def __call__(self, sample: np.ndarray, sample_mask: Optional[np.ndarray], sample_properties: OrderedDict, 44 | sample_fn: Optional[ 45 | Callable[[bool], Tuple[Union[Tuple[np.ndarray, np.ndarray], np.ndarray], Any]]] = None, 46 | dest_bbox: Optional[np.ndarray] = None, return_locations: bool = False) \ 47 | -> Union[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, List[Tuple[np.ndarray, np.ndarray]]]]: 48 | return self.apply(sample, sample_mask, sample_properties, sample_fn, dest_bbox, return_locations) 49 | 50 | @abstractmethod 51 | def loss(self, pred, target): 52 | """ 53 | Loss function to be used when training with this task. May be simply calling a torch loss function 54 | :param pred: 55 | :param target: 56 | :return: 57 | """ 58 | 59 | @abstractmethod 60 | def label_is_seg(self): 61 | """ 62 | Returns whether the label for this loss function is a segmentation (of classes) or not. 63 | :return: 64 | """ 65 | 66 | def inference_nonlin(self, data): 67 | """ 68 | Optional nonlinearity to be applied to network output at inference time.
69 | :param data: 70 | :return: 71 | """ 72 | return data 73 | -------------------------------------------------------------------------------- /nnood/self_supervised_task/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | 5 | 6 | def check_object_overlap(patch_corner: np.ndarray, patch_mask: np.ndarray, patch_object_mask: np.ndarray, 7 | dest_object_mask: np.ndarray, min_overlap_pct: float): 8 | # If no overlap required, skip calculations and return true. 9 | if min_overlap_pct == 0: 10 | return True 11 | 12 | dest_object_mask = extract_dest_patch_mask(patch_corner, patch_mask.shape, dest_object_mask) 13 | patch_and_object = patch_mask & patch_object_mask 14 | patch_and_dest = patch_and_object & dest_object_mask 15 | 16 | patch_obj_dst_area = np.sum(patch_and_dest) 17 | patch_obj_area = np.sum(patch_and_object) 18 | 19 | # Avoid division by zero. If the patch covers none of the source object, we want to reject. 20 | if patch_obj_area == 0: 21 | return False 22 | 23 | return (patch_obj_dst_area / patch_obj_area) >= min_overlap_pct 24 | 25 | 26 | def extract_dest_patch_mask(patch_corner: np.ndarray, patch_shape: Tuple[int], dest_object_mask: np.ndarray): 27 | assert len(patch_corner) == len(dest_object_mask.shape), 'Patch coordinate and destination object mask must have ' \ 28 | f'equal number of dimensions: {patch_corner}, ' \ 29 | f'{dest_object_mask.shape}' 30 | assert len(patch_corner) == len(patch_shape), 'Patch coordinate and patch shape must have equal ' \ 31 | f'number of dimensions: {patch_corner}, {patch_shape}' 32 | 33 | return dest_object_mask[get_patch_slices(patch_corner, patch_shape)] 34 | 35 | 36 | def get_patch_slices(patch_corner: np.ndarray, patch_shape: Tuple[int]) -> Tuple[slice]: 37 | return tuple([slice(c, c + d) for (c, d) in zip(patch_corner, patch_shape)]) 38 | 39 | 40 | # Same as above, but with additional slice at beginning to include all image channels. 41 | def get_patch_image_slices(patch_corner: np.ndarray, patch_shape: Tuple[int], img_channels: int) -> Tuple[slice]: 42 | return tuple([slice(img_channels)] + list(get_patch_slices(patch_corner, patch_shape))) 43 | -------------------------------------------------------------------------------- /nnood/training/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .
import * 3 | -------------------------------------------------------------------------------- /nnood/training/data_augmentation/downsampling.py: -------------------------------------------------------------------------------- 1 | from batchgenerators.transforms.abstract_transforms import AbstractTransform 2 | import numpy as np 3 | from skimage.transform import resize 4 | 5 | 6 | class DownsampleSegForDSTransform(AbstractTransform): 7 | """ 8 | data_dict['output_key'] will be a list of segmentations scaled according to ds_scales 9 | """ 10 | def __init__(self, ds_scales=(1, 0.5, 0.25), order=1, input_key='seg', output_key='seg', axes=None): 11 | self.axes = axes 12 | self.output_key = output_key 13 | self.input_key = input_key 14 | self.order = order 15 | self.ds_scales = ds_scales 16 | 17 | def __call__(self, **data_dict): 18 | data_dict[self.output_key] = downsample_seg_for_ds_transform(data_dict[self.input_key], self.ds_scales, 19 | self.order, self.axes) 20 | return data_dict 21 | 22 | 23 | def downsample_seg_for_ds_transform(seg, ds_scales=((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)), order=1, 24 | axes=None): 25 | if axes is None: 26 | axes = list(range(2, len(seg.shape))) 27 | 28 | output = [] 29 | 30 | for s in ds_scales: 31 | 32 | if all([i == 1 for i in s]): 33 | output.append(seg) 34 | else: 35 | new_shape = np.array(seg.shape).astype(float) 36 | for i, a in enumerate(axes): 37 | new_shape[a] *= s[i] 38 | new_shape = np.round(new_shape).astype(int) 39 | out_seg = np.zeros(new_shape, dtype=seg.dtype) 40 | for b in range(seg.shape[0]): 41 | for c in range(seg.shape[1]): 42 | out_seg[b, c] = resize(seg[b, c], new_shape[2:], order) 43 | output.append(out_seg) 44 | 45 | return output 46 | -------------------------------------------------------------------------------- /nnood/training/dataloading/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import * 3 | -------------------------------------------------------------------------------- /nnood/training/loss_functions/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . 
import * 3 | -------------------------------------------------------------------------------- /nnood/training/loss_functions/deep_supervision.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class MultipleOutputLoss(nn.Module): 5 | def __init__(self, loss, weight_factors=None): 6 | """ 7 | use this if you have several outputs and ground truth (both list of same len) and the loss should be computed 8 | between them (x[0] and y[0], x[1] and y[1] etc) 9 | :param loss: 10 | :param weight_factors: 11 | """ 12 | super(MultipleOutputLoss, self).__init__() 13 | self.weight_factors = weight_factors 14 | self.loss = loss 15 | 16 | def forward(self, x, y): 17 | assert isinstance(x, (tuple, list)), 'x must be either tuple or list' 18 | assert isinstance(y, (tuple, list)), 'y must be either tuple or list' 19 | if self.weight_factors is None: 20 | weights = [1] * len(x) 21 | else: 22 | weights = self.weight_factors 23 | 24 | combined_loss = weights[0] * self.loss(x[0], y[0]) 25 | for i in range(1, len(x)): 26 | if weights[i] != 0: 27 | combined_loss += weights[i] * self.loss(x[i], y[i]) 28 | return combined_loss 29 | -------------------------------------------------------------------------------- /nnood/training/network_training/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import * 3 | -------------------------------------------------------------------------------- /nnood/training/network_training/nnOODTrainerDS.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from nnood.training.network_training.nnOODTrainer import nnOODTrainer 4 | from nnood.training.loss_functions.deep_supervision import MultipleOutputLoss 5 | 6 | 7 | class nnOODTrainerDS(nnOODTrainer): 8 | 9 | def __init__(self, *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | 12 | self.ds_loss_weights = None 13 | 14 | def setup_DA_params(self): 15 | super(nnOODTrainerDS, self).setup_DA_params() 16 | self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod( 17 | np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1] 18 | 19 | def initialize(self, training=True, force_load_plans=False): 20 | """ 21 | Add loss function wrapper for deep supervision. 22 | 23 | :param training: 24 | :param force_load_plans: 25 | :return: 26 | """ 27 | if not self.was_initialized: 28 | super(nnOODTrainerDS, self).initialize(training, force_load_plans) 29 | 30 | # Set up deep supervision loss 31 | # We need to know the number of outputs of the network 32 | net_numpool = len(self.net_num_pool_op_kernel_sizes) 33 | 34 | # We give each output a weight which decreases exponentially (division by 2) as the resolution decreases 35 | # this gives higher resolution outputs more weight in the loss 36 | weights = np.array([1 / (2 ** i) for i in range(net_numpool)]) 37 | 38 | # We don't use the lowest 2 outputs. Normalize weights so that they sum to 1 39 | mask = np.array([True] + [True if i < net_numpool - 1 else False for i in range(1, net_numpool)]) 40 | weights[~mask] = 0 41 | weights = weights / weights.sum() 42 | self.ds_loss_weights = weights 43 | 44 | self.loss = MultipleOutputLoss(self.loss, self.ds_loss_weights) 45 | 46 | def run_online_evaluation(self, output, target): 47 | """ 48 | due to deep supervision the return value and the reference are now lists of tensors. 
We only need the full 49 | resolution output, as that is what we are ultimately interested in; the others are ignored. 50 | :param output: 51 | :param target: 52 | :return: 53 | """ 54 | target = target[0] 55 | output = output[0] 56 | return super().run_online_evaluation(output, target) 57 | 58 | def _wrap_ds_fn(self, tmp_ds_val: bool, fn, *args, **kwargs): 59 | """ 60 | Helper for temporarily setting the network's deep supervision flag to tmp_ds_val while running fn 61 | :param fn: 62 | :param args: 63 | :param kwargs: 64 | :return: 65 | """ 66 | ds = self.network.do_ds 67 | self.network.do_ds = tmp_ds_val 68 | 69 | ret = fn(*args, **kwargs) 70 | 71 | self.network.do_ds = ds 72 | 73 | return ret 74 | 75 | def validate(self, *args, **kwargs): 76 | """ 77 | We need to wrap this because we need to enforce self.network.do_ds = False for prediction. 78 | """ 79 | return self._wrap_ds_fn(False, super(nnOODTrainerDS, self).validate, *args, **kwargs) 80 | 81 | def predict_preprocessed_data(self, *args, **kwargs) -> np.ndarray: 82 | return self._wrap_ds_fn(False, super().predict_preprocessed_data, *args, **kwargs) 83 | 84 | def run_training(self): 85 | return self._wrap_ds_fn(True, super(nnOODTrainerDS, self).run_training) 86 | -------------------------------------------------------------------------------- /nnood/training/nnOOD_run_training.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from nnood.paths import default_plans_identifier 4 | from nnood.training.network_training.nnOODTrainer import nnOODTrainer 5 | from nnood.utils.default_configuration import get_default_configuration 6 | from nnood.utils.miscellaneous import load_pretrained_weights 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('network', help='can only be one of: \'lowres\', \'fullres\', \'cascade_fullres\'') 12 | parser.add_argument('network_trainer', help='class name of trainer to be used') 13 | parser.add_argument('dataset', help='Dataset name') 14 | parser.add_argument('task', help='Task name') 15 | parser.add_argument('fold', help='0, 1, ..., 4 or \'all\'') 16 | parser.add_argument('-val', '--validation_only', help='use this if you want to only run the validation', 17 | action='store_true') 18 | parser.add_argument('-c', '--continue_training', help='use this if you want to continue a previous training run', 19 | action='store_true') 20 | parser.add_argument('-cf', '--continue_from', 21 | help='Use this to specify which checkpoint to continue training from. Must also set ' 22 | '--continue_training. Must be one of [model_final_checkpoint, model_latest, model_best]. ' 23 | 'Optional; if not set, training continues from the latest checkpoint.', 24 | required=False, default=None) 25 | parser.add_argument('-p', help='plans identifier. Only change this if you created a custom experiment planner', 26 | default=default_plans_identifier, required=False) 27 | parser.add_argument('--use_compressed_data', default=False, action='store_true', 28 | help='If you set use_compressed_data, the training cases will not be decompressed. Reading ' 29 | 'compressed data is much more CPU and RAM intensive and should only be used if you know ' 30 | 'what you are ' 31 | 'doing', required=False) 32 | parser.add_argument('--deterministic', 33 | help='Makes training deterministic, but reduces training speed substantially. Probably not ' 34 | 'necessary. Deterministic training will make you overfit to some random seed.
Don\'t use ' 35 | 'that.', 36 | required=False, default=False, action='store_true') 37 | parser.add_argument('--npz', required=False, default=False, action='store_true', help='if set then nnOOD will ' 38 | 'export npz files of ' 39 | 'predicted segmentations ' 40 | 'in the validation as well. ' 41 | 'This is needed to run the ' 42 | 'ensembling step, so enable ' 43 | 'this unless you are certain ' 44 | 'you will not use ensembling') 45 | parser.add_argument('--fp32', required=False, default=False, action='store_true', 46 | help='disable mixed precision training and run old school fp32') 47 | parser.add_argument('--val_folder', required=False, default='validation_raw', 48 | help='name of the validation folder. No need to use this for most people') 49 | parser.add_argument('--disable_saving', required=False, action='store_true', 50 | help='If set, nnOOD will not save any parameter files (except a temporary checkpoint that ' 51 | 'will be removed at the end of the training). Useful for development when you are ' 52 | 'only interested in the results and want to save some disk space') 53 | parser.add_argument('--val_disable_overwrite', action='store_false', default=True, 54 | help='If set, validation does not overwrite existing segmentations') 55 | parser.add_argument('-pretrained_weights', type=str, required=False, default=None, 56 | help='path to checkpoint file to be used as pretrained model (use a .model file, for ' 57 | 'example model_final_checkpoint.model). Will only be used when actually training. ' 58 | 'Optional. Beta. Use with caution.') 59 | parser.add_argument('--load_dataset_ram', required=False, action='store_true', default=False, 60 | help='Load entire dataset into RAM; use carefully, only suggested for smaller, 2D datasets.') 61 | 62 | args = parser.parse_args() 63 | 64 | dataset = args.dataset 65 | task = args.task 66 | fold = args.fold 67 | network = args.network 68 | network_trainer = args.network_trainer 69 | validation_only = args.validation_only 70 | plans_identifier = args.p 71 | 72 | continue_training = args.continue_training 73 | continue_from = args.continue_from 74 | load_dataset_ram = args.load_dataset_ram 75 | 76 | use_compressed_data = args.use_compressed_data 77 | decompress_data = not use_compressed_data 78 | 79 | deterministic = args.deterministic 80 | 81 | fp32 = args.fp32 82 | run_mixed_precision = not fp32 83 | 84 | val_folder = args.val_folder 85 | 86 | if fold == 'all': 87 | pass 88 | else: 89 | fold = int(fold) 90 | 91 | plans_file, output_folder_name, dataset_directory, stage, trainer_class, task_class =\ 92 | get_default_configuration(network, dataset, task, network_trainer, plans_identifier) 93 | 94 | if trainer_class is None: 95 | raise RuntimeError('Could not find trainer class in nnood.training.network_training') 96 | if task_class is None: 97 | raise RuntimeError('Could not find task class in nnood.self_supervised_task') 98 | 99 | if network == 'cascade_fullres': 100 | raise NotImplementedError('Trying to run cascade full res, but I haven\'t made that yet!')
101 | # assert issubclass(trainer_class, (nnUNetTrainerCascadeFullRes, nnUNetTrainerV2CascadeFullRes)), \ 102 | # 'If running 3d_cascade_fullres then your ' \ 103 | # 'trainer class must be derived from ' \ 104 | # 'nnOODTrainerCascadeFullRes' 105 | else: 106 | assert issubclass(trainer_class, 107 | nnOODTrainer), 'network_trainer was found but is not derived from nnOODTrainer' 108 | 109 | trainer = trainer_class(plans_file, fold, task_class, output_folder=output_folder_name, 110 | dataset_directory=dataset_directory, stage=stage, unpack_data=decompress_data, 111 | deterministic=deterministic, fp16=run_mixed_precision, load_dataset_ram=load_dataset_ram) 112 | if args.disable_saving: 113 | trainer.save_final_checkpoint = False # whether or not to save the final checkpoint 114 | trainer.save_best_checkpoint = False # whether or not to save the best checkpoint according to the validation metric 115 | trainer.save_intermediate_checkpoints = True # whether or not to save checkpoint_latest. We need that in case 116 | # the training crashes 117 | trainer.save_latest_only = True # if false it will not store/overwrite _latest but separate files each time an intermediate checkpoint is saved 118 | 119 | trainer.initialize(not validation_only) 120 | 121 | if not validation_only: 122 | if continue_training: 123 | # -c was set, continue a previous training and ignore pretrained weights 124 | if continue_from: 125 | assert continue_from in ['model_final_checkpoint', 'model_latest', 'model_best'],\ 126 | f'Unexpected checkpoint name: {continue_from}' 127 | 128 | checkpoint_file = trainer.output_folder / f'{continue_from}.model' 129 | assert checkpoint_file.is_file(), f'Missing checkpoint file: {checkpoint_file}' 130 | trainer.load_checkpoint(checkpoint_file) 131 | else: 132 | trainer.load_latest_checkpoint() 133 | 134 | elif args.pretrained_weights is not None: 135 | # Start a new training, using pre-trained weights. 136 | load_pretrained_weights(trainer.network, args.pretrained_weights) 137 | else: 138 | # new training without pretrained weights, do nothing 139 | pass 140 | 141 | trainer.run_training() 142 | else: 143 | trainer.load_final_checkpoint(train=False) 144 | 145 | trainer.network.eval() 146 | 147 | # Predict validation 148 | # trainer.validate(save=args.npz, validation_folder_name=val_folder, overwrite=args.val_disable_overwrite) 149 | 150 | if network == 'lowres': 151 | print('FYI, even though this is a lowres experiment, we don\'t predict the next stage, as the data in self-sup ' 152 | 'tasks is dynamic.') 153 | 154 | 155 | if __name__ == '__main__': 156 | main() 157 | -------------------------------------------------------------------------------- /nnood/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .
import * 3 | -------------------------------------------------------------------------------- /nnood/utils/default_configuration.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import nnood 4 | from nnood.experiment_planning.utils import summarise_plans 5 | from nnood.paths import default_plans_identifier, preprocessed_data_base, results_base 6 | from nnood.utils.file_operations import load_pickle 7 | from nnood.utils.miscellaneous import recursive_find_python_class 8 | 9 | 10 | def get_default_configuration(network: str, dataset_name: str, task_name: str, network_trainer: str, 11 | plans_identifier=default_plans_identifier, silent=False): 12 | 13 | assert network in ['lowres', 'fullres', 'cascade_fullres'], \ 14 | 'network can only be one of the following: \'lowres\', \'fullres\', \'cascade_fullres\'' 15 | 16 | prep_data_base = Path(preprocessed_data_base) 17 | dataset_directory = prep_data_base / dataset_name 18 | plans_file = prep_data_base / dataset_name / plans_identifier 19 | 20 | plans = load_pickle(plans_file) 21 | possible_stages = list(plans['plans_per_stage'].keys()) 22 | 23 | if network in ('lowres', 'cascade_fullres') and len(possible_stages) == 1: 24 | raise RuntimeError('lowres/cascade_fullres only applies if there is more than one stage. This task does ' 25 | 'not require the cascade. Run fullres instead') 26 | 27 | if network == 'lowres': 28 | stage = 0 29 | else: 30 | stage = possible_stages[-1] 31 | 32 | trainer_class = recursive_find_python_class([str(Path(nnood.__path__[0], 'training', 'network_training'))], 33 | network_trainer, current_module='nnood.training.network_training') 34 | task_class = recursive_find_python_class([str(Path(nnood.__path__[0], 'self_supervised_task'))], 35 | task_name, current_module='nnood.self_supervised_task') 36 | 37 | output_folder_name = Path(results_base, dataset_name, task_name, network, network_trainer + '__' + plans_identifier) 38 | 39 | if not silent: 40 | print('###############################################') 41 | print('I am running the following nnOOD network: %s' % network) 42 | print('My trainer class is: ', trainer_class) 43 | print('My task class is: ', task_class) 44 | print('For that I will be using the following configuration:') 45 | summarise_plans(plans) 46 | print('I am using stage %d from these plans' % stage) 47 | 48 | print('\nI am using data from this folder: ', dataset_directory / f'{plans["data_identifier"]}_stage{stage}') 49 | 50 | print('###############################################') 51 | return plans_file, output_folder_name, dataset_directory, stage, trainer_class, task_class 52 | -------------------------------------------------------------------------------- /nnood/utils/file_operations.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import json 3 | from typing import Union, Any 4 | from pathlib import Path 5 | 6 | from nnood.paths import results_base, default_plans_identifier 7 | 8 | 9 | def load_pickle(file_path: Union[str, Path], mode: str = 'rb'): 10 | with open(file_path, mode) as f: 11 | content = pickle.load(f) 12 | return content 13 | 14 | 15 | def save_pickle(content: Any, file_path: Union[str, Path], mode: str = 'wb'): 16 | with open(file_path, mode) as f: 17 | pickle.dump(content, f) 18 | 19 | 20 | def load_json(file_path: Union[str, Path]): 21 | with open(file_path) as f: 22 | content = json.load(f) 23 | return content 24 | 25 | 26 | def
save_json(content: dict, file_path: Union[str, Path], indent: int = 4, sort_keys: bool = True): 27 | with open(file_path, 'w') as f: 28 | json.dump(content, f, sort_keys=sort_keys, indent=indent) 29 | 30 | 31 | def load_results_json(dataset: str, task: str, plans_identifier: str = default_plans_identifier): 32 | results_folder = Path(results_base, dataset, task, 'testResults', plans_identifier) 33 | return load_json(results_folder / 'summary.json') 34 | -------------------------------------------------------------------------------- /nnood/utils/miscellaneous.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import importlib 3 | import pkgutil 4 | from pathlib import Path 5 | import re 6 | from typing import Dict, List, Tuple 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | class no_op: 13 | def __enter__(self): 14 | pass 15 | 16 | def __exit__(self, *args): 17 | pass 18 | 19 | 20 | def recursive_find_python_class(folder, class_name, current_module): 21 | tr = None 22 | for importer, modname, ispkg in pkgutil.iter_modules(folder): 23 | if not ispkg: 24 | m = importlib.import_module(current_module + "." + modname) 25 | if hasattr(m, class_name): 26 | tr = getattr(m, class_name) 27 | break 28 | 29 | if tr is None: 30 | for importer, modname, ispkg in pkgutil.iter_modules(folder): 31 | if ispkg: 32 | next_current_module = current_module + "." + modname 33 | tr = recursive_find_python_class([str(Path(folder[0]) / modname)], class_name, 34 | current_module=next_current_module) 35 | if tr is not None: 36 | break 37 | 38 | return tr 39 | 40 | 41 | def load_pretrained_weights(network, fname, verbose=False): 42 | """ 43 | THIS DOES NOT TRANSFER SEGMENTATION HEADS! 44 | """ 45 | saved_model = torch.load(fname) 46 | pretrained_dict = saved_model['state_dict'] 47 | 48 | new_state_dict = {} 49 | 50 | # if the state dict comes from an nn.DataParallel model but we use a non-parallel model here then the state dict 51 | # keys do not match. Use this heuristic to make them match 52 | for k, value in pretrained_dict.items(): 53 | key = k 54 | # remove module. prefix from DDP models 55 | if key.startswith('module.'): 56 | key = key[7:] 57 | new_state_dict[key] = value 58 | 59 | pretrained_dict = new_state_dict 60 | 61 | model_dict = network.state_dict() 62 | ok = True 63 | for key, _ in model_dict.items(): 64 | if 'conv_blocks' in key: 65 | if (key in pretrained_dict) and (model_dict[key].shape == pretrained_dict[key].shape): 66 | continue 67 | else: 68 | ok = False 69 | break 70 | 71 | # 1. filter out unnecessary keys 72 | if ok: 73 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if 74 | (k in model_dict) and (model_dict[k].shape == pretrained_dict[k].shape)} 75 | # 2.
overwrite entries in the existing state dict 76 | model_dict.update(pretrained_dict) 77 | print("################### Loading pretrained weights from file ", fname, '###################') 78 | if verbose: 79 | print("Below is the list of overlapping blocks in the pretrained model and the nnOOD architecture:") 80 | for key, _ in pretrained_dict.items(): 81 | print(key) 82 | print("################### Done ###################") 83 | network.load_state_dict(model_dict) 84 | else: 85 | raise RuntimeError("Pretrained weights are not compatible with the current network architecture") 86 | 87 | 88 | def get_sample_ids_and_files(input_folder: Path, expected_modalities: Dict[int, str]) -> List[Tuple[str, List[Path]]]: 89 | 90 | assert input_folder.is_dir() 91 | 92 | id_to_files_dict = OrderedDict() 93 | 94 | for f in input_folder.iterdir(): 95 | if not f.is_file(): 96 | continue 97 | 98 | match = re.fullmatch(r'(.*)_(\d\d\d\d)(\..*)', f.name) 99 | 100 | if match is None: 101 | print(f'File "{f}" does not match the expected name format, so will be ignored') 102 | continue 103 | 104 | sample_id, mod_num, suffix = match.groups() 105 | mod_num = int(mod_num) 106 | 107 | if sample_id not in id_to_files_dict: 108 | id_to_files_dict[sample_id] = OrderedDict() 109 | 110 | id_to_files_dict[sample_id][mod_num] = f 111 | 112 | invalid_folder = False 113 | for sample_id, files in id_to_files_dict.items(): 114 | 115 | missing_mods = [m for m in range(len(expected_modalities)) if m not in files] 116 | 117 | # noinspection PySimplifyBooleanCheck 118 | if missing_mods != []: # Don't like simplification, removes clarity 119 | print(f'Sample "{sample_id}" is missing modalities: {missing_mods}') 120 | invalid_folder = True 121 | continue 122 | 123 | # Check modality formats 124 | for i, mod in expected_modalities.items(): 125 | if 'png' in mod and files[i].suffix != '.png': 126 | print(f'Sample "{sample_id}" expected a .png for modality {i}: {files[i]}') 127 | invalid_folder = True 128 | 129 | if invalid_folder: 130 | raise RuntimeError(f'Problems with files in {input_folder}') 131 | 132 | id_to_files_list = [] 133 | # convert dictionary to list, return 134 | for sample_id, file_dict in id_to_files_dict.items(): 135 | sample_files = list(file_dict.values()) 136 | sample_files.sort() 137 | 138 | id_to_files_list.append((sample_id, sample_files)) 139 | 140 | return id_to_files_list 141 | 142 | 143 | def make_pos_enc(sample_shape: np.ndarray) -> np.ndarray: 144 | """ 145 | :param sample_shape: Shape of sample data, excluding channels dimension 146 | :return: positional encoding of shape (len(sample_shape), *sample_shape) 147 | """ 148 | # One coordinate encoding channel per dimension 149 | sample_coords = np.zeros((len(sample_shape), *sample_shape)) 150 | 151 | for dim_num, dim_size in enumerate(sample_shape): 152 | # Coordinates range from -1 to 1 in each dimension 153 | dim_coords = np.linspace(-1, 1, dim_size) 154 | 155 | # Expand coords so they are correctly broadcast (meaning they only change in their respective dimension) 156 | for _ in range(dim_num): 157 | dim_coords = np.expand_dims(dim_coords, axis=0) 158 | 159 | for _ in range(len(sample_shape) - dim_num - 1): 160 | dim_coords = np.expand_dims(dim_coords, axis=-1) 161 | 162 | sample_coords[dim_num] = dim_coords 163 | 164 | return sample_coords 165 | 166 | 167 | def make_hypersphere_mask(radius: int, dims: int): 168 | L = np.arange(-radius, radius + 1) 169 | # It thinks meshgrid returns a string for some reason 170 | # noinspection PyTypeChecker 171 | mg: List[np.ndarray] = np.meshgrid(*([L] * dims))
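# Each meshgrid array holds one coordinate axis, so summing their squares gives each point's squared Euclidean distance from the centre; comparing against radius ** 2 below keeps exactly the points inside the hypersphere.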
172 | return np.sum([D ** 2 for D in mg], axis=0) <= radius ** 2 173 | 174 | 175 | def make_default_mask(): 176 | return np.array([]) 177 | -------------------------------------------------------------------------------- /nnood/utils/to_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def maybe_to_torch(d): 5 | if isinstance(d, list): 6 | d = [maybe_to_torch(i) if not isinstance(i, torch.Tensor) else i for i in d] 7 | elif not isinstance(d, torch.Tensor): 8 | d = torch.from_numpy(d).float() 9 | return d 10 | 11 | 12 | def to_cuda(data, non_blocking=True, gpu_id=0): 13 | if isinstance(data, list): 14 | data = [i.cuda(gpu_id, non_blocking=non_blocking) for i in data] 15 | else: 16 | data = data.cuda(gpu_id, non_blocking=non_blocking) 17 | return data 18 | -------------------------------------------------------------------------------- /notebooks/mvtec_obj_stats: -------------------------------------------------------------------------------- 1 | { 2 | "Names": [ 3 | "Avg Width", 4 | "Avg Height", 5 | "Avg Area" 6 | ], 7 | "bottle": [ 8 | 0.7173256050429903, 9 | 0.7176468249554452, 10 | 0.6585728395061727 11 | ], 12 | "cable": [ 13 | 1, 14 | 1, 15 | 1 16 | ], 17 | "capsule": [ 18 | 0.7734256626578657, 19 | 0.25443810950465723, 20 | 0.22041565999999999 21 | ], 22 | "carpet": [ 23 | 1, 24 | 1, 25 | 1 26 | ], 27 | "grid": [ 28 | 1, 29 | 1, 30 | 1 31 | ], 32 | "hazelnut": [ 33 | 0.5228198711918352, 34 | 0.5157151684145468, 35 | 0.3518977165222168 36 | ], 37 | "leather": [ 38 | 1, 39 | 1, 40 | 1 41 | ], 42 | "metal_nut": [ 43 | 0.6131968469857051, 44 | 0.6134618172954254, 45 | 0.530491918367347 46 | ], 47 | "pill": [ 48 | 0.6820452496374071, 49 | 0.4732081863891496, 50 | 0.42012362499999995 51 | ], 52 | "screw": [ 53 | 0.19722929140107912, 54 | 0.18863691351511225, 55 | 0.10097871780395508 56 | ], 57 | "tile": [ 58 | 1, 59 | 1, 60 | 1 61 | ], 62 | "toothbrush": [ 63 | 0.3160567873443487, 64 | 0.7759843525947859, 65 | 0.30294075012207033 66 | ], 67 | "transistor": [ 68 | 1, 69 | 1, 70 | 1 71 | ], 72 | "wood": [ 73 | 1, 74 | 1, 75 | 1 76 | ], 77 | "zipper": [ 78 | 0.7847899627685547, 79 | 0.9914257195083441, 80 | 0.7847899627685547 81 | ] 82 | } -------------------------------------------------------------------------------- /notebooks/readme.md: -------------------------------------------------------------------------------- 1 | Notebooks for demoing and debugging self-supervised tasks. 2 | - self_sup_task_visualiser: demos complete self-supervised tasks. 3 | - patch_interpolation_helper_test: experiment with different task components to make new tasks. 4 | - dataloader_test: visualise the output of a dataloader (with a calibrated task). 5 | - object_mask_helper: compare the statistics of the object masks of different datasets (used for computing the generalised NSA parameters; see the sketch after this list). 6 | - sampling_tests: visualise the different location sampling distributions. 7 | - trainer_test: run test epochs on new or previously trained trainers. 8 | - results_notebooks: review results from trained and tested models.
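The `mvtec_obj_stats` file above records, for each MVTec AD class, the average bounding-box width, height and area of the object masks as fractions of the image (an entry of 1 means the masks cover the whole image). As a rough illustration of how such normalised statistics could be computed for a single boolean object mask (a minimal sketch under that assumption; `object_mask_stats` is a hypothetical helper, not part of nnOOD):

```python
import numpy as np

def object_mask_stats(mask: np.ndarray):
    # Hypothetical helper: normalised bounding-box width/height and area
    # fraction of a 2D boolean object mask.
    ys, xs = np.nonzero(mask)
    height = (ys.max() - ys.min() + 1) / mask.shape[0]
    width = (xs.max() - xs.min() + 1) / mask.shape[1]
    area = mask.mean()  # fraction of pixels covered by the object
    return width, height, area

# A mask covering the whole image gives (1.0, 1.0, 1.0), like the full-image entries above.
mask = np.zeros((100, 100), dtype=bool)
mask[20:80, 30:70] = True
print(object_mask_stats(mask))  # (0.4, 0.6, 0.24)
```

Averaging these per-image values over a dataset would produce entries like those stored in `mvtec_obj_stats`.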
-------------------------------------------------------------------------------- /notebooks/results_notebooks/mvtec.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "tags": [], 8 | "pycharm": { 9 | "name": "#%%\n" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "from collections import OrderedDict\n", 15 | "import pandas as pd\n", 16 | "from nnood.paths import default_plans_identifier\n", 17 | "from nnood.data.dataset_conversion.convert_mvtec import OBJECTS, TEXTURES\n", 18 | "from nnood.utils.file_operations import load_results_json\n", 19 | "import numpy as np\n", 20 | "\n", 21 | "def get_results(dataset, task, plans_identifier=default_plans_identifier):\n", 22 | " result_metrics = load_results_json(dataset, task, plans_identifier)['results']\n", 23 | " return result_metrics" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": { 30 | "pycharm": { 31 | "name": "#%%\n" 32 | } 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "\n", 37 | "tasks = ['FPI', 'CutPaste']\n", 38 | "metrics = ['AP score', 'AUROC']" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": { 45 | "pycharm": { 46 | "name": "#%%\n" 47 | } 48 | }, 49 | "outputs": [], 50 | "source": [ 51 | "results = OrderedDict()\n", 52 | "\n", 53 | "for m in metrics:\n", 54 | " results[m] = []\n", 55 | "\n", 56 | "for class_name in OBJECTS + TEXTURES:\n", 57 | " \n", 58 | " if class_name is TEXTURES[0]:\n", 59 | " # Add objects average\n", 60 | " for m in metrics:\n", 61 | " all_object_results = results[m]\n", 62 | " results[m].append(np.mean(all_object_results, axis=0))\n", 63 | " \n", 64 | " for m in metrics:\n", 65 | " results[m].append([])\n", 66 | " \n", 67 | " for t in tasks:\n", 68 | " metric_results = get_results('mvtec_ad_' + class_name, t)\n", 69 | "\n", 70 | " for m in metrics:\n", 71 | " results[m][-1].append(metric_results[m])\n", 72 | "\n", 73 | "for m in metrics:\n", 74 | " # Add texture average\n", 75 | " all_texture_results = results[m][-len(TEXTURES):]\n", 76 | " \n", 77 | " results[m].append(np.mean(all_texture_results, axis=0))\n", 78 | " \n", 79 | " # Add total average\n", 80 | " all_nonaverage_results = np.concatenate((results[m][:len(OBJECTS)], results[m][len(OBJECTS) + 1: -1]))\n", 81 | " \n", 82 | " assert all_nonaverage_results.shape[0] == (len(OBJECTS) + len(TEXTURES)), f'Shape: {all_nonaverage_results.shape}'\n", 83 | " results[m].append(np.mean(all_nonaverage_results, axis=0))\n", 84 | "\n", 85 | "\n", 86 | " \n", 87 | " " 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": { 94 | "pycharm": { 95 | "name": "#%%\n" 96 | } 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "miindex = pd.MultiIndex.from_tuples(\n", 101 | " [('Object', obj) for obj in OBJECTS + ['Average']] + [('Texture', txtr) for txtr in TEXTURES + ['Average']] + [('Total', 'Average')])\n", 102 | "dataframes = OrderedDict()\n", 103 | "\n", 104 | "for m in metrics:\n", 105 | " dataframes[m] = pd.DataFrame(results[m], index=miindex, columns=tasks)\n" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": { 112 | "pycharm": { 113 | "name": "#%%\n" 114 | } 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "dataframes['AUROC']" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": { 125 | "pycharm": { 126 | "name": "#%%\n" 127 | } 128 | },
"outputs": [], 130 | "source": [ 131 | "auroc_avgs = dataframes['AUROC'].loc['Total', 'Average']\n" 132 | ] 133 | } 134 | ], 135 | "metadata": { 136 | "kernelspec": { 137 | "display_name": "Python 3", 138 | "language": "python", 139 | "name": "python3" 140 | }, 141 | "language_info": { 142 | "codemirror_mode": { 143 | "name": "ipython", 144 | "version": 3 145 | }, 146 | "file_extension": ".py", 147 | "mimetype": "text/x-python", 148 | "name": "python", 149 | "nbconvert_exporter": "python", 150 | "pygments_lexer": "ipython3", 151 | "version": "3.8.10" 152 | } 153 | }, 154 | "nbformat": 4, 155 | "nbformat_minor": 4 156 | } -------------------------------------------------------------------------------- /notebooks/trainer_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "pycharm": { 8 | "name": "#%%\n" 9 | } 10 | }, 11 | "outputs": [], 12 | "source": [ 13 | "import numpy as np\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import os\n", 16 | "import sys\n", 17 | "from tqdm import tqdm\n", 18 | "\n", 19 | "module_path = os.path.abspath(os.path.join('../..'))\n", 20 | "\n", 21 | "if module_path not in sys.path:\n", 22 | " sys.path.append(module_path)\n", 23 | "\n", 24 | "def wl_to_lh(window, level):\n", 25 | " low = level - window / 2\n", 26 | " high = level + window / 2\n", 27 | " return low,high\n", 28 | "\n", 29 | "def display_image(img, phys_size=None, window=None, level=None, existing_ax=None):\n", 30 | "\n", 31 | " if window is None:\n", 32 | " window = np.max(img) - np.min(img)\n", 33 | "\n", 34 | " if level is None:\n", 35 | " level = window / 2 + np.min(img)\n", 36 | "\n", 37 | " low,high = wl_to_lh(window,level)\n", 38 | "\n", 39 | " if existing_ax is None:\n", 40 | " # Display the orthogonal slices\n", 41 | " fig, axes = plt.subplots(figsize=(14, 8))\n", 42 | " else:\n", 43 | " axes = existing_ax\n", 44 | "\n", 45 | " axes.imshow(img, clim=(low, high), extent= None if phys_size is None else (0, phys_size[0], phys_size[1], 0), cmap='gray')\n", 46 | "\n", 47 | " if existing_ax is None:\n", 48 | " plt.show()\n", 49 | " \n", 50 | "def print_stats(arr):\n", 51 | " print(np.mean(arr),', ',np.std(arr))\n", 52 | " print(np.min(arr), '-', np.max(arr))\n", 53 | " print(arr.shape)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": { 60 | "pycharm": { 61 | "name": "#%%\n" 62 | } 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "from nnood.utils.default_configuration import get_default_configuration\n", 67 | "\n", 68 | "def prepare_test_trainer(network_type, dset_name, task, network_trainer_type, fold):\n", 69 | "\n", 70 | " plans_file, output_folder_name, dataset_directory, stage, trainer_class, task_class =\\\n", 71 | " get_default_configuration(network_type, dset_name, task, network_trainer_type, silent=True)\n", 72 | " \n", 73 | " trainer = trainer_class(plans_file, fold, task_class, output_folder=output_folder_name,\n", 74 | " dataset_directory=dataset_directory, stage=stage, unpack_data=True,\n", 75 | " deterministic=False, fp16=True, load_dataset_ram=False)\n", 76 | " \n", 77 | " trainer.no_print = True\n", 78 | " \n", 79 | " # Need to set training to get datasets loaded\n", 80 | " trainer.initialize(training=True)\n", 81 | " trainer.load_final_checkpoint(train=True)\n", 82 | " trainer.network.eval()\n", 83 | " trainer.track_auroc = trainer.track_metrics = trainer.track_ap = True\n", 84 | " \n", 
85 | " return trainer" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": { 92 | "pycharm": { 93 | "name": "#%%\n" 94 | } 95 | }, 96 | "outputs": [], 97 | "source": [ 98 | "curr_trainer = prepare_test_trainer('fullres', 'chestXray14_PA_male', 'FPI', 'nnOODTrainerDS', 0)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": { 105 | "pycharm": { 106 | "name": "#%%\n" 107 | } 108 | }, 109 | "outputs": [], 110 | "source": [ 111 | "import torch\n", 112 | "\n", 113 | "def run_test_batch(trnr):\n", 114 | " \n", 115 | " with torch.no_grad():\n", 116 | " trnr.run_iteration(trnr.val_gen, False, True)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": { 123 | "pycharm": { 124 | "name": "#%%\n" 125 | } 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "curr_trainer.track_ap = True\n", 130 | "curr_trainer.trac_auroc = True\n", 131 | "\n", 132 | "for _ in tqdm(range(100)):\n", 133 | " run_test_batch(curr_trainer)\n", 134 | " \n", 135 | "curr_trainer.finish_online_evaluation()" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": { 142 | "pycharm": { 143 | "name": "#%%\n" 144 | } 145 | }, 146 | "outputs": [], 147 | "source": [ 148 | "all_results_auroc = {}\n", 149 | "all_results_ap = {}\n", 150 | "\n", 151 | "for dset in ['chestXray14_PA_male', 'chestXray14_PA_female']:\n", 152 | " all_results_auroc[dset] = {}\n", 153 | " all_results_ap[dset] = {}\n", 154 | " print('Dataset', dset)\n", 155 | " \n", 156 | " for t in ['FPI', 'CutPaste', 'PII', 'NSA', 'NSAMixed']:\n", 157 | " print('Task', t)\n", 158 | " \n", 159 | " all_results_auroc[dset][t] = {'all': []}\n", 160 | " all_results_ap[dset][t] = {'all': []}\n", 161 | " for i in range(5):\n", 162 | " tmp_trainer = prepare_test_trainer('fullres', dset, t, 'nnOODTrainerDS', i)\n", 163 | " \n", 164 | " for _ in tqdm(range(40), desc=f'Fold {i}'):\n", 165 | " run_test_batch(tmp_trainer)\n", 166 | " \n", 167 | " fold_res = tmp_trainer.finish_online_evaluation()\n", 168 | " all_results_auroc[dset][t]['all'].append(fold_res['AUROC'])\n", 169 | " all_results_ap[dset][t]['all'].append(fold_res['AP'])\n", 170 | " \n", 171 | " all_results_auroc[dset][t]['avg'] = np.mean(all_results_auroc[dset][t]['all'])\n", 172 | " all_results_auroc[dset][t]['std'] = np.std(all_results_auroc[dset][t]['all'])\n", 173 | " \n", 174 | " all_results_ap[dset][t]['avg'] = np.mean(all_results_ap[dset][t]['all'])\n", 175 | " all_results_ap[dset][t]['std'] = np.std(all_results_ap[dset][t]['all'])\n", 176 | " \n", 177 | " print('Average AUROC', all_results_auroc[dset][t]['avg'])\n", 178 | " print('Average AP', all_results_ap[dset][t]['avg'])\n", 179 | " print()\n", 180 | " \n", 181 | " " 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": { 188 | "pycharm": { 189 | "name": "#%%\n" 190 | } 191 | }, 192 | "outputs": [], 193 | "source": [ 194 | "all_results_ap" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": { 201 | "pycharm": { 202 | "name": "#%%\n" 203 | } 204 | }, 205 | "outputs": [], 206 | "source": [ 207 | "from nnood.utils.file_operations import save_json, load_json\n", 208 | "\n", 209 | "save_json(all_results_auroc, 'trainer_auroc_resultsAP09.json')\n", 210 | "save_json(all_results_ap, 'trainer_ap_resultsAP09.json')" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 
| "metadata": { 217 | "pycharm": { 218 | "name": "#%%\n" 219 | } 220 | }, 221 | "outputs": [], 222 | "source": [ 223 | "\n", 224 | "all_results_auroc_old = load_json('trainer_auroc_results.json')\n", 225 | "all_results_ap_old = load_json('trainer_ap_results.json')" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": { 232 | "pycharm": { 233 | "name": "#%%\n" 234 | } 235 | }, 236 | "outputs": [], 237 | "source": [ 238 | "for d_set in all_results_ap.keys():\n", 239 | " print(d_set, '\\n')\n", 240 | " for t in all_results_ap[d_set].keys():\n", 241 | " print(t)\n", 242 | " print('Old: ', all_results_ap_old[d_set][t]['avg'])\n", 243 | " print('New: ', all_results_ap[d_set][t]['avg'])\n", 244 | " print()" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": { 251 | "pycharm": { 252 | "name": "#%%\n" 253 | } 254 | }, 255 | "outputs": [], 256 | "source": [ 257 | "all_results_ap" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": { 264 | "pycharm": { 265 | "name": "#%%\n" 266 | } 267 | }, 268 | "outputs": [], 269 | "source": [ 270 | "all_results_auroc" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": { 277 | "pycharm": { 278 | "name": "#%%\n" 279 | } 280 | }, 281 | "outputs": [], 282 | "source": [] 283 | } 284 | ], 285 | "metadata": { 286 | "kernelspec": { 287 | "display_name": "Python 3", 288 | "language": "python", 289 | "name": "python3" 290 | }, 291 | "language_info": { 292 | "codemirror_mode": { 293 | "name": "ipython", 294 | "version": 3 295 | }, 296 | "file_extension": ".py", 297 | "mimetype": "text/x-python", 298 | "name": "python", 299 | "nbconvert_exporter": "python", 300 | "pygments_lexer": "ipython3", 301 | "version": "3.8.10" 302 | } 303 | }, 304 | "nbformat": 4, 305 | "nbformat_minor": 4 306 | } -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # nnOOD: A Framework for Benchmarking Self-supervised Anomaly Localisation Methods 2 | 3 | Recently there have been a number of anomaly detection methods that introduce synthetic anomalies into otherwise healthy 4 | data. 5 | The aim of these methods is to generalise to real, unseen anomalies. 6 | nnOOD is able to compare these methods under a standardised training regime, based on 7 | [nnU-Net](https://github.com/MIC-DKFZ/nnUNet/), where each model is trained to an equal level 8 | of saturation (each can perform the selected synthetic task similarly well). 9 | 10 | To facilitate future task development, we also provide a compartmentalised framework for 11 | constructing tasks based on blending patches (the current state-of-the-art), to facilitate further 12 | investigation into this area. 13 | 14 | Got any feedback on this project or ideas on how to make it more useful? 15 | I'd love to hear them! 16 | [Open an issue](https://github.com/matt-baugh/nnOOD/issues) and we can talk it through :). 17 | 18 | ![overview](documentation/nnOOD_overview_v2.png) 19 | 20 | *Overview of the nnOOD framework. The green components are entirely new to nnOOD, orange components differ 21 | significantly from their nnU-Net counterparts and grey components have only minor changes.* 22 | 23 | ## Installation 24 | 25 | 1. 
Install [PyTorch](https://pytorch.org/get-started/locally/), making sure that you use a version which works with 26 | your hardware. 27 | 2. Install nnOOD using 28 | ```bash 29 | git clone https://github.com/matt-baugh/nnOOD 30 | cd nnOOD 31 | pip install -e . 32 | ``` 33 | 34 | Once installed, follow the [experiment walkthrough](documentation/experiment_walkthrough.md) to run nnOOD. 35 | 36 | ## Summary of main entrypoints: 37 | 38 | Full details on how to use these entrypoints to run an experiment are available in 39 | [documentation/experiment_walkthrough.md](documentation/experiment_walkthrough.md). 40 | 41 | - [nnood/experiment_planning/nnOOD_plan_and_preprocess.py](nnood/experiment_planning/nnOOD_plan_and_preprocess.py) 42 | - Given a dataset in the correct format, plan the experiment + model configuration, and optionally prepare the data. 43 | - The dataset should be within the folder named in the environment variable `nnood_raw_data_base` (see [nnood/paths.py](nnood/paths.py)) 44 | - [nnood/training/nnOOD_run_training.py](nnood/training/nnOOD_run_training.py) 45 | - Train a model on a given dataset using a given self-supervised task. 46 | - Must have already run nnOOD_plan_and_preprocess for the dataset. 47 | - The self-supervised task must be the name of a class within `nnood/self_supervised_task`, which extends and implements 48 | the interface of [nnood/self_supervised_task/self_sup_task.py](nnood/self_supervised_task/self_sup_task.py) 49 | - [nnood/evaluation/nnOOD_run_testing.py](nnood/evaluation/nnOOD_run_testing.py) 50 | - Predict and evaluate a model, trained on a specified self-supervised task, on the dataset's test set. 51 | - The test set is defined as the images in `$nnood_raw_data_base/DATASET_NAME/imagesTs` 52 | - Labels are in `$nnood_raw_data_base/DATASET_NAME/labelsTs`. If there is no label present for a sample, we assume 53 | it to be normal, so assign a label of all zeroes. 54 | 55 | ## Existing self-supervised tasks 56 | - [Foreign Patch Interpolation](https://www.melba-journal.org/papers/2022:013.html) 57 | - [CutPaste](https://openaccess.thecvf.com/content/CVPR2021/html/Li_CutPaste_Self-Supervised_Learning_for_Anomaly_Detection_and_Localization_CVPR_2021_paper.html) 58 | - [Poisson Image Interpolation](https://link.springer.com/content/pdf/10.1007%2F978-3-030-87240-3_56.pdf) 59 | - [Natural Synthetic Anomalies](https://arxiv.org/abs/2109.15222) (both source and mixed-gradient variants). 60 | 61 | Full details on how to implement your own self-supervised tasks, either from scratch or using the modular components for 62 | making a patch-blending based task, are available in 63 | [documentation/synthetic_task_guide.md](documentation/synthetic_task_guide.md). 64 | 65 | ## Notebooks 66 | 67 | To test or demo components of nnOOD, use the notebooks within `notebooks`, described in [notebooks/readme.md](notebooks/readme.md).
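The results notebooks are built on `nnood.utils.file_operations.load_results_json`, which you can also use directly to pull a trained and tested model's summary metrics (a minimal sketch; it assumes an `mvtec_ad_bottle` dataset has already been trained and tested with the `FPI` task):

```python
from nnood.utils.file_operations import load_results_json

# Reads $nnood_results_base/mvtec_ad_bottle/FPI/testResults/<plans identifier>/summary.json
results = load_results_json('mvtec_ad_bottle', 'FPI')['results']
print('AP score:', results['AP score'])
print('AUROC:', results['AUROC'])
```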
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='nnood', 5 | version='', 6 | packages=['nnood', 'nnood.data', 'nnood.data.dataset_conversion', 'nnood.utils', 'nnood.training', 7 | 'nnood.training.dataloading', 'nnood.training.loss_functions', 'nnood.training.network_training', 8 | 'nnood.training.data_augmentation', 'nnood.inference', 'nnood.evaluation', 'nnood.preprocessing', 9 | 'nnood.experiment_planning', 'nnood.network_architecture', 'nnood.self_supervised_task', 10 | 'nnood.self_supervised_task.patch_transforms'], 11 | url='', 12 | license='', 13 | author='', 14 | author_email='', 15 | description='', 16 | install_requires=[ 17 | 'numpy', 18 | 'nibabel', 19 | 'SimpleITK', 20 | 'tqdm', 21 | 'opencv-python', 22 | 'pandas', 23 | 'torch>=1.10.0', 24 | 'matplotlib', 25 | 'scikit-learn>=1.0.1', 26 | 'batchgenerators>=0.23', 27 | 'scikit-image>=0.19.0', 28 | 'scipy', 29 | 'unittest2', 30 | 'pie-torch' 31 | ] 32 | ) 33 | --------------------------------------------------------------------------------