├── dataio ├── __init__.py ├── loader │ ├── hms_dataset.py │ ├── __init__.py │ ├── test_dataset.py │ ├── utils.py │ ├── cmr_3D_dataset.py │ ├── ukbb_dataset.py │ └── us_dataset.py └── transformation │ ├── __init__.py │ ├── transforms.py │ └── myImageTransformations.py ├── utils ├── __init__.py ├── html.py ├── util.py ├── error_logger.py ├── post_process_crf.py ├── metrics.py └── visualiser.py ├── models ├── layers │ ├── __init__.py │ ├── loss.py │ └── nonlocal_layer.py ├── networks │ ├── unet_2D.py │ ├── unet_3D.py │ ├── __init__.py │ ├── sononet.py │ ├── unet_nonlocal_2D.py │ ├── unet_nonlocal_3D.py │ ├── unet_CT_dsv_3D.py │ ├── unet_grid_attention_3D.py │ ├── sononet_grid_attention.py │ ├── unet_CT_single_att_dsv_3D.py │ └── unet_CT_multi_att_dsv_3D.py ├── __init__.py ├── base_model.py ├── aggregated_classifier.py ├── utils.py ├── feedforward_seg_model.py └── feedforward_classifier.py ├── figures ├── figure1.png └── figure2.jpg ├── setup.py ├── LICENSE ├── README.md ├── configs ├── config_unet_ct_dsv.json ├── config_unet_ct_multi_att_dsv.json ├── config_sononet_8.json ├── config_sononet_grid_att_8_ft.json ├── config_sononet_grid_att_8.json └── config_sononet_grid_att_8_deepsup.json ├── .gitignore ├── visualise_fmaps.py ├── visualise_att_maps_epoch.py ├── validation.py ├── train_segmentation.py ├── test_classification.py ├── visualise_attention.py └── train_classifaction.py /dataio/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozan-oktay/Attention-Gated-Networks/HEAD/figures/figure1.png -------------------------------------------------------------------------------- /figures/figure2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ozan-oktay/Attention-Gated-Networks/HEAD/figures/figure2.jpg -------------------------------------------------------------------------------- /dataio/loader/hms_dataset.py: -------------------------------------------------------------------------------- 1 | # Author: Ozan Oktay 2 | # Date: January 2018 3 | 4 | class HMSDataset: 5 | 6 | def __init__(self): 7 | raise NotImplementedError -------------------------------------------------------------------------------- /dataio/transformation/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataio.transformation.transforms import Transformations 3 | 4 | 5 | def get_dataset_transformation(name, opts=None): 6 | ''' 7 | :param opts: augmentation parameters 8 | :return: a dictionary of transformation functions keyed by data split 9 | ''' 10 | # Build the transformation object and initialise the augmentation parameters 11 | trans_obj = Transformations(name) 12 | if opts: trans_obj.initialise(opts) 13 | 14 | # Print the input options 15 | trans_obj.print() 16 | 17 | # Returns a dictionary of transformations 18 | return trans_obj.get_transformation() 19 | 
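A minimal usage sketch for the factory above; the config file and the `'acdc_sax'` key are taken from elsewhere in this repository, while the `'train'` split key is an assumption based on how `visualise_fmaps.py` below indexes the returned dictionary with `'test'`:

    from dataio.transformation import get_dataset_transformation
    from utils.util import json_file_to_pyobj

    # augmentation options come from one of the JSON files in configs/
    json_opts = json_file_to_pyobj('configs/config_unet_ct_dsv.json')

    # returns a dictionary of transform callables keyed by data split
    transforms = get_dataset_transformation('acdc_sax', opts=json_opts.augmentation)
    train_transform = transforms['train']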
-------------------------------------------------------------------------------- /dataio/loader/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from dataio.loader.ukbb_dataset import UKBBDataset 4 | from dataio.loader.test_dataset import TestDataset 5 | from dataio.loader.hms_dataset import HMSDataset 6 | from dataio.loader.cmr_3D_dataset import CMR3DDataset 7 | from dataio.loader.us_dataset import UltraSoundDataset 8 | 9 | 10 | def get_dataset(name): 11 | """get_dataset 12 | 13 | :param name: 14 | """ 15 | return { 16 | 'ukbb_sax': CMR3DDataset, 17 | 'acdc_sax': CMR3DDataset, 18 | 'rvsc_sax': CMR3DDataset, 19 | 'hms_sax': HMSDataset, 20 | 'test_sax': TestDataset, 21 | 'us': UltraSoundDataset 22 | }[name] 23 | 24 | 25 | def get_dataset_path(dataset_name, opts): 26 | """get_data_path 27 | 28 | :param dataset_name: 29 | :param opts: 30 | """ 31 | 32 | return getattr(opts, dataset_name) 33 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | 5 | with open('README.md') as f: 6 | readme = f.read() 7 | 8 | setup(name='AttentionGatedNetworks', 9 | version='1.0', 10 | description='Pytorch library for Soft Attention', 11 | long_description=readme, 12 | author='Ozan Oktay & Jo Schlemper', 13 | install_requires=[ 14 | "numpy", 15 | "torch", 16 | "matplotlib", 17 | "scipy", 18 | "torchvision", 19 | "tqdm", 20 | "visdom", 21 | "nibabel", 22 | "scikit-image", 23 | "h5py", 24 | "pandas", 25 | "dominate", 26 | 'torchsample==0.1.3', 27 | ], 28 | dependency_links=[ 29 | 'https://github.com/ozan-oktay/torchsample/tarball/master#egg=torchsample-0.1.3' 30 | ], 31 | packages=find_packages(exclude=('tests', 'docs')) 32 | ) 33 | 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Ozan Oktay 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Attention Gated Networks
(Image Classification & Segmentation) 2 | 3 | PyTorch implementation of attention gates used in U-Net and VGG-16 models. The framework can be utilised in both medical image classification and segmentation tasks. 4 | 5 | 
<p align="center">
6 | <img src="figures/figure1.png"> <br>
7 | The schematics of the proposed Attention-Gated Sononet
8 | </p>
9 | 
10 | <p align="center">
11 | <img src="figures/figure2.jpg"> <br>
12 | The schematics of the proposed additive attention gate
13 | </p>
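### Example

The networks are exposed through the factory in `models/networks/__init__.py`; a minimal construction sketch for the 3D attention U-Net, where the input size follows the `patch_size` in the shipped CT configs and the argument values are illustrative:

    import torch
    from models.networks import get_network

    # 3D attention U-Net with multi-scale attention gates and deep supervision
    net = get_network('unet_ct_multi_att_dsv', n_classes=4, in_channels=1,
                      feature_scale=4, tensor_dim='3D')
    logits = net(torch.randn(1, 1, 160, 160, 96))  # (batch, channel, x, y, z) -> (1, 4, 160, 160, 96)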
 14 | 15 | ### References: 16 | 17 | 1) "Attention-Gated Networks for Improving Ultrasound Scan Plane Detection", MIDL'18, Amsterdam <br>
 18 | [Conference Paper](https://openreview.net/pdf?id=BJtn7-3sM) <br>
 19 | [Conference Poster](https://www.doc.ic.ac.uk/~oo2113/posters/MIDL2018_poster_Jo.pdf) 20 | 21 | 2) "Attention U-Net: Learning Where to Look for the Pancreas", MIDL'18, Amsterdam <br>
 22 | [Conference Paper](https://openreview.net/pdf?id=Skft7cijM) <br>
23 | [Conference Poster](https://www.doc.ic.ac.uk/~oo2113/posters/MIDL2018_poster.pdf) 24 | 25 | ### Installation 26 | pip install --process-dependency-links -e . 27 | 28 | -------------------------------------------------------------------------------- /configs/config_unet_ct_dsv.json: -------------------------------------------------------------------------------- 1 | { 2 | "training":{ 3 | "arch_type": "acdc_sax", 4 | "n_epochs": 1000, 5 | "save_epoch_freq": 10, 6 | "lr_policy": "step", 7 | "lr_decay_iters": 250, 8 | "batchSize": 2, 9 | "preloadData": true 10 | }, 11 | "visualisation":{ 12 | "display_port": 8098, 13 | "no_html": true, 14 | "display_winsize": 256, 15 | "display_id": 1, 16 | "display_single_pane_ncols": 0 17 | }, 18 | "data_path": { 19 | "acdc_sax": "/vol/biomedic2/oo2113/dataset/ken_abdominal_ct/" 20 | }, 21 | "augmentation": { 22 | "acdc_sax": { 23 | "shift": [0.1,0.1], 24 | "rotate": 15.0, 25 | "scale": [0.7,1.3], 26 | "intensity": [1.0,1.0], 27 | "random_flip_prob": 0.5, 28 | "scale_size": [160,160,96], 29 | "patch_size": [160,160,96] 30 | } 31 | }, 32 | "model":{ 33 | "type":"seg", 34 | "continue_train": false, 35 | "which_epoch": -1, 36 | "model_type": "unet_ct_dsv", 37 | "tensor_dim": "3D", 38 | "division_factor": 16, 39 | "input_nc": 1, 40 | "output_nc": 4, 41 | "lr_rate": 1e-4, 42 | "l2_reg_weight": 1e-6, 43 | "feature_scale": 4, 44 | "gpu_ids": [0], 45 | "isTrain": true, 46 | "checkpoints_dir": "./checkpoints", 47 | "experiment_name": "experiment_unet_ct_dsv_big", 48 | "criterion": "dice_loss" 49 | } 50 | } 51 | 52 | 53 | -------------------------------------------------------------------------------- /configs/config_unet_ct_multi_att_dsv.json: -------------------------------------------------------------------------------- 1 | { 2 | "training":{ 3 | "arch_type": "acdc_sax", 4 | "n_epochs": 1000, 5 | "save_epoch_freq": 10, 6 | "lr_policy": "step", 7 | "lr_decay_iters": 250, 8 | "batchSize": 2, 9 | "preloadData": true 10 | }, 11 | "visualisation":{ 12 | "display_port": 8099, 13 | "no_html": true, 14 | "display_winsize": 256, 15 | "display_id": 1, 16 | "display_single_pane_ncols": 0 17 | }, 18 | "data_path": { 19 | "acdc_sax": "/vol/biomedic2/oo2113/dataset/ken_abdominal_ct/" 20 | }, 21 | "augmentation": { 22 | "acdc_sax": { 23 | "shift": [0.1,0.1], 24 | "rotate": 15.0, 25 | "scale": [0.7,1.3], 26 | "intensity": [1.0,1.0], 27 | "random_flip_prob": 0.5, 28 | "scale_size": [160,160,96], 29 | "patch_size": [160,160,96] 30 | } 31 | }, 32 | "model":{ 33 | "type":"seg", 34 | "continue_train": false, 35 | "which_epoch": -1, 36 | "model_type": "unet_ct_multi_att_dsv", 37 | "tensor_dim": "3D", 38 | "division_factor": 16, 39 | "input_nc": 1, 40 | "output_nc": 4, 41 | "lr_rate": 1e-4, 42 | "l2_reg_weight": 1e-6, 43 | "feature_scale": 4, 44 | "gpu_ids": [0], 45 | "isTrain": true, 46 | "checkpoints_dir": "./checkpoints", 47 | "experiment_name": "experiment_unet_ct_multi_att_dsv", 48 | "criterion": "dice_loss" 49 | } 50 | } 51 | 52 | 53 | -------------------------------------------------------------------------------- /configs/config_sononet_8.json: -------------------------------------------------------------------------------- 1 | { 2 | "training":{ 3 | "max_it":10, 4 | "arch_type": "us", 5 | "n_epochs": 300, 6 | "save_epoch_freq": 10, 7 | "lr_policy": "step_warmstart", 8 | "lr_decay_iters": 25, 9 | "lr_red_factor": 0.1, 10 | "batchSize": 64, 11 | "preloadData": false, 12 | "num_workers" : 8, 13 | "sampler": "weighted2", 14 | "bgd_weight_multiplier": 13 15 | }, 16 | 
"visualisation":{ 17 | "display_port": 8181, 18 | "no_html": true, 19 | "display_winsize": 256, 20 | "display_id": 1, 21 | "display_single_pane_ncols": 0 22 | }, 23 | "data_path": { 24 | "us": "/vol/bitbucket/js3611/data_ultrasound/preproc_combined_inp_224x288.hdf5" 25 | }, 26 | "augmentation": { 27 | "us": { 28 | "patch_size": [208, 272], 29 | "shift": [0.02,0.02], 30 | "rotate": 25.0, 31 | "scale": [0.7,1.3], 32 | "intensity": [1.0,1.0], 33 | "random_flip_prob": 0.5 34 | } 35 | }, 36 | "model":{ 37 | "type":"classifier", 38 | "continue_train": false, 39 | "which_epoch": 0, 40 | "model_type": "sononet2", 41 | "tensor_dim": "2D", 42 | "input_nc": 1, 43 | "output_nc": 14, 44 | "lr_rate": 0.1, 45 | "l2_reg_weight": 1e-6, 46 | "feature_scale": 8, 47 | "gpu_ids": [0], 48 | "isTrain": true, 49 | "checkpoints_dir": "./checkpoints", 50 | "experiment_name": "experiment_sononet_8" 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /dataio/loader/test_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | import os 4 | 5 | from os import listdir 6 | from os.path import join 7 | from .utils import load_nifti_img, check_exceptions, is_image_file 8 | 9 | 10 | class TestDataset(data.Dataset): 11 | def __init__(self, root_dir, transform): 12 | super(TestDataset, self).__init__() 13 | image_dir = join(root_dir, 'image') 14 | self.image_filenames = sorted([join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]) 15 | 16 | # Add the corresponding ground-truth images if they exist 17 | self.label_filenames = [] 18 | label_dir = join(root_dir, 'label') 19 | if os.path.isdir(label_dir): 20 | self.label_filenames = sorted([join(label_dir, x) for x in listdir(label_dir) if is_image_file(x)]) 21 | assert len(self.label_filenames) == len(self.image_filenames) 22 | 23 | # data pre-processing 24 | self.transform = transform 25 | 26 | # report the number of images in the dataset 27 | print('Number of test images: {0} NIFTIs'.format(self.__len__())) 28 | 29 | def __getitem__(self, index): 30 | 31 | # load the NIFTI images 32 | input, input_meta = load_nifti_img(self.image_filenames[index], dtype=np.int16) 33 | 34 | # load the label image if it exists 35 | if self.label_filenames: 36 | label, _ = load_nifti_img(self.label_filenames[index], dtype=np.int16) 37 | check_exceptions(input, label) 38 | else: 39 | label = [] 40 | check_exceptions(input) 41 | 42 | # Pre-process the input 3D Nifti image 43 | input = self.transform(input) 44 | 45 | return input, input_meta, label 46 | 47 | def __len__(self): 48 | return len(self.image_filenames) -------------------------------------------------------------------------------- /configs/config_sononet_grid_att_8_ft.json: -------------------------------------------------------------------------------- 1 | { 2 | "training":{ 3 | "max_it":100, 4 | "arch_type": "us", 5 | "n_epochs": 300, 6 | "save_epoch_freq": 1, 7 | "lr_policy": "step_warmstart", 8 | "lr_decay_iters": 25, 9 | "lr_red_factor": 0.1, 10 | "batchSize": 64, 11 | "preloadData": false, 12 | "num_workers" : 8, 13 | "sampler": "weighted2", 14 | "bgd_weight_multiplier": 13 15 | }, 16 | "visualisation":{ 17 | "display_port": 8181, 18 | "no_html": true, 19 | "display_winsize": 256, 20 | "display_id": 1, 21 | "display_single_pane_ncols": 0 22 | }, 23 | "data_path": { 24 | "us": "/vol/bitbucket/js3611/data_ultrasound/preproc_combined_inp_224x288.hdf5" 25 | }, 26 | 
"augmentation": { 27 | "us": { 28 | "patch_size": [208, 272], 29 | "shift": [0.02,0.02], 30 | "rotate": 25.0, 31 | "scale": [0.7,1.3], 32 | "intensity": [1.0,1.0], 33 | "random_flip_prob": 0.5 34 | } 35 | }, 36 | "model":{ 37 | "type":"aggregated_classifier", 38 | "criterion":"cross_entropy", 39 | "model_type": "sononet_grid_attention", 40 | "nonlocal_mode": "concatenation_mean_flow", 41 | "aggregation_mode": "ft", 42 | "weight":[1], 43 | "aggregation":"mean", 44 | "continue_train": false, 45 | "which_epoch": 0, 46 | "tensor_dim": "2D", 47 | "input_nc": 1, 48 | "output_nc": 14, 49 | "lr_rate": 0.1, 50 | "l2_reg_weight": 1e-6, 51 | "feature_scale": 8, 52 | "gpu_ids": [0], 53 | "isTrain": true, 54 | "checkpoints_dir": "./checkpoints", 55 | "experiment_name": "experiment_sononet_grid_attention_fs8_avg_v12" 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /configs/config_sononet_grid_att_8.json: -------------------------------------------------------------------------------- 1 | { 2 | "training":{ 3 | "max_it":10, 4 | "arch_type": "us", 5 | "n_epochs": 300, 6 | "save_epoch_freq": 1, 7 | "lr_policy": "step_warmstart", 8 | "lr_decay_iters": 25, 9 | "lr_red_factor": 0.1, 10 | "batchSize": 64, 11 | "preloadData": false, 12 | "num_workers" : 8, 13 | "sampler": "weighted2", 14 | "bgd_weight_multiplier": 13 15 | }, 16 | "visualisation":{ 17 | "display_port": 8181, 18 | "no_html": true, 19 | "display_winsize": 256, 20 | "display_id": 1, 21 | "display_single_pane_ncols": 0 22 | }, 23 | "data_path": { 24 | "us": "/vol/bitbucket/js3611/data_ultrasound/preproc_combined_inp_224x288.hdf5" 25 | }, 26 | "augmentation": { 27 | "us": { 28 | "patch_size": [208, 272], 29 | "shift": [0.02,0.02], 30 | "rotate": 25.0, 31 | "scale": [0.7,1.3], 32 | "intensity": [1.0,1.0], 33 | "random_flip_prob": 0.5 34 | } 35 | }, 36 | "model":{ 37 | "type":"aggregated_classifier", 38 | "criterion":"cross_entropy", 39 | "model_type": "sononet_grid_attention", 40 | "nonlocal_mode": "concatenation_mean_flow", 41 | "aggregation_mode": "mean", 42 | "weight":[1, 1, 1], 43 | "aggregation":"mean", 44 | "continue_train": false, 45 | "which_epoch": 0, 46 | "tensor_dim": "2D", 47 | "input_nc": 1, 48 | "output_nc": 14, 49 | "lr_rate": 0.1, 50 | "l2_reg_weight": 1e-6, 51 | "feature_scale": 8, 52 | "gpu_ids": [0], 53 | "isTrain": true, 54 | "checkpoints_dir": "./checkpoints", 55 | "experiment_name": "experiment_sononet_grid_attention_fs8_avg_v12" 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.pyc 6 | 7 | # Saved results 8 | checkpoints/ 9 | checkpoints_2/ 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | .static_storage/ 61 | .media/ 62 | local_settings.py 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | 111 | *~ 112 | .#* 113 | flycheck* -------------------------------------------------------------------------------- /configs/config_sononet_grid_att_8_deepsup.json: -------------------------------------------------------------------------------- 1 | { 2 | "training":{ 3 | "max_it":100, 4 | "arch_type": "us", 5 | "n_epochs": 300, 6 | "save_epoch_freq": 1, 7 | "lr_policy": "step_warmstart", 8 | "lr_decay_iters": 25, 9 | "lr_red_factor": 0.1, 10 | "batchSize": 64, 11 | "preloadData": false, 12 | "num_workers" : 8, 13 | "sampler": "weighted2", 14 | "bgd_weight_multiplier": 13 15 | }, 16 | "visualisation":{ 17 | "display_port": 8181, 18 | "no_html": true, 19 | "display_winsize": 256, 20 | "display_id": 1, 21 | "display_single_pane_ncols": 0 22 | }, 23 | "data_path": { 24 | "us": "/vol/bitbucket/js3611/data_ultrasound/preproc_combined_inp_224x288.hdf5" 25 | }, 26 | "augmentation": { 27 | "us": { 28 | "patch_size": [208, 272], 29 | "shift": [0.02,0.02], 30 | "rotate": 25.0, 31 | "scale": [0.7,1.3], 32 | "intensity": [1.0,1.0], 33 | "random_flip_prob": 0.5 34 | } 35 | }, 36 | "model":{ 37 | "type":"aggregated_classifier", 38 | "criterion":"cross_entropy", 39 | "model_type": "sononet_grid_attention", 40 | "nonlocal_mode": "concatenation_mean_flow", 41 | "aggregation_mode": "deep_sup", 42 | "weight":[1, 0.1, 0.1, 0.1], 43 | "aggregation":"idx", 44 | "aggregation_param":0, 45 | "continue_train": false, 46 | "which_epoch": 0, 47 | "tensor_dim": "2D", 48 | "input_nc": 1, 49 | "output_nc": 14, 50 | "lr_rate": 0.1, 51 | "l2_reg_weight": 1e-6, 52 | "feature_scale": 8, 53 | "gpu_ids": [0], 54 | "isTrain": true, 55 | "checkpoints_dir": "./checkpoints", 56 | "experiment_name": "experiment_sononet_grid_attention_fs8_avg_v12" 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /dataio/loader/utils.py: -------------------------------------------------------------------------------- 1 | import nibabel as nib 2 | import numpy as np 3 | import os 4 | from utils.util import mkdir 5 | 6 | def is_image_file(filename): 7 | return any(filename.endswith(extension) for extension in [".nii.gz"]) 8 | 9 | 10 | def load_nifti_img(filepath, dtype): 11 | ''' 12 | NIFTI Image Loader 13 | :param filepath: path to the input NIFTI image 14 | :param dtype: dataio type of the nifti numpy array 15 | :return: return numpy array 16 | ''' 17 | nim = nib.load(filepath) 
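# note: get_data() and get_affine() below were the standard nibabel API when
# this loader was written; recent nibabel releases have removed both, so on
# newer versions the equivalents are np.asanyarray(nim.dataobj) and nim.affine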
18 | out_nii_array = np.array(nim.get_data(),dtype=dtype) 19 | out_nii_array = np.squeeze(out_nii_array) # drop singleton dim in case temporal dim exists 20 | meta = {'affine': nim.get_affine(), 21 | 'dim': nim.header['dim'], 22 | 'pixdim': nim.header['pixdim'], 23 | 'name': os.path.basename(filepath) 24 | } 25 | 26 | return out_nii_array, meta 27 | 28 | 29 | def write_nifti_img(input_nii_array, meta, savedir): 30 | mkdir(savedir) 31 | affine = meta['affine'][0].cpu().numpy() 32 | pixdim = meta['pixdim'][0].cpu().numpy() 33 | dim = meta['dim'][0].cpu().numpy() 34 | 35 | img = nib.Nifti1Image(input_nii_array, affine=affine) 36 | img.header['dim'] = dim 37 | img.header['pixdim'] = pixdim 38 | 39 | savename = os.path.join(savedir, meta['name'][0]) 40 | print('saving: ', savename) 41 | nib.save(img, savename) 42 | 43 | 44 | def check_exceptions(image, label=None): 45 | if label is not None: 46 | if image.shape != label.shape: 47 | print('Error: mismatched size, image.shape = {0}, ' 48 | 'label.shape = {1}'.format(image.shape, label.shape)) 49 | #print('Skip {0}, {1}'.format(image_name, label_name)) 50 | raise(Exception('image and label sizes do not match')) 51 | 52 | if image.max() < 1e-6: 53 | print('Error: blank image, image.max = {0}'.format(image.max())) 54 | #print('Skip {0} {1}'.format(image_name, label_name)) 55 | raise (Exception('blank image exception')) -------------------------------------------------------------------------------- /utils/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, reflesh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | self.img_dir = os.path.join(self.web_dir, 'images') 11 | if not os.path.exists(self.web_dir): 12 | os.makedirs(self.web_dir) 13 | if not os.path.exists(self.img_dir): 14 | os.makedirs(self.img_dir) 15 | # print(self.img_dir) 16 | 17 | self.doc = dominate.document(title=title) 18 | if reflesh > 0: 19 | with self.doc.head: 20 | meta(http_equiv="reflesh", content=str(reflesh)) 21 | 22 | def get_image_dir(self): 23 | return self.img_dir 24 | 25 | def add_header(self, str): 26 | with self.doc: 27 | h3(str) 28 | 29 | def add_table(self, border=1): 30 | self.t = table(border=border, style="table-layout: fixed;") 31 | self.doc.add(self.t) 32 | 33 | def add_images(self, ims, txts, links, width=400): 34 | self.add_table() 35 | with self.t: 36 | with tr(): 37 | for im, txt, link in zip(ims, txts, links): 38 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 39 | with p(): 40 | with a(href=os.path.join('images', link)): 41 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 42 | br() 43 | p(txt) 44 | 45 | def save(self): 46 | html_file = '%s/index.html' % self.web_dir 47 | f = open(html_file, 'wt') 48 | f.write(self.doc.render()) 49 | f.close() 50 | 51 | 52 | if __name__ == '__main__': 53 | html = HTML('web/', 'test_html') 54 | html.add_header('hello world') 55 | 56 | ims = [] 57 | txts = [] 58 | links = [] 59 | for n in range(4): 60 | ims.append('image_%d.png' % n) 61 | txts.append('text_%d' % n) 62 | links.append('image_%d.png' % n) 63 | html.add_images(ims, txts, links) 64 | html.save() -------------------------------------------------------------------------------- /dataio/loader/cmr_3D_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | 
import datetime 4 | 5 | from os import listdir 6 | from os.path import join 7 | from .utils import load_nifti_img, check_exceptions, is_image_file 8 | 9 | 10 | class CMR3DDataset(data.Dataset): 11 | def __init__(self, root_dir, split, transform=None, preload_data=False): 12 | super(CMR3DDataset, self).__init__() 13 | image_dir = join(root_dir, split, 'image') 14 | target_dir = join(root_dir, split, 'label') 15 | self.image_filenames = sorted([join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]) 16 | self.target_filenames = sorted([join(target_dir, x) for x in listdir(target_dir) if is_image_file(x)]) 17 | assert len(self.image_filenames) == len(self.target_filenames) 18 | 19 | # report the number of images in the dataset 20 | print('Number of {0} images: {1} NIFTIs'.format(split, self.__len__())) 21 | 22 | # data augmentation 23 | self.transform = transform 24 | 25 | # data load into the ram memory 26 | self.preload_data = preload_data 27 | if self.preload_data: 28 | print('Preloading the {0} dataset ...'.format(split)) 29 | self.raw_images = [load_nifti_img(ii, dtype=np.int16)[0] for ii in self.image_filenames] 30 | self.raw_labels = [load_nifti_img(ii, dtype=np.uint8)[0] for ii in self.target_filenames] 31 | print('Loading is done\n') 32 | 33 | 34 | def __getitem__(self, index): 35 | # update the seed to avoid workers sample the same augmentation parameters 36 | np.random.seed(datetime.datetime.now().second + datetime.datetime.now().microsecond) 37 | 38 | # load the nifti images 39 | if not self.preload_data: 40 | input, _ = load_nifti_img(self.image_filenames[index], dtype=np.int16) 41 | target, _ = load_nifti_img(self.target_filenames[index], dtype=np.uint8) 42 | else: 43 | input = np.copy(self.raw_images[index]) 44 | target = np.copy(self.raw_labels[index]) 45 | 46 | # handle exceptions 47 | check_exceptions(input, target) 48 | if self.transform: 49 | input, target = self.transform(input, target) 50 | 51 | return input, target 52 | 53 | def __len__(self): 54 | return len(self.image_filenames) -------------------------------------------------------------------------------- /dataio/loader/ukbb_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | import datetime 4 | 5 | from os import listdir 6 | from os.path import join 7 | from .utils import load_nifti_img, check_exceptions, is_image_file 8 | 9 | 10 | class UKBBDataset(data.Dataset): 11 | def __init__(self, root_dir, split, transform=None, preload_data=False): 12 | super(UKBBDataset, self).__init__() 13 | image_dir = join(root_dir, split, 'image') 14 | target_dir = join(root_dir, split, 'label') 15 | self.image_filenames = sorted([join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]) 16 | self.target_filenames = sorted([join(target_dir, x) for x in listdir(target_dir) if is_image_file(x)]) 17 | assert len(self.image_filenames) == len(self.target_filenames) 18 | 19 | # report the number of images in the dataset 20 | print('Number of {0} images: {1} NIFTIs'.format(split, self.__len__())) 21 | 22 | # data augmentation 23 | self.transform = transform 24 | 25 | # data load into the ram memory 26 | self.preload_data = preload_data 27 | if self.preload_data: 28 | print('Preloading the {0} dataset ...'.format(split)) 29 | self.raw_images = [load_nifti_img(ii, dtype=np.int16)[0] for ii in self.image_filenames] 30 | self.raw_labels = [load_nifti_img(ii, dtype=np.uint8)[0] for ii in self.target_filenames] 31 | 
print('Loading is done\n') 32 | 33 | def __getitem__(self, index): 34 | # update the seed to avoid workers sample the same augmentation parameters 35 | np.random.seed(datetime.datetime.now().second + datetime.datetime.now().microsecond) 36 | 37 | # load the nifti images 38 | if not self.preload_data: 39 | input, _ = load_nifti_img(self.image_filenames[index], dtype=np.int16) 40 | target, _ = load_nifti_img(self.target_filenames[index], dtype=np.uint8) 41 | else: 42 | input = np.copy(self.raw_images[index]) 43 | target = np.copy(self.raw_labels[index]) 44 | 45 | # pass a random slice for the time being 46 | id_slice = np.random.randint(0,input.shape[2]) 47 | input = input[:,:,[id_slice]] 48 | target= target[:,:,[id_slice]] 49 | 50 | # handle exceptions 51 | check_exceptions(input, target) 52 | if self.transform: 53 | input, target = self.transform(input, target) 54 | 55 | return input, target 56 | 57 | def __len__(self): 58 | return len(self.image_filenames) -------------------------------------------------------------------------------- /dataio/loader/us_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import h5py 4 | import numpy as np 5 | import datetime 6 | 7 | from os import listdir 8 | from os.path import join 9 | #from .utils import check_exceptions 10 | 11 | 12 | class UltraSoundDataset(data.Dataset): 13 | def __init__(self, root_path, split, transform=None, preload_data=False): 14 | super(UltraSoundDataset, self).__init__() 15 | 16 | f = h5py.File(root_path) 17 | 18 | self.images = f['x_'+split] 19 | 20 | if preload_data: 21 | self.images = np.array(self.images[:]) 22 | 23 | self.labels = np.array(f['p_'+split][:], dtype=np.int64)#[:1000] 24 | self.label_names = [x.decode('utf-8') for x in f['label_names'][:].tolist()] 25 | #print(self.label_names) 26 | #print(np.unique(self.labels[:])) 27 | # construct weight for entry 28 | self.n_class = len(self.label_names) 29 | class_weight = np.zeros(self.n_class) 30 | for lab in range(self.n_class): 31 | class_weight[lab] = np.sum(self.labels[:] == lab) 32 | 33 | class_weight = 1 / class_weight 34 | 35 | self.weight = np.zeros(len(self.labels)) 36 | for i in range(len(self.labels)): 37 | self.weight[i] = class_weight[self.labels[i]] 38 | 39 | #print(class_weight) 40 | assert len(self.images) == len(self.labels) 41 | 42 | # data augmentation 43 | self.transform = transform 44 | 45 | # report the number of images in the dataset 46 | print('Number of {0} images: {1} NIFTIs'.format(split, self.__len__())) 47 | 48 | def __getitem__(self, index): 49 | # update the seed to avoid workers sample the same augmentation parameters 50 | np.random.seed(datetime.datetime.now().second + datetime.datetime.now().microsecond) 51 | 52 | # load the nifti images 53 | input = self.images[index][0] 54 | target = self.labels[index] 55 | 56 | #input = input.transpose((1,2,0)) 57 | 58 | # handle exceptions 59 | #check_exceptions(input, target) 60 | if self.transform: 61 | input = self.transform(input) 62 | 63 | #print(input.shape, torch.from_numpy(np.array([target]))) 64 | #print("target",np.int64(target)) 65 | return input, int(target) 66 | 67 | def __len__(self): 68 | return len(self.images) 69 | 70 | 71 | # if __name__ == '__main__': 72 | # dataset = UltraSoundDataset("/vol/bitbucket/js3611/data_ultrasound/preproc_combined_inp_224x288.hdf5",'test') 73 | 74 | # from torch.utils.data import DataLoader, sampler 75 | # ds = DataLoader(dataset=dataset, num_workers=1, 
batch_size=2) 76 | -------------------------------------------------------------------------------- /models/networks/unet_2D.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | from .utils import unetConv2, unetUp 4 | import torch.nn.functional as F 5 | from models.networks_other import init_weights 6 | 7 | class unet_2D(nn.Module): 8 | 9 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True): 10 | super(unet_2D, self).__init__() 11 | self.is_deconv = is_deconv 12 | self.in_channels = in_channels 13 | self.is_batchnorm = is_batchnorm 14 | self.feature_scale = feature_scale 15 | 16 | filters = [64, 128, 256, 512, 1024] 17 | filters = [int(x / self.feature_scale) for x in filters] 18 | 19 | # downsampling 20 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm) 21 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 22 | 23 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) 24 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 25 | 26 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) 27 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 28 | 29 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) 30 | self.maxpool4 = nn.MaxPool2d(kernel_size=2) 31 | 32 | self.center = unetConv2(filters[3], filters[4], self.is_batchnorm) 33 | 34 | # upsampling 35 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv) 36 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv) 37 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv) 38 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv) 39 | 40 | # final conv (without any concat) 41 | self.final = nn.Conv2d(filters[0], n_classes, 1) 42 | 43 | # initialise weights 44 | for m in self.modules(): 45 | if isinstance(m, nn.Conv2d): 46 | init_weights(m, init_type='kaiming') 47 | elif isinstance(m, nn.BatchNorm2d): 48 | init_weights(m, init_type='kaiming') 49 | 50 | 51 | def forward(self, inputs): 52 | conv1 = self.conv1(inputs) 53 | maxpool1 = self.maxpool1(conv1) 54 | 55 | conv2 = self.conv2(maxpool1) 56 | maxpool2 = self.maxpool2(conv2) 57 | 58 | conv3 = self.conv3(maxpool2) 59 | maxpool3 = self.maxpool3(conv3) 60 | 61 | conv4 = self.conv4(maxpool3) 62 | maxpool4 = self.maxpool4(conv4) 63 | 64 | center = self.center(maxpool4) 65 | up4 = self.up_concat4(conv4, center) 66 | up3 = self.up_concat3(conv3, up4) 67 | up2 = self.up_concat2(conv2, up3) 68 | up1 = self.up_concat1(conv1, up2) 69 | 70 | final = self.final(up1) 71 | 72 | return final 73 | 74 | @staticmethod 75 | def apply_argmax_softmax(pred): 76 | log_p = F.softmax(pred, dim=1) 77 | 78 | return log_p 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /models/networks/unet_3D.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | from .utils import UnetConv3, UnetUp3 4 | import torch.nn.functional as F 5 | from models.networks_other import init_weights 6 | 7 | class unet_3D(nn.Module): 8 | 9 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True): 10 | super(unet_3D, self).__init__() 11 | self.is_deconv = is_deconv 12 | self.in_channels = in_channels 13 | self.is_batchnorm = is_batchnorm 14 | self.feature_scale = feature_scale 15 | 16 | filters = [64, 128, 256, 512, 
1024] 17 | filters = [int(x / self.feature_scale) for x in filters] 18 | 19 | # downsampling 20 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm) 21 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 1)) 22 | 23 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm) 24 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 1)) 25 | 26 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm) 27 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 1)) 28 | 29 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm) 30 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 1)) 31 | 32 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm) 33 | 34 | # upsampling 35 | self.up_concat4 = UnetUp3(filters[4], filters[3], self.is_deconv, is_batchnorm) 36 | self.up_concat3 = UnetUp3(filters[3], filters[2], self.is_deconv, is_batchnorm) 37 | self.up_concat2 = UnetUp3(filters[2], filters[1], self.is_deconv, is_batchnorm) 38 | self.up_concat1 = UnetUp3(filters[1], filters[0], self.is_deconv, is_batchnorm) 39 | 40 | # final conv (without any concat) 41 | self.final = nn.Conv3d(filters[0], n_classes, 1) 42 | 43 | # initialise weights 44 | for m in self.modules(): 45 | if isinstance(m, nn.Conv3d): 46 | init_weights(m, init_type='kaiming') 47 | elif isinstance(m, nn.BatchNorm3d): 48 | init_weights(m, init_type='kaiming') 49 | 50 | def forward(self, inputs): 51 | conv1 = self.conv1(inputs) 52 | maxpool1 = self.maxpool1(conv1) 53 | 54 | conv2 = self.conv2(maxpool1) 55 | maxpool2 = self.maxpool2(conv2) 56 | 57 | conv3 = self.conv3(maxpool2) 58 | maxpool3 = self.maxpool3(conv3) 59 | 60 | conv4 = self.conv4(maxpool3) 61 | maxpool4 = self.maxpool4(conv4) 62 | 63 | center = self.center(maxpool4) 64 | up4 = self.up_concat4(conv4, center) 65 | up3 = self.up_concat3(conv3, up4) 66 | up2 = self.up_concat2(conv2, up3) 67 | up1 = self.up_concat1(conv1, up2) 68 | 69 | final = self.final(up1) 70 | 71 | return final 72 | 73 | @staticmethod 74 | def apply_argmax_softmax(pred): 75 | log_p = F.softmax(pred, dim=1) 76 | 77 | return log_p 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /models/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet_2D import * 2 | from .unet_3D import * 3 | from .unet_nonlocal_2D import * 4 | from .unet_nonlocal_3D import * 5 | from .unet_grid_attention_3D import * 6 | from .unet_CT_dsv_3D import * 7 | from .unet_CT_single_att_dsv_3D import * 8 | from .unet_CT_multi_att_dsv_3D import * 9 | from .sononet import * 10 | from .sononet_grid_attention import * 11 | 12 | def get_network(name, n_classes, in_channels=3, feature_scale=4, tensor_dim='2D', 13 | nonlocal_mode='embedded_gaussian', attention_dsample=(2,2,2), 14 | aggregation_mode='concat'): 15 | model = _get_model_instance(name, tensor_dim) 16 | 17 | if name in ['unet', 'unet_ct_dsv']: 18 | model = model(n_classes=n_classes, 19 | is_batchnorm=True, 20 | in_channels=in_channels, 21 | feature_scale=feature_scale, 22 | is_deconv=False) 23 | elif name in ['unet_nonlocal']: 24 | model = model(n_classes=n_classes, 25 | is_batchnorm=True, 26 | in_channels=in_channels, 27 | is_deconv=False, 28 | nonlocal_mode=nonlocal_mode, 29 | feature_scale=feature_scale) 30 | elif name in ['unet_grid_gating', 31 | 'unet_ct_single_att_dsv', 32 | 'unet_ct_multi_att_dsv']: 33 | model = model(n_classes=n_classes, 34 | is_batchnorm=True, 35 | 
in_channels=in_channels, 36 | nonlocal_mode=nonlocal_mode, 37 | feature_scale=feature_scale, 38 | attention_dsample=attention_dsample, 39 | is_deconv=False) 40 | elif name in ['sononet','sononet2']: 41 | model = model(n_classes=n_classes, 42 | is_batchnorm=True, 43 | in_channels=in_channels, 44 | feature_scale=feature_scale) 45 | elif name in ['sononet_grid_attention']: 46 | model = model(n_classes=n_classes, 47 | is_batchnorm=True, 48 | in_channels=in_channels, 49 | feature_scale=feature_scale, 50 | nonlocal_mode=nonlocal_mode, 51 | aggregation_mode=aggregation_mode) 52 | else: 53 | raise 'Model {} not available'.format(name) 54 | 55 | return model 56 | 57 | 58 | def _get_model_instance(name, tensor_dim): 59 | return { 60 | 'unet':{'2D': unet_2D, '3D': unet_3D}, 61 | 'unet_nonlocal':{'2D': unet_nonlocal_2D, '3D': unet_nonlocal_3D}, 62 | 'unet_grid_gating': {'3D': unet_grid_attention_3D}, 63 | 'unet_ct_dsv': {'3D': unet_CT_dsv_3D}, 64 | 'unet_ct_single_att_dsv': {'3D': unet_CT_single_att_dsv_3D}, 65 | 'unet_ct_multi_att_dsv': {'3D': unet_CT_multi_att_dsv_3D}, 66 | 'sononet': {'2D': sononet}, 67 | 'sononet2': {'2D': sononet2}, 68 | 'sononet_grid_attention': {'2D': sononet_grid_attention} 69 | }[name][tensor_dim] 70 | -------------------------------------------------------------------------------- /models/networks/sononet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch.nn as nn 4 | from .utils import unetConv2, unetUp, conv2DBatchNormRelu, conv2DBatchNorm 5 | import torch.nn.functional as F 6 | from models.networks_other import init_weights 7 | 8 | class sononet(nn.Module): 9 | 10 | def __init__(self, feature_scale=4, n_classes=21, in_channels=3, is_batchnorm=True, n_convs=None): 11 | super(sononet, self).__init__() 12 | self.in_channels = in_channels 13 | self.is_batchnorm = is_batchnorm 14 | self.feature_scale = feature_scale 15 | self.n_classes= n_classes 16 | 17 | filters = [64, 128, 256, 512] 18 | filters = [int(x / self.feature_scale) for x in filters] 19 | 20 | if n_convs is None: 21 | n_convs = [2,2,3,3,3] 22 | 23 | # downsampling 24 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm, n=n_convs[0]) 25 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 26 | 27 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm, n=n_convs[1]) 28 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 29 | 30 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm, n=n_convs[2]) 31 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 32 | 33 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm, n=n_convs[3]) 34 | self.maxpool4 = nn.MaxPool2d(kernel_size=2) 35 | 36 | self.conv5 = unetConv2(filters[3], filters[3], self.is_batchnorm, n=n_convs[4]) 37 | 38 | # adaptation layer 39 | self.conv5_p = conv2DBatchNormRelu(filters[3], filters[2], 1, 1, 0) 40 | self.conv6_p = conv2DBatchNorm(filters[2], self.n_classes, 1, 1, 0) 41 | 42 | # initialise weights 43 | for m in self.modules(): 44 | if isinstance(m, nn.Conv2d): 45 | init_weights(m, init_type='kaiming') 46 | elif isinstance(m, nn.BatchNorm2d): 47 | init_weights(m, init_type='kaiming') 48 | 49 | 50 | def forward(self, inputs): 51 | # Feature Extraction 52 | conv1 = self.conv1(inputs) 53 | maxpool1 = self.maxpool1(conv1) 54 | 55 | conv2 = self.conv2(maxpool1) 56 | maxpool2 = self.maxpool2(conv2) 57 | 58 | conv3 = self.conv3(maxpool2) 59 | maxpool3 = self.maxpool3(conv3) 60 | 61 | conv4 = self.conv4(maxpool3) 62 | maxpool4 = 
self.maxpool4(conv4) 63 | 64 | conv5 = self.conv5(maxpool4) 65 | 66 | conv5_p = self.conv5_p(conv5) 67 | conv6_p = self.conv6_p(conv5_p) 68 | 69 | batch_size = inputs.shape[0] 70 | pooled = F.adaptive_avg_pool2d(conv6_p, (1, 1)).view(batch_size, -1) 71 | 72 | return pooled 73 | 74 | 75 | @staticmethod 76 | def apply_argmax_softmax(pred): 77 | log_p = F.softmax(pred, dim=1) 78 | 79 | return log_p 80 | 81 | 82 | def sononet2(feature_scale=4, n_classes=21, in_channels=3, is_batchnorm=True): 83 | return sononet(feature_scale, n_classes, in_channels, is_batchnorm, n_convs=[3,3,3,2,2]) 84 | 85 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # Abstract level model definition 2 | # Returns the model class for specified network type 3 | import os 4 | 5 | 6 | class ModelOpts: 7 | def __init__(self): 8 | self.gpu_ids = [0] 9 | self.isTrain = True 10 | self.continue_train = False 11 | self.which_epoch = int(0) 12 | self.save_dir = './checkpoints/default' 13 | self.model_type = 'unet' 14 | self.input_nc = 1 15 | self.output_nc = 4 16 | self.lr_rate = 1e-12 17 | self.l2_reg_weight = 0.0 18 | self.feature_scale = 4 19 | self.tensor_dim = '2D' 20 | self.path_pre_trained_model = None 21 | self.criterion = 'cross_entropy' 22 | self.type = 'seg' 23 | 24 | # Attention 25 | self.nonlocal_mode = 'concatenation' 26 | self.attention_dsample = (2,2,2) 27 | 28 | # Attention Classifier 29 | self.aggregation_mode = 'concatenation' 30 | 31 | 32 | def initialise(self, json_opts): 33 | opts = json_opts 34 | 35 | self.raw = json_opts 36 | self.gpu_ids = opts.gpu_ids 37 | self.isTrain = opts.isTrain 38 | self.save_dir = os.path.join(opts.checkpoints_dir, opts.experiment_name) 39 | self.model_type = opts.model_type 40 | self.input_nc = opts.input_nc 41 | self.output_nc = opts.output_nc 42 | self.continue_train = opts.continue_train 43 | self.which_epoch = opts.which_epoch 44 | 45 | if hasattr(opts, 'type'): self.type = opts.type 46 | if hasattr(opts, 'l2_reg_weight'): self.l2_reg_weight = opts.l2_reg_weight 47 | if hasattr(opts, 'lr_rate'): self.lr_rate = opts.lr_rate 48 | if hasattr(opts, 'feature_scale'): self.feature_scale = opts.feature_scale 49 | if hasattr(opts, 'tensor_dim'): self.tensor_dim = opts.tensor_dim 50 | 51 | if hasattr(opts, 'path_pre_trained_model'): self.path_pre_trained_model = opts.path_pre_trained_model 52 | if hasattr(opts, 'criterion'): self.criterion = opts.criterion 53 | 54 | if hasattr(opts, 'nonlocal_mode'): self.nonlocal_mode = opts.nonlocal_mode 55 | if hasattr(opts, 'attention_dsample'): self.attention_dsample = opts.attention_dsample 56 | # Classifier 57 | if hasattr(opts, 'aggregation_mode'): self.aggregation_mode = opts.aggregation_mode 58 | 59 | def get_model(json_opts): 60 | 61 | # Neural Network Model Initialisation 62 | model = None 63 | model_opts = ModelOpts() 64 | model_opts.initialise(json_opts) 65 | 66 | # Print the model type 67 | print('\nInitialising model {}'.format(model_opts.model_type)) 68 | 69 | model_type = model_opts.type 70 | if model_type == 'seg': 71 | # Return the model type 72 | from .feedforward_seg_model import FeedForwardSegmentation 73 | model = FeedForwardSegmentation() 74 | 75 | elif model_type == 'classifier': 76 | # Return the model type 77 | from .feedforward_classifier import FeedForwardClassifier 78 | model = FeedForwardClassifier() 79 | 80 | elif model_type == 'aggregated_classifier': 81 | # Return the model type 
82 | from .aggregated_classifier import AggregatedClassifier 83 | model = AggregatedClassifier() 84 | 85 | 86 | # Initialise the created model 87 | model.initialize(model_opts) 88 | print("Model [%s] is created" % (model.name())) 89 | 90 | return model 91 | -------------------------------------------------------------------------------- /models/networks/unet_nonlocal_2D.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | from .utils import unetConv2, unetUp 4 | from models.layers.nonlocal_layer import NONLocalBlock2D 5 | import torch.nn.functional as F 6 | 7 | 8 | class unet_nonlocal_2D(nn.Module): 9 | 10 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, 11 | is_batchnorm=True, nonlocal_mode='embedded_gaussian', nonlocal_sf=4): 12 | super(unet_nonlocal_2D, self).__init__() 13 | self.is_deconv = is_deconv 14 | self.in_channels = in_channels 15 | self.is_batchnorm = is_batchnorm 16 | self.feature_scale = feature_scale 17 | 18 | filters = [64, 128, 256, 512, 1024] 19 | filters = [int(x / self.feature_scale) for x in filters] 20 | 21 | # downsampling 22 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm) 23 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 24 | self.nonlocal1 = NONLocalBlock2D(in_channels=filters[0], inter_channels=filters[0] // 4, 25 | sub_sample_factor=nonlocal_sf, mode=nonlocal_mode) 26 | 27 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) 28 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 29 | self.nonlocal2 = NONLocalBlock2D(in_channels=filters[1], inter_channels=filters[1] // 4, 30 | sub_sample_factor=nonlocal_sf, mode=nonlocal_mode) 31 | 32 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) 33 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 34 | 35 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) 36 | self.maxpool4 = nn.MaxPool2d(kernel_size=2) 37 | 38 | self.center = unetConv2(filters[3], filters[4], self.is_batchnorm) 39 | 40 | # upsampling 41 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv) 42 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv) 43 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv) 44 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv) 45 | 46 | # final conv (without any concat) 47 | self.final = nn.Conv2d(filters[0], n_classes, 1) 48 | 49 | # initialise weights 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 53 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 54 | elif isinstance(m, nn.BatchNorm2d): 55 | m.weight.data.fill_(1) 56 | m.bias.data.zero_() 57 | 58 | def forward(self, inputs): 59 | conv1 = self.conv1(inputs) 60 | maxpool1 = self.maxpool1(conv1) 61 | nonlocal1 = self.nonlocal1(maxpool1) 62 | 63 | conv2 = self.conv2(nonlocal1) 64 | maxpool2 = self.maxpool2(conv2) 65 | nonlocal2 = self.nonlocal2(maxpool2) 66 | 67 | conv3 = self.conv3(nonlocal2) 68 | maxpool3 = self.maxpool3(conv3) 69 | 70 | conv4 = self.conv4(maxpool3) 71 | maxpool4 = self.maxpool4(conv4) 72 | 73 | center = self.center(maxpool4) 74 | up4 = self.up_concat4(conv4, center) 75 | up3 = self.up_concat3(conv3, up4) 76 | up2 = self.up_concat2(conv2, up3) 77 | up1 = self.up_concat1(conv1, up2) 78 | 79 | final = self.final(up1) 80 | 81 | return final 82 | 83 | @staticmethod 84 | def apply_argmax_softmax(pred): 85 | log_p = F.softmax(pred, dim=1) 86 | 87 | return log_p 88 | -------------------------------------------------------------------------------- /models/networks/unet_nonlocal_3D.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | from .utils import UnetConv3, UnetUp3 4 | import torch.nn.functional as F 5 | from models.layers.nonlocal_layer import NONLocalBlock3D 6 | from models.networks_other import init_weights 7 | 8 | class unet_nonlocal_3D(nn.Module): 9 | 10 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True, 11 | nonlocal_mode='embedded_gaussian', nonlocal_sf=4): 12 | super(unet_nonlocal_3D, self).__init__() 13 | self.is_deconv = is_deconv 14 | self.in_channels = in_channels 15 | self.is_batchnorm = is_batchnorm 16 | self.feature_scale = feature_scale 17 | 18 | filters = [64, 128, 256, 512, 1024] 19 | filters = [int(x / self.feature_scale) for x in filters] 20 | 21 | # downsampling 22 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm) 23 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 1)) 24 | 25 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm) 26 | self.nonlocal2 = NONLocalBlock3D(in_channels=filters[1], inter_channels=filters[1] // 4, 27 | sub_sample_factor=nonlocal_sf, mode=nonlocal_mode) 28 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 1)) 29 | 30 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm) 31 | self.nonlocal3 = NONLocalBlock3D(in_channels=filters[2], inter_channels=filters[2] // 4, 32 | sub_sample_factor=nonlocal_sf, mode=nonlocal_mode) 33 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 1)) 34 | 35 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm) 36 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 1)) 37 | 38 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm) 39 | 40 | # upsampling 41 | self.up_concat4 = UnetUp3(filters[4], filters[3], self.is_deconv) 42 | self.up_concat3 = UnetUp3(filters[3], filters[2], self.is_deconv) 43 | self.up_concat2 = UnetUp3(filters[2], filters[1], self.is_deconv) 44 | self.up_concat1 = UnetUp3(filters[1], filters[0], self.is_deconv) 45 | 46 | # final conv (without any concat) 47 | self.final = nn.Conv3d(filters[0], n_classes, 1) 48 | 49 | # initialise weights 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv3d): 52 | init_weights(m, init_type='kaiming') 53 | elif isinstance(m, nn.BatchNorm3d): 54 | init_weights(m, init_type='kaiming') 55 | 56 | def forward(self, inputs): 57 | conv1 = self.conv1(inputs) 58 | maxpool1 = self.maxpool1(conv1) 59 | 60 | conv2 = self.conv2(maxpool1) 61 | 
nl2 = self.nonlocal2(conv2) 62 | maxpool2 = self.maxpool2(nl2) 63 | 64 | conv3 = self.conv3(maxpool2) 65 | nl3 = self.nonlocal3(conv3) 66 | maxpool3 = self.maxpool3(nl3) 67 | 68 | conv4 = self.conv4(maxpool3) 69 | maxpool4 = self.maxpool4(conv4) 70 | 71 | center = self.center(maxpool4) 72 | up4 = self.up_concat4(conv4, center) 73 | up3 = self.up_concat3(nl3, up4) 74 | up2 = self.up_concat2(nl2, up3) 75 | up1 = self.up_concat1(conv1, up2) 76 | 77 | final = self.final(up1) 78 | 79 | return final 80 | 81 | @staticmethod 82 | def apply_argmax_softmax(pred): 83 | log_p = F.softmax(pred, dim=1) 84 | 85 | return log_p 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /visualise_fmaps.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from dataio.loader import get_dataset, get_dataset_path 4 | from dataio.transformation import get_dataset_transformation 5 | from utils.util import json_file_to_pyobj 6 | from models import get_model 7 | 8 | import matplotlib.cm as cm 9 | import matplotlib.pyplot as plt 10 | import math, numpy, os 11 | from scipy.misc import imresize 12 | from skimage.transform import resize 13 | from dataio.loader.utils import write_nifti_img 14 | from torch.nn import functional as F 15 | 16 | def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None): 17 | plt.ion() 18 | filters = units.shape[2] 19 | n_columns = round(math.sqrt(filters)) 20 | n_rows = math.ceil(filters / n_columns) + 1 21 | fig = plt.figure(figure_id, figsize=(n_rows*3,n_columns*3)) 22 | fig.clf() 23 | 24 | for i in range(filters): 25 | ax1 = plt.subplot(n_rows, n_columns, i+1) 26 | plt.imshow(units[:,:,i].T, interpolation=interp, cmap=colormap) 27 | plt.axis('on') 28 | ax1.set_xticklabels([]) 29 | ax1.set_yticklabels([]) 30 | plt.colorbar() 31 | if colormap_lim: 32 | plt.clim(colormap_lim[0],colormap_lim[1]) 33 | 34 | plt.subplots_adjust(wspace=0, hspace=0) 35 | plt.tight_layout() 36 | 37 | # Load options 38 | json_opts = json_file_to_pyobj('/vol/biomedic2/oo2113/projects/syntAI/ukbb_pytorch/configs_final/debug_ct.json') 39 | 40 | # Setup the NN Model 41 | model = get_model(json_opts.model) 42 | 43 | # Setup Dataset and Augmentation 44 | dataset_class = get_dataset('test_sax') 45 | dataset_path = get_dataset_path('test_sax', json_opts.data_path) 46 | dataset_transform = get_dataset_transformation('test_sax', json_opts.augmentation) 47 | 48 | # Setup Data Loader 49 | dataset = dataset_class(dataset_path, transform=dataset_transform['test']) 50 | data_loader = DataLoader(dataset=dataset, num_workers=1, batch_size=1, shuffle=False) 51 | 52 | # test 53 | for iteration, (input_arr, input_meta, _) in enumerate(data_loader, 1): 54 | model.set_input(input_arr) 55 | layer_name = 'attentionblock1' 56 | inp_fmap, out_fmap = model.get_feature_maps(layer_name=layer_name, upscale=False) 57 | 58 | # Display the input image and Down_sample the input image 59 | orig_input_img = model.input.permute(2, 3, 4, 1, 0).cpu().numpy() 60 | upsampled_attention = F.upsample(out_fmap[1], size=input_arr.size()[2:], mode='trilinear').data.squeeze().permute(1,2,3,0).cpu().numpy() 61 | upsampled_fmap_before = F.upsample(inp_fmap[0], size=input_arr.size()[2:], mode='trilinear').data.squeeze().permute(1,2,3,0).cpu().numpy() 62 | upsampled_fmap_after = F.upsample(out_fmap[2], size=input_arr.size()[2:], 
mode='trilinear').data.squeeze().permute(1,2,3,0).cpu().numpy() 63 | 64 | # Define the directories 65 | save_directory = os.path.join('/vol/bitbucket/oo2113/tmp/feature_maps', layer_name) 66 | basename = input_meta['name'][0].split('.')[0] 67 | 68 | # Write the attentions to a nifti image 69 | input_meta['name'][0] = basename + '_img.nii.gz' 70 | write_nifti_img(orig_input_img, input_meta, savedir=save_directory) 71 | 72 | input_meta['name'][0] = basename + '_att.nii.gz' 73 | write_nifti_img(upsampled_attention, input_meta, savedir=save_directory) 74 | 75 | input_meta['name'][0] = basename + '_fmap_before.nii.gz' 76 | write_nifti_img(upsampled_fmap_before, input_meta, savedir=save_directory) 77 | 78 | input_meta['name'][0] = basename + '_fmap_after.nii.gz' 79 | write_nifti_img(upsampled_fmap_after, input_meta, savedir=save_directory) 80 | 81 | model.destructor() 82 | #if iteration == 1: break -------------------------------------------------------------------------------- /visualise_att_maps_epoch.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from dataio.loader import get_dataset, get_dataset_path 4 | from dataio.transformation import get_dataset_transformation 5 | from utils.util import json_file_to_pyobj 6 | from models import get_model 7 | 8 | import matplotlib.cm as cm 9 | import matplotlib.pyplot as plt 10 | import math, numpy, os 11 | from dataio.loader.utils import write_nifti_img 12 | from torch.nn import functional as F 13 | 14 | 15 | def mkdirfun(directory): 16 | if not os.path.exists(directory): 17 | os.makedirs(directory) 18 | 19 | 20 | def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None): 21 | plt.ion() 22 | filters = units.shape[2] 23 | n_columns = round(math.sqrt(filters)) 24 | n_rows = math.ceil(filters / n_columns) + 1 25 | fig = plt.figure(figure_id, figsize=(n_rows*3,n_columns*3)) 26 | fig.clf() 27 | 28 | for i in range(filters): 29 | ax1 = plt.subplot(n_rows, n_columns, i+1) 30 | plt.imshow(units[:,:,i].T, interpolation=interp, cmap=colormap) 31 | plt.axis('on') 32 | ax1.set_xticklabels([]) 33 | ax1.set_yticklabels([]) 34 | plt.colorbar() 35 | if colormap_lim: 36 | plt.clim(colormap_lim[0],colormap_lim[1]) 37 | 38 | plt.subplots_adjust(wspace=0, hspace=0) 39 | plt.tight_layout() 40 | 41 | # Epochs 42 | layer_name = 'attentionblock2' 43 | layer_save_directory = os.path.join('/vol/bitbucket/oo2113/tmp/attention_maps', layer_name); mkdirfun(layer_save_directory) 44 | epochs = range(225, 230, 3) 45 | att_maps = list() 46 | int_imgs = list() 47 | subject_id = int(2) 48 | for epoch in epochs: 49 | 50 | # Load options and replace the epoch attribute 51 | json_opts = json_file_to_pyobj('/vol/biomedic2/oo2113/projects/syntAI/ukbb_pytorch/configs_final/debug_ct.json') 52 | json_opts = json_opts._replace(model=json_opts.model._replace(which_epoch=epoch)) 53 | 54 | # Setup the NN Model 55 | model = get_model(json_opts.model) 56 | 57 | # Setup Dataset and Augmentation 58 | dataset_class = get_dataset('test_sax') 59 | dataset_path = get_dataset_path('test_sax', json_opts.data_path) 60 | dataset_transform = get_dataset_transformation('test_sax', json_opts.augmentation) 61 | 62 | # Setup Data Loader 63 | dataset = dataset_class(dataset_path, transform=dataset_transform['test']) 64 | data_loader = DataLoader(dataset=dataset, num_workers=1, batch_size=1, shuffle=False) 65 | 66 | # test 67 | for iteration, (input_arr, input_meta, _) in enumerate(data_loader, 1): 
68 |         # look for the subject_id
69 |         if iteration == subject_id:
70 |             # load the input image into the model
71 |             model.set_input(input_arr)
72 |             inp_fmap, out_fmap = model.get_feature_maps(layer_name=layer_name, upscale=False)
73 | 
74 |             # Fetch the input image and upsample the attention map to the input resolution
75 |             orig_input_img = model.input.permute(2, 3, 4, 1, 0).cpu().numpy()
76 |             upsampled_attention = F.upsample(out_fmap[1], size=input_arr.size()[2:], mode='trilinear').data.squeeze().permute(1,2,3,0).cpu().numpy()
77 | 
78 |             # Append it to the list
79 |             int_imgs.append(orig_input_img[:,:,:,0,0])
80 |             att_maps.append(upsampled_attention[:,:,:,1])
81 | 
82 |     # release the model before loading the next epoch's checkpoint
83 |     model.destructor()
84 | 
85 | # Write the attentions to a nifti image
86 | input_meta['name'][0] = str(subject_id) + '_img_2.nii.gz'
87 | int_imgs = numpy.array(int_imgs).transpose([1,2,3,0])
88 | write_nifti_img(int_imgs, input_meta, savedir=layer_save_directory)
89 | 
90 | input_meta['name'][0] = str(subject_id) + '_att_2.nii.gz'
91 | att_maps = numpy.array(att_maps).transpose([1,2,3,0])
92 | write_nifti_img(att_maps, input_meta, savedir=layer_save_directory)
93 | 
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy
3 | import torch
4 | from utils.util import mkdir
5 | from .networks_other import get_n_parameters
6 | 
7 | class BaseModel():
8 |     def __init__(self):
9 |         self.input = None
10 |         self.net = None
11 |         self.isTrain = False
12 |         self.use_cuda = True
13 |         self.schedulers = []
14 |         self.optimizers = []
15 |         self.save_dir = None
16 |         self.gpu_ids = []
17 |         self.which_epoch = int(0)
18 |         self.path_pre_trained_model = None
19 | 
20 |     def name(self):
21 |         return 'BaseModel'
22 | 
23 |     def initialize(self, opt, **kwargs):
24 |         self.gpu_ids = opt.gpu_ids
25 |         self.isTrain = opt.isTrain
26 |         self.ImgTensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
27 |         self.LblTensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
28 |         self.save_dir = opt.save_dir; mkdir(self.save_dir)
29 | 
30 |     def set_input(self, input):
31 |         self.input = input
32 | 
33 |     def set_scheduler(self, train_opt):
34 |         pass
35 | 
36 |     def forward(self, split):
37 |         pass
38 | 
39 |     # used in test time, no backprop
40 |     def test(self):
41 |         pass
42 | 
43 |     def get_image_paths(self):
44 |         pass
45 | 
46 |     def optimize_parameters(self):
47 |         pass
48 | 
49 |     def get_current_visuals(self):
50 |         return self.input
51 | 
52 |     def get_current_errors(self):
53 |         return {}
54 | 
55 |     def get_input_size(self):
56 |         return self.input.size() if self.input is not None else None
57 | 
58 |     def save(self, label):
59 |         pass
60 | 
61 |     # helper saving function that can be used by subclasses
62 |     def save_network(self, network, network_label, epoch_label, gpu_ids):
63 |         print('Saving the model {0} at the end of epoch {1}'.format(network_label, epoch_label))
64 |         save_filename = '{0:03d}_net_{1}.pth'.format(epoch_label, network_label)
65 |         save_path = os.path.join(self.save_dir, save_filename)
66 |         torch.save(network.cpu().state_dict(), save_path)
67 |         if len(gpu_ids) and torch.cuda.is_available():
68 |             network.cuda(gpu_ids[0])
69 | 
70 |     # helper loading function that can be used by subclasses
71 |     def load_network(self, network, network_label, epoch_label):
72 |         print('Loading the model {0} - epoch {1}'.format(network_label, epoch_label))
73 |         save_filename = '{0:03d}_net_{1}.pth'.format(epoch_label, network_label)
74 |         save_path = 
os.path.join(self.save_dir, save_filename) 75 | network.load_state_dict(torch.load(save_path)) 76 | 77 | def load_network_from_path(self, network, network_filepath, strict): 78 | network_label = os.path.basename(network_filepath) 79 | epoch_label = network_label.split('_')[0] 80 | print('Loading the model {0} - epoch {1}'.format(network_label, epoch_label)) 81 | network.load_state_dict(torch.load(network_filepath), strict=strict) 82 | 83 | # update learning rate (called once every epoch) 84 | def update_learning_rate(self, metric=None, epoch=None): 85 | for scheduler in self.schedulers: 86 | if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): 87 | scheduler.step(metrics=metric) 88 | else: 89 | scheduler.step() 90 | lr = self.optimizers[0].param_groups[0]['lr'] 91 | print('current learning rate = %.7f' % lr) 92 | 93 | # returns the number of trainable parameters 94 | def get_number_parameters(self): 95 | return get_n_parameters(self.net) 96 | 97 | # clean up the GPU memory 98 | def destructor(self): 99 | del self.net 100 | del self.input 101 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | from PIL import Image 4 | import inspect, re 5 | import numpy as np 6 | import os 7 | import collections 8 | import json 9 | import csv 10 | from skimage.exposure import rescale_intensity 11 | 12 | # Converts a Tensor into a Numpy array 13 | # |imtype|: the desired type of the converted numpy array 14 | def tensor2im(image_tensor, imgtype='img', datatype=np.uint8): 15 | image_numpy = image_tensor[0].cpu().float().numpy() 16 | if image_numpy.ndim == 4:# image_numpy (C x W x H x S) 17 | mid_slice = image_numpy.shape[-1]//2 18 | image_numpy = image_numpy[:,:,:,mid_slice] 19 | if image_numpy.shape[0] == 1: 20 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 21 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) 22 | if imgtype == 'img': 23 | image_numpy = (image_numpy + 8) / 16.0 * 255.0 24 | if np.unique(image_numpy).size == int(1): 25 | return image_numpy.astype(datatype) 26 | return rescale_intensity(image_numpy.astype(datatype)) 27 | 28 | 29 | def diagnose_network(net, name='network'): 30 | mean = 0.0 31 | count = 0 32 | for param in net.parameters(): 33 | if param.grad is not None: 34 | mean += torch.mean(torch.abs(param.grad.data)) 35 | count += 1 36 | if count > 0: 37 | mean = mean / count 38 | print(name) 39 | print(mean) 40 | 41 | 42 | def save_image(image_numpy, image_path): 43 | image_pil = Image.fromarray(image_numpy) 44 | image_pil.save(image_path) 45 | 46 | 47 | def info(object, spacing=10, collapse=1): 48 | """Print methods and doc strings. 
49 | Takes module, class, list, dictionary, or string.""" 50 | methodList = [e for e in dir(object) if isinstance(getattr(object, e), collections.Callable)] 51 | processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s) 52 | print( "\n".join(["%s %s" % 53 | (method.ljust(spacing), 54 | processFunc(str(getattr(object, method).__doc__))) 55 | for method in methodList]) ) 56 | 57 | def varname(p): 58 | for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]: 59 | m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line) 60 | if m: 61 | return m.group(1) 62 | 63 | def print_numpy(x, val=True, shp=False): 64 | x = x.astype(np.float64) 65 | if shp: 66 | print('shape,', x.shape) 67 | if val: 68 | x = x.flatten() 69 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 70 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 71 | 72 | 73 | def mkdirs(paths): 74 | if isinstance(paths, list) and not isinstance(paths, str): 75 | for path in paths: 76 | mkdir(path) 77 | else: 78 | mkdir(paths) 79 | 80 | 81 | def mkdir(path): 82 | if not os.path.exists(path): 83 | os.makedirs(path) 84 | 85 | 86 | def json_file_to_pyobj(filename): 87 | def _json_object_hook(d): return collections.namedtuple('X', d.keys())(*d.values()) 88 | def json2obj(data): return json.loads(data, object_hook=_json_object_hook) 89 | return json2obj(open(filename).read()) 90 | 91 | 92 | def determine_crop_size(inp_shape, div_factor): 93 | div_factor= np.array(div_factor, dtype=np.float32) 94 | new_shape = np.ceil(np.divide(inp_shape, div_factor)) * div_factor 95 | pre_pad = np.round((new_shape - inp_shape) / 2.0).astype(np.int16) 96 | post_pad = ((new_shape - inp_shape) - pre_pad).astype(np.int16) 97 | return pre_pad, post_pad 98 | 99 | 100 | def csv_write(out_filename, in_header_list, in_val_list): 101 | with open(out_filename, 'w') as f: 102 | writer = csv.writer(f) 103 | writer.writerow(in_header_list) 104 | writer.writerows(zip(*in_val_list)) 105 | -------------------------------------------------------------------------------- /models/networks/unet_CT_dsv_3D.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .utils import UnetConv3, UnetUp3_CT, UnetDsv3 3 | import torch.nn.functional as F 4 | from models.networks_other import init_weights 5 | import torch 6 | 7 | class unet_CT_dsv_3D(nn.Module): 8 | 9 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True): 10 | super(unet_CT_dsv_3D, self).__init__() 11 | self.is_deconv = is_deconv 12 | self.in_channels = in_channels 13 | self.is_batchnorm = is_batchnorm 14 | self.feature_scale = feature_scale 15 | 16 | filters = [64, 128, 256, 512, 1024] 17 | filters = [int(x / self.feature_scale) for x in filters] 18 | 19 | # downsampling 20 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 21 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 22 | 23 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 24 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 25 | 26 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 27 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 28 | 29 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 30 | self.maxpool4 = 
nn.MaxPool3d(kernel_size=(2, 2, 2)) 31 | 32 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 33 | 34 | # upsampling 35 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 36 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 37 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 38 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 39 | 40 | # deep supervision 41 | self.dsv4 = UnetDsv3(in_size=filters[3], out_size=n_classes, scale_factor=8) 42 | self.dsv3 = UnetDsv3(in_size=filters[2], out_size=n_classes, scale_factor=4) 43 | self.dsv2 = UnetDsv3(in_size=filters[1], out_size=n_classes, scale_factor=2) 44 | self.dsv1 = nn.Conv3d(in_channels=filters[0], out_channels=n_classes, kernel_size=1) 45 | 46 | # final conv (without any concat) 47 | self.final = nn.Conv3d(n_classes*4, n_classes, 1) 48 | 49 | 50 | # initialise weights 51 | for m in self.modules(): 52 | if isinstance(m, nn.Conv3d): 53 | init_weights(m, init_type='kaiming') 54 | elif isinstance(m, nn.BatchNorm3d): 55 | init_weights(m, init_type='kaiming') 56 | 57 | def forward(self, inputs): 58 | conv1 = self.conv1(inputs) 59 | maxpool1 = self.maxpool1(conv1) 60 | 61 | conv2 = self.conv2(maxpool1) 62 | maxpool2 = self.maxpool2(conv2) 63 | 64 | conv3 = self.conv3(maxpool2) 65 | maxpool3 = self.maxpool3(conv3) 66 | 67 | conv4 = self.conv4(maxpool3) 68 | maxpool4 = self.maxpool4(conv4) 69 | 70 | center = self.center(maxpool4) 71 | up4 = self.up_concat4(conv4, center) 72 | up3 = self.up_concat3(conv3, up4) 73 | up2 = self.up_concat2(conv2, up3) 74 | up1 = self.up_concat1(conv1, up2) 75 | 76 | # Deep Supervision 77 | dsv4 = self.dsv4(up4) 78 | dsv3 = self.dsv3(up3) 79 | dsv2 = self.dsv2(up2) 80 | dsv1 = self.dsv1(up1) 81 | final = self.final(torch.cat([dsv1,dsv2,dsv3,dsv4], dim=1)) 82 | 83 | return final 84 | 85 | @staticmethod 86 | def apply_argmax_softmax(pred): 87 | log_p = F.softmax(pred, dim=1) 88 | 89 | return log_p 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /models/aggregated_classifier.py: -------------------------------------------------------------------------------- 1 | import os, collections 2 | import numpy as np 3 | import torch 4 | from torch.autograd import Variable 5 | from .feedforward_classifier import FeedForwardClassifier 6 | 7 | 8 | class AggregatedClassifier(FeedForwardClassifier): 9 | def name(self): 10 | return 'AggregatedClassifier' 11 | 12 | def initialize(self, opts, **kwargs): 13 | FeedForwardClassifier.initialize(self, opts, **kwargs) 14 | 15 | weight = self.opts.raw.weight[:] # copy 16 | weight_t = torch.from_numpy(np.array(weight, dtype=np.float32)) 17 | self.weight = weight 18 | self.aggregation = opts.raw.aggregation 19 | self.aggregation_param = opts.raw.aggregation_param 20 | self.aggregation_weight = Variable(weight_t, volatile=True).view(-1,1,1).cuda() 21 | 22 | def compute_loss(self): 23 | """Compute loss function. 
Iterate over multiple output""" 24 | preds = self.predictions 25 | weights = self.weight 26 | if not isinstance(preds, collections.Sequence): 27 | preds = [preds] 28 | weights = [1] 29 | 30 | loss = 0 31 | for lmda, prediction in zip(weights, preds): 32 | if lmda == 0: 33 | continue 34 | loss += lmda * self.criterion(prediction, self.target) 35 | 36 | self.loss = loss 37 | 38 | def aggregate_output(self): 39 | """Given a list of predictions from net, make a decision based on aggreagation rule""" 40 | if isinstance(self.predictions, collections.Sequence): 41 | logits = [] 42 | for pred in self.predictions: 43 | logit = self.net.apply_argmax_softmax(pred).unsqueeze(0) 44 | logits.append(logit) 45 | 46 | logits = torch.cat(logits, 0) 47 | if self.aggregation == 'max': 48 | self.pred = logits.data.max(0)[0].max(1) 49 | elif self.aggregation == 'mean': 50 | self.pred = logits.data.mean(0).max(1) 51 | elif self.aggregation == 'weighted_mean': 52 | self.pred = (self.aggregation_weight.expand_as(logits) * logits).data.mean(0).max(1) 53 | elif self.aggregation == 'idx': 54 | self.pred = logits[self.aggregation_param].data.max(1) 55 | else: 56 | # Apply a softmax and return a segmentation map 57 | self.logits = self.net.apply_argmax_softmax(self.predictions) 58 | self.pred = self.logits.data.max(1) 59 | 60 | 61 | def forward(self, split): 62 | if split == 'train': 63 | self.predictions = self.net(Variable(self.input)) 64 | elif split in ['validation', 'test']: 65 | self.predictions = self.net(Variable(self.input, volatile=True)) 66 | self.aggregate_output() 67 | 68 | def backward(self): 69 | self.compute_loss() 70 | self.loss.backward() 71 | 72 | def validate(self): 73 | self.net.eval() 74 | self.forward(split='test') 75 | self.compute_loss() 76 | self.accumulate_results() 77 | 78 | def update_state(self, epoch): 79 | """ A function that is called at the end of every epoch. Can adjust state of the network here. 80 | For example, if one wants to change the loss weights for prediction during training (e.g. 
deep supervision), """ 81 | if hasattr(self.opts.raw,'late_gate'): 82 | if epoch < self.opts.raw.late_gate: 83 | self.weight[0] = 0 84 | self.weight[1] = 0 85 | print('='*10,'weight={}'.format(self.weight), '='*10) 86 | if epoch == self.opts.raw.late_gate: 87 | self.weight = self.opts.raw.weight[:] 88 | weight_t = torch.from_numpy(np.array(self.weight, dtype=np.float32)) 89 | self.aggregation_weight = Variable(weight_t,volatile=True).view(-1,1,1).cuda() 90 | print('='*10,'weight={}'.format(self.weight), '='*10) 91 | -------------------------------------------------------------------------------- /models/layers/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.modules.loss import _Loss 5 | from torch.autograd import Function, Variable 6 | 7 | def cross_entropy_2D(input, target, weight=None, size_average=True): 8 | n, c, h, w = input.size() 9 | log_p = F.log_softmax(input, dim=1) 10 | log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) 11 | target = target.view(target.numel()) 12 | loss = F.nll_loss(log_p, target, weight=weight, size_average=False) 13 | if size_average: 14 | loss /= float(target.numel()) 15 | return loss 16 | 17 | 18 | def cross_entropy_3D(input, target, weight=None, size_average=True): 19 | n, c, h, w, s = input.size() 20 | log_p = F.log_softmax(input, dim=1) 21 | log_p = log_p.transpose(1, 2).transpose(2, 3).transpose(3, 4).contiguous().view(-1, c) 22 | target = target.view(target.numel()) 23 | loss = F.nll_loss(log_p, target, weight=weight, size_average=False) 24 | if size_average: 25 | loss /= float(target.numel()) 26 | return loss 27 | 28 | 29 | class SoftDiceLoss(nn.Module): 30 | def __init__(self, n_classes): 31 | super(SoftDiceLoss, self).__init__() 32 | self.one_hot_encoder = One_Hot(n_classes).forward 33 | self.n_classes = n_classes 34 | 35 | def forward(self, input, target): 36 | smooth = 0.01 37 | batch_size = input.size(0) 38 | 39 | input = F.softmax(input, dim=1).view(batch_size, self.n_classes, -1) 40 | target = self.one_hot_encoder(target).contiguous().view(batch_size, self.n_classes, -1) 41 | 42 | inter = torch.sum(input * target, 2) + smooth 43 | union = torch.sum(input, 2) + torch.sum(target, 2) + smooth 44 | 45 | score = torch.sum(2.0 * inter / union) 46 | score = 1.0 - score / (float(batch_size) * float(self.n_classes)) 47 | 48 | return score 49 | 50 | 51 | class CustomSoftDiceLoss(nn.Module): 52 | def __init__(self, n_classes, class_ids): 53 | super(CustomSoftDiceLoss, self).__init__() 54 | self.one_hot_encoder = One_Hot(n_classes).forward 55 | self.n_classes = n_classes 56 | self.class_ids = class_ids 57 | 58 | def forward(self, input, target): 59 | smooth = 0.01 60 | batch_size = input.size(0) 61 | 62 | input = F.softmax(input[:,self.class_ids], dim=1).view(batch_size, len(self.class_ids), -1) 63 | target = self.one_hot_encoder(target).contiguous().view(batch_size, self.n_classes, -1) 64 | target = target[:, self.class_ids, :] 65 | 66 | inter = torch.sum(input * target, 2) + smooth 67 | union = torch.sum(input, 2) + torch.sum(target, 2) + smooth 68 | 69 | score = torch.sum(2.0 * inter / union) 70 | score = 1.0 - score / (float(batch_size) * float(self.n_classes)) 71 | 72 | return score 73 | 74 | 75 | class One_Hot(nn.Module): 76 | def __init__(self, depth): 77 | super(One_Hot, self).__init__() 78 | self.depth = depth 79 | self.ones = torch.sparse.torch.eye(depth).cuda() 80 | 81 | def forward(self, 
X_in): 82 | n_dim = X_in.dim() 83 | output_size = X_in.size() + torch.Size([self.depth]) 84 | num_element = X_in.numel() 85 | X_in = X_in.data.long().view(num_element) 86 | out = Variable(self.ones.index_select(0, X_in)).view(output_size) 87 | return out.permute(0, -1, *range(1, n_dim)).squeeze(dim=2).float() 88 | 89 | def __repr__(self): 90 | return self.__class__.__name__ + "({})".format(self.depth) 91 | 92 | 93 | if __name__ == '__main__': 94 | from torch.autograd import Variable 95 | depth=3 96 | batch_size=2 97 | encoder = One_Hot(depth=depth).forward 98 | y = Variable(torch.LongTensor(batch_size, 1, 1, 2 ,2).random_() % depth).cuda() # 4 classes,1x3x3 img 99 | y_onehot = encoder(y) 100 | x = Variable(torch.randn(y_onehot.size()).float()).cuda() 101 | dicemetric = SoftDiceLoss(n_classes=depth) 102 | dicemetric(x,y) -------------------------------------------------------------------------------- /validation.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from dataio.loader import get_dataset, get_dataset_path 4 | from dataio.transformation import get_dataset_transformation 5 | from utils.util import json_file_to_pyobj 6 | 7 | from models import get_model 8 | import numpy as np 9 | import os 10 | from utils.metrics import dice_score, distance_metric, precision_and_recall 11 | from utils.error_logger import StatLogger 12 | 13 | 14 | def mkdirfun(directory): 15 | if not os.path.exists(directory): 16 | os.makedirs(directory) 17 | 18 | 19 | def validation(json_name): 20 | # Load options 21 | json_opts = json_file_to_pyobj(json_name) 22 | train_opts = json_opts.training 23 | 24 | # Setup the NN Model 25 | model = get_model(json_opts.model) 26 | save_directory = os.path.join(model.save_dir, train_opts.arch_type); mkdirfun(save_directory) 27 | 28 | # Setup Dataset and Augmentation 29 | dataset_class = get_dataset(train_opts.arch_type) 30 | dataset_path = get_dataset_path(train_opts.arch_type, json_opts.data_path) 31 | dataset_transform = get_dataset_transformation(train_opts.arch_type, opts=json_opts.augmentation) 32 | 33 | # Setup Data Loader 34 | dataset = dataset_class(dataset_path, split='validation', transform=dataset_transform['valid']) 35 | data_loader = DataLoader(dataset=dataset, num_workers=8, batch_size=1, shuffle=False) 36 | 37 | # Visualisation Parameters 38 | #visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir) 39 | 40 | # Setup stats logger 41 | stat_logger = StatLogger() 42 | 43 | # test 44 | for iteration, data in enumerate(data_loader, 1): 45 | model.set_input(data[0], data[1]) 46 | model.test() 47 | 48 | input_arr = np.squeeze(data[0].cpu().numpy()).astype(np.float32) 49 | label_arr = np.squeeze(data[1].cpu().numpy()).astype(np.int16) 50 | output_arr = np.squeeze(model.pred_seg.cpu().byte().numpy()).astype(np.int16) 51 | 52 | # If there is a label image - compute statistics 53 | dice_vals = dice_score(label_arr, output_arr, n_class=int(4)) 54 | md, hd = distance_metric(label_arr, output_arr, dx=2.00, k=2) 55 | precision, recall = precision_and_recall(label_arr, output_arr, n_class=int(4)) 56 | stat_logger.update(split='test', input_dict={'img_name': '', 57 | 'dice_LV': dice_vals[1], 58 | 'dice_MY': dice_vals[2], 59 | 'dice_RV': dice_vals[3], 60 | 'prec_MYO':precision[2], 61 | 'reca_MYO':recall[2], 62 | 'md_MYO': md, 63 | 'hd_MYO': hd 64 | }) 65 | 66 | # Write a nifti image 67 | import SimpleITK as sitk 68 | input_img = sitk.GetImageFromArray(np.transpose(input_arr, 
(2, 1, 0))); input_img.SetDirection([-1,0,0,0,-1,0,0,0,1]) 69 | label_img = sitk.GetImageFromArray(np.transpose(label_arr, (2, 1, 0))); label_img.SetDirection([-1,0,0,0,-1,0,0,0,1]) 70 | predi_img = sitk.GetImageFromArray(np.transpose(output_arr,(2, 1, 0))); predi_img.SetDirection([-1,0,0,0,-1,0,0,0,1]) 71 | 72 | sitk.WriteImage(input_img, os.path.join(save_directory,'{}_img.nii.gz'.format(iteration))) 73 | sitk.WriteImage(label_img, os.path.join(save_directory,'{}_lbl.nii.gz'.format(iteration))) 74 | sitk.WriteImage(predi_img, os.path.join(save_directory,'{}_pred.nii.gz'.format(iteration))) 75 | 76 | stat_logger.statlogger2csv(split='test', out_csv_name=os.path.join(save_directory,'stats.csv')) 77 | for key, (mean_val, std_val) in stat_logger.get_errors(split='test').items(): 78 | print('-',key,': \t{0:.3f}+-{1:.3f}'.format(mean_val, std_val),'-') 79 | 80 | 81 | if __name__ == '__main__': 82 | import argparse 83 | 84 | parser = argparse.ArgumentParser(description='CNN Seg Validation Function') 85 | 86 | parser.add_argument('-c', '--config', help='testing config file', required=True) 87 | args = parser.parse_args() 88 | 89 | validation(args.config) 90 | -------------------------------------------------------------------------------- /models/networks/unet_grid_attention_3D.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .utils import UnetConv3, UnetUp3, UnetGridGatingSignal3 3 | import torch.nn.functional as F 4 | from models.layers.grid_attention_layer import GridAttentionBlock3D 5 | from models.networks_other import init_weights 6 | 7 | class unet_grid_attention_3D(nn.Module): 8 | 9 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, 10 | nonlocal_mode='concatenation', attention_dsample=(2,2,2), is_batchnorm=True): 11 | super(unet_grid_attention_3D, self).__init__() 12 | self.is_deconv = is_deconv 13 | self.in_channels = in_channels 14 | self.is_batchnorm = is_batchnorm 15 | self.feature_scale = feature_scale 16 | 17 | filters = [64, 128, 256, 512, 1024] 18 | filters = [int(x / self.feature_scale) for x in filters] 19 | 20 | # downsampling 21 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm) 22 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 1)) 23 | 24 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm) 25 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 1)) 26 | 27 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm) 28 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 1)) 29 | 30 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm) 31 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 1)) 32 | 33 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm) 34 | self.gating = UnetGridGatingSignal3(filters[4], filters[3], kernel_size=(1, 1, 1), is_batchnorm=self.is_batchnorm) 35 | 36 | # attention blocks 37 | self.attentionblock2 = GridAttentionBlock3D(in_channels=filters[1], gating_channels=filters[3], 38 | inter_channels=filters[1], sub_sample_factor=attention_dsample, mode=nonlocal_mode) 39 | self.attentionblock3 = GridAttentionBlock3D(in_channels=filters[2], gating_channels=filters[3], 40 | inter_channels=filters[2], sub_sample_factor=attention_dsample, mode=nonlocal_mode) 41 | self.attentionblock4 = GridAttentionBlock3D(in_channels=filters[3], gating_channels=filters[3], 42 | inter_channels=filters[3], sub_sample_factor=attention_dsample, mode=nonlocal_mode) 43 | 44 | # upsampling 45 | self.up_concat4 = 
UnetUp3(filters[4], filters[3], self.is_deconv, self.is_batchnorm) 46 | self.up_concat3 = UnetUp3(filters[3], filters[2], self.is_deconv, self.is_batchnorm) 47 | self.up_concat2 = UnetUp3(filters[2], filters[1], self.is_deconv, self.is_batchnorm) 48 | self.up_concat1 = UnetUp3(filters[1], filters[0], self.is_deconv, self.is_batchnorm) 49 | 50 | # final conv (without any concat) 51 | self.final = nn.Conv3d(filters[0], n_classes, 1) 52 | 53 | # initialise weights 54 | for m in self.modules(): 55 | if isinstance(m, nn.Conv3d): 56 | init_weights(m, init_type='kaiming') 57 | elif isinstance(m, nn.BatchNorm3d): 58 | init_weights(m, init_type='kaiming') 59 | 60 | def forward(self, inputs): 61 | # Feature Extraction 62 | conv1 = self.conv1(inputs) 63 | maxpool1 = self.maxpool1(conv1) 64 | 65 | conv2 = self.conv2(maxpool1) 66 | maxpool2 = self.maxpool2(conv2) 67 | 68 | conv3 = self.conv3(maxpool2) 69 | maxpool3 = self.maxpool3(conv3) 70 | 71 | conv4 = self.conv4(maxpool3) 72 | maxpool4 = self.maxpool4(conv4) 73 | 74 | # Gating Signal Generation 75 | center = self.center(maxpool4) 76 | gating = self.gating(center) 77 | 78 | # Attention Mechanism 79 | g_conv4, att4 = self.attentionblock4(conv4, gating) 80 | g_conv3, att3 = self.attentionblock3(conv3, gating) 81 | g_conv2, att2 = self.attentionblock2(conv2, gating) 82 | 83 | # Upscaling Part (Decoder) 84 | up4 = self.up_concat4(g_conv4, center) 85 | up3 = self.up_concat3(g_conv3, up4) 86 | up2 = self.up_concat2(g_conv2, up3) 87 | up1 = self.up_concat1(conv1, up2) 88 | 89 | final = self.final(up1) 90 | 91 | return final 92 | 93 | @staticmethod 94 | def apply_argmax_softmax(pred): 95 | log_p = F.softmax(pred, dim=1) 96 | 97 | return log_p 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /utils/error_logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .util import csv_write 3 | 4 | 5 | class BaseMeter(object): 6 | """Just a place holderb""" 7 | 8 | def __init__(self, name): 9 | self.reset() 10 | self.name = name 11 | 12 | def reset(self): 13 | pass 14 | 15 | def update(self, val): 16 | self.val = val 17 | 18 | def get_value(self): 19 | return self.val 20 | 21 | 22 | class AverageMeter(object): 23 | """Computes and stores the average and current value""" 24 | 25 | def __init__(self, name): 26 | self.reset() 27 | self.name = name 28 | 29 | def reset(self): 30 | self.val = 0.0 31 | self.avg = 0.0 32 | self.sum = 0.0 33 | self.count = 0.0 34 | 35 | def update(self, val, n=1.0): 36 | self.val = val 37 | self.sum += val * n 38 | self.count += n 39 | self.avg = self.sum / self.count 40 | 41 | def get_value(self): 42 | return self.avg 43 | 44 | class StatMeter(object): 45 | """Computes and stores the error vals and image names""" 46 | 47 | def __init__(self, name, csv_name=None): 48 | self.reset() 49 | self.name = name 50 | 51 | def reset(self): 52 | self.vals = [] 53 | self.img_names = [] 54 | 55 | def update(self, val, img_name): 56 | self.vals.append(val) 57 | self.img_names.append(img_name) 58 | 59 | def return_average(self): 60 | values_array = np.array(self.vals, dtype=np.float) 61 | return np.nanmean(values_array) 62 | 63 | def return_std(self): 64 | values_array = np.array(self.vals, dtype=np.float) 65 | return np.nanstd(values_array) 66 | 67 | 68 | class ErrorLogger(object): 69 | 70 | def __init__(self): 71 | self.variables = {'train': dict(), 72 | 'validation': dict(), 
73 | 'test': dict() 74 | } 75 | 76 | def update(self, input_dict, split): 77 | 78 | for key, value in input_dict.items(): 79 | if key not in self.variables[split]: 80 | if np.isscalar(value): 81 | self.variables[split][key] = AverageMeter(name=key) 82 | else: 83 | self.variables[split][key] = BaseMeter(name=key) 84 | 85 | self.variables[split][key].update(value) 86 | 87 | 88 | def get_errors(self, split): 89 | output = dict() 90 | for key, meter_obj in self.variables[split].items(): 91 | output[key] = meter_obj.get_value() 92 | return output 93 | 94 | def reset(self): 95 | for key, meter_obj in self.variables['train'].items(): 96 | meter_obj.reset() 97 | for key, meter_obj in self.variables['validation'].items(): 98 | meter_obj.reset() 99 | for key, meter_obj in self.variables['test'].items(): 100 | meter_obj.reset() 101 | 102 | 103 | class StatLogger(object): 104 | 105 | def __init__(self): 106 | self.variables = {'train': dict(), 107 | 'validation': dict(), 108 | 'test': dict() 109 | } 110 | 111 | def update(self, input_dict, split): 112 | img_name = input_dict.pop('img_name', None) 113 | for key, value in input_dict.items(): 114 | if key not in self.variables[split]: 115 | self.variables[split][key] = StatMeter(name=key) 116 | self.variables[split][key].update(val=value, img_name=img_name) 117 | 118 | def get_errors(self, split): 119 | output = dict() 120 | for key, meter_obj in self.variables[split].items(): 121 | output[key] = (meter_obj.return_average(), meter_obj.return_std()) 122 | return output 123 | 124 | def statlogger2csv(self, split, out_csv_name): 125 | csv_values = []; csv_header = [] 126 | for loopId, (meter_key, meter_obj) in enumerate(self.variables[split].items(), 1): 127 | if loopId == 1: csv_values.append(meter_obj.img_names); csv_header.append('img_names') 128 | csv_values.append(meter_obj.vals) 129 | csv_header.append(meter_key) 130 | csv_write(out_csv_name, csv_header, csv_values) 131 | 132 | def reset(self): 133 | for key, meter_obj in self.variables['train'].items(): 134 | meter_obj.reset() 135 | for key, meter_obj in self.variables['validation'].items(): 136 | meter_obj.reset() 137 | for key, meter_obj in self.variables['test'].items(): 138 | meter_obj.reset() 139 | -------------------------------------------------------------------------------- /train_segmentation.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from torch.utils.data import DataLoader 3 | from tqdm import tqdm 4 | 5 | 6 | from dataio.loader import get_dataset, get_dataset_path 7 | from dataio.transformation import get_dataset_transformation 8 | from utils.util import json_file_to_pyobj 9 | from utils.visualiser import Visualiser 10 | from utils.error_logger import ErrorLogger 11 | 12 | from models import get_model 13 | 14 | def train(arguments): 15 | 16 | # Parse input arguments 17 | json_filename = arguments.config 18 | network_debug = arguments.debug 19 | 20 | # Load options 21 | json_opts = json_file_to_pyobj(json_filename) 22 | train_opts = json_opts.training 23 | 24 | # Architecture type 25 | arch_type = train_opts.arch_type 26 | 27 | # Setup Dataset and Augmentation 28 | ds_class = get_dataset(arch_type) 29 | ds_path = get_dataset_path(arch_type, json_opts.data_path) 30 | ds_transform = get_dataset_transformation(arch_type, opts=json_opts.augmentation) 31 | 32 | # Setup the NN Model 33 | model = get_model(json_opts.model) 34 | if network_debug: 35 | print('# of pars: ', model.get_number_parameters()) 36 | print('fp time: {0:.3f} 
sec\tbp time: {1:.3f} sec per sample'.format(*model.get_fp_bp_time())) 37 | exit() 38 | 39 | # Setup Data Loader 40 | train_dataset = ds_class(ds_path, split='train', transform=ds_transform['train'], preload_data=train_opts.preloadData) 41 | valid_dataset = ds_class(ds_path, split='validation', transform=ds_transform['valid'], preload_data=train_opts.preloadData) 42 | test_dataset = ds_class(ds_path, split='test', transform=ds_transform['valid'], preload_data=train_opts.preloadData) 43 | train_loader = DataLoader(dataset=train_dataset, num_workers=16, batch_size=train_opts.batchSize, shuffle=True) 44 | valid_loader = DataLoader(dataset=valid_dataset, num_workers=16, batch_size=train_opts.batchSize, shuffle=False) 45 | test_loader = DataLoader(dataset=test_dataset, num_workers=16, batch_size=train_opts.batchSize, shuffle=False) 46 | 47 | # Visualisation Parameters 48 | visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir) 49 | error_logger = ErrorLogger() 50 | 51 | # Training Function 52 | model.set_scheduler(train_opts) 53 | for epoch in range(model.which_epoch, train_opts.n_epochs): 54 | print('(epoch: %d, total # iters: %d)' % (epoch, len(train_loader))) 55 | 56 | # Training Iterations 57 | for epoch_iter, (images, labels) in tqdm(enumerate(train_loader, 1), total=len(train_loader)): 58 | # Make a training update 59 | model.set_input(images, labels) 60 | model.optimize_parameters() 61 | #model.optimize_parameters_accumulate_grd(epoch_iter) 62 | 63 | # Error visualisation 64 | errors = model.get_current_errors() 65 | error_logger.update(errors, split='train') 66 | 67 | # Validation and Testing Iterations 68 | for loader, split in zip([valid_loader, test_loader], ['validation', 'test']): 69 | for epoch_iter, (images, labels) in tqdm(enumerate(loader, 1), total=len(loader)): 70 | 71 | # Make a forward pass with the model 72 | model.set_input(images, labels) 73 | model.validate() 74 | 75 | # Error visualisation 76 | errors = model.get_current_errors() 77 | stats = model.get_segmentation_stats() 78 | error_logger.update({**errors, **stats}, split=split) 79 | 80 | # Visualise predictions 81 | visuals = model.get_current_visuals() 82 | visualizer.display_current_results(visuals, epoch=epoch, save_result=False) 83 | 84 | # Update the plots 85 | for split in ['train', 'validation', 'test']: 86 | visualizer.plot_current_errors(epoch, error_logger.get_errors(split), split_name=split) 87 | visualizer.print_current_errors(epoch, error_logger.get_errors(split), split_name=split) 88 | error_logger.reset() 89 | 90 | # Save the model parameters 91 | if epoch % train_opts.save_epoch_freq == 0: 92 | model.save(epoch) 93 | 94 | # Update the model learning rate 95 | model.update_learning_rate() 96 | 97 | 98 | if __name__ == '__main__': 99 | import argparse 100 | 101 | parser = argparse.ArgumentParser(description='CNN Seg Training Function') 102 | 103 | parser.add_argument('-c', '--config', help='training config file', required=True) 104 | parser.add_argument('-d', '--debug', help='returns number of parameters and bp/fp runtime', action='store_true') 105 | args = parser.parse_args() 106 | 107 | train(args) 108 | -------------------------------------------------------------------------------- /utils/post_process_crf.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | import numpy as np, nibabel as nib 3 | 4 | import pydensecrf.densecrf as dcrf 5 | from pydensecrf.utils import create_pairwise_bilateral, create_pairwise_gaussian 6 | 
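
# A minimal sketch (not part of the original pipeline) of the unary term that
# apply_crf() below builds inline: for an (H, W, n_class) probability map, the
# DenseCRF unary energy is the pixel-wise negative log-probability,
# U[c, i] = -log(P_c(i) + eps), flattened to shape (n_class, n_pixel).
def _unary_energy_sketch(prob, eps=1e-10):
    n_class = prob.shape[-1]
    U = -np.log(np.transpose(prob, axes=(2, 0, 1)) + eps)  # (n_class, H, W)
    return np.ascontiguousarray(U.reshape((n_class, -1)))  # (n_class, n_pixel)
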
7 | 8 | def apply_crf(input_image, input_prob, theta_a, theta_b, theta_r, mu1, mu2): 9 | n_slices = input_image.shape[2] 10 | output = np.zeros(input_image.shape) 11 | for slice_id in range(n_slices): 12 | image = input_image[:,:,slice_id] 13 | prob = input_prob[:,:,slice_id,:] 14 | 15 | n_pixel = image.shape[0] * image.shape[1] 16 | n_class = prob.shape[-1] 17 | 18 | P = np.transpose(prob, axes=(2, 0, 1)) 19 | 20 | # Setup the CRF model 21 | d = dcrf.DenseCRF(n_pixel, n_class) 22 | 23 | # Set unary potentials (negative log probability) 24 | U = - np.log(P + 1e-10) 25 | U = np.ascontiguousarray(U.reshape((n_class, n_pixel))) 26 | d.setUnaryEnergy(U) 27 | 28 | # Set edge potential 29 | # This creates the color-dependent features and then add them to the CRF 30 | feats = create_pairwise_bilateral(sdims=(theta_a, theta_a), schan=(theta_b,), img=image, chdim=-1) 31 | d.addPairwiseEnergy(feats, compat=mu1, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC) 32 | 33 | # This creates the color-independent features and then add them to the CRF 34 | feats = create_pairwise_gaussian(sdims=(theta_r, theta_r), shape=image.shape) 35 | d.addPairwiseEnergy(feats, compat=mu2, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC) 36 | 37 | # Perform the inference 38 | Q = d.inference(5) 39 | res = np.argmax(Q, axis=0).astype('float32') 40 | res = np.reshape(res, image.shape).astype(dtype='int8') 41 | output[:,:,slice_id] = res 42 | 43 | return output 44 | 45 | 46 | if __name__ == '__main__': 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('--n_train', metavar='int', nargs=1, default=['80'], help='number of training subjects') 49 | args = parser.parse_args() 50 | 51 | # Data path 52 | data_path = '/vol/medic02/users/wbai/data/cardiac_atlas/Biobank' 53 | data_list = sorted(os.listdir(data_path)) 54 | dest_path = '/vol/bitbucket/wbai/cardiac_cnn/Biobank/seg' 55 | 56 | # Model name 57 | size = 224 58 | n_train = int(args.n_train[0]) 59 | model_name = 'FCN_VGG16_sz{0}_n{1}_2d'.format(size, n_train) 60 | epoch = 500 61 | 62 | # model_name = 'FCN_VGG16_sz{0}_prob_atlas_stepwise'.format(size) 63 | # epoch = 200 64 | 65 | # model_name = 'FCN_VGG16_sz{0}_auto_context_stepwise'.format(size) 66 | # epoch = 200 67 | 68 | for data in data_list: 69 | print(data) 70 | data_dir = os.path.join(data_path, data) 71 | dest_dir = os.path.join(dest_path, data) 72 | 73 | # tune_dir = os.path.join(dest_dir, 'tune') 74 | # if not os.path.exists(tune_dir): 75 | # os.mkdir(tune_dir) 76 | 77 | for fr in ['ED', 'ES']: 78 | # Read image 79 | nim = nib.load(os.path.join(data_dir, 'image_{0}.nii.gz'.format(fr))) 80 | image = np.squeeze(nim.get_data()) 81 | 82 | # Scale the intensity to be [0, 1] so that we can set a consistent intensity parameter for CRF 83 | #image = intensity_rescaling(image, 1, 99) 84 | 85 | # Read probability map 86 | nim = nib.load(os.path.join(dest_dir, 'prob_{0}_{1}_epoch{2:03d}.nii.gz'.format(fr, model_name, epoch))) 87 | prob = nim.get_data() 88 | 89 | # Apply CRF 90 | mu1 = 1 91 | theta_a = 0.5 92 | theta_b = 1 93 | mu2 = 2 94 | theta_r = 1 95 | seg = apply_crf(image, prob, theta_a, theta_b, theta_r, mu1, mu2) 96 | 97 | # Save the CRF segmentation 98 | seg_name = os.path.join(dest_dir, 'seg_{0}_{1}_epoch{2:03d}_crf.nii.gz'.format(fr, model_name, epoch)) 99 | nib.save(nib.Nifti1Image(seg, nim.affine), seg_name) 100 | 101 | # For parameter tuning 102 | # nib.save(nib.Nifti1Image(seg, nim.affine), os.path.join(tune_dir, 'seg_{0}_crf_mu2{1}_sr{2}.nii.gz'.format(fr, mu2, 
theta_r))) 103 | # nib.save(nib.Nifti1Image(seg, nim.affine), os.path.join(tune_dir, 'seg_{0}_crf_mu1{1}_sa{2:.1f}_sb{1}.nii.gz'.format(fr, mu1, theta_a, theta_b))) 104 | 105 | # # Fit to the template 106 | # template_dir = '/vol/medic02/users/wbai/data/imperial_atlas/template' 107 | # par_dir = '/vol/vipdata/data/biobank/cardiac/Application_18545/par' 108 | # out_name = os.path.join(dest_dir, 'seg_{0}_{1}_epoch{2:03d}_crf_fit.nii.gz'.format(fr, model_name, epoch)) 109 | # fit_to_template(seg_name, fr, template_dir, par_dir, out_name) -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Misc Utility functions 3 | ''' 4 | 5 | import os 6 | import numpy as np 7 | import torch.optim as optim 8 | from torch.nn import CrossEntropyLoss 9 | from utils.metrics import segmentation_scores, dice_score_list 10 | from sklearn import metrics 11 | from .layers.loss import * 12 | 13 | def get_optimizer(option, params): 14 | opt_alg = 'sgd' if not hasattr(option, 'optim') else option.optim 15 | if opt_alg == 'sgd': 16 | optimizer = optim.SGD(params, 17 | lr=option.lr_rate, 18 | momentum=0.9, 19 | nesterov=True, 20 | weight_decay=option.l2_reg_weight) 21 | 22 | if opt_alg == 'adam': 23 | optimizer = optim.Adam(params, 24 | lr=option.lr_rate, 25 | betas=(0.9, 0.999), 26 | weight_decay=option.l2_reg_weight) 27 | 28 | return optimizer 29 | 30 | 31 | def get_criterion(opts): 32 | if opts.criterion == 'cross_entropy': 33 | if opts.type == 'seg': 34 | criterion = cross_entropy_2D if opts.tensor_dim == '2D' else cross_entropy_3D 35 | elif 'classifier' in opts.type: 36 | criterion = CrossEntropyLoss() 37 | elif opts.criterion == 'dice_loss': 38 | criterion = SoftDiceLoss(opts.output_nc) 39 | elif opts.criterion == 'dice_loss_pancreas_only': 40 | criterion = CustomSoftDiceLoss(opts.output_nc, class_ids=[0, 2]) 41 | 42 | return criterion 43 | 44 | def recursive_glob(rootdir='.', suffix=''): 45 | """Performs recursive glob with given suffix and rootdir 46 | :param rootdir is the root directory 47 | :param suffix is the suffix to be searched 48 | """ 49 | return [os.path.join(looproot, filename) 50 | for looproot, _, filenames in os.walk(rootdir) 51 | for filename in filenames if filename.endswith(suffix)] 52 | 53 | def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1, max_iter=30000, power=0.9,): 54 | """Polynomial decay of learning rate 55 | :param init_lr is base learning rate 56 | :param iter is a current iteration 57 | :param lr_decay_iter how frequently decay occurs, default is 1 58 | :param max_iter is number of maximum iterations 59 | :param power is a polymomial power 60 | 61 | """ 62 | if iter % lr_decay_iter or iter > max_iter: 63 | return optimizer 64 | 65 | for param_group in optimizer.param_groups: 66 | param_group['lr'] = init_lr*(1 - iter/max_iter)**power 67 | 68 | 69 | def adjust_learning_rate(optimizer, init_lr, epoch): 70 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 71 | lr = init_lr * (0.1 ** (epoch // 30)) 72 | for param_group in optimizer.param_groups: 73 | param_group['lr'] = lr 74 | 75 | 76 | def segmentation_stats(pred_seg, target): 77 | n_classes = pred_seg.size(1) 78 | pred_lbls = pred_seg.data.max(1)[1].cpu().numpy() 79 | gt = np.squeeze(target.data.cpu().numpy(), axis=1) 80 | gts, preds = [], [] 81 | for gt_, pred_ in zip(gt, pred_lbls): 82 | gts.append(gt_) 83 | preds.append(pred_) 84 | 85 | iou = 
segmentation_scores(gts, preds, n_class=n_classes) 86 | dice = dice_score_list(gts, preds, n_class=n_classes) 87 | 88 | return iou, dice 89 | 90 | 91 | def classification_scores(gts, preds, labels): 92 | accuracy = metrics.accuracy_score(gts, preds) 93 | class_accuracies = [] 94 | for lab in labels: # TODO Fix 95 | class_accuracies.append(metrics.accuracy_score(gts[gts == lab], preds[gts == lab])) 96 | class_accuracies = np.array(class_accuracies) 97 | 98 | f1_micro = metrics.f1_score(gts, preds, average='micro') 99 | precision_micro = metrics.precision_score(gts, preds, average='micro') 100 | recall_micro = metrics.recall_score(gts, preds, average='micro') 101 | f1_macro = metrics.f1_score(gts, preds, average='macro') 102 | precision_macro = metrics.precision_score(gts, preds, average='macro') 103 | recall_macro = metrics.recall_score(gts, preds, average='macro') 104 | 105 | # class wise score 106 | f1s = metrics.f1_score(gts, preds, average=None) 107 | precisions = metrics.precision_score(gts, preds, average=None) 108 | recalls = metrics.recall_score(gts, preds, average=None) 109 | 110 | confusion = metrics.confusion_matrix(gts,preds, labels=labels) 111 | 112 | #TODO confusion matrix, recall, precision 113 | return accuracy, f1_micro, precision_micro, recall_micro, f1_macro, precision_macro, recall_macro, confusion, class_accuracies, f1s, precisions, recalls 114 | 115 | 116 | def classification_stats(pred_seg, target, labels): 117 | return classification_scores(target, pred_seg, labels) 118 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Originally written by wkentaro 2 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py 3 | 4 | import numpy as np 5 | import cv2 6 | 7 | def _fast_hist(label_true, label_pred, n_class): 8 | mask = (label_true >= 0) & (label_true < n_class) 9 | hist = np.bincount( 10 | n_class * label_true[mask].astype(int) + 11 | label_pred[mask], minlength=n_class**2).reshape(n_class, n_class) 12 | return hist 13 | 14 | 15 | def segmentation_scores(label_trues, label_preds, n_class): 16 | """Returns accuracy score evaluation result. 
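    All four scores are derived from the accumulated n_class x n_class
    confusion matrix `hist` built below; for instance, the per-class IoU
    behind mean IU is hist[c, c] / (hist[c, :].sum() + hist[:, c].sum() - hist[c, c]).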
17 | - overall accuracy 18 | - mean accuracy 19 | - mean IU 20 | - fwavacc 21 | """ 22 | hist = np.zeros((n_class, n_class)) 23 | for lt, lp in zip(label_trues, label_preds): 24 | hist += _fast_hist(lt.flatten(), lp.flatten(), n_class) 25 | acc = np.diag(hist).sum() / hist.sum() 26 | acc_cls = np.diag(hist) / hist.sum(axis=1) 27 | acc_cls = np.nanmean(acc_cls) 28 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 29 | mean_iu = np.nanmean(iu) 30 | freq = hist.sum(axis=1) / hist.sum() 31 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 32 | 33 | return {'overall_acc': acc, 34 | 'mean_acc': acc_cls, 35 | 'freq_w_acc': fwavacc, 36 | 'mean_iou': mean_iu} 37 | 38 | 39 | def dice_score_list(label_gt, label_pred, n_class): 40 | """ 41 | 42 | :param label_gt: [WxH] (2D images) 43 | :param label_pred: [WxH] (2D images) 44 | :param n_class: number of label classes 45 | :return: 46 | """ 47 | epsilon = 1.0e-6 48 | assert len(label_gt) == len(label_pred) 49 | batchSize = len(label_gt) 50 | dice_scores = np.zeros((batchSize, n_class), dtype=np.float32) 51 | for batch_id, (l_gt, l_pred) in enumerate(zip(label_gt, label_pred)): 52 | for class_id in range(n_class): 53 | img_A = np.array(l_gt == class_id, dtype=np.float32).flatten() 54 | img_B = np.array(l_pred == class_id, dtype=np.float32).flatten() 55 | score = 2.0 * np.sum(img_A * img_B) / (np.sum(img_A) + np.sum(img_B) + epsilon) 56 | dice_scores[batch_id, class_id] = score 57 | 58 | return np.mean(dice_scores, axis=0) 59 | 60 | 61 | def dice_score(label_gt, label_pred, n_class): 62 | """ 63 | 64 | :param label_gt: 65 | :param label_pred: 66 | :param n_class: 67 | :return: 68 | """ 69 | 70 | epsilon = 1.0e-6 71 | assert np.all(label_gt.shape == label_pred.shape) 72 | dice_scores = np.zeros(n_class, dtype=np.float32) 73 | for class_id in range(n_class): 74 | img_A = np.array(label_gt == class_id, dtype=np.float32).flatten() 75 | img_B = np.array(label_pred == class_id, dtype=np.float32).flatten() 76 | score = 2.0 * np.sum(img_A * img_B) / (np.sum(img_A) + np.sum(img_B) + epsilon) 77 | dice_scores[class_id] = score 78 | 79 | return dice_scores 80 | 81 | 82 | def precision_and_recall(label_gt, label_pred, n_class): 83 | from sklearn.metrics import precision_score, recall_score 84 | assert len(label_gt) == len(label_pred) 85 | precision = np.zeros(n_class, dtype=np.float32) 86 | recall = np.zeros(n_class, dtype=np.float32) 87 | img_A = np.array(label_gt, dtype=np.float32).flatten() 88 | img_B = np.array(label_pred, dtype=np.float32).flatten() 89 | precision[:] = precision_score(img_A, img_B, average=None, labels=range(n_class)) 90 | recall[:] = recall_score(img_A, img_B, average=None, labels=range(n_class)) 91 | 92 | return precision, recall 93 | 94 | 95 | def distance_metric(seg_A, seg_B, dx, k): 96 | """ 97 | Measure the distance errors between the contours of two segmentations. 98 | The manual contours are drawn on 2D slices. 99 | We calculate contour to contour distance for each slice. 
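    Concretely, for the contour point sets A and B of one slice and pixel
    spacing dx, the code below computes
        MD = 0.5 * (mean_a min_b ||a - b|| + mean_b min_a ||a - b||) * dx
        HD = max(max_a min_b ||a - b||, max_b min_a ||a - b||) * dx
    and averages each over the slices on which both contours exist.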
100 | """ 101 | 102 | # Extract the label k from the segmentation maps to generate binary maps 103 | seg_A = (seg_A == k) 104 | seg_B = (seg_B == k) 105 | 106 | table_md = [] 107 | table_hd = [] 108 | X, Y, Z = seg_A.shape 109 | for z in range(Z): 110 | # Binary mask at this slice 111 | slice_A = seg_A[:, :, z].astype(np.uint8) 112 | slice_B = seg_B[:, :, z].astype(np.uint8) 113 | 114 | # The distance is defined only when both contours exist on this slice 115 | if np.sum(slice_A) > 0 and np.sum(slice_B) > 0: 116 | # Find contours and retrieve all the points 117 | _, contours, _ = cv2.findContours(cv2.inRange(slice_A, 1, 1), 118 | cv2.RETR_EXTERNAL, 119 | cv2.CHAIN_APPROX_NONE) 120 | pts_A = contours[0] 121 | for i in range(1, len(contours)): 122 | pts_A = np.vstack((pts_A, contours[i])) 123 | 124 | _, contours, _ = cv2.findContours(cv2.inRange(slice_B, 1, 1), 125 | cv2.RETR_EXTERNAL, 126 | cv2.CHAIN_APPROX_NONE) 127 | pts_B = contours[0] 128 | for i in range(1, len(contours)): 129 | pts_B = np.vstack((pts_B, contours[i])) 130 | 131 | # Distance matrix between point sets 132 | M = np.zeros((len(pts_A), len(pts_B))) 133 | for i in range(len(pts_A)): 134 | for j in range(len(pts_B)): 135 | M[i, j] = np.linalg.norm(pts_A[i, 0] - pts_B[j, 0]) 136 | 137 | # Mean distance and hausdorff distance 138 | md = 0.5 * (np.mean(np.min(M, axis=0)) + np.mean(np.min(M, axis=1))) * dx 139 | hd = np.max([np.max(np.min(M, axis=0)), np.max(np.min(M, axis=1))]) * dx 140 | table_md += [md] 141 | table_hd += [hd] 142 | 143 | # Return the mean distance and Hausdorff distance across 2D slices 144 | mean_md = np.mean(table_md) if table_md else None 145 | mean_hd = np.mean(table_hd) if table_hd else None 146 | return mean_md, mean_hd -------------------------------------------------------------------------------- /models/networks/sononet_grid_attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch.nn as nn 4 | from .utils import unetConv2, unetUp, conv2DBatchNormRelu, conv2DBatchNorm 5 | import torch 6 | import torch.nn.functional as F 7 | from models.layers.grid_attention_layer import GridAttentionBlock2D_TORR as AttentionBlock2D 8 | from models.networks_other import init_weights 9 | 10 | class sononet_grid_attention(nn.Module): 11 | 12 | def __init__(self, feature_scale=4, n_classes=21, in_channels=3, is_batchnorm=True, n_convs=None, 13 | nonlocal_mode='concatenation', aggregation_mode='concat'): 14 | super(sononet_grid_attention, self).__init__() 15 | self.in_channels = in_channels 16 | self.is_batchnorm = is_batchnorm 17 | self.feature_scale = feature_scale 18 | self.n_classes= n_classes 19 | self.aggregation_mode = aggregation_mode 20 | self.deep_supervised = True 21 | 22 | if n_convs is None: 23 | n_convs = [3, 3, 3, 2, 2] 24 | 25 | filters = [64, 128, 256, 512] 26 | filters = [int(x / self.feature_scale) for x in filters] 27 | 28 | #################### 29 | # Feature Extraction 30 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm, n=n_convs[0]) 31 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 32 | 33 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm, n=n_convs[1]) 34 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 35 | 36 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm, n=n_convs[2]) 37 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 38 | 39 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm, n=n_convs[3]) 40 | self.maxpool4 = 
nn.MaxPool2d(kernel_size=2)
41 | 
42 |         self.conv5 = unetConv2(filters[3], filters[3], self.is_batchnorm, n=n_convs[4])
43 | 
44 |         ################
45 |         # Attention Maps
46 |         self.compatibility_score1 = AttentionBlock2D(in_channels=filters[2], gating_channels=filters[3],
47 |                                                      inter_channels=filters[3], sub_sample_factor=(1,1),
48 |                                                      mode=nonlocal_mode, use_W=False, use_phi=True,
49 |                                                      use_theta=True, use_psi=True, nonlinearity1='relu')
50 | 
51 |         self.compatibility_score2 = AttentionBlock2D(in_channels=filters[3], gating_channels=filters[3],
52 |                                                      inter_channels=filters[3], sub_sample_factor=(1,1),
53 |                                                      mode=nonlocal_mode, use_W=False, use_phi=True,
54 |                                                      use_theta=True, use_psi=True, nonlinearity1='relu')
55 | 
56 |         #########################
57 |         # Aggregation Strategies
58 |         self.attention_filter_sizes = [filters[2], filters[3]]
59 | 
60 |         if aggregation_mode == 'concat':
61 |             self.classifier = nn.Linear(filters[2]+filters[3]+filters[3], n_classes)
62 |             self.aggregate = self.aggregation_concat
63 | 
64 |         else:
65 |             self.classifier1 = nn.Linear(filters[2], n_classes)
66 |             self.classifier2 = nn.Linear(filters[3], n_classes)
67 |             self.classifier3 = nn.Linear(filters[3], n_classes)
68 |             self.classifiers = [self.classifier1, self.classifier2, self.classifier3]
69 | 
70 |             if aggregation_mode == 'mean':
71 |                 self.aggregate = self.aggregation_sep
72 | 
73 |             elif aggregation_mode == 'deep_sup':
74 |                 self.classifier = nn.Linear(filters[2] + filters[3] + filters[3], n_classes)
75 |                 self.aggregate = self.aggregation_ds
76 | 
77 |             elif aggregation_mode == 'ft':
78 |                 self.classifier = nn.Linear(n_classes*3, n_classes)
79 |                 self.aggregate = self.aggregation_ft
80 |             else:
81 |                 raise NotImplementedError
82 | 
83 |         ####################
84 |         # initialise weights
85 |         for m in self.modules():
86 |             if isinstance(m, nn.Conv2d):
87 |                 init_weights(m, init_type='kaiming')
88 |             elif isinstance(m, nn.BatchNorm2d):
89 |                 init_weights(m, init_type='kaiming')
90 | 
91 | 
92 |     def aggregation_sep(self, *attended_maps):
93 |         return [ clf(att) for clf, att in zip(self.classifiers, attended_maps) ]
94 | 
95 |     def aggregation_ft(self, *attended_maps):
96 |         preds = self.aggregation_sep(*attended_maps)
97 |         return self.classifier(torch.cat(preds, dim=1))
98 | 
99 |     def aggregation_ds(self, *attended_maps):
100 |         preds_sep = self.aggregation_sep(*attended_maps)
101 |         pred = self.aggregation_concat(*attended_maps)
102 |         return [pred] + preds_sep
103 | 
104 |     def aggregation_concat(self, *attended_maps):
105 |         return self.classifier(torch.cat(attended_maps, dim=1))
106 | 
107 | 
108 |     def forward(self, inputs):
109 |         # Feature Extraction
110 |         conv1 = self.conv1(inputs)
111 |         maxpool1 = self.maxpool1(conv1)
112 | 
113 |         conv2 = self.conv2(maxpool1)
114 |         maxpool2 = self.maxpool2(conv2)
115 | 
116 |         conv3 = self.conv3(maxpool2)
117 |         maxpool3 = self.maxpool3(conv3)
118 | 
119 |         conv4 = self.conv4(maxpool3)
120 |         maxpool4 = self.maxpool4(conv4)
121 | 
122 |         conv5 = self.conv5(maxpool4)
123 | 
124 |         batch_size = inputs.shape[0]
125 |         pooled = F.adaptive_avg_pool2d(conv5, (1, 1)).view(batch_size, -1)
126 | 
127 |         # Attention Mechanism
128 |         g_conv1, att1 = self.compatibility_score1(conv3, conv5)
129 |         g_conv2, att2 = self.compatibility_score2(conv4, conv5)
130 | 
131 |         # flatten to get single feature vector
132 |         fsizes = self.attention_filter_sizes
133 |         g1 = torch.sum(g_conv1.view(batch_size, fsizes[0], -1), dim=-1)
134 |         g2 = torch.sum(g_conv2.view(batch_size, fsizes[1], -1), dim=-1)
135 | 
136 |         return self.aggregate(g1, g2, pooled)
137 | 
138 | 
139 | 
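    # Usage sketch (hypothetical shapes, not from the original file): with the
    # default aggregation_mode='concat', forward() returns one logit tensor
    # built from g1 (filters[2]-d), g2 and pooled (both filters[3]-d), e.g.
    #   net = sononet_grid_attention(feature_scale=4, n_classes=14, in_channels=1)
    #   logits = net(torch.randn(2, 1, 224, 288))  # -> (2, 14)
    # In 'deep_sup' mode the call instead returns [joint_pred, pred1, pred2, pred3].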
@staticmethod 140 | def apply_argmax_softmax(pred): 141 | log_p = F.softmax(pred, dim=1) 142 | 143 | return log_p 144 | -------------------------------------------------------------------------------- /models/networks/unet_CT_single_att_dsv_3D.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from .utils import UnetConv3, UnetUp3_CT, UnetGridGatingSignal3, UnetDsv3 4 | import torch.nn.functional as F 5 | from models.networks_other import init_weights 6 | from models.layers.grid_attention_layer import GridAttentionBlock3D 7 | 8 | 9 | class unet_CT_single_att_dsv_3D(nn.Module): 10 | 11 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, 12 | nonlocal_mode='concatenation', attention_dsample=(2,2,2), is_batchnorm=True): 13 | super(unet_CT_single_att_dsv_3D, self).__init__() 14 | self.is_deconv = is_deconv 15 | self.in_channels = in_channels 16 | self.is_batchnorm = is_batchnorm 17 | self.feature_scale = feature_scale 18 | 19 | filters = [64, 128, 256, 512, 1024] 20 | filters = [int(x / self.feature_scale) for x in filters] 21 | 22 | # downsampling 23 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 24 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 25 | 26 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 27 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 28 | 29 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 30 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 31 | 32 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 33 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 34 | 35 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 36 | self.gating = UnetGridGatingSignal3(filters[4], filters[4], kernel_size=(1, 1, 1), is_batchnorm=self.is_batchnorm) 37 | 38 | # attention blocks 39 | self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1], 40 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) 41 | self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2], 42 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) 43 | self.attentionblock4 = MultiAttentionBlock(in_size=filters[3], gate_size=filters[4], inter_size=filters[3], 44 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) 45 | 46 | # upsampling 47 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 48 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 49 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 50 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 51 | 52 | # deep supervision 53 | self.dsv4 = UnetDsv3(in_size=filters[3], out_size=n_classes, scale_factor=8) 54 | self.dsv3 = UnetDsv3(in_size=filters[2], out_size=n_classes, scale_factor=4) 55 | self.dsv2 = UnetDsv3(in_size=filters[1], out_size=n_classes, scale_factor=2) 56 | self.dsv1 = nn.Conv3d(in_channels=filters[0], out_channels=n_classes, kernel_size=1) 57 | 58 | # final conv (without any concat) 59 | self.final = nn.Conv3d(n_classes*4, n_classes, 1) 60 | 61 | # initialise weights 62 | for m in self.modules(): 63 | if 
isinstance(m, nn.Conv3d): 64 | init_weights(m, init_type='kaiming') 65 | elif isinstance(m, nn.BatchNorm3d): 66 | init_weights(m, init_type='kaiming') 67 | 68 | def forward(self, inputs): 69 | # Feature Extraction 70 | conv1 = self.conv1(inputs) 71 | maxpool1 = self.maxpool1(conv1) 72 | 73 | conv2 = self.conv2(maxpool1) 74 | maxpool2 = self.maxpool2(conv2) 75 | 76 | conv3 = self.conv3(maxpool2) 77 | maxpool3 = self.maxpool3(conv3) 78 | 79 | conv4 = self.conv4(maxpool3) 80 | maxpool4 = self.maxpool4(conv4) 81 | 82 | # Gating Signal Generation 83 | center = self.center(maxpool4) 84 | gating = self.gating(center) 85 | 86 | # Attention Mechanism 87 | # Upscaling Part (Decoder) 88 | g_conv4, att4 = self.attentionblock4(conv4, gating) 89 | up4 = self.up_concat4(g_conv4, center) 90 | g_conv3, att3 = self.attentionblock3(conv3, up4) 91 | up3 = self.up_concat3(g_conv3, up4) 92 | g_conv2, att2 = self.attentionblock2(conv2, up3) 93 | up2 = self.up_concat2(g_conv2, up3) 94 | up1 = self.up_concat1(conv1, up2) 95 | 96 | # Deep Supervision 97 | dsv4 = self.dsv4(up4) 98 | dsv3 = self.dsv3(up3) 99 | dsv2 = self.dsv2(up2) 100 | dsv1 = self.dsv1(up1) 101 | final = self.final(torch.cat([dsv1,dsv2,dsv3,dsv4], dim=1)) 102 | 103 | return final 104 | 105 | 106 | @staticmethod 107 | def apply_argmax_softmax(pred): 108 | log_p = F.softmax(pred, dim=1) 109 | 110 | return log_p 111 | 112 | 113 | class MultiAttentionBlock(nn.Module): 114 | def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor): 115 | super(MultiAttentionBlock, self).__init__() 116 | self.gate_block_1 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size, 117 | inter_channels=inter_size, mode=nonlocal_mode, 118 | sub_sample_factor= sub_sample_factor) 119 | self.combine_gates = nn.Sequential(nn.Conv3d(in_size, in_size, kernel_size=1, stride=1, padding=0), 120 | nn.BatchNorm3d(in_size), 121 | nn.ReLU(inplace=True) 122 | ) 123 | 124 | # initialise the blocks 125 | for m in self.children(): 126 | if m.__class__.__name__.find('GridAttentionBlock3D') != -1: continue 127 | init_weights(m, init_type='kaiming') 128 | 129 | def forward(self, input, gating_signal): 130 | gate_1, attention_1 = self.gate_block_1(input, gating_signal) 131 | 132 | return self.combine_gates(gate_1), attention_1 133 | 134 | 135 | -------------------------------------------------------------------------------- /models/feedforward_seg_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.optim as optim 4 | 5 | from collections import OrderedDict 6 | import utils.util as util 7 | from .base_model import BaseModel 8 | from .networks import get_network 9 | from .layers.loss import * 10 | from .networks_other import get_scheduler, print_network, benchmark_fp_bp_time 11 | from .utils import segmentation_stats, get_optimizer, get_criterion 12 | from .networks.utils import HookBasedFeatureExtractor 13 | 14 | 15 | class FeedForwardSegmentation(BaseModel): 16 | 17 | def name(self): 18 | return 'FeedForwardSegmentation' 19 | 20 | def initialize(self, opts, **kwargs): 21 | BaseModel.initialize(self, opts, **kwargs) 22 | self.isTrain = opts.isTrain 23 | 24 | # define network input and output pars 25 | self.input = None 26 | self.target = None 27 | self.tensor_dim = opts.tensor_dim 28 | 29 | # load/define networks 30 | self.net = get_network(opts.model_type, n_classes=opts.output_nc, 31 | in_channels=opts.input_nc, 
nonlocal_mode=opts.nonlocal_mode,
32 | tensor_dim=opts.tensor_dim, feature_scale=opts.feature_scale,
33 | attention_dsample=opts.attention_dsample)
34 | if self.use_cuda: self.net = self.net.cuda()
35 |
36 | # load the model if a path is specified or it is in inference mode
37 | if not self.isTrain or opts.continue_train:
38 | self.path_pre_trained_model = opts.path_pre_trained_model
39 | if self.path_pre_trained_model:
40 | self.load_network_from_path(self.net, self.path_pre_trained_model, strict=False)
41 | self.which_epoch = int(0)
42 | else:
43 | self.which_epoch = opts.which_epoch
44 | self.load_network(self.net, 'S', self.which_epoch)
45 |
46 | # training objective
47 | if self.isTrain:
48 | self.criterion = get_criterion(opts)
49 | # initialize optimizers
50 | self.schedulers = []
51 | self.optimizers = []
52 | self.optimizer_S = get_optimizer(opts, self.net.parameters())
53 | self.optimizers.append(self.optimizer_S)
54 |
55 | # print the network details
56 | # (layer summary and total parameter count)
57 | if kwargs.get('verbose', True):
58 | print('Network is initialized')
59 | print_network(self.net)
60 |
61 | def set_scheduler(self, train_opt):
62 | for optimizer in self.optimizers:
63 | self.schedulers.append(get_scheduler(optimizer, train_opt))
64 | print('Scheduler is added for optimiser {0}'.format(optimizer))
65 |
66 | def set_input(self, *inputs):
67 | # self.input.resize_(inputs[0].size()).copy_(inputs[0])
68 | for idx, _input in enumerate(inputs):
69 | # If it's a 5D array and 2D model then (B x C x H x W x Z) -> (BZ x C x H x W)
70 | bs = _input.size()
71 | if (self.tensor_dim == '2D') and (len(bs) > 4):
72 | _input = _input.permute(0,4,1,2,3).contiguous().view(bs[0]*bs[4], bs[1], bs[2], bs[3])
73 |
74 | # Define that it's a cuda array
75 | if idx == 0:
76 | self.input = _input.cuda() if self.use_cuda else _input
77 | elif idx == 1:
78 | self.target = Variable(_input.cuda()) if self.use_cuda else Variable(_input)
79 | assert self.input.size() == self.target.size()
80 |
81 | def forward(self, split):
82 | if split == 'train':
83 | self.prediction = self.net(Variable(self.input))
84 | elif split == 'test':
85 | self.prediction = self.net(Variable(self.input, volatile=True))
86 | # Apply a softmax and return a segmentation map
87 | self.logits = self.net.apply_argmax_softmax(self.prediction)
88 | self.pred_seg = self.logits.data.max(1)[1].unsqueeze(1)
89 |
90 | def backward(self):
91 | self.loss_S = self.criterion(self.prediction, self.target)
92 | self.loss_S.backward()
93 |
94 | def optimize_parameters(self):
95 | self.net.train()
96 | self.forward(split='train')
97 |
98 | self.optimizer_S.zero_grad()
99 | self.backward()
100 | self.optimizer_S.step()
101 |
102 | # This function updates the network parameters every "accumulate_iters" iterations
103 | def optimize_parameters_accumulate_grd(self, iteration):
104 | accumulate_iters = int(2)
105 | if iteration == 0: self.optimizer_S.zero_grad()
106 | self.net.train()
107 | self.forward(split='train')
108 | self.backward()
109 |
110 | if iteration % accumulate_iters == 0:
111 | self.optimizer_S.step()
112 | self.optimizer_S.zero_grad()
113 |
114 | def test(self):
115 | self.net.eval()
116 | self.forward(split='test')
117 |
118 | def validate(self):
119 | self.net.eval()
120 | self.forward(split='test')
121 | self.loss_S = self.criterion(self.prediction, self.target)
122 |
123 | def get_segmentation_stats(self):
124 | self.seg_scores, self.dice_score = segmentation_stats(self.prediction, self.target)
125 | seg_stats = [('Overall_Acc', 
self.seg_scores['overall_acc']), ('Mean_IOU', self.seg_scores['mean_iou'])] 126 | for class_id in range(self.dice_score.size): 127 | seg_stats.append(('Class_{}'.format(class_id), self.dice_score[class_id])) 128 | return OrderedDict(seg_stats) 129 | 130 | def get_current_errors(self): 131 | return OrderedDict([('Seg_Loss', self.loss_S.data[0]) 132 | ]) 133 | 134 | def get_current_visuals(self): 135 | inp_img = util.tensor2im(self.input, 'img') 136 | seg_img = util.tensor2im(self.pred_seg, 'lbl') 137 | return OrderedDict([('out_S', seg_img), ('inp_S', inp_img)]) 138 | 139 | def get_feature_maps(self, layer_name, upscale): 140 | feature_extractor = HookBasedFeatureExtractor(self.net, layer_name, upscale) 141 | return feature_extractor.forward(Variable(self.input)) 142 | 143 | # returns the fp/bp times of the model 144 | def get_fp_bp_time (self, size=None): 145 | if size is None: 146 | size = (1, 1, 160, 160, 96) 147 | 148 | inp_array = Variable(torch.zeros(*size)).cuda() 149 | out_array = Variable(torch.zeros(*size)).cuda() 150 | fp, bp = benchmark_fp_bp_time(self.net, inp_array, out_array) 151 | 152 | bsize = size[0] 153 | return fp/float(bsize), bp/float(bsize) 154 | 155 | def save(self, epoch_label): 156 | self.save_network(self.net, 'S', epoch_label, self.gpu_ids) 157 | -------------------------------------------------------------------------------- /models/networks/unet_CT_multi_att_dsv_3D.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from .utils import UnetConv3, UnetUp3_CT, UnetGridGatingSignal3, UnetDsv3 4 | import torch.nn.functional as F 5 | from models.networks_other import init_weights 6 | from models.layers.grid_attention_layer import GridAttentionBlock3D 7 | 8 | 9 | class unet_CT_multi_att_dsv_3D(nn.Module): 10 | 11 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, 12 | nonlocal_mode='concatenation', attention_dsample=(2,2,2), is_batchnorm=True): 13 | super(unet_CT_multi_att_dsv_3D, self).__init__() 14 | self.is_deconv = is_deconv 15 | self.in_channels = in_channels 16 | self.is_batchnorm = is_batchnorm 17 | self.feature_scale = feature_scale 18 | 19 | filters = [64, 128, 256, 512, 1024] 20 | filters = [int(x / self.feature_scale) for x in filters] 21 | 22 | # downsampling 23 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 24 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 25 | 26 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 27 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 28 | 29 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 30 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 31 | 32 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 33 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 34 | 35 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 36 | self.gating = UnetGridGatingSignal3(filters[4], filters[4], kernel_size=(1, 1, 1), is_batchnorm=self.is_batchnorm) 37 | 38 | # attention blocks 39 | self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1], 40 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) 41 | self.attentionblock3 = 
MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2], 42 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) 43 | self.attentionblock4 = MultiAttentionBlock(in_size=filters[3], gate_size=filters[4], inter_size=filters[3], 44 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) 45 | 46 | # upsampling 47 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 48 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 49 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 50 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 51 | 52 | # deep supervision 53 | self.dsv4 = UnetDsv3(in_size=filters[3], out_size=n_classes, scale_factor=8) 54 | self.dsv3 = UnetDsv3(in_size=filters[2], out_size=n_classes, scale_factor=4) 55 | self.dsv2 = UnetDsv3(in_size=filters[1], out_size=n_classes, scale_factor=2) 56 | self.dsv1 = nn.Conv3d(in_channels=filters[0], out_channels=n_classes, kernel_size=1) 57 | 58 | # final conv (without any concat) 59 | self.final = nn.Conv3d(n_classes*4, n_classes, 1) 60 | 61 | # initialise weights 62 | for m in self.modules(): 63 | if isinstance(m, nn.Conv3d): 64 | init_weights(m, init_type='kaiming') 65 | elif isinstance(m, nn.BatchNorm3d): 66 | init_weights(m, init_type='kaiming') 67 | 68 | def forward(self, inputs): 69 | # Feature Extraction 70 | conv1 = self.conv1(inputs) 71 | maxpool1 = self.maxpool1(conv1) 72 | 73 | conv2 = self.conv2(maxpool1) 74 | maxpool2 = self.maxpool2(conv2) 75 | 76 | conv3 = self.conv3(maxpool2) 77 | maxpool3 = self.maxpool3(conv3) 78 | 79 | conv4 = self.conv4(maxpool3) 80 | maxpool4 = self.maxpool4(conv4) 81 | 82 | # Gating Signal Generation 83 | center = self.center(maxpool4) 84 | gating = self.gating(center) 85 | 86 | # Attention Mechanism 87 | # Upscaling Part (Decoder) 88 | g_conv4, att4 = self.attentionblock4(conv4, gating) 89 | up4 = self.up_concat4(g_conv4, center) 90 | g_conv3, att3 = self.attentionblock3(conv3, up4) 91 | up3 = self.up_concat3(g_conv3, up4) 92 | g_conv2, att2 = self.attentionblock2(conv2, up3) 93 | up2 = self.up_concat2(g_conv2, up3) 94 | up1 = self.up_concat1(conv1, up2) 95 | 96 | # Deep Supervision 97 | dsv4 = self.dsv4(up4) 98 | dsv3 = self.dsv3(up3) 99 | dsv2 = self.dsv2(up2) 100 | dsv1 = self.dsv1(up1) 101 | final = self.final(torch.cat([dsv1,dsv2,dsv3,dsv4], dim=1)) 102 | 103 | return final 104 | 105 | 106 | @staticmethod 107 | def apply_argmax_softmax(pred): 108 | log_p = F.softmax(pred, dim=1) 109 | 110 | return log_p 111 | 112 | 113 | class MultiAttentionBlock(nn.Module): 114 | def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor): 115 | super(MultiAttentionBlock, self).__init__() 116 | self.gate_block_1 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size, 117 | inter_channels=inter_size, mode=nonlocal_mode, 118 | sub_sample_factor= sub_sample_factor) 119 | self.gate_block_2 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size, 120 | inter_channels=inter_size, mode=nonlocal_mode, 121 | sub_sample_factor=sub_sample_factor) 122 | self.combine_gates = nn.Sequential(nn.Conv3d(in_size*2, in_size, kernel_size=1, stride=1, padding=0), 123 | nn.BatchNorm3d(in_size), 124 | nn.ReLU(inplace=True) 125 | ) 126 | 127 | # initialise the blocks 128 | for m in self.children(): 129 | if m.__class__.__name__.find('GridAttentionBlock3D') != -1: continue 130 | init_weights(m, init_type='kaiming') 131 | 132 | def forward(self, input, gating_signal): 
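        # Both gates attend to the same (input, gating_signal) pair in parallel;
        # their gated outputs are concatenated channel-wise and fused back down to
        # in_size channels by the 1x1x1 convolution in self.combine_gates, while
        # the two attention maps are returned concatenated for later visualisation.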
133 | gate_1, attention_1 = self.gate_block_1(input, gating_signal) 134 | gate_2, attention_2 = self.gate_block_2(input, gating_signal) 135 | 136 | return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1) 137 | 138 | 139 | -------------------------------------------------------------------------------- /test_classification.py: -------------------------------------------------------------------------------- 1 | import os, sys, numpy as np 2 | from torch.utils.data import DataLoader, sampler 3 | from tqdm import tqdm 4 | 5 | 6 | from dataio.loader import get_dataset, get_dataset_path 7 | from dataio.transformation import get_dataset_transformation 8 | from utils.util import json_file_to_pyobj 9 | from utils.visualiser import Visualiser 10 | from utils.error_logger import ErrorLogger 11 | from models.networks_other import adjust_learning_rate 12 | 13 | from models import get_model 14 | 15 | class HiddenPrints: 16 | def __enter__(self): 17 | self._original_stdout = sys.stdout 18 | sys.stdout = None 19 | 20 | def __exit__(self, exc_type, exc_val, exc_tb): 21 | sys.stdout = self._original_stdout 22 | 23 | class StratifiedSampler(object): 24 | """Stratified Sampling 25 | Provides equal representation of target classes in each batch 26 | """ 27 | def __init__(self, class_vector, batch_size): 28 | """ 29 | Arguments 30 | --------- 31 | class_vector : torch tensor 32 | a vector of class labels 33 | batch_size : integer 34 | batch_size 35 | """ 36 | self.class_vector = class_vector 37 | self.batch_size = batch_size 38 | self.num_iter = len(class_vector) // 52 39 | self.n_class = 14 40 | self.sample_n = 2 41 | # create pool of each vectors 42 | indices = {} 43 | for i in range(self.n_class): 44 | indices[i] = np.where(self.class_vector == i)[0] 45 | 46 | self.indices = indices 47 | self.background_index = np.argmax([ len(indices[i]) for i in range(self.n_class)]) 48 | 49 | 50 | def gen_sample_array(self): 51 | # sample 2 from each class 52 | sample_array = [] 53 | for i in range(self.num_iter): 54 | arrs = [] 55 | for i in range(self.n_class): 56 | n = self.sample_n 57 | if i == self.background_index: 58 | n = self.sample_n * (self.n_class-1) 59 | arr = np.random.choice(self.indices[i], n) 60 | arrs.append(arr) 61 | 62 | sample_array.append(np.hstack(arrs)) 63 | return np.hstack(sample_array) 64 | 65 | def __iter__(self): 66 | return iter(self.gen_sample_array()) 67 | 68 | def __len__(self): 69 | return len(self.class_vector) 70 | 71 | 72 | def test(arguments): 73 | 74 | # Parse input arguments 75 | json_filename = arguments.config 76 | network_debug = arguments.debug 77 | 78 | # Load options 79 | json_opts = json_file_to_pyobj(json_filename) 80 | train_opts = json_opts.training 81 | 82 | # Architecture type 83 | arch_type = train_opts.arch_type 84 | 85 | # Setup Dataset and Augmentation 86 | ds_class = get_dataset(arch_type) 87 | ds_path = get_dataset_path(arch_type, json_opts.data_path) 88 | ds_transform = get_dataset_transformation(arch_type, opts=json_opts.augmentation) 89 | 90 | # Setup the NN Model 91 | with HiddenPrints(): 92 | model = get_model(json_opts.model) 93 | 94 | if network_debug: 95 | print('# of pars: ', model.get_number_parameters()) 96 | print('fp time: {0:.8f} sec\tbp time: {1:.8f} sec per sample'.format(*model.get_fp_bp_time2((1,1,224,288)))) 97 | exit() 98 | 99 | # Setup Data Loader 100 | num_workers = train_opts.num_workers if hasattr(train_opts, 'num_workers') else 16 101 | 102 | valid_dataset = ds_class(ds_path, split='val', 
transform=ds_transform['valid'], preload_data=train_opts.preloadData) 103 | test_dataset = ds_class(ds_path, split='test', transform=ds_transform['valid'], preload_data=train_opts.preloadData) 104 | # loader 105 | batch_size = train_opts.batchSize 106 | valid_loader = DataLoader(dataset=valid_dataset, num_workers=num_workers, batch_size=train_opts.batchSize, shuffle=False) 107 | test_loader = DataLoader(dataset=test_dataset, num_workers=0, batch_size=train_opts.batchSize, shuffle=False) 108 | 109 | # Visualisation Parameters 110 | filename = 'test_loss_log.txt' 111 | visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir, 112 | filename=filename) 113 | error_logger = ErrorLogger() 114 | 115 | # Training Function 116 | track_labels = np.arange(len(valid_dataset.label_names)) 117 | model.set_labels(track_labels) 118 | model.set_scheduler(train_opts) 119 | 120 | if hasattr(model.net, 'deep_supervised'): 121 | model.net.deep_supervised = False 122 | 123 | # Validation and Testing Iterations 124 | pr_lbls = [] 125 | gt_lbls = [] 126 | for loader, split in zip([test_loader], ['test']): 127 | #for loader, split in zip([valid_loader, test_loader], ['validation', 'test']): 128 | model.reset_results() 129 | 130 | for epoch_iter, (images, labels) in tqdm(enumerate(loader, 1), total=len(loader)): 131 | 132 | # Make a forward pass with the model 133 | model.set_input(images, labels) 134 | model.validate() 135 | 136 | # Error visualisation 137 | errors = model.get_accumulated_errors() 138 | stats = model.get_classification_stats() 139 | error_logger.update({**errors, **stats}, split=split) 140 | 141 | # Update the plots 142 | # for split in ['train', 'validation', 'test']: 143 | for split in ['test']: 144 | # exclude bckground 145 | #track_labels = np.delete(track_labels, 3) 146 | #show_labels = train_dataset.label_names[:3] + train_dataset.label_names[4:] 147 | show_labels = valid_dataset.label_names 148 | visualizer.plot_current_errors(300, error_logger.get_errors(split), split_name=split, labels=show_labels) 149 | visualizer.print_current_errors(300, error_logger.get_errors(split), split_name=split) 150 | 151 | import pickle as pkl 152 | dst_file = os.path.join(model.save_dir, 'test_result.pkl') 153 | with open(dst_file, 'wb') as f: 154 | d = error_logger.get_errors(split) 155 | d['labels'] = valid_dataset.label_names 156 | d['pr_lbls'] = np.hstack(model.pr_lbls) 157 | d['gt_lbls'] = np.hstack(model.gt_lbls) 158 | pkl.dump(d, f) 159 | 160 | error_logger.reset() 161 | 162 | if arguments.time: 163 | print('# of pars: ', model.get_number_parameters()) 164 | print('fp time: {0:.8f} sec\tbp time: {1:.8f} sec per sample'.format(*model.get_fp_bp_time2((1,1,224,288)))) 165 | 166 | 167 | if __name__ == '__main__': 168 | import argparse 169 | 170 | parser = argparse.ArgumentParser(description='CNN Seg Training Function') 171 | 172 | parser.add_argument('-c', '--config', help='training config file', required=True) 173 | parser.add_argument('-d', '--debug', help='returns number of parameters and bp/fp runtime', action='store_true') 174 | parser.add_argument('-t', '--time', help='returns number of parameters and bp/fp runtime', action='store_true') 175 | args = parser.parse_args() 176 | 177 | test(args) 178 | -------------------------------------------------------------------------------- /models/feedforward_classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import utils.util as util 4 | from collections import 
OrderedDict 5 | 6 | import torch 7 | from torch.autograd import Variable 8 | from .base_model import BaseModel 9 | from .networks import get_network 10 | from .layers.loss import * 11 | from .networks_other import get_scheduler, print_network, benchmark_fp_bp_time 12 | from .utils import classification_stats, get_optimizer, get_criterion 13 | from .networks.utils import HookBasedFeatureExtractor 14 | 15 | 16 | class FeedForwardClassifier(BaseModel): 17 | 18 | def name(self): 19 | return 'FeedForwardClassifier' 20 | 21 | def initialize(self, opts, **kwargs): 22 | BaseModel.initialize(self, opts, **kwargs) 23 | self.opts = opts 24 | self.isTrain = opts.isTrain 25 | 26 | # define network input and output pars 27 | self.input = None 28 | self.target = None 29 | self.labels = None 30 | self.tensor_dim = opts.tensor_dim 31 | 32 | # load/define networks 33 | self.net = get_network(opts.model_type, n_classes=opts.output_nc, 34 | in_channels=opts.input_nc, nonlocal_mode=opts.nonlocal_mode, 35 | tensor_dim=opts.tensor_dim, feature_scale=opts.feature_scale, 36 | attention_dsample=opts.attention_dsample, 37 | aggregation_mode=opts.aggregation_mode) 38 | if self.use_cuda: self.net = self.net.cuda() 39 | 40 | # load the model if a path is specified or it is in inference mode 41 | if not self.isTrain or opts.continue_train: 42 | self.path_pre_trained_model = opts.path_pre_trained_model 43 | if self.path_pre_trained_model: 44 | self.load_network_from_path(self.net, self.path_pre_trained_model, strict=False) 45 | self.which_epoch = int(0) 46 | else: 47 | self.which_epoch = opts.which_epoch 48 | self.load_network(self.net, 'S', self.which_epoch) 49 | 50 | # training objective 51 | if self.isTrain: 52 | self.criterion = get_criterion(opts) 53 | # initialize optimizers 54 | self.schedulers = [] 55 | self.optimizers = [] 56 | 57 | self.optimizer = get_optimizer(opts, self.net.parameters()) 58 | self.optimizers.append(self.optimizer) 59 | 60 | # print the network details 61 | if kwargs.get('verbose', True): 62 | print('Network is initialized') 63 | print_network(self.net) 64 | 65 | # for accumulator 66 | self.reset_results() 67 | 68 | def set_scheduler(self, train_opt): 69 | for optimizer in self.optimizers: 70 | self.schedulers.append(get_scheduler(optimizer, train_opt)) 71 | print('Scheduler is added for optimiser {0}'.format(optimizer)) 72 | 73 | def set_input(self, *inputs): 74 | # self.input.resize_(inputs[0].size()).copy_(inputs[0]) 75 | for idx, _input in enumerate(inputs): 76 | # If it's a 5D array and 2D model then (B x C x H x W x Z) -> (BZ x C x H x W) 77 | bs = _input.size() 78 | if (self.tensor_dim == '2D') and (len(bs) > 4): 79 | _input = _input.permute(0,4,1,2,3).contiguous().view(bs[0]*bs[4], bs[1], bs[2], bs[3]) 80 | 81 | # Define that it's a cuda array 82 | if idx == 0: 83 | self.input = _input.cuda() if self.use_cuda else _input 84 | elif idx == 1: 85 | self.target = Variable(_input.cuda()) if self.use_cuda else Variable(_input) 86 | assert self.input.shape[0] == self.target.shape[0] 87 | 88 | def forward(self, split): 89 | if split == 'train': 90 | self.prediction = self.net(Variable(self.input)) 91 | elif split in ['validation', 'test']: 92 | self.prediction = self.net(Variable(self.input, volatile=True)) 93 | # Apply a softmax and return a segmentation map 94 | self.logits = self.net.apply_argmax_softmax(self.prediction) 95 | self.pred = self.logits.data.max(1) 96 | 97 | 98 | def backward(self): 99 | #print(self.net.apply_argmax_softmax(self.prediction), self.target) 100 | self.loss = 
self.criterion(self.prediction, self.target)
101 | self.loss.backward()
102 |
103 | def optimize_parameters(self):
104 | self.net.train()
105 | self.forward(split='train')
106 |
107 | self.optimizer.zero_grad()
108 | self.backward()
109 | self.optimizer.step()
110 |
111 | def test(self):
112 | self.net.eval()
113 | self.forward(split='test')
114 | self.accumulate_results()
115 |
116 | def validate(self):
117 | self.net.eval()
118 | self.forward(split='test')
119 | self.loss = self.criterion(self.prediction, self.target)
120 | self.accumulate_results()
121 |
122 | def reset_results(self):
123 | self.losses = []
124 | self.pr_lbls = []
125 | self.pr_probs = []
126 | self.gt_lbls = []
127 |
128 | def accumulate_results(self):
129 | self.losses.append(self.loss.data[0])
130 | self.pr_probs.append(self.pred[0].cpu().numpy())
131 | self.pr_lbls.append(self.pred[1].cpu().numpy())
132 | self.gt_lbls.append(self.target.data.cpu().numpy())
133 |
134 | def get_classification_stats(self):
135 | self.pr_lbls = np.concatenate(self.pr_lbls)
136 | self.gt_lbls = np.concatenate(self.gt_lbls)
137 | res = classification_stats(self.pr_lbls, self.gt_lbls, self.labels)
138 | (self.accuracy, self.f1_micro, self.precision_micro,
139 | self.recall_micro, self.f1_macro, self.precision_macro,
140 | self.recall_macro, self.confusion, self.class_accuracies,
141 | self.f1s, self.precisions, self.recalls) = res
142 |
143 | breakdown = dict(type='table',
144 | colnames=['|accuracy|', '|precision|', '|recall|', '|f1_score|'],
145 | rownames=self.labels,
146 | data=[self.class_accuracies, self.precisions, self.recalls, self.f1s])
147 |
148 | return OrderedDict([('accuracy', self.accuracy),
149 | ('confusion', self.confusion),
150 | ('f1', self.f1_macro),
151 | ('precision', self.precision_macro),
152 | ('recall', self.recall_macro),
153 | ('breakdown', breakdown)])
154 |
155 |
156 | def get_current_errors(self):
157 | return OrderedDict([('CE', self.loss.data[0])])
158 |
159 | def get_accumulated_errors(self):
160 | return OrderedDict([('CE', np.mean(self.losses))])
161 |
162 | def get_current_visuals(self):
163 | inp_img = util.tensor2im(self.input, 'img')
164 | return OrderedDict([('inp_S', inp_img)])
165 |
166 | def get_feature_maps(self, layer_name, upscale):
167 | feature_extractor = HookBasedFeatureExtractor(self.net, layer_name, upscale)
168 | return feature_extractor.forward(Variable(self.input))
169 |
170 |
171 | def save(self, epoch_label):
172 | self.save_network(self.net, 'S', epoch_label, self.gpu_ids)
173 |
174 | def set_labels(self, labels):
175 | self.labels = labels
176 |
177 | def load_network_from_path(self, network, network_filepath, strict):
178 | network_label = os.path.basename(network_filepath)
179 | epoch_label = network_label.split('_')[0]
180 | print('Loading the model {0} - epoch {1}'.format(network_label, epoch_label))
181 | network.load_state_dict(torch.load(network_filepath), strict=strict)
182 |
183 | def update_state(self, epoch):
184 | pass
185 |
186 | def get_fp_bp_time2(self, size=None):
187 | # returns the fp/bp times of the model
188 | if size is None:
189 | size = (8, 1, 192, 192)
190 |
191 | inp_array = Variable(torch.rand(*size)).cuda()
192 | out_array = Variable(torch.rand(*size)).cuda()
193 | fp, bp = benchmark_fp_bp_time(self.net, inp_array, out_array)
194 |
195 | bsize = size[0]
196 | return fp/float(bsize), bp/float(bsize)
197 |
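For reference, test_classification.py (above) drives FeedForwardClassifier end-to-end; the following condensed sketch of that evaluation loop is illustrative only, assuming one of the shipped configs (e.g. configs/config_sononet_grid_att_8.json) with its data_path section filled in:

import numpy as np
from torch.utils.data import DataLoader

from dataio.loader import get_dataset, get_dataset_path
from dataio.transformation import get_dataset_transformation
from utils.util import json_file_to_pyobj
from models import get_model

json_opts = json_file_to_pyobj('configs/config_sononet_grid_att_8.json')
train_opts = json_opts.training

model = get_model(json_opts.model)  # a FeedForwardClassifier for these configs
ds_class = get_dataset(train_opts.arch_type)
ds_path = get_dataset_path(train_opts.arch_type, json_opts.data_path)
ds_transform = get_dataset_transformation(train_opts.arch_type, opts=json_opts.augmentation)

test_set = ds_class(ds_path, split='test', transform=ds_transform['valid'],
                    preload_data=train_opts.preloadData)
loader = DataLoader(dataset=test_set, batch_size=train_opts.batchSize, shuffle=False)

model.set_labels(np.arange(len(test_set.label_names)))
model.reset_results()
for images, labels in loader:
    model.set_input(images, labels)   # CUDA transfer and 5D -> 4D reshape if needed
    model.validate()                  # eval() + forward + loss + accumulate_results()
stats = model.get_classification_stats()  # accuracy, macro F1/precision/recall, confusion

-------------------------------------------------------------------------------- /dataio/transformation/transforms.py: 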
-------------------------------------------------------------------------------- 1 | import torchsample.transforms as ts 2 | from pprint import pprint 3 | 4 | 5 | class Transformations: 6 | 7 | def __init__(self, name): 8 | self.name = name 9 | 10 | # Input patch and scale size 11 | self.scale_size = (192, 192, 1) 12 | self.patch_size = (128, 128, 1) 13 | # self.patch_size = (208, 272, 1) 14 | 15 | # Affine and Intensity Transformations 16 | self.shift_val = (0.1, 0.1) 17 | self.rotate_val = 15.0 18 | self.scale_val = (0.7, 1.3) 19 | self.inten_val = (1.0, 1.0) 20 | self.random_flip_prob = 0.0 21 | 22 | # Divisibility factor for testing 23 | self.division_factor = (16, 16, 1) 24 | 25 | def get_transformation(self): 26 | return { 27 | 'ukbb_sax': self.cmr_3d_sax_transform, 28 | 'hms_sax': self.hms_sax_transform, 29 | 'test_sax': self.test_3d_sax_transform, 30 | 'acdc_sax': self.cmr_3d_sax_transform, 31 | 'us': self.ultrasound_transform, 32 | }[self.name]() 33 | 34 | def print(self): 35 | print('\n\n############# Augmentation Parameters #############') 36 | pprint(vars(self)) 37 | print('###################################################\n\n') 38 | 39 | def initialise(self, opts): 40 | t_opts = getattr(opts, self.name) 41 | 42 | # Affine and Intensity Transformations 43 | if hasattr(t_opts, 'scale_size'): self.scale_size = t_opts.scale_size 44 | if hasattr(t_opts, 'patch_size'): self.patch_size = t_opts.patch_size 45 | if hasattr(t_opts, 'shift_val'): self.shift_val = t_opts.shift 46 | if hasattr(t_opts, 'rotate_val'): self.rotate_val = t_opts.rotate 47 | if hasattr(t_opts, 'scale_val'): self.scale_val = t_opts.scale 48 | if hasattr(t_opts, 'inten_val'): self.inten_val = t_opts.intensity 49 | if hasattr(t_opts, 'random_flip_prob'): self.random_flip_prob = t_opts.random_flip_prob 50 | if hasattr(t_opts, 'division_factor'): self.division_factor = t_opts.division_factor 51 | 52 | def ukbb_sax_transform(self): 53 | 54 | train_transform = ts.Compose([ts.PadNumpy(size=self.scale_size), 55 | ts.ToTensor(), 56 | ts.ChannelsFirst(), 57 | ts.TypeCast(['float', 'float']), 58 | ts.RandomFlip(h=True, v=True, p=self.random_flip_prob), 59 | ts.RandomAffine(rotation_range=self.rotate_val, translation_range=self.shift_val, 60 | zoom_range=self.scale_val, interp=('bilinear', 'nearest')), 61 | ts.NormalizeMedicPercentile(norm_flag=(True, False)), 62 | ts.RandomCrop(size=self.patch_size), 63 | ts.TypeCast(['float', 'long']) 64 | ]) 65 | 66 | valid_transform = ts.Compose([ts.PadNumpy(size=self.scale_size), 67 | ts.ToTensor(), 68 | ts.ChannelsFirst(), 69 | ts.TypeCast(['float', 'float']), 70 | ts.NormalizeMedicPercentile(norm_flag=(True, False)), 71 | ts.SpecialCrop(size=self.patch_size, crop_type=0), 72 | ts.TypeCast(['float', 'long']) 73 | ]) 74 | 75 | return {'train': train_transform, 'valid': valid_transform} 76 | 77 | def cmr_3d_sax_transform(self): 78 | 79 | train_transform = ts.Compose([ts.PadNumpy(size=self.scale_size), 80 | ts.ToTensor(), 81 | ts.ChannelsFirst(), 82 | ts.TypeCast(['float', 'float']), 83 | ts.RandomFlip(h=True, v=True, p=self.random_flip_prob), 84 | ts.RandomAffine(rotation_range=self.rotate_val, translation_range=self.shift_val, 85 | zoom_range=self.scale_val, interp=('bilinear', 'nearest')), 86 | #ts.NormalizeMedicPercentile(norm_flag=(True, False)), 87 | ts.NormalizeMedic(norm_flag=(True, False)), 88 | ts.ChannelsLast(), 89 | ts.AddChannel(axis=0), 90 | ts.RandomCrop(size=self.patch_size), 91 | ts.TypeCast(['float', 'long']) 92 | ]) 93 | 94 | valid_transform = 
ts.Compose([ts.PadNumpy(size=self.scale_size), 95 | ts.ToTensor(), 96 | ts.ChannelsFirst(), 97 | ts.TypeCast(['float', 'float']), 98 | #ts.NormalizeMedicPercentile(norm_flag=(True, False)), 99 | ts.NormalizeMedic(norm_flag=(True, False)), 100 | ts.ChannelsLast(), 101 | ts.AddChannel(axis=0), 102 | ts.SpecialCrop(size=self.patch_size, crop_type=0), 103 | ts.TypeCast(['float', 'long']) 104 | ]) 105 | 106 | return {'train': train_transform, 'valid': valid_transform} 107 | 108 | def hms_sax_transform(self): 109 | 110 | # Training transformation 111 | # 2D Stack input - 3D High Resolution output segmentation 112 | 113 | train_transform = [] 114 | valid_transform = [] 115 | 116 | # First pad to a fixed size 117 | # Torch tensor 118 | # Channels first 119 | # Joint affine transformation 120 | # In-plane respiratory motion artefacts (translation and rotation) 121 | # Random Crop 122 | # Normalise the intensity range 123 | train_transform = ts.Compose([]) 124 | 125 | return {'train': train_transform, 'valid': valid_transform} 126 | 127 | def test_3d_sax_transform(self): 128 | test_transform = ts.Compose([ts.PadFactorNumpy(factor=self.division_factor), 129 | ts.ToTensor(), 130 | ts.ChannelsFirst(), 131 | ts.TypeCast(['float']), 132 | #ts.NormalizeMedicPercentile(norm_flag=True), 133 | ts.NormalizeMedic(norm_flag=True), 134 | ts.ChannelsLast(), 135 | ts.AddChannel(axis=0), 136 | ]) 137 | 138 | return {'test': test_transform} 139 | 140 | 141 | def ultrasound_transform(self): 142 | 143 | train_transform = ts.Compose([ts.ToTensor(), 144 | ts.TypeCast(['float']), 145 | ts.AddChannel(axis=0), 146 | ts.SpecialCrop(self.patch_size,0), 147 | ts.RandomFlip(h=True, v=False, p=self.random_flip_prob), 148 | ts.RandomAffine(rotation_range=self.rotate_val, 149 | translation_range=self.shift_val, 150 | zoom_range=self.scale_val, 151 | interp=('bilinear')), 152 | ts.StdNormalize(), 153 | ]) 154 | 155 | valid_transform = ts.Compose([ts.ToTensor(), 156 | ts.TypeCast(['float']), 157 | ts.AddChannel(axis=0), 158 | ts.SpecialCrop(self.patch_size,0), 159 | ts.StdNormalize(), 160 | ]) 161 | 162 | return {'train': train_transform, 'valid': valid_transform} 163 | -------------------------------------------------------------------------------- /visualise_attention.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from dataio.loader import get_dataset, get_dataset_path 4 | from dataio.transformation import get_dataset_transformation 5 | from utils.util import json_file_to_pyobj 6 | from utils.visualiser import Visualiser 7 | from models import get_model 8 | import os, time 9 | 10 | # import matplotlib 11 | # matplotlib.use('Agg') 12 | 13 | import matplotlib.cm as cm 14 | import matplotlib.pyplot as plt 15 | import math, numpy 16 | import numpy as np 17 | from scipy.misc import imresize 18 | from skimage.transform import resize 19 | 20 | def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None, title=''): 21 | plt.ion() 22 | filters = units.shape[2] 23 | n_columns = round(math.sqrt(filters)) 24 | n_rows = math.ceil(filters / n_columns) + 1 25 | fig = plt.figure(figure_id, figsize=(n_rows*3,n_columns*3)) 26 | fig.clf() 27 | 28 | for i in range(filters): 29 | ax1 = plt.subplot(n_rows, n_columns, i+1) 30 | plt.imshow(units[:,:,i].T, interpolation=interp, cmap=colormap) 31 | plt.axis('on') 32 | ax1.set_xticklabels([]) 33 | ax1.set_yticklabels([]) 34 | plt.colorbar() 35 | if colormap_lim: 36 | 
plt.clim(colormap_lim[0],colormap_lim[1]) 37 | 38 | plt.subplots_adjust(wspace=0, hspace=0) 39 | plt.tight_layout() 40 | plt.suptitle(title) 41 | 42 | def plotNNFilterOverlay(input_im, units, figure_id, interp='bilinear', 43 | colormap=cm.jet, colormap_lim=None, title='', alpha=0.8): 44 | plt.ion() 45 | filters = units.shape[2] 46 | fig = plt.figure(figure_id, figsize=(5,5)) 47 | fig.clf() 48 | 49 | for i in range(filters): 50 | plt.imshow(input_im[:,:,0], interpolation=interp, cmap='gray') 51 | plt.imshow(units[:,:,i], interpolation=interp, cmap=colormap, alpha=alpha) 52 | plt.axis('off') 53 | plt.colorbar() 54 | plt.title(title, fontsize='small') 55 | if colormap_lim: 56 | plt.clim(colormap_lim[0],colormap_lim[1]) 57 | 58 | plt.subplots_adjust(wspace=0, hspace=0) 59 | plt.tight_layout() 60 | 61 | # plt.savefig('{}/{}.png'.format(dir_name,time.time())) 62 | 63 | 64 | 65 | 66 | ## Load options 67 | PAUSE = .01 68 | #config_name = 'config_sononet_attention_fs8_v6.json' 69 | #config_name = 'config_sononet_attention_fs8_v8.json' 70 | #config_name = 'config_sononet_attention_fs8_v9.json' 71 | #config_name = 'config_sononet_attention_fs8_v10.json' 72 | #config_name = 'config_sononet_attention_fs8_v11.json' 73 | #config_name = 'config_sononet_attention_fs8_v13.json' 74 | #config_name = 'config_sononet_attention_fs8_v14.json' 75 | #config_name = 'config_sononet_attention_fs8_v15.json' 76 | #config_name = 'config_sononet_attention_fs8_v16.json' 77 | #config_name = 'config_sononet_grid_attention_fs8_v1.json' 78 | config_name = 'config_sononet_grid_attention_fs8_deepsup_v1.json' 79 | config_name = 'config_sononet_grid_attention_fs8_deepsup_v2.json' 80 | config_name = 'config_sononet_grid_attention_fs8_deepsup_v3.json' 81 | config_name = 'config_sononet_grid_attention_fs8_deepsup_v4.json' 82 | 83 | # config_name = 'config_sononet_grid_att_fs8_avg.json' 84 | config_name = 'config_sononet_grid_att_fs8_avg_v2.json' 85 | # config_name = 'config_sononet_grid_att_fs8_avg_v3.json' 86 | #config_name = 'config_sononet_grid_att_fs8_avg_v4.json' 87 | #config_name = 'config_sononet_grid_att_fs8_avg_v5.json' 88 | #config_name = 'config_sononet_grid_att_fs8_avg_v5.json' 89 | #config_name = 'config_sononet_grid_att_fs8_avg_v6.json' 90 | #config_name = 'config_sononet_grid_att_fs8_avg_v7.json' 91 | #config_name = 'config_sononet_grid_att_fs8_avg_v8.json' 92 | #config_name = 'config_sononet_grid_att_fs8_avg_v9.json' 93 | #config_name = 'config_sononet_grid_att_fs8_avg_v10.json' 94 | #config_name = 'config_sononet_grid_att_fs8_avg_v11.json' 95 | #config_name = 'config_sononet_grid_att_fs8_avg_v12.json' 96 | 97 | config_name = 'config_sononet_grid_att_fs8_avg_v12_scratch.json' 98 | config_name = 'config_sononet_grid_att_fs4_avg_v12.json' 99 | 100 | #config_name = 'config_sononet_grid_attention_fs8_v3.json' 101 | 102 | json_opts = json_file_to_pyobj('/vol/bitbucket/js3611/projects/transfer_learning/ultrasound/configs_2/{}'.format(config_name)) 103 | train_opts = json_opts.training 104 | 105 | dir_name = os.path.join('visualisation_debug', config_name) 106 | if not os.path.isdir(dir_name): 107 | os.makedirs(dir_name) 108 | os.makedirs(os.path.join(dir_name,'pos')) 109 | os.makedirs(os.path.join(dir_name,'neg')) 110 | 111 | # Setup the NN Model 112 | model = get_model(json_opts.model) 113 | if hasattr(model.net, 'classification_mode'): 114 | model.net.classification_mode = 'attention' 115 | if hasattr(model.net, 'deep_supervised'): 116 | model.net.deep_supervised = False 117 | 118 | # Setup Dataset and Augmentation 119 | 
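# Note: the deterministic 'valid' pipeline is used below even though split='train',
# so crops and intensity normalisation stay reproducible while browsing attention
# maps (see the 'valid' branches in dataio/transformation/transforms.py).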
dataset_class = get_dataset(train_opts.arch_type) 120 | dataset_path = get_dataset_path(train_opts.arch_type, json_opts.data_path) 121 | dataset_transform = get_dataset_transformation(train_opts.arch_type, opts=json_opts.augmentation) 122 | 123 | # Setup Data Loader 124 | dataset = dataset_class(dataset_path, split='train', transform=dataset_transform['valid']) 125 | data_loader = DataLoader(dataset=dataset, num_workers=1, batch_size=1, shuffle=True) 126 | 127 | # test 128 | for iteration, data in enumerate(data_loader, 1): 129 | model.set_input(data[0], data[1]) 130 | 131 | cls = dataset.label_names[int(data[1])] 132 | 133 | model.validate() 134 | pred_class = model.pred[1] 135 | pred_cls = dataset.label_names[int(pred_class)] 136 | 137 | ######################################################### 138 | # Display the input image and Down_sample the input image 139 | input_img = model.input[0,0].cpu().numpy() 140 | #input_img = numpy.expand_dims(imresize(input_img, (fmap_size[0], fmap_size[1]), interp='bilinear'), axis=2) 141 | input_img = numpy.expand_dims(input_img, axis=2) 142 | 143 | # plotNNFilter(input_img, figure_id=0, colormap="gray") 144 | plotNNFilterOverlay(input_img, numpy.zeros_like(input_img), figure_id=0, interp='bilinear', 145 | colormap=cm.jet, title='[GT:{}|P:{}]'.format(cls, pred_cls),alpha=0) 146 | 147 | chance = np.random.random() < 0.01 if cls == "BACKGROUND" else 1 148 | if cls != pred_cls: 149 | plt.savefig('{}/neg/{:03d}.png'.format(dir_name,iteration)) 150 | elif cls == pred_cls and chance: 151 | plt.savefig('{}/pos/{:03d}.png'.format(dir_name,iteration)) 152 | ######################################################### 153 | # Compatibility Scores overlay with input 154 | attentions = [] 155 | for i in [1,2]: 156 | fmap = model.get_feature_maps('compatibility_score%d'%i, upscale=False) 157 | if not fmap: 158 | continue 159 | 160 | # Output of the attention block 161 | fmap_0 = fmap[0].squeeze().permute(1,2,0).cpu().numpy() 162 | fmap_size = fmap_0.shape 163 | 164 | # Attention coefficient (b x c x w x h x s) 165 | attention = fmap[1].squeeze().cpu().numpy() 166 | attention = attention[:, :] 167 | #attention = numpy.expand_dims(resize(attention, (fmap_size[0], fmap_size[1]), mode='constant', preserve_range=True), axis=2) 168 | attention = numpy.expand_dims(resize(attention, (input_img.shape[0], input_img.shape[1]), mode='constant', preserve_range=True), axis=2) 169 | 170 | # this one is useless 171 | #plotNNFilter(fmap_0, figure_id=i+3, interp='bilinear', colormap=cm.jet, title='compat. feature %d' %i) 172 | 173 | plotNNFilterOverlay(input_img, attention, figure_id=i, interp='bilinear', colormap=cm.jet, title='[GT:{}|P:{}] compat. {}'.format(cls,pred_cls,i), alpha=0.5) 174 | attentions.append(attention) 175 | 176 | #plotNNFilterOverlay(input_img, attentions[0], figure_id=4, interp='bilinear', colormap=cm.jet, title='[GT:{}|P:{}] compat. (all)'.format(cls, pred_cls), alpha=0.5) 177 | plotNNFilterOverlay(input_img, numpy.mean(attentions,0), figure_id=4, interp='bilinear', colormap=cm.jet, title='[GT:{}|P:{}] compat. 
(all)'.format(cls, pred_cls), alpha=0.5)
178 |
179 | if cls != pred_cls:
180 | plt.savefig('{}/neg/{:03d}_hm.png'.format(dir_name,iteration))
181 | elif cls == pred_cls and chance:
182 | plt.savefig('{}/pos/{:03d}_hm.png'.format(dir_name,iteration))
183 | # Linear embedding g(x)
184 | # (b, c, h, w)
185 | #gx = fmap[2].squeeze().permute(1,2,0).cpu().numpy()
186 | #plotNNFilter(gx, figure_id=3, interp='nearest', colormap=cm.jet)
187 |
188 | plt.show()
189 | plt.pause(PAUSE)
190 |
191 | model.destructor()
192 | #if iteration == 1: break
193 |
-------------------------------------------------------------------------------- /utils/visualiser.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import os
4 | import ntpath
5 | import time
6 | from utils import util, html
7 |
8 | # Use the following command to launch a visdom server
9 | # python -m visdom.server
10 |
11 | class Visualiser():
12 | def __init__(self, opt, save_dir, filename='loss_log.txt'):
13 | self.display_id = opt.display_id
14 | self.use_html = not opt.no_html
15 | self.win_size = opt.display_winsize
16 | self.save_dir = save_dir
17 | self.name = os.path.basename(self.save_dir)
18 | self.saved = False
19 | self.display_single_pane_ncols = opt.display_single_pane_ncols
20 |
21 | # Error plots
22 | self.error_plots = dict()
23 | self.error_wins = dict()
24 |
25 | if self.display_id > 0:
26 | import visdom
27 | self.vis = visdom.Visdom(port=opt.display_port)
28 |
29 | if self.use_html:
30 | self.web_dir = os.path.join(self.save_dir, 'web')
31 | self.img_dir = os.path.join(self.web_dir, 'images')
32 | print('create web directory %s...' % self.web_dir)
33 | util.mkdirs([self.web_dir, self.img_dir])
34 | self.log_name = os.path.join(self.save_dir, filename)
35 | with open(self.log_name, "a") as log_file:
36 | now = time.strftime("%c")
37 | log_file.write('================ Training Loss (%s) ================\n' % now)
38 |
39 | def reset(self):
40 | self.saved = False
41 |
42 | # |visuals|: dictionary of images to display or save
43 | def display_current_results(self, visuals, epoch, save_result):
44 | if self.display_id > 0: # show images in the browser
45 | ncols = self.display_single_pane_ncols
46 | if ncols > 0:
47 | h, w = next(iter(visuals.values())).shape[:2]
48 | table_css = """<style>
49 | table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
50 | table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
51 | </style>""" % (w, h)
52 | title = self.name
53 | label_html = ''
54 | label_html_row = ''
55 | nrows = int(np.ceil(len(visuals.items()) / ncols))
56 | images = []
57 | idx = 0
58 | for label, image_numpy in visuals.items():
59 | label_html_row += '<td>%s</td>' % label
60 | images.append(image_numpy.transpose([2, 0, 1]))
61 | idx += 1
62 | if idx % ncols == 0:
63 | label_html += '<tr>%s</tr>' % label_html_row
64 | label_html_row = ''
65 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
66 | while idx % ncols != 0:
67 | images.append(white_image)
68 | label_html_row += '<td></td>'
69 | idx += 1
70 | if label_html_row != '':
71 | label_html += '<tr>%s</tr>' % label_html_row
72 | # pane col = image row
73 | self.vis.images(images, nrow=ncols, win=self.display_id + 1,
74 | padding=2, opts=dict(title=title + ' images'))
75 | label_html = '<table>%s</table>
' % label_html 76 | self.vis.text(table_css + label_html, win=self.display_id + 2, 77 | opts=dict(title=title + ' labels')) 78 | else: 79 | idx = 1 80 | for label, image_numpy in visuals.items(): 81 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), 82 | win=self.display_id + idx) 83 | idx += 1 84 | 85 | if self.use_html and (save_result or not self.saved): # save images to a html file 86 | self.saved = True 87 | for label, image_numpy in visuals.items(): 88 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 89 | util.save_image(image_numpy, img_path) 90 | # update website 91 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) 92 | for n in range(epoch, 0, -1): 93 | webpage.add_header('epoch [%d]' % n) 94 | ims = [] 95 | txts = [] 96 | links = [] 97 | 98 | for label, image_numpy in visuals.items(): 99 | img_path = 'epoch%.3d_%s.png' % (n, label) 100 | ims.append(img_path) 101 | txts.append(label) 102 | links.append(img_path) 103 | webpage.add_images(ims, txts, links, width=self.win_size) 104 | webpage.save() 105 | 106 | def plot_table_html(self, x, y, key, split_name, **kwargs): 107 | key_s = key+'_'+split_name 108 | if key_s not in self.error_plots: 109 | self.error_wins[key_s] = self.display_id * 3 + len(self.error_wins) 110 | else: 111 | self.vis.close(self.error_plots[key_s]) 112 | 113 | 114 | table = pd.DataFrame(np.array(y['data']).transpose(), 115 | index=kwargs['labels'], columns=y['colnames']) 116 | table_html = table.round(2).to_html(col_space=200, bold_rows=True, border=12) 117 | 118 | self.error_plots[key_s] = self.vis.text(table_html, 119 | opts=dict(title=self.name+split_name, 120 | width=350, height=350, 121 | win=self.error_wins[key_s])) 122 | 123 | 124 | def plot_heatmap(self, x, y, key, split_name, **kwargs): 125 | key_s = key+'_'+split_name 126 | if key_s not in self.error_plots: 127 | self.error_wins[key_s] = self.display_id * 3 + len(self.error_wins) 128 | else: 129 | self.vis.close(self.error_plots[key_s]) 130 | self.error_plots[key_s] = self.vis.heatmap( 131 | X=y, 132 | opts=dict( 133 | columnnames=kwargs['labels'], 134 | rownames=kwargs['labels'], 135 | title=self.name + ' confusion matrix', 136 | win=self.error_wins[key_s])) 137 | 138 | def plot_line(self, x, y, key, split_name): 139 | if key not in self.error_plots: 140 | self.error_wins[key] = self.display_id * 3 + len(self.error_wins) 141 | self.error_plots[key] = self.vis.line( 142 | X=np.array([x, x]), 143 | Y=np.array([y, y]), 144 | opts=dict( 145 | legend=[split_name], 146 | title=self.name + ' {} over time'.format(key), 147 | xlabel='Epochs', 148 | ylabel=key, 149 | win=self.error_wins[key] 150 | )) 151 | else: 152 | self.vis.updateTrace(X=np.array([x]), Y=np.array([y]), win=self.error_plots[key], name=split_name) 153 | # errors: dictionary of error labels and values 154 | def plot_current_errors(self, epoch, errors, split_name, counter_ratio=0.0, **kwargs): 155 | if self.display_id > 0: 156 | for key in errors.keys(): 157 | x = epoch + counter_ratio 158 | y = errors[key] 159 | if isinstance(y, dict): 160 | if y['type'] == 'table': 161 | self.plot_table_html(x,y,key,split_name, **kwargs) 162 | elif np.isscalar(y): 163 | self.plot_line(x,y,key,split_name) 164 | elif y.ndim == 2: 165 | self.plot_heatmap(x,y,key,split_name, **kwargs) 166 | 167 | 168 | # errors: same format as |errors| of plotCurrentErrors 169 | def print_current_errors(self, epoch, errors, split_name): 170 | message = '(epoch: %d, split: %s) ' % (epoch, 
split_name) 171 | for k, v in errors.items(): 172 | if np.isscalar(v): 173 | message += '%s: %.3f ' % (k, v) 174 | 175 | print(message) 176 | with open(self.log_name, "a") as log_file: 177 | log_file.write('%s\n' % message) 178 | 179 | # save image to the disk 180 | def save_images(self, webpage, visuals, image_path): 181 | image_dir = webpage.get_image_dir() 182 | short_path = ntpath.basename(image_path[0]) 183 | name = os.path.splitext(short_path)[0] 184 | 185 | webpage.add_header(name) 186 | ims = [] 187 | txts = [] 188 | links = [] 189 | 190 | for label, image_numpy in visuals.items(): 191 | image_name = '%s_%s.png' % (name, label) 192 | save_path = os.path.join(image_dir, image_name) 193 | util.save_image(image_numpy, save_path) 194 | 195 | ims.append(image_name) 196 | txts.append(label) 197 | links.append(image_name) 198 | webpage.add_images(ims, txts, links, width=self.win_size) 199 | -------------------------------------------------------------------------------- /train_classifaction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader, sampler 3 | from tqdm import tqdm 4 | 5 | 6 | from dataio.loader import get_dataset, get_dataset_path 7 | from dataio.transformation import get_dataset_transformation 8 | from utils.util import json_file_to_pyobj 9 | from utils.visualiser import Visualiser 10 | from utils.error_logger import ErrorLogger 11 | from models.networks_other import adjust_learning_rate 12 | 13 | from models import get_model 14 | 15 | 16 | class StratifiedSampler(object): 17 | """Stratified Sampling 18 | Provides equal representation of target classes in each batch 19 | """ 20 | def __init__(self, class_vector, batch_size): 21 | """ 22 | Arguments 23 | --------- 24 | class_vector : torch tensor 25 | a vector of class labels 26 | batch_size : integer 27 | batch_size 28 | """ 29 | self.class_vector = class_vector 30 | self.batch_size = batch_size 31 | self.num_iter = len(class_vector) // 52 32 | self.n_class = 14 33 | self.sample_n = 2 34 | # create pool of each vectors 35 | indices = {} 36 | for i in range(self.n_class): 37 | indices[i] = np.where(self.class_vector == i)[0] 38 | 39 | self.indices = indices 40 | self.background_index = np.argmax([ len(indices[i]) for i in range(self.n_class)]) 41 | 42 | 43 | def gen_sample_array(self): 44 | # sample 2 from each class 45 | sample_array = [] 46 | for i in range(self.num_iter): 47 | arrs = [] 48 | for i in range(self.n_class): 49 | n = self.sample_n 50 | if i == self.background_index: 51 | n = self.sample_n * (self.n_class-1) 52 | arr = np.random.choice(self.indices[i], n) 53 | arrs.append(arr) 54 | 55 | sample_array.append(np.hstack(arrs)) 56 | return np.hstack(sample_array) 57 | 58 | def __iter__(self): 59 | return iter(self.gen_sample_array()) 60 | 61 | def __len__(self): 62 | return len(self.class_vector) 63 | 64 | 65 | # Not using anymore 66 | def check_warm_start(epoch, model, train_opts): 67 | if hasattr(train_opts, "warm_start_epoch"): 68 | if epoch < train_opts.warm_start_epoch: 69 | print('... warm_start: lr={}'.format(train_opts.warm_start_lr)) 70 | adjust_learning_rate(model.optimizers[0], train_opts.warm_start_lr) 71 | elif epoch == train_opts.warm_start_epoch: 72 | print('... 
warm_start ended: lr={}'.format(model.opts.lr_rate)) 73 | adjust_learning_rate(model.optimizers[0], model.opts.lr_rate) 74 | 75 | 76 | def train(arguments): 77 | 78 | # Parse input arguments 79 | json_filename = arguments.config 80 | network_debug = arguments.debug 81 | 82 | # Load options 83 | json_opts = json_file_to_pyobj(json_filename) 84 | train_opts = json_opts.training 85 | 86 | # Architecture type 87 | arch_type = train_opts.arch_type 88 | 89 | # Setup Dataset and Augmentation 90 | ds_class = get_dataset(arch_type) 91 | ds_path = get_dataset_path(arch_type, json_opts.data_path) 92 | ds_transform = get_dataset_transformation(arch_type, opts=json_opts.augmentation) 93 | 94 | # Setup the NN Model 95 | model = get_model(json_opts.model) 96 | if network_debug: 97 | print('# of pars: ', model.get_number_parameters()) 98 | print('fp time: {0:.3f} sec\tbp time: {1:.3f} sec per sample'.format(*model.get_fp_bp_time())) 99 | exit() 100 | 101 | # Setup Data Loader 102 | num_workers = train_opts.num_workers if hasattr(train_opts, 'num_workers') else 16 103 | train_dataset = ds_class(ds_path, split='train', transform=ds_transform['train'], preload_data=train_opts.preloadData) 104 | valid_dataset = ds_class(ds_path, split='val', transform=ds_transform['valid'], preload_data=train_opts.preloadData) 105 | test_dataset = ds_class(ds_path, split='test', transform=ds_transform['valid'], preload_data=train_opts.preloadData) 106 | 107 | # create sampler 108 | if train_opts.sampler == 'stratified': 109 | print('stratified sampler') 110 | train_sampler = StratifiedSampler(train_dataset.labels, train_opts.batchSize) 111 | batch_size = 52 112 | elif train_opts.sampler == 'weighted2': 113 | print('weighted sampler with background weight={}x'.format(train_opts.bgd_weight_multiplier)) 114 | # modify and increase background weight 115 | weight = train_dataset.weight 116 | bgd_weight = np.min(weight) 117 | weight[abs(weight - bgd_weight) < 1e-8] = bgd_weight * train_opts.bgd_weight_multiplier 118 | train_sampler = sampler.WeightedRandomSampler(weight, len(train_dataset.weight)) 119 | batch_size = train_opts.batchSize 120 | else: 121 | print('weighted sampler') 122 | train_sampler = sampler.WeightedRandomSampler(train_dataset.weight, len(train_dataset.weight)) 123 | batch_size = train_opts.batchSize 124 | 125 | # loader 126 | train_loader = DataLoader(dataset=train_dataset, num_workers=num_workers, 127 | batch_size=batch_size, sampler=train_sampler) 128 | valid_loader = DataLoader(dataset=valid_dataset, num_workers=num_workers, batch_size=train_opts.batchSize, shuffle=True) 129 | test_loader = DataLoader(dataset=test_dataset, num_workers=num_workers, batch_size=train_opts.batchSize, shuffle=True) 130 | 131 | # Visualisation Parameters 132 | visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir) 133 | error_logger = ErrorLogger() 134 | 135 | # Training Function 136 | track_labels = np.arange(len(train_dataset.label_names)) 137 | model.set_labels(track_labels) 138 | model.set_scheduler(train_opts) 139 | 140 | if hasattr(model, 'update_state'): 141 | model.update_state(0) 142 | 143 | for epoch in range(model.which_epoch, train_opts.n_epochs): 144 | print('(epoch: %d, total # iters: %d)' % (epoch, len(train_loader))) 145 | 146 | # # # --- Start --- 147 | # import matplotlib.pyplot as plt 148 | # plt.ion() 149 | # plt.figure() 150 | # target_arr = np.zeros(14) 151 | # # # --- End --- 152 | 153 | # Training Iterations 154 | for epoch_iter, (images, labels) in tqdm(enumerate(train_loader, 1), 
total=len(train_loader)): 155 | # Make a training update 156 | model.set_input(images, labels) 157 | model.optimize_parameters() 158 | 159 | if epoch == (train_opts.n_epochs-1): 160 | import time 161 | time.sleep(36000) 162 | 163 | if train_opts.max_it == epoch_iter: 164 | break 165 | 166 | # # # --- visualise distribution --- 167 | # for lab in labels.numpy(): 168 | # target_arr[lab] += 1 169 | # plt.clf(); plt.bar(train_dataset.label_names, target_arr); plt.pause(0.01) 170 | # # # --- End --- 171 | 172 | # Visualise predictions 173 | if epoch_iter <= 100: 174 | visuals = model.get_current_visuals() 175 | visualizer.display_current_results(visuals, epoch=epoch, save_result=False) 176 | 177 | # Error visualisation 178 | errors = model.get_current_errors() 179 | error_logger.update(errors, split='train') 180 | 181 | # Validation and Testing Iterations 182 | pr_lbls = [] 183 | gt_lbls = [] 184 | for loader, split in zip([valid_loader, test_loader], ['validation', 'test']): 185 | model.reset_results() 186 | 187 | for epoch_iter, (images, labels) in tqdm(enumerate(loader, 1), total=len(loader)): 188 | 189 | # Make a forward pass with the model 190 | model.set_input(images, labels) 191 | model.validate() 192 | 193 | # Visualise predictions 194 | visuals = model.get_current_visuals() 195 | visualizer.display_current_results(visuals, epoch=epoch, save_result=False) 196 | 197 | if train_opts.max_it == epoch_iter: 198 | break 199 | 200 | # Error visualisation 201 | errors = model.get_accumulated_errors() 202 | stats = model.get_classification_stats() 203 | error_logger.update({**errors, **stats}, split=split) 204 | 205 | # HACK save validation error 206 | if split == 'validation': 207 | valid_err = errors['CE'] 208 | 209 | # Update the plots 210 | for split in ['train', 'validation', 'test']: 211 | # exclude bckground 212 | #track_labels = np.delete(track_labels, 3) 213 | #show_labels = train_dataset.label_names[:3] + train_dataset.label_names[4:] 214 | show_labels = train_dataset.label_names 215 | visualizer.plot_current_errors(epoch, error_logger.get_errors(split), split_name=split, labels=show_labels) 216 | visualizer.print_current_errors(epoch, error_logger.get_errors(split), split_name=split) 217 | error_logger.reset() 218 | 219 | # Save the model parameters 220 | if epoch % train_opts.save_epoch_freq == 0: 221 | model.save(epoch) 222 | 223 | if hasattr(model, 'update_state'): 224 | model.update_state(epoch) 225 | 226 | # Update the model learning rate 227 | model.update_learning_rate(metric=valid_err, epoch=epoch) 228 | 229 | 230 | if __name__ == '__main__': 231 | import argparse 232 | 233 | parser = argparse.ArgumentParser(description='CNN Classification Training Function') 234 | 235 | parser.add_argument('-c', '--config', help='training config file', required=True) 236 | parser.add_argument('-d', '--debug', help='returns number of parameters and bp/fp runtime', action='store_true') 237 | args = parser.parse_args() 238 | 239 | train(args) 240 | -------------------------------------------------------------------------------- /models/layers/nonlocal_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from models.networks_other import init_weights 5 | 6 | 7 | class _NonLocalBlockND(nn.Module): 8 | def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian', 9 | sub_sample_factor=4, bn_layer=True): 10 | super(_NonLocalBlockND, 
/models/layers/nonlocal_layer.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as F
from models.networks_other import init_weights


class _NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian',
                 sub_sample_factor=4, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]
        assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation',
                        'concat_proper', 'concat_proper_down']

        # print('Dimension: %d, mode: %s' % (dimension, mode))

        self.mode = mode
        self.dimension = dimension
        self.sub_sample_factor = sub_sample_factor if isinstance(sub_sample_factor, list) else [sub_sample_factor]

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool = nn.MaxPool3d
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool = nn.MaxPool2d
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool = nn.MaxPool1d
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            # zero-init so the block starts out as an identity mapping
            # (nn.init.constant is deprecated in favour of nn.init.constant_)
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = None
        self.phi = None

        if mode in ['embedded_gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']:
            self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                                 kernel_size=1, stride=1, padding=0)
            self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)

            if mode in ['concatenation']:
                self.wf_phi = nn.Linear(self.inter_channels, 1, bias=False)
                self.wf_theta = nn.Linear(self.inter_channels, 1, bias=False)
            elif mode in ['concat_proper', 'concat_proper_down']:
                self.psi = nn.Conv2d(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1,
                                     padding=0, bias=True)

        if mode == 'embedded_gaussian':
            self.operation_function = self._embedded_gaussian
        elif mode == 'dot_product':
            self.operation_function = self._dot_product
        elif mode == 'gaussian':
            self.operation_function = self._gaussian
        elif mode == 'concatenation':
            self.operation_function = self._concatenation
        elif mode == 'concat_proper':
            self.operation_function = self._concatenation_proper
        elif mode == 'concat_proper_down':
            self.operation_function = self._concatenation_proper_down
        else:
            raise NotImplementedError('Unknown operation function.')

        if any(ss > 1 for ss in self.sub_sample_factor):
            self.g = nn.Sequential(self.g, max_pool(kernel_size=sub_sample_factor))
            if self.phi is None:
                self.phi = max_pool(kernel_size=sub_sample_factor)
            else:
                self.phi = nn.Sequential(self.phi, max_pool(kernel_size=sub_sample_factor))
            if mode == 'concat_proper_down':
                self.theta = nn.Sequential(self.theta, max_pool(kernel_size=sub_sample_factor))

        # Initialise weights
        for m in self.children():
            init_weights(m, init_type='kaiming')

    def forward(self, x):
        '''
        :param x: (b, c, t, h, w)
        :return:
        '''

        output = self.operation_function(x)
        return output

    def _embedded_gaussian(self, x):
        batch_size = x.size(0)

        # g=>(b, c, t, h, w)->(b, 0.5c, t, h, w)->(b, thw, 0.5c)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c)
        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)
        # f=>(b, thw, 0.5c)dot(b, 0.5c, thw) = (b, thw, thw)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        # (b, thw, thw)dot(b, thw, 0.5c) = (b, thw, 0.5c)->(b, 0.5c, t, h, w)->(b, c, t, h, w)
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _gaussian(self, x):
        batch_size = x.size(0)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = x.view(batch_size, self.in_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        # sub_sample_factor is stored as a list, so compare element-wise
        # (the original `self.sub_sample_factor > 1` raises a TypeError)
        if any(ss > 1 for ss in self.sub_sample_factor):
            phi_x = self.phi(x).view(batch_size, self.in_channels, -1)
        else:
            phi_x = x.view(batch_size, self.in_channels, -1)

        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _dot_product(self, x):
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        N = f.size(-1)
        f_div_C = f / N

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _concatenation(self, x):
        batch_size = x.size(0)

        # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)

        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c)
        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw/s**2, 0.5c)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1)

        # theta => (b, thw, 0.5c) -> (b, thw, 1) -> (b, 1, thw) -> (expand) (b, thw/s**2, thw)
        # phi   => (b, thw/s**2, 0.5c) -> (b, thw/s**2, 1) -> (expand) (b, thw/s**2, thw)
        # f => RELU[(b, thw/s**2, thw) + (b, thw/s**2, thw)] = (b, thw/s**2, thw)
        f = self.wf_theta(theta_x).permute(0, 2, 1).repeat(1, phi_x.size(1), 1) + \
            self.wf_phi(phi_x).repeat(1, 1, theta_x.size(1))
        f = F.relu(f, inplace=True)

        # Normalise the relations
        N = f.size(-1)
        f_div_c = f / N

        # g(x_j) * f(x_j, x_i)
        # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw)
        y = torch.matmul(g_x, f_div_c)
        y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _concatenation_proper(self, x):
        batch_size = x.size(0)

        # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)

        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)
        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw)
        # phi   => (b, 0.5c, thw/s**2) -> (expand) (b, 0.5c, thw/s**2, thw)
        # f => RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw)
        f = theta_x.unsqueeze(dim=2).repeat(1, 1, phi_x.size(2), 1) + \
            phi_x.unsqueeze(dim=3).repeat(1, 1, 1, theta_x.size(2))
        f = F.relu(f, inplace=True)

        # psi -> W_psi^t * f -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw)
        f = torch.squeeze(self.psi(f), dim=1)

        # Normalise the relations
        f_div_c = F.softmax(f, dim=1)

        # g(x_j) * f(x_j, x_i)
        # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw)
        y = torch.matmul(g_x, f_div_c)
        y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _concatenation_proper_down(self, x):
        batch_size = x.size(0)

        # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)

        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)
        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2)
        theta_x = self.theta(x)
        downsampled_size = theta_x.size()
        theta_x = theta_x.view(batch_size, self.inter_channels, -1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw)
        # phi   => (b, 0.5c, thw/s**2) -> (expand) (b, 0.5c, thw/s**2, thw)
        # f => RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw)
        f = theta_x.unsqueeze(dim=2).repeat(1, 1, phi_x.size(2), 1) + \
            phi_x.unsqueeze(dim=3).repeat(1, 1, 1, theta_x.size(2))
        f = F.relu(f, inplace=True)

        # psi -> W_psi^t * f -> (b, 0.5c, thw/s**2, thw) -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw)
        f = torch.squeeze(self.psi(f), dim=1)

        # Normalise the relations
        f_div_c = F.softmax(f, dim=1)

        # g(x_j) * f(x_j, x_i)
        # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw)
        y = torch.matmul(g_x, f_div_c)
        y = y.contiguous().view(batch_size, self.inter_channels, *downsampled_size[2:])

        # upsample the final feature maps back to the input size  # (b, 0.5c, t/s1, h/s2, w/s3)
        # (F.upsample is deprecated in favour of F.interpolate)
        y = F.interpolate(y, size=x.size()[2:], mode='trilinear')

        # attention block output
        W_y = self.W(y)
        z = W_y + x

        return z
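
# All of the modes above instantiate the generic non-local operation of
# Wang et al., "Non-local Neural Networks" (CVPR 2018):
#
#     y_i = (1 / C(x)) * sum_j f(x_i, x_j) * g(x_j),      z_i = W * y_i + x_i
#
# where f scores the pairwise relation between positions i and j (embedded
# Gaussian, Gaussian, dot product, or one of the concatenation variants),
# g is the learned 1x1 embedding above, and C(x) is the normaliser (a softmax
# over j, or simply N for the dot-product and plain concatenation modes).
# Because W is zero-initialised in __init__, each block starts out as an
# identity mapping.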


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, mode=mode,
                                              sub_sample_factor=sub_sample_factor,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, mode=mode,
                                              sub_sample_factor=sub_sample_factor,
                                              bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=3, mode=mode,
                                              sub_sample_factor=sub_sample_factor,
                                              bn_layer=bn_layer)


if __name__ == '__main__':
    # smoke test -- torch.autograd.Variable is obsolete, plain tensors work directly
    mode_list = ['concatenation']
    # mode_list = ['embedded_gaussian', 'gaussian', 'dot_product', ]

    for mode in mode_list:
        print(mode)
        img = torch.zeros(2, 4, 5)
        net = NONLocalBlock1D(4, mode=mode, sub_sample_factor=2)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 4, 5, 3)
        net = NONLocalBlock2D(4, mode=mode, sub_sample_factor=1, bn_layer=False)
        out = net(img)
        print(out.size())

        img = torch.zeros(2, 4, 5, 4, 5)
        net = NONLocalBlock3D(4, mode=mode)
        out = net(img)
        print(out.size())
--------------------------------------------------------------------------------
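
A minimal sketch of dropping one of these blocks into a small 2-D CNN (the
surrounding layers, the shapes and the `TinyNet` name are illustrative
assumptions, not code from this repository; it assumes the package and its
`init_weights` helper are importable):

    import torch
    from torch import nn

    from models.layers.nonlocal_layer import NONLocalBlock2D

    class TinyNet(nn.Module):
        """Hypothetical wrapper, for illustration only."""
        def __init__(self):
            super(TinyNet, self).__init__()
            self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
            # the block preserves both the channel count and the spatial size
            self.nl = NONLocalBlock2D(16, mode='embedded_gaussian', sub_sample_factor=2)

        def forward(self, x):
            return self.nl(torch.relu(self.conv(x)))

    net = TinyNet()
    print(net(torch.zeros(2, 3, 32, 32)).size())  # expected: torch.Size([2, 16, 32, 32])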
/dataio/transformation/myImageTransformations.py:
--------------------------------------------------------------------------------
import collections.abc
import numbers

import numpy as np
import scipy.ndimage
from PIL import Image
from scipy.ndimage import gaussian_filter, map_coordinates


def center_crop(x, center_crop_size):
    assert x.ndim == 3
    centerw, centerh = x.shape[1] // 2, x.shape[2] // 2
    halfw, halfh = center_crop_size[0] // 2, center_crop_size[1] // 2
    return x[:, centerw - halfw:centerw + halfw, centerh - halfh:centerh + halfh]


def to_tensor(x):
    import torch
    x = x.transpose((2, 0, 1))
    return torch.from_numpy(x).float()


def random_num_generator(config, random_state=np.random):
    if config[0] == 'uniform':
        ret = random_state.uniform(config[1], config[2], 1)[0]
    elif config[0] == 'lognormal':
        ret = random_state.lognormal(config[1], config[2], 1)[0]
    else:
        print(config)
        raise Exception('unsupported format')
    return ret


def poisson_downsampling(image, peak, random_state=np.random):
    if not isinstance(image, np.ndarray):
        imgArr = np.array(image, dtype='float32')
    else:
        imgArr = image.astype('float32')
    Q = imgArr.max(axis=(0, 1)) / peak
    if Q[0] == 0:
        return imgArr
    ima_lambda = imgArr / Q
    noisy_img = random_state.poisson(lam=ima_lambda)
    return noisy_img.astype('float32')


def elastic_transform(image, alpha=1000, sigma=30, spline_order=1, mode='nearest', random_state=np.random):
    """Elastic deformation of image as described in [Simard2003]_.

    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
       Convolutional Neural Networks applied to Visual Document Analysis", in
       Proc. of the International Conference on Document Analysis and
       Recognition, 2003.
    """
    assert image.ndim == 3
    shape = image.shape[:2]

    dx = gaussian_filter((random_state.rand(*shape) * 2 - 1),
                         sigma, mode="constant", cval=0) * alpha
    dy = gaussian_filter((random_state.rand(*shape) * 2 - 1),
                         sigma, mode="constant", cval=0) * alpha

    x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
    indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))]
    result = np.empty_like(image)
    for i in range(image.shape[2]):
        result[:, :, i] = map_coordinates(
            image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape)
    return result


class Merge(object):
    """Merge a group of images along a given axis
    """

    def __init__(self, axis=-1):
        self.axis = axis

    def __call__(self, images):
        if isinstance(images, collections.abc.Sequence) or isinstance(images, np.ndarray):
            assert all([isinstance(i, np.ndarray)
                        for i in images]), 'only numpy array is supported'
            shapes = [list(i.shape) for i in images]
            for s in shapes:
                s[self.axis] = None
            assert all([s == shapes[0] for s in shapes]
                       ), 'shapes must be the same except the merge axis'
            return np.concatenate(images, axis=self.axis)
        else:
            raise Exception("obj is not a sequence (list, tuple, etc)")


class Split(object):
    """Split an image into individual arrays along a given axis
    """

    def __init__(self, *slices, **kwargs):
        assert isinstance(slices, collections.abc.Sequence)
        slices_ = []
        for s in slices:
            if isinstance(s, collections.abc.Sequence):
                slices_.append(slice(*s))
            else:
                slices_.append(s)
        assert all([isinstance(s, slice) for s in slices_]
                   ), 'slices must consist of slice instances'
        self.slices = slices_
        self.axis = kwargs.get('axis', -1)

    def __call__(self, image):
        if isinstance(image, np.ndarray):
            ret = []
            for s in self.slices:
                sl = [slice(None)] * image.ndim
                sl[self.axis] = s
                # index with a tuple -- indexing with a list is deprecated
                ret.append(image[tuple(sl)])
            return ret
        else:
            raise Exception("obj is not an numpy array")


class ElasticTransform(object):
    """Apply elastic transformation on a numpy.ndarray (H x W x C)
    """

    def __init__(self, alpha, sigma):
        self.alpha = alpha
        self.sigma = sigma

    def __call__(self, image):
        if isinstance(self.alpha, collections.abc.Sequence):
            alpha = random_num_generator(self.alpha)
        else:
            alpha = self.alpha
        if isinstance(self.sigma, collections.abc.Sequence):
            sigma = random_num_generator(self.sigma)
        else:
            sigma = self.sigma
        return elastic_transform(image, alpha=alpha, sigma=sigma)


class PoissonSubsampling(object):
    """Poisson subsampling on a numpy.ndarray (H x W x C)
    """

    def __init__(self, peak, random_state=np.random):
        self.peak = peak
        self.random_state = random_state

    def __call__(self, image):
        if isinstance(self.peak, collections.abc.Sequence):
            peak = random_num_generator(
                self.peak, random_state=self.random_state)
        else:
            peak = self.peak
        return poisson_downsampling(image, peak, random_state=self.random_state)


class AddGaussianNoise(object):
    """Add gaussian noise to a numpy.ndarray (H x W x C)
    """

    def __init__(self, mean, sigma, random_state=np.random):
        self.sigma = sigma
        self.mean = mean
        self.random_state = random_state

    def __call__(self, image):
        if isinstance(self.sigma, collections.abc.Sequence):
            sigma = random_num_generator(self.sigma, random_state=self.random_state)
        else:
            sigma = self.sigma
        if isinstance(self.mean, collections.abc.Sequence):
            mean = random_num_generator(self.mean, random_state=self.random_state)
        else:
            mean = self.mean
        row, col, ch = image.shape
        gauss = self.random_state.normal(mean, sigma, (row, col, ch))
        gauss = gauss.reshape(row, col, ch)
        image += gauss
        return image


class AddSpeckleNoise(object):
    """Add speckle noise to a numpy.ndarray (H x W x C)
    """

    def __init__(self, mean, sigma, random_state=np.random):
        self.sigma = sigma
        self.mean = mean
        self.random_state = random_state

    def __call__(self, image):
        if isinstance(self.sigma, collections.abc.Sequence):
            sigma = random_num_generator(
                self.sigma, random_state=self.random_state)
        else:
            sigma = self.sigma
        if isinstance(self.mean, collections.abc.Sequence):
            mean = random_num_generator(
                self.mean, random_state=self.random_state)
        else:
            mean = self.mean
        row, col, ch = image.shape
        gauss = self.random_state.normal(mean, sigma, (row, col, ch))
        gauss = gauss.reshape(row, col, ch)
        image += image * gauss
        return image


class GaussianBlurring(object):
    """Apply gaussian blur to a numpy.ndarray (H x W x C)
    """

    def __init__(self, sigma, random_state=np.random):
        self.sigma = sigma
        self.random_state = random_state

    def __call__(self, image):
        if isinstance(self.sigma, collections.abc.Sequence):
            sigma = random_num_generator(
                self.sigma, random_state=self.random_state)
        else:
            sigma = self.sigma
        image = gaussian_filter(image, sigma=(sigma, sigma, 0))
        return image


class AddGaussianPoissonNoise(object):
    """Add poisson noise with gaussian blurred image to a numpy.ndarray (H x W x C)
    """

    def __init__(self, sigma, peak, random_state=np.random):
        self.sigma = sigma
        self.peak = peak
        self.random_state = random_state

    def __call__(self, image):
        if isinstance(self.sigma, collections.abc.Sequence):
            sigma = random_num_generator(
                self.sigma, random_state=self.random_state)
        else:
            sigma = self.sigma
        if isinstance(self.peak, collections.abc.Sequence):
            peak = random_num_generator(
                self.peak, random_state=self.random_state)
        else:
            peak = self.peak
        bg = gaussian_filter(image, sigma=(sigma, sigma, 0))
        bg = poisson_downsampling(
            bg, peak=peak, random_state=self.random_state)
        return image + bg


class MaxScaleNumpy(object):
    """Scale each channel of the numpy array into [range_min, range_max]
    using its min and max, i.e.
    channel = range_min + (channel - min) * (range_max - range_min) / (max - min)
    """

    def __init__(self, range_min=0.0, range_max=1.0):
        self.scale = (range_min, range_max)

    def __call__(self, image):
        mn = image.min(axis=(0, 1))
        mx = image.max(axis=(0, 1))
        return self.scale[0] + (image - mn) * (self.scale[1] - self.scale[0]) / (mx - mn)


class MedianScaleNumpy(object):
    """Scale each channel of the numpy array using its min and median, i.e.
    channel = range_min + (channel - min) * (range_max - range_min) / (median - min)
    """

    def __init__(self, range_min=0.0, range_max=1.0):
        self.scale = (range_min, range_max)

    def __call__(self, image):
        mn = image.min(axis=(0, 1))
        md = np.median(image, axis=(0, 1))
        return self.scale[0] + (image - mn) * (self.scale[1] - self.scale[0]) / (md - mn)


class NormalizeNumpy(object):
    """Normalize each channel of the numpy array i.e.
    channel = (channel - mean) / std
    """

    def __call__(self, image):
        image -= image.mean(axis=(0, 1))
        s = image.std(axis=(0, 1))
        s[s == 0] = 1.0
        image /= s
        return image


class MutualExclude(object):
    """Zero out `from_channel` wherever `exclude_channel` is positive
    """

    def __init__(self, exclude_channel, from_channel):
        self.from_channel = from_channel
        self.exclude_channel = exclude_channel

    def __call__(self, image):
        mask = image[:, :, self.exclude_channel] > 0
        image[:, :, self.from_channel][mask] = 0
        return image


class RandomCropNumpy(object):
    """Crops the given numpy array at a random location to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)
    """

    def __init__(self, size, random_state=np.random):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.random_state = random_state

    def __call__(self, img):
        # axis 0 is height, axis 1 is width (the original swapped the two)
        h, w = img.shape[:2]
        th, tw = self.size
        if h == th and w == tw:
            return img

        y1 = self.random_state.randint(0, h - th) if h > th else 0
        x1 = self.random_state.randint(0, w - tw) if w > tw else 0
        return img[y1:y1 + th, x1:x1 + tw, :]


class CenterCropNumpy(object):
    """Crops the given numpy array at the center to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        # axis 0 is height, axis 1 is width (the original swapped the two)
        h, w = img.shape[:2]
        th, tw = self.size
        y1 = int(round((h - th) / 2.))
        x1 = int(round((w - tw) / 2.))
        return img[y1:y1 + th, x1:x1 + tw, :]


class RandomRotate(object):
    """Rotate a PIL.Image or numpy.ndarray (H x W x C) randomly
    """

    def __init__(self, angle_range=(0.0, 360.0), axes=(0, 1), mode='reflect', random_state=np.random):
        assert isinstance(angle_range, tuple)
        self.angle_range = angle_range
        self.random_state = random_state
        self.axes = axes
        self.mode = mode

    def __call__(self, image):
        angle = self.random_state.uniform(
            self.angle_range[0], self.angle_range[1])
        if isinstance(image, np.ndarray):
            mi, ma = image.min(), image.max()
            # scipy.ndimage.interpolation.rotate is deprecated; use scipy.ndimage.rotate
            image = scipy.ndimage.rotate(
                image, angle, reshape=False, axes=self.axes, mode=self.mode)
            return np.clip(image, mi, ma)
        elif isinstance(image, Image.Image):
            return image.rotate(angle)
        else:
            raise Exception('unsupported type')


class BilinearResize(object):
    """Resize a PIL.Image or numpy.ndarray (H x W x C) by a zoom factor
    """

    def __init__(self, zoom):
        self.zoom = [zoom, zoom, 1]

    def __call__(self, image):
        if isinstance(image, np.ndarray):
            return scipy.ndimage.zoom(image, self.zoom)
        elif isinstance(image, Image.Image):
            # the original referenced the undefined `self.size` here;
            # derive the target size from the zoom factor instead
            w, h = image.size
            return image.resize((int(w * self.zoom[1]), int(h * self.zoom[0])), Image.BILINEAR)
        else:
            raise Exception('unsupported type')


class EnhancedCompose(object):
    """Composes several transforms together.
    Args:
        transforms (List[Transform]): list of transforms to compose.
    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            if isinstance(t, collections.abc.Sequence):
                assert isinstance(img, collections.abc.Sequence) and len(img) == len(
                    t), "size of image group and transform group does not fit"
                tmp_ = []
                for i, im_ in enumerate(img):
                    if callable(t[i]):
                        tmp_.append(t[i](im_))
                    else:
                        tmp_.append(im_)
                img = tmp_
            elif callable(t):
                img = t(img)
            elif t is None:
                continue
            else:
                raise Exception('unexpected type')
        return img


if __name__ == '__main__':
    from torchvision.transforms import Lambda

    input_channel = 3
    target_channel = 3

    # define a transform pipeline
    transform = EnhancedCompose([
        Merge(),
        RandomCropNumpy(size=(512, 512)),
        RandomRotate(),
        Split([0, input_channel], [input_channel, input_channel + target_channel]),
        [CenterCropNumpy(size=(256, 256)), CenterCropNumpy(size=(256, 256))],
        [NormalizeNumpy(), MaxScaleNumpy(0, 1.0)],
        # for non-pytorch usage, remove to_tensor conversion
        [Lambda(to_tensor), Lambda(to_tensor)]
    ])
    # read input images for the test
    image_in = np.array(Image.open('input.jpg'))
    image_target = np.array(Image.open('target.jpg'))

    # apply the transform
    x, y = transform([image_in, image_target])
--------------------------------------------------------------------------------
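
As a usage sketch, a plausible augmentation pipeline for a single-channel
image/label pair, built only from the classes above (parameter values and the
variable names are illustrative assumptions):

    import numpy as np

    from dataio.transformation.myImageTransformations import (
        AddGaussianNoise, ElasticTransform, EnhancedCompose, Merge,
        NormalizeNumpy, RandomRotate, Split)

    # geometric transforms act on the merged stack so image and label stay
    # aligned; intensity transforms then touch the image branch only
    transform = EnhancedCompose([
        Merge(),                                   # (H, W, 1)+(H, W, 1) -> (H, W, 2)
        RandomRotate(angle_range=(-15.0, 15.0)),
        ElasticTransform(alpha=1000, sigma=30),
        Split([0, 1], [1, 2]),                     # back to [image, label]
        [AddGaussianNoise(mean=0.0, sigma=0.05), None],
        [NormalizeNumpy(), None],
    ])

    image = np.random.rand(64, 64, 1).astype('float32')
    label = (np.random.rand(64, 64, 1) > 0.5).astype('float32')
    img_aug, lab_aug = transform([image, label])
    print(img_aug.shape, lab_aug.shape)  # (64, 64, 1) (64, 64, 1)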