├── dataio
│   ├── __init__.py
│   ├── loader
│   │   ├── hms_dataset.py
│   │   ├── __init__.py
│   │   ├── test_dataset.py
│   │   ├── utils.py
│   │   ├── cmr_3D_dataset.py
│   │   ├── ukbb_dataset.py
│   │   └── us_dataset.py
│   └── transformation
│       ├── __init__.py
│       ├── transforms.py
│       └── myImageTransformations.py
├── utils
│   ├── __init__.py
│   ├── html.py
│   ├── util.py
│   ├── error_logger.py
│   ├── post_process_crf.py
│   ├── metrics.py
│   └── visualiser.py
├── models
│   ├── layers
│   │   ├── __init__.py
│   │   ├── loss.py
│   │   └── nonlocal_layer.py
│   ├── networks
│   │   ├── unet_2D.py
│   │   ├── unet_3D.py
│   │   ├── __init__.py
│   │   ├── sononet.py
│   │   ├── unet_nonlocal_2D.py
│   │   ├── unet_nonlocal_3D.py
│   │   ├── unet_CT_dsv_3D.py
│   │   ├── unet_grid_attention_3D.py
│   │   ├── sononet_grid_attention.py
│   │   ├── unet_CT_single_att_dsv_3D.py
│   │   └── unet_CT_multi_att_dsv_3D.py
│   ├── __init__.py
│   ├── base_model.py
│   ├── aggregated_classifier.py
│   ├── utils.py
│   ├── feedforward_seg_model.py
│   └── feedforward_classifier.py
├── figures
│   ├── figure1.png
│   └── figure2.jpg
├── setup.py
├── LICENSE
├── README.md
├── configs
│   ├── config_unet_ct_dsv.json
│   ├── config_unet_ct_multi_att_dsv.json
│   ├── config_sononet_8.json
│   ├── config_sononet_grid_att_8_ft.json
│   ├── config_sononet_grid_att_8.json
│   └── config_sononet_grid_att_8_deepsup.json
├── .gitignore
├── visualise_fmaps.py
├── visualise_att_maps_epoch.py
├── validation.py
├── train_segmentation.py
├── test_classification.py
├── visualise_attention.py
└── train_classifaction.py
/dataio/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/layers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/figures/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozan-oktay/Attention-Gated-Networks/HEAD/figures/figure1.png
--------------------------------------------------------------------------------
/figures/figure2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ozan-oktay/Attention-Gated-Networks/HEAD/figures/figure2.jpg
--------------------------------------------------------------------------------
/dataio/loader/hms_dataset.py:
--------------------------------------------------------------------------------
1 | # Author: Ozan Oktay
2 | # Date: January 2018
3 |
4 | class HMSDataset:
5 |
6 | def __init__(self):
7 |         raise NotImplementedError
--------------------------------------------------------------------------------
/dataio/transformation/__init__.py:
--------------------------------------------------------------------------------
1 | import json
2 | from dataio.transformation.transforms import Transformations
3 |
4 |
5 | def get_dataset_transformation(name, opts=None):
6 | '''
7 |     :param name: dataset name; :param opts: augmentation parameters
8 |     :return: a dictionary of transformation functions
9 | '''
10 | # Build the transformation object and initialise the augmentation parameters
11 | trans_obj = Transformations(name)
12 | if opts: trans_obj.initialise(opts)
13 |
14 | # Print the input options
15 | trans_obj.print()
16 |
17 | # Returns a dictionary of transformations
18 | return trans_obj.get_transformation()
19 |
--------------------------------------------------------------------------------
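A minimal usage sketch of this factory (the `'us'` key and the `'train'` split follow this repo's configs and entry scripts, e.g. `visualise_fmaps.py`; `json_file_to_pyobj` lives in `utils/util.py`):

```python
from dataio.transformation import get_dataset_transformation
from utils.util import json_file_to_pyobj

# Pull the augmentation section from one of the repo's configs.
json_opts = json_file_to_pyobj('configs/config_sononet_8.json')
ds_transform = get_dataset_transformation('us', opts=json_opts.augmentation)

# The factory returns a dictionary of transforms indexed by split.
train_transform = ds_transform['train']
```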
/dataio/loader/__init__.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from dataio.loader.ukbb_dataset import UKBBDataset
4 | from dataio.loader.test_dataset import TestDataset
5 | from dataio.loader.hms_dataset import HMSDataset
6 | from dataio.loader.cmr_3D_dataset import CMR3DDataset
7 | from dataio.loader.us_dataset import UltraSoundDataset
8 |
9 |
10 | def get_dataset(name):
11 | """get_dataset
12 |
13 | :param name:
14 | """
15 | return {
16 | 'ukbb_sax': CMR3DDataset,
17 | 'acdc_sax': CMR3DDataset,
18 | 'rvsc_sax': CMR3DDataset,
19 | 'hms_sax': HMSDataset,
20 | 'test_sax': TestDataset,
21 | 'us': UltraSoundDataset
22 | }[name]
23 |
24 |
25 | def get_dataset_path(dataset_name, opts):
26 |     """get_dataset_path
27 |
28 | :param dataset_name:
29 | :param opts:
30 | """
31 |
32 | return getattr(opts, dataset_name)
33 |
--------------------------------------------------------------------------------
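A minimal sketch of the two factory functions above (`/path/to/acdc_sax` is a placeholder; in the real scripts, `opts` is the parsed `"data_path"` section of a config file):

```python
from types import SimpleNamespace
from dataio.loader import get_dataset, get_dataset_path

# Stand-in for the parsed "data_path" config section.
data_path_opts = SimpleNamespace(acdc_sax='/path/to/acdc_sax')

dataset_class = get_dataset('acdc_sax')                      # -> CMR3DDataset
dataset_path = get_dataset_path('acdc_sax', data_path_opts)  # attribute lookup
train_set = dataset_class(dataset_path, split='train', preload_data=False)
```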
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import setup, find_packages
4 |
5 | with open('README.md') as f:
6 | readme = f.read()
7 |
8 | setup(name='AttentionGatedNetworks',
9 | version='1.0',
10 | description='Pytorch library for Soft Attention',
11 | long_description=readme,
12 | author='Ozan Oktay & Jo Schlemper',
13 | install_requires=[
14 | "numpy",
15 | "torch",
16 | "matplotlib",
17 | "scipy",
18 | "torchvision",
19 | "tqdm",
20 | "visdom",
21 | "nibabel",
22 | "scikit-image",
23 | "h5py",
24 | "pandas",
25 | "dominate",
26 | 'torchsample==0.1.3',
27 | ],
28 | dependency_links=[
29 | 'https://github.com/ozan-oktay/torchsample/tarball/master#egg=torchsample-0.1.3'
30 | ],
31 | packages=find_packages(exclude=('tests', 'docs'))
32 | )
33 |
34 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Ozan Oktay
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Attention Gated Networks (Image Classification & Segmentation)
2 |
3 | Pytorch implementation of attention gates used in U-Net and VGG-16 models. The framework can be utilised in both medical image classification and segmentation tasks.
4 |
5 |
6 | ![The schematics of the proposed Attention-Gated Sononet](figures/figure1.png)
7 |
8 | The schematics of the proposed Attention-Gated Sononet
9 |
10 | ![The schematics of the proposed additive attention gate](figures/figure2.jpg)
11 |
12 | The schematics of the proposed additive attention gate
13 |
14 |
15 | ### References:
16 |
17 | 1) "Attention-Gated Networks for Improving Ultrasound Scan Plane Detection", MIDL'18, Amsterdam
18 | [Conference Paper](https://openreview.net/pdf?id=BJtn7-3sM)
19 | [Conference Poster](https://www.doc.ic.ac.uk/~oo2113/posters/MIDL2018_poster_Jo.pdf)
20 |
21 | 2) "Attention U-Net: Learning Where to Look for the Pancreas", MIDL'18, Amsterdam
22 | [Conference Paper](https://openreview.net/pdf?id=Skft7cijM)
23 | [Conference Poster](https://www.doc.ic.ac.uk/~oo2113/posters/MIDL2018_poster.pdf)
24 |
25 | ### Installation
26 | pip install --process-dependency-links -e .
27 |
28 |
--------------------------------------------------------------------------------
/configs/config_unet_ct_dsv.json:
--------------------------------------------------------------------------------
1 | {
2 | "training":{
3 | "arch_type": "acdc_sax",
4 | "n_epochs": 1000,
5 | "save_epoch_freq": 10,
6 | "lr_policy": "step",
7 | "lr_decay_iters": 250,
8 | "batchSize": 2,
9 | "preloadData": true
10 | },
11 | "visualisation":{
12 | "display_port": 8098,
13 | "no_html": true,
14 | "display_winsize": 256,
15 | "display_id": 1,
16 | "display_single_pane_ncols": 0
17 | },
18 | "data_path": {
19 | "acdc_sax": "/vol/biomedic2/oo2113/dataset/ken_abdominal_ct/"
20 | },
21 | "augmentation": {
22 | "acdc_sax": {
23 | "shift": [0.1,0.1],
24 | "rotate": 15.0,
25 | "scale": [0.7,1.3],
26 | "intensity": [1.0,1.0],
27 | "random_flip_prob": 0.5,
28 | "scale_size": [160,160,96],
29 | "patch_size": [160,160,96]
30 | }
31 | },
32 | "model":{
33 | "type":"seg",
34 | "continue_train": false,
35 | "which_epoch": -1,
36 | "model_type": "unet_ct_dsv",
37 | "tensor_dim": "3D",
38 | "division_factor": 16,
39 | "input_nc": 1,
40 | "output_nc": 4,
41 | "lr_rate": 1e-4,
42 | "l2_reg_weight": 1e-6,
43 | "feature_scale": 4,
44 | "gpu_ids": [0],
45 | "isTrain": true,
46 | "checkpoints_dir": "./checkpoints",
47 | "experiment_name": "experiment_unet_ct_dsv_big",
48 | "criterion": "dice_loss"
49 | }
50 | }
51 |
52 |
53 |
--------------------------------------------------------------------------------
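How a config like the one above is consumed (the attribute-style access mirrors the entry scripts such as `visualise_fmaps.py`, which load configs via `utils.util.json_file_to_pyobj`):

```python
from utils.util import json_file_to_pyobj

json_opts = json_file_to_pyobj('configs/config_unet_ct_dsv.json')
print(json_opts.model.model_type)    # 'unet_ct_dsv'
print(json_opts.training.batchSize)  # 2
print(json_opts.data_path.acdc_sax)  # dataset root directory
```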
/configs/config_unet_ct_multi_att_dsv.json:
--------------------------------------------------------------------------------
1 | {
2 | "training":{
3 | "arch_type": "acdc_sax",
4 | "n_epochs": 1000,
5 | "save_epoch_freq": 10,
6 | "lr_policy": "step",
7 | "lr_decay_iters": 250,
8 | "batchSize": 2,
9 | "preloadData": true
10 | },
11 | "visualisation":{
12 | "display_port": 8099,
13 | "no_html": true,
14 | "display_winsize": 256,
15 | "display_id": 1,
16 | "display_single_pane_ncols": 0
17 | },
18 | "data_path": {
19 | "acdc_sax": "/vol/biomedic2/oo2113/dataset/ken_abdominal_ct/"
20 | },
21 | "augmentation": {
22 | "acdc_sax": {
23 | "shift": [0.1,0.1],
24 | "rotate": 15.0,
25 | "scale": [0.7,1.3],
26 | "intensity": [1.0,1.0],
27 | "random_flip_prob": 0.5,
28 | "scale_size": [160,160,96],
29 | "patch_size": [160,160,96]
30 | }
31 | },
32 | "model":{
33 | "type":"seg",
34 | "continue_train": false,
35 | "which_epoch": -1,
36 | "model_type": "unet_ct_multi_att_dsv",
37 | "tensor_dim": "3D",
38 | "division_factor": 16,
39 | "input_nc": 1,
40 | "output_nc": 4,
41 | "lr_rate": 1e-4,
42 | "l2_reg_weight": 1e-6,
43 | "feature_scale": 4,
44 | "gpu_ids": [0],
45 | "isTrain": true,
46 | "checkpoints_dir": "./checkpoints",
47 | "experiment_name": "experiment_unet_ct_multi_att_dsv",
48 | "criterion": "dice_loss"
49 | }
50 | }
51 |
52 |
53 |
--------------------------------------------------------------------------------
/configs/config_sononet_8.json:
--------------------------------------------------------------------------------
1 | {
2 | "training":{
3 | "max_it":10,
4 | "arch_type": "us",
5 | "n_epochs": 300,
6 | "save_epoch_freq": 10,
7 | "lr_policy": "step_warmstart",
8 | "lr_decay_iters": 25,
9 | "lr_red_factor": 0.1,
10 | "batchSize": 64,
11 | "preloadData": false,
12 | "num_workers" : 8,
13 | "sampler": "weighted2",
14 | "bgd_weight_multiplier": 13
15 | },
16 | "visualisation":{
17 | "display_port": 8181,
18 | "no_html": true,
19 | "display_winsize": 256,
20 | "display_id": 1,
21 | "display_single_pane_ncols": 0
22 | },
23 | "data_path": {
24 | "us": "/vol/bitbucket/js3611/data_ultrasound/preproc_combined_inp_224x288.hdf5"
25 | },
26 | "augmentation": {
27 | "us": {
28 | "patch_size": [208, 272],
29 | "shift": [0.02,0.02],
30 | "rotate": 25.0,
31 | "scale": [0.7,1.3],
32 | "intensity": [1.0,1.0],
33 | "random_flip_prob": 0.5
34 | }
35 | },
36 | "model":{
37 | "type":"classifier",
38 | "continue_train": false,
39 | "which_epoch": 0,
40 | "model_type": "sononet2",
41 | "tensor_dim": "2D",
42 | "input_nc": 1,
43 | "output_nc": 14,
44 | "lr_rate": 0.1,
45 | "l2_reg_weight": 1e-6,
46 | "feature_scale": 8,
47 | "gpu_ids": [0],
48 | "isTrain": true,
49 | "checkpoints_dir": "./checkpoints",
50 | "experiment_name": "experiment_sononet_8"
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/dataio/loader/test_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import numpy as np
3 | import os
4 |
5 | from os import listdir
6 | from os.path import join
7 | from .utils import load_nifti_img, check_exceptions, is_image_file
8 |
9 |
10 | class TestDataset(data.Dataset):
11 | def __init__(self, root_dir, transform):
12 | super(TestDataset, self).__init__()
13 | image_dir = join(root_dir, 'image')
14 | self.image_filenames = sorted([join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)])
15 |
16 | # Add the corresponding ground-truth images if they exist
17 | self.label_filenames = []
18 | label_dir = join(root_dir, 'label')
19 | if os.path.isdir(label_dir):
20 | self.label_filenames = sorted([join(label_dir, x) for x in listdir(label_dir) if is_image_file(x)])
21 | assert len(self.label_filenames) == len(self.image_filenames)
22 |
23 | # data pre-processing
24 | self.transform = transform
25 |
26 | # report the number of images in the dataset
27 | print('Number of test images: {0} NIFTIs'.format(self.__len__()))
28 |
29 | def __getitem__(self, index):
30 |
31 | # load the NIFTI images
32 | input, input_meta = load_nifti_img(self.image_filenames[index], dtype=np.int16)
33 |
34 | # load the label image if it exists
35 | if self.label_filenames:
36 | label, _ = load_nifti_img(self.label_filenames[index], dtype=np.int16)
37 | check_exceptions(input, label)
38 | else:
39 | label = []
40 | check_exceptions(input)
41 |
42 | # Pre-process the input 3D Nifti image
43 | input = self.transform(input)
44 |
45 | return input, input_meta, label
46 |
47 | def __len__(self):
48 | return len(self.image_filenames)
--------------------------------------------------------------------------------
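A minimal sketch of iterating `TestDataset` (`/path/to/test_sax` is a placeholder directory containing an `image/` folder, and optionally a `label/` folder, of `.nii.gz` files; calling the transformation factory without `opts` falls back to its defaults):

```python
from torch.utils.data import DataLoader
from dataio.loader.test_dataset import TestDataset
from dataio.transformation import get_dataset_transformation

ds_transform = get_dataset_transformation('test_sax')
dataset = TestDataset('/path/to/test_sax', transform=ds_transform['test'])
loader = DataLoader(dataset, batch_size=1, shuffle=False)

for input_arr, input_meta, label in loader:
    print(input_meta['name'][0], input_arr.shape)
```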
/configs/config_sononet_grid_att_8_ft.json:
--------------------------------------------------------------------------------
1 | {
2 | "training":{
3 | "max_it":100,
4 | "arch_type": "us",
5 | "n_epochs": 300,
6 | "save_epoch_freq": 1,
7 | "lr_policy": "step_warmstart",
8 | "lr_decay_iters": 25,
9 | "lr_red_factor": 0.1,
10 | "batchSize": 64,
11 | "preloadData": false,
12 | "num_workers" : 8,
13 | "sampler": "weighted2",
14 | "bgd_weight_multiplier": 13
15 | },
16 | "visualisation":{
17 | "display_port": 8181,
18 | "no_html": true,
19 | "display_winsize": 256,
20 | "display_id": 1,
21 | "display_single_pane_ncols": 0
22 | },
23 | "data_path": {
24 | "us": "/vol/bitbucket/js3611/data_ultrasound/preproc_combined_inp_224x288.hdf5"
25 | },
26 | "augmentation": {
27 | "us": {
28 | "patch_size": [208, 272],
29 | "shift": [0.02,0.02],
30 | "rotate": 25.0,
31 | "scale": [0.7,1.3],
32 | "intensity": [1.0,1.0],
33 | "random_flip_prob": 0.5
34 | }
35 | },
36 | "model":{
37 | "type":"aggregated_classifier",
38 | "criterion":"cross_entropy",
39 | "model_type": "sononet_grid_attention",
40 | "nonlocal_mode": "concatenation_mean_flow",
41 | "aggregation_mode": "ft",
42 | "weight":[1],
43 | "aggregation":"mean",
44 | "continue_train": false,
45 | "which_epoch": 0,
46 | "tensor_dim": "2D",
47 | "input_nc": 1,
48 | "output_nc": 14,
49 | "lr_rate": 0.1,
50 | "l2_reg_weight": 1e-6,
51 | "feature_scale": 8,
52 | "gpu_ids": [0],
53 | "isTrain": true,
54 | "checkpoints_dir": "./checkpoints",
55 | "experiment_name": "experiment_sononet_grid_attention_fs8_avg_v12"
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/configs/config_sononet_grid_att_8.json:
--------------------------------------------------------------------------------
1 | {
2 | "training":{
3 | "max_it":10,
4 | "arch_type": "us",
5 | "n_epochs": 300,
6 | "save_epoch_freq": 1,
7 | "lr_policy": "step_warmstart",
8 | "lr_decay_iters": 25,
9 | "lr_red_factor": 0.1,
10 | "batchSize": 64,
11 | "preloadData": false,
12 | "num_workers" : 8,
13 | "sampler": "weighted2",
14 | "bgd_weight_multiplier": 13
15 | },
16 | "visualisation":{
17 | "display_port": 8181,
18 | "no_html": true,
19 | "display_winsize": 256,
20 | "display_id": 1,
21 | "display_single_pane_ncols": 0
22 | },
23 | "data_path": {
24 | "us": "/vol/bitbucket/js3611/data_ultrasound/preproc_combined_inp_224x288.hdf5"
25 | },
26 | "augmentation": {
27 | "us": {
28 | "patch_size": [208, 272],
29 | "shift": [0.02,0.02],
30 | "rotate": 25.0,
31 | "scale": [0.7,1.3],
32 | "intensity": [1.0,1.0],
33 | "random_flip_prob": 0.5
34 | }
35 | },
36 | "model":{
37 | "type":"aggregated_classifier",
38 | "criterion":"cross_entropy",
39 | "model_type": "sononet_grid_attention",
40 | "nonlocal_mode": "concatenation_mean_flow",
41 | "aggregation_mode": "mean",
42 | "weight":[1, 1, 1],
43 | "aggregation":"mean",
44 | "continue_train": false,
45 | "which_epoch": 0,
46 | "tensor_dim": "2D",
47 | "input_nc": 1,
48 | "output_nc": 14,
49 | "lr_rate": 0.1,
50 | "l2_reg_weight": 1e-6,
51 | "feature_scale": 8,
52 | "gpu_ids": [0],
53 | "isTrain": true,
54 | "checkpoints_dir": "./checkpoints",
55 | "experiment_name": "experiment_sononet_grid_attention_fs8_avg_v12"
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.pyc
6 |
7 | # Saved results
8 | checkpoints/
9 | checkpoints_2/
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | .hypothesis/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | .static_storage/
61 | .media/
62 | local_settings.py
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # pyenv
81 | .python-version
82 |
83 | # celery beat schedule file
84 | celerybeat-schedule
85 |
86 | # SageMath parsed files
87 | *.sage.py
88 |
89 | # Environments
90 | .env
91 | .venv
92 | env/
93 | venv/
94 | ENV/
95 | env.bak/
96 | venv.bak/
97 |
98 | # Spyder project settings
99 | .spyderproject
100 | .spyproject
101 |
102 | # Rope project settings
103 | .ropeproject
104 |
105 | # mkdocs documentation
106 | /site
107 |
108 | # mypy
109 | .mypy_cache/
110 |
111 | *~
112 | .#*
113 | flycheck*
--------------------------------------------------------------------------------
/configs/config_sononet_grid_att_8_deepsup.json:
--------------------------------------------------------------------------------
1 | {
2 | "training":{
3 | "max_it":100,
4 | "arch_type": "us",
5 | "n_epochs": 300,
6 | "save_epoch_freq": 1,
7 | "lr_policy": "step_warmstart",
8 | "lr_decay_iters": 25,
9 | "lr_red_factor": 0.1,
10 | "batchSize": 64,
11 | "preloadData": false,
12 | "num_workers" : 8,
13 | "sampler": "weighted2",
14 | "bgd_weight_multiplier": 13
15 | },
16 | "visualisation":{
17 | "display_port": 8181,
18 | "no_html": true,
19 | "display_winsize": 256,
20 | "display_id": 1,
21 | "display_single_pane_ncols": 0
22 | },
23 | "data_path": {
24 | "us": "/vol/bitbucket/js3611/data_ultrasound/preproc_combined_inp_224x288.hdf5"
25 | },
26 | "augmentation": {
27 | "us": {
28 | "patch_size": [208, 272],
29 | "shift": [0.02,0.02],
30 | "rotate": 25.0,
31 | "scale": [0.7,1.3],
32 | "intensity": [1.0,1.0],
33 | "random_flip_prob": 0.5
34 | }
35 | },
36 | "model":{
37 | "type":"aggregated_classifier",
38 | "criterion":"cross_entropy",
39 | "model_type": "sononet_grid_attention",
40 | "nonlocal_mode": "concatenation_mean_flow",
41 | "aggregation_mode": "deep_sup",
42 | "weight":[1, 0.1, 0.1, 0.1],
43 | "aggregation":"idx",
44 | "aggregation_param":0,
45 | "continue_train": false,
46 | "which_epoch": 0,
47 | "tensor_dim": "2D",
48 | "input_nc": 1,
49 | "output_nc": 14,
50 | "lr_rate": 0.1,
51 | "l2_reg_weight": 1e-6,
52 | "feature_scale": 8,
53 | "gpu_ids": [0],
54 | "isTrain": true,
55 | "checkpoints_dir": "./checkpoints",
56 | "experiment_name": "experiment_sononet_grid_attention_fs8_avg_v12"
57 | }
58 | }
59 |
--------------------------------------------------------------------------------
/dataio/loader/utils.py:
--------------------------------------------------------------------------------
1 | import nibabel as nib
2 | import numpy as np
3 | import os
4 | from utils.util import mkdir
5 |
6 | def is_image_file(filename):
7 | return any(filename.endswith(extension) for extension in [".nii.gz"])
8 |
9 |
10 | def load_nifti_img(filepath, dtype):
11 | '''
12 | NIFTI Image Loader
13 | :param filepath: path to the input NIFTI image
14 |     :param dtype: data type of the NIfTI numpy array
15 | :return: return numpy array
16 | '''
17 |     nim = nib.load(filepath)
18 |     out_nii_array = np.array(nim.get_fdata(), dtype=dtype)  # get_data() was removed in nibabel 5.0
19 |     out_nii_array = np.squeeze(out_nii_array)  # drop singleton dim in case temporal dim exists
20 |     meta = {'affine': nim.affine,  # get_affine() was removed in nibabel 3.0
21 | 'dim': nim.header['dim'],
22 | 'pixdim': nim.header['pixdim'],
23 | 'name': os.path.basename(filepath)
24 | }
25 |
26 | return out_nii_array, meta
27 |
28 |
29 | def write_nifti_img(input_nii_array, meta, savedir):
30 | mkdir(savedir)
31 | affine = meta['affine'][0].cpu().numpy()
32 | pixdim = meta['pixdim'][0].cpu().numpy()
33 | dim = meta['dim'][0].cpu().numpy()
34 |
35 | img = nib.Nifti1Image(input_nii_array, affine=affine)
36 | img.header['dim'] = dim
37 | img.header['pixdim'] = pixdim
38 |
39 | savename = os.path.join(savedir, meta['name'][0])
40 | print('saving: ', savename)
41 | nib.save(img, savename)
42 |
43 |
44 | def check_exceptions(image, label=None):
45 | if label is not None:
46 | if image.shape != label.shape:
47 | print('Error: mismatched size, image.shape = {0}, '
48 | 'label.shape = {1}'.format(image.shape, label.shape))
49 | #print('Skip {0}, {1}'.format(image_name, label_name))
50 | raise(Exception('image and label sizes do not match'))
51 |
52 | if image.max() < 1e-6:
53 | print('Error: blank image, image.max = {0}'.format(image.max()))
54 | #print('Skip {0} {1}'.format(image_name, label_name))
55 | raise (Exception('blank image exception'))
--------------------------------------------------------------------------------
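A small sketch of `load_nifti_img` (placeholder path). Note that `write_nifti_img` expects the `meta` values as DataLoader-collated tensors (it indexes `[0]` and calls `.cpu()`), so it is meant to be fed from a batch, as in `visualise_fmaps.py` below, rather than directly from this loader:

```python
import numpy as np
from dataio.loader.utils import load_nifti_img

image, meta = load_nifti_img('/path/to/image.nii.gz', dtype=np.int16)
print(meta['name'], image.shape, meta['pixdim'][1:4])  # name, shape, voxel spacing
```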
/utils/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 |
5 |
6 | class HTML:
7 |     def __init__(self, web_dir, title, refresh=0):
8 | self.title = title
9 | self.web_dir = web_dir
10 | self.img_dir = os.path.join(self.web_dir, 'images')
11 | if not os.path.exists(self.web_dir):
12 | os.makedirs(self.web_dir)
13 | if not os.path.exists(self.img_dir):
14 | os.makedirs(self.img_dir)
15 | # print(self.img_dir)
16 |
17 | self.doc = dominate.document(title=title)
18 |         if refresh > 0:
19 |             with self.doc.head:
20 |                 meta(http_equiv="refresh", content=str(refresh))
21 |
22 | def get_image_dir(self):
23 | return self.img_dir
24 |
25 | def add_header(self, str):
26 | with self.doc:
27 | h3(str)
28 |
29 | def add_table(self, border=1):
30 | self.t = table(border=border, style="table-layout: fixed;")
31 | self.doc.add(self.t)
32 |
33 | def add_images(self, ims, txts, links, width=400):
34 | self.add_table()
35 | with self.t:
36 | with tr():
37 | for im, txt, link in zip(ims, txts, links):
38 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
39 | with p():
40 | with a(href=os.path.join('images', link)):
41 | img(style="width:%dpx" % width, src=os.path.join('images', im))
42 | br()
43 | p(txt)
44 |
45 | def save(self):
46 | html_file = '%s/index.html' % self.web_dir
47 |         # use a context manager so the file is closed even on error
48 |         with open(html_file, 'wt') as f:
49 |             f.write(self.doc.render())
50 |
51 |
52 | if __name__ == '__main__':
53 | html = HTML('web/', 'test_html')
54 | html.add_header('hello world')
55 |
56 | ims = []
57 | txts = []
58 | links = []
59 | for n in range(4):
60 | ims.append('image_%d.png' % n)
61 | txts.append('text_%d' % n)
62 | links.append('image_%d.png' % n)
63 | html.add_images(ims, txts, links)
64 | html.save()
--------------------------------------------------------------------------------
/dataio/loader/cmr_3D_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import numpy as np
3 | import datetime
4 |
5 | from os import listdir
6 | from os.path import join
7 | from .utils import load_nifti_img, check_exceptions, is_image_file
8 |
9 |
10 | class CMR3DDataset(data.Dataset):
11 | def __init__(self, root_dir, split, transform=None, preload_data=False):
12 | super(CMR3DDataset, self).__init__()
13 | image_dir = join(root_dir, split, 'image')
14 | target_dir = join(root_dir, split, 'label')
15 | self.image_filenames = sorted([join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)])
16 | self.target_filenames = sorted([join(target_dir, x) for x in listdir(target_dir) if is_image_file(x)])
17 | assert len(self.image_filenames) == len(self.target_filenames)
18 |
19 | # report the number of images in the dataset
20 | print('Number of {0} images: {1} NIFTIs'.format(split, self.__len__()))
21 |
22 | # data augmentation
23 | self.transform = transform
24 |
25 |         # optionally preload the data into RAM
26 | self.preload_data = preload_data
27 | if self.preload_data:
28 | print('Preloading the {0} dataset ...'.format(split))
29 | self.raw_images = [load_nifti_img(ii, dtype=np.int16)[0] for ii in self.image_filenames]
30 | self.raw_labels = [load_nifti_img(ii, dtype=np.uint8)[0] for ii in self.target_filenames]
31 | print('Loading is done\n')
32 |
33 |
34 | def __getitem__(self, index):
35 |         # re-seed so that DataLoader workers do not sample the same augmentation parameters
36 | np.random.seed(datetime.datetime.now().second + datetime.datetime.now().microsecond)
37 |
38 | # load the nifti images
39 | if not self.preload_data:
40 | input, _ = load_nifti_img(self.image_filenames[index], dtype=np.int16)
41 | target, _ = load_nifti_img(self.target_filenames[index], dtype=np.uint8)
42 | else:
43 | input = np.copy(self.raw_images[index])
44 | target = np.copy(self.raw_labels[index])
45 |
46 | # handle exceptions
47 | check_exceptions(input, target)
48 | if self.transform:
49 | input, target = self.transform(input, target)
50 |
51 | return input, target
52 |
53 | def __len__(self):
54 | return len(self.image_filenames)
--------------------------------------------------------------------------------
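The `np.random.seed(...)` call in `__getitem__` above works around DataLoader workers inheriting identical NumPy RNG state and therefore sampling identical augmentations. A common alternative (a sketch, not the repo's method) seeds each worker once via `worker_init_fn`:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader

def seed_worker(worker_id):
    # torch assigns every worker a distinct initial seed; fold it into NumPy.
    np.random.seed((torch.initial_seed() + worker_id) % 2**32)

# loader = DataLoader(dataset, num_workers=8, worker_init_fn=seed_worker)
```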
/dataio/loader/ukbb_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import numpy as np
3 | import datetime
4 |
5 | from os import listdir
6 | from os.path import join
7 | from .utils import load_nifti_img, check_exceptions, is_image_file
8 |
9 |
10 | class UKBBDataset(data.Dataset):
11 | def __init__(self, root_dir, split, transform=None, preload_data=False):
12 | super(UKBBDataset, self).__init__()
13 | image_dir = join(root_dir, split, 'image')
14 | target_dir = join(root_dir, split, 'label')
15 | self.image_filenames = sorted([join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)])
16 | self.target_filenames = sorted([join(target_dir, x) for x in listdir(target_dir) if is_image_file(x)])
17 | assert len(self.image_filenames) == len(self.target_filenames)
18 |
19 | # report the number of images in the dataset
20 | print('Number of {0} images: {1} NIFTIs'.format(split, self.__len__()))
21 |
22 | # data augmentation
23 | self.transform = transform
24 |
25 |         # optionally preload the data into RAM
26 | self.preload_data = preload_data
27 | if self.preload_data:
28 | print('Preloading the {0} dataset ...'.format(split))
29 | self.raw_images = [load_nifti_img(ii, dtype=np.int16)[0] for ii in self.image_filenames]
30 | self.raw_labels = [load_nifti_img(ii, dtype=np.uint8)[0] for ii in self.target_filenames]
31 | print('Loading is done\n')
32 |
33 | def __getitem__(self, index):
34 |         # re-seed so that DataLoader workers do not sample the same augmentation parameters
35 | np.random.seed(datetime.datetime.now().second + datetime.datetime.now().microsecond)
36 |
37 | # load the nifti images
38 | if not self.preload_data:
39 | input, _ = load_nifti_img(self.image_filenames[index], dtype=np.int16)
40 | target, _ = load_nifti_img(self.target_filenames[index], dtype=np.uint8)
41 | else:
42 | input = np.copy(self.raw_images[index])
43 | target = np.copy(self.raw_labels[index])
44 |
45 | # pass a random slice for the time being
46 | id_slice = np.random.randint(0,input.shape[2])
47 | input = input[:,:,[id_slice]]
48 | target= target[:,:,[id_slice]]
49 |
50 | # handle exceptions
51 | check_exceptions(input, target)
52 | if self.transform:
53 | input, target = self.transform(input, target)
54 |
55 | return input, target
56 |
57 | def __len__(self):
58 | return len(self.image_filenames)
--------------------------------------------------------------------------------
/dataio/loader/us_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import h5py
4 | import numpy as np
5 | import datetime
6 |
7 | from os import listdir
8 | from os.path import join
9 | #from .utils import check_exceptions
10 |
11 |
12 | class UltraSoundDataset(data.Dataset):
13 | def __init__(self, root_path, split, transform=None, preload_data=False):
14 | super(UltraSoundDataset, self).__init__()
15 |
16 |         f = h5py.File(root_path, 'r')  # explicit read-only mode; h5py's default changed across versions
17 |
18 | self.images = f['x_'+split]
19 |
20 | if preload_data:
21 | self.images = np.array(self.images[:])
22 |
23 | self.labels = np.array(f['p_'+split][:], dtype=np.int64)#[:1000]
24 | self.label_names = [x.decode('utf-8') for x in f['label_names'][:].tolist()]
25 | #print(self.label_names)
26 | #print(np.unique(self.labels[:]))
27 | # construct weight for entry
28 | self.n_class = len(self.label_names)
29 | class_weight = np.zeros(self.n_class)
30 | for lab in range(self.n_class):
31 | class_weight[lab] = np.sum(self.labels[:] == lab)
32 |
33 | class_weight = 1 / class_weight
34 |
35 | self.weight = np.zeros(len(self.labels))
36 | for i in range(len(self.labels)):
37 | self.weight[i] = class_weight[self.labels[i]]
38 |
39 | #print(class_weight)
40 | assert len(self.images) == len(self.labels)
41 |
42 | # data augmentation
43 | self.transform = transform
44 |
45 | # report the number of images in the dataset
46 |         print('Number of {0} images: {1}'.format(split, self.__len__()))
47 |
48 | def __getitem__(self, index):
49 |         # re-seed so that DataLoader workers do not sample the same augmentation parameters
50 | np.random.seed(datetime.datetime.now().second + datetime.datetime.now().microsecond)
51 |
52 |         # fetch the image frame and its label from the HDF5 arrays
53 | input = self.images[index][0]
54 | target = self.labels[index]
55 |
56 | #input = input.transpose((1,2,0))
57 |
58 | # handle exceptions
59 | #check_exceptions(input, target)
60 | if self.transform:
61 | input = self.transform(input)
62 |
63 | #print(input.shape, torch.from_numpy(np.array([target])))
64 | #print("target",np.int64(target))
65 | return input, int(target)
66 |
67 | def __len__(self):
68 | return len(self.images)
69 |
70 |
71 | # if __name__ == '__main__':
72 | # dataset = UltraSoundDataset("/vol/bitbucket/js3611/data_ultrasound/preproc_combined_inp_224x288.hdf5",'test')
73 |
74 | # from torch.utils.data import DataLoader, sampler
75 | # ds = DataLoader(dataset=dataset, num_workers=1, batch_size=2)
76 |
--------------------------------------------------------------------------------
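The per-sample `self.weight` array built in `__init__` (inverse class frequency) is intended for balanced sampling; the configs set `"sampler": "weighted2"`. A hedged sketch of the idea using PyTorch's stock `WeightedRandomSampler` (the repo's own sampler wiring lives in the training scripts; the HDF5 path is a placeholder):

```python
from torch.utils.data import DataLoader, WeightedRandomSampler
from dataio.loader.us_dataset import UltraSoundDataset

dataset = UltraSoundDataset('/path/to/preproc_combined_inp_224x288.hdf5', 'train')
sampler = WeightedRandomSampler(dataset.weight, num_samples=len(dataset),
                                replacement=True)
loader = DataLoader(dataset, batch_size=64, sampler=sampler)
```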
/models/networks/unet_2D.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 | from .utils import unetConv2, unetUp
4 | import torch.nn.functional as F
5 | from models.networks_other import init_weights
6 |
7 | class unet_2D(nn.Module):
8 |
9 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True):
10 | super(unet_2D, self).__init__()
11 | self.is_deconv = is_deconv
12 | self.in_channels = in_channels
13 | self.is_batchnorm = is_batchnorm
14 | self.feature_scale = feature_scale
15 |
16 | filters = [64, 128, 256, 512, 1024]
17 | filters = [int(x / self.feature_scale) for x in filters]
18 |
19 | # downsampling
20 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
21 | self.maxpool1 = nn.MaxPool2d(kernel_size=2)
22 |
23 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
24 | self.maxpool2 = nn.MaxPool2d(kernel_size=2)
25 |
26 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
27 | self.maxpool3 = nn.MaxPool2d(kernel_size=2)
28 |
29 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
30 | self.maxpool4 = nn.MaxPool2d(kernel_size=2)
31 |
32 | self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)
33 |
34 | # upsampling
35 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv)
36 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv)
37 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv)
38 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv)
39 |
40 | # final conv (without any concat)
41 | self.final = nn.Conv2d(filters[0], n_classes, 1)
42 |
43 | # initialise weights
44 | for m in self.modules():
45 | if isinstance(m, nn.Conv2d):
46 | init_weights(m, init_type='kaiming')
47 | elif isinstance(m, nn.BatchNorm2d):
48 | init_weights(m, init_type='kaiming')
49 |
50 |
51 | def forward(self, inputs):
52 | conv1 = self.conv1(inputs)
53 | maxpool1 = self.maxpool1(conv1)
54 |
55 | conv2 = self.conv2(maxpool1)
56 | maxpool2 = self.maxpool2(conv2)
57 |
58 | conv3 = self.conv3(maxpool2)
59 | maxpool3 = self.maxpool3(conv3)
60 |
61 | conv4 = self.conv4(maxpool3)
62 | maxpool4 = self.maxpool4(conv4)
63 |
64 | center = self.center(maxpool4)
65 | up4 = self.up_concat4(conv4, center)
66 | up3 = self.up_concat3(conv3, up4)
67 | up2 = self.up_concat2(conv2, up3)
68 | up1 = self.up_concat1(conv1, up2)
69 |
70 | final = self.final(up1)
71 |
72 | return final
73 |
74 | @staticmethod
75 | def apply_argmax_softmax(pred):
76 | log_p = F.softmax(pred, dim=1)
77 |
78 | return log_p
79 |
--------------------------------------------------------------------------------
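A quick shape sanity check for `unet_2D` (a sketch; the expected output shape assumes `unetConv2` is size-preserving, and the input sides must be divisible by 16 because of the four 2x2 max-pools):

```python
import torch
from models.networks.unet_2D import unet_2D

net = unet_2D(feature_scale=4, n_classes=4, in_channels=1)
x = torch.randn(1, 1, 160, 160)
y = net(x)
print(y.shape)  # expected: torch.Size([1, 4, 160, 160])
```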
/models/networks/unet_3D.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 | from .utils import UnetConv3, UnetUp3
4 | import torch.nn.functional as F
5 | from models.networks_other import init_weights
6 |
7 | class unet_3D(nn.Module):
8 |
9 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True):
10 | super(unet_3D, self).__init__()
11 | self.is_deconv = is_deconv
12 | self.in_channels = in_channels
13 | self.is_batchnorm = is_batchnorm
14 | self.feature_scale = feature_scale
15 |
16 | filters = [64, 128, 256, 512, 1024]
17 | filters = [int(x / self.feature_scale) for x in filters]
18 |
19 | # downsampling
20 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm)
21 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 1))
22 |
23 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm)
24 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 1))
25 |
26 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm)
27 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 1))
28 |
29 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm)
30 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 1))
31 |
32 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm)
33 |
34 | # upsampling
35 | self.up_concat4 = UnetUp3(filters[4], filters[3], self.is_deconv, is_batchnorm)
36 | self.up_concat3 = UnetUp3(filters[3], filters[2], self.is_deconv, is_batchnorm)
37 | self.up_concat2 = UnetUp3(filters[2], filters[1], self.is_deconv, is_batchnorm)
38 | self.up_concat1 = UnetUp3(filters[1], filters[0], self.is_deconv, is_batchnorm)
39 |
40 | # final conv (without any concat)
41 | self.final = nn.Conv3d(filters[0], n_classes, 1)
42 |
43 | # initialise weights
44 | for m in self.modules():
45 | if isinstance(m, nn.Conv3d):
46 | init_weights(m, init_type='kaiming')
47 | elif isinstance(m, nn.BatchNorm3d):
48 | init_weights(m, init_type='kaiming')
49 |
50 | def forward(self, inputs):
51 | conv1 = self.conv1(inputs)
52 | maxpool1 = self.maxpool1(conv1)
53 |
54 | conv2 = self.conv2(maxpool1)
55 | maxpool2 = self.maxpool2(conv2)
56 |
57 | conv3 = self.conv3(maxpool2)
58 | maxpool3 = self.maxpool3(conv3)
59 |
60 | conv4 = self.conv4(maxpool3)
61 | maxpool4 = self.maxpool4(conv4)
62 |
63 | center = self.center(maxpool4)
64 | up4 = self.up_concat4(conv4, center)
65 | up3 = self.up_concat3(conv3, up4)
66 | up2 = self.up_concat2(conv2, up3)
67 | up1 = self.up_concat1(conv1, up2)
68 |
69 | final = self.final(up1)
70 |
71 | return final
72 |
73 | @staticmethod
74 | def apply_argmax_softmax(pred):
75 | log_p = F.softmax(pred, dim=1)
76 |
77 | return log_p
78 |
--------------------------------------------------------------------------------
/models/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .unet_2D import *
2 | from .unet_3D import *
3 | from .unet_nonlocal_2D import *
4 | from .unet_nonlocal_3D import *
5 | from .unet_grid_attention_3D import *
6 | from .unet_CT_dsv_3D import *
7 | from .unet_CT_single_att_dsv_3D import *
8 | from .unet_CT_multi_att_dsv_3D import *
9 | from .sononet import *
10 | from .sononet_grid_attention import *
11 |
12 | def get_network(name, n_classes, in_channels=3, feature_scale=4, tensor_dim='2D',
13 | nonlocal_mode='embedded_gaussian', attention_dsample=(2,2,2),
14 | aggregation_mode='concat'):
15 | model = _get_model_instance(name, tensor_dim)
16 |
17 | if name in ['unet', 'unet_ct_dsv']:
18 | model = model(n_classes=n_classes,
19 | is_batchnorm=True,
20 | in_channels=in_channels,
21 | feature_scale=feature_scale,
22 | is_deconv=False)
23 | elif name in ['unet_nonlocal']:
24 | model = model(n_classes=n_classes,
25 | is_batchnorm=True,
26 | in_channels=in_channels,
27 | is_deconv=False,
28 | nonlocal_mode=nonlocal_mode,
29 | feature_scale=feature_scale)
30 | elif name in ['unet_grid_gating',
31 | 'unet_ct_single_att_dsv',
32 | 'unet_ct_multi_att_dsv']:
33 | model = model(n_classes=n_classes,
34 | is_batchnorm=True,
35 | in_channels=in_channels,
36 | nonlocal_mode=nonlocal_mode,
37 | feature_scale=feature_scale,
38 | attention_dsample=attention_dsample,
39 | is_deconv=False)
40 | elif name in ['sononet','sononet2']:
41 | model = model(n_classes=n_classes,
42 | is_batchnorm=True,
43 | in_channels=in_channels,
44 | feature_scale=feature_scale)
45 | elif name in ['sononet_grid_attention']:
46 | model = model(n_classes=n_classes,
47 | is_batchnorm=True,
48 | in_channels=in_channels,
49 | feature_scale=feature_scale,
50 | nonlocal_mode=nonlocal_mode,
51 | aggregation_mode=aggregation_mode)
52 | else:
53 |         raise NotImplementedError('Model {} not available'.format(name))
54 |
55 | return model
56 |
57 |
58 | def _get_model_instance(name, tensor_dim):
59 | return {
60 | 'unet':{'2D': unet_2D, '3D': unet_3D},
61 | 'unet_nonlocal':{'2D': unet_nonlocal_2D, '3D': unet_nonlocal_3D},
62 | 'unet_grid_gating': {'3D': unet_grid_attention_3D},
63 | 'unet_ct_dsv': {'3D': unet_CT_dsv_3D},
64 | 'unet_ct_single_att_dsv': {'3D': unet_CT_single_att_dsv_3D},
65 | 'unet_ct_multi_att_dsv': {'3D': unet_CT_multi_att_dsv_3D},
66 | 'sononet': {'2D': sononet},
67 | 'sononet2': {'2D': sononet2},
68 | 'sononet_grid_attention': {'2D': sononet_grid_attention}
69 | }[name][tensor_dim]
70 |
--------------------------------------------------------------------------------
/models/networks/sononet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import torch.nn as nn
4 | from .utils import unetConv2, unetUp, conv2DBatchNormRelu, conv2DBatchNorm
5 | import torch.nn.functional as F
6 | from models.networks_other import init_weights
7 |
8 | class sononet(nn.Module):
9 |
10 | def __init__(self, feature_scale=4, n_classes=21, in_channels=3, is_batchnorm=True, n_convs=None):
11 | super(sononet, self).__init__()
12 | self.in_channels = in_channels
13 | self.is_batchnorm = is_batchnorm
14 | self.feature_scale = feature_scale
15 |         self.n_classes = n_classes
16 |
17 | filters = [64, 128, 256, 512]
18 | filters = [int(x / self.feature_scale) for x in filters]
19 |
20 | if n_convs is None:
21 | n_convs = [2,2,3,3,3]
22 |
23 | # downsampling
24 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm, n=n_convs[0])
25 | self.maxpool1 = nn.MaxPool2d(kernel_size=2)
26 |
27 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm, n=n_convs[1])
28 | self.maxpool2 = nn.MaxPool2d(kernel_size=2)
29 |
30 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm, n=n_convs[2])
31 | self.maxpool3 = nn.MaxPool2d(kernel_size=2)
32 |
33 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm, n=n_convs[3])
34 | self.maxpool4 = nn.MaxPool2d(kernel_size=2)
35 |
36 | self.conv5 = unetConv2(filters[3], filters[3], self.is_batchnorm, n=n_convs[4])
37 |
38 | # adaptation layer
39 | self.conv5_p = conv2DBatchNormRelu(filters[3], filters[2], 1, 1, 0)
40 | self.conv6_p = conv2DBatchNorm(filters[2], self.n_classes, 1, 1, 0)
41 |
42 | # initialise weights
43 | for m in self.modules():
44 | if isinstance(m, nn.Conv2d):
45 | init_weights(m, init_type='kaiming')
46 | elif isinstance(m, nn.BatchNorm2d):
47 | init_weights(m, init_type='kaiming')
48 |
49 |
50 | def forward(self, inputs):
51 | # Feature Extraction
52 | conv1 = self.conv1(inputs)
53 | maxpool1 = self.maxpool1(conv1)
54 |
55 | conv2 = self.conv2(maxpool1)
56 | maxpool2 = self.maxpool2(conv2)
57 |
58 | conv3 = self.conv3(maxpool2)
59 | maxpool3 = self.maxpool3(conv3)
60 |
61 | conv4 = self.conv4(maxpool3)
62 | maxpool4 = self.maxpool4(conv4)
63 |
64 | conv5 = self.conv5(maxpool4)
65 |
66 | conv5_p = self.conv5_p(conv5)
67 | conv6_p = self.conv6_p(conv5_p)
68 |
69 | batch_size = inputs.shape[0]
70 | pooled = F.adaptive_avg_pool2d(conv6_p, (1, 1)).view(batch_size, -1)
71 |
72 | return pooled
73 |
74 |
75 | @staticmethod
76 | def apply_argmax_softmax(pred):
77 | log_p = F.softmax(pred, dim=1)
78 |
79 | return log_p
80 |
81 |
82 | def sononet2(feature_scale=4, n_classes=21, in_channels=3, is_batchnorm=True):
83 | return sononet(feature_scale, n_classes, in_channels, is_batchnorm, n_convs=[3,3,3,2,2])
84 |
85 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Abstract level model definition
2 | # Returns the model class for specified network type
3 | import os
4 |
5 |
6 | class ModelOpts:
7 | def __init__(self):
8 | self.gpu_ids = [0]
9 | self.isTrain = True
10 | self.continue_train = False
11 | self.which_epoch = int(0)
12 | self.save_dir = './checkpoints/default'
13 | self.model_type = 'unet'
14 | self.input_nc = 1
15 | self.output_nc = 4
16 | self.lr_rate = 1e-12
17 | self.l2_reg_weight = 0.0
18 | self.feature_scale = 4
19 | self.tensor_dim = '2D'
20 | self.path_pre_trained_model = None
21 | self.criterion = 'cross_entropy'
22 | self.type = 'seg'
23 |
24 | # Attention
25 | self.nonlocal_mode = 'concatenation'
26 | self.attention_dsample = (2,2,2)
27 |
28 | # Attention Classifier
29 | self.aggregation_mode = 'concatenation'
30 |
31 |
32 | def initialise(self, json_opts):
33 | opts = json_opts
34 |
35 | self.raw = json_opts
36 | self.gpu_ids = opts.gpu_ids
37 | self.isTrain = opts.isTrain
38 | self.save_dir = os.path.join(opts.checkpoints_dir, opts.experiment_name)
39 | self.model_type = opts.model_type
40 | self.input_nc = opts.input_nc
41 | self.output_nc = opts.output_nc
42 | self.continue_train = opts.continue_train
43 | self.which_epoch = opts.which_epoch
44 |
45 | if hasattr(opts, 'type'): self.type = opts.type
46 | if hasattr(opts, 'l2_reg_weight'): self.l2_reg_weight = opts.l2_reg_weight
47 | if hasattr(opts, 'lr_rate'): self.lr_rate = opts.lr_rate
48 | if hasattr(opts, 'feature_scale'): self.feature_scale = opts.feature_scale
49 | if hasattr(opts, 'tensor_dim'): self.tensor_dim = opts.tensor_dim
50 |
51 | if hasattr(opts, 'path_pre_trained_model'): self.path_pre_trained_model = opts.path_pre_trained_model
52 | if hasattr(opts, 'criterion'): self.criterion = opts.criterion
53 |
54 | if hasattr(opts, 'nonlocal_mode'): self.nonlocal_mode = opts.nonlocal_mode
55 | if hasattr(opts, 'attention_dsample'): self.attention_dsample = opts.attention_dsample
56 | # Classifier
57 | if hasattr(opts, 'aggregation_mode'): self.aggregation_mode = opts.aggregation_mode
58 |
59 | def get_model(json_opts):
60 |
61 | # Neural Network Model Initialisation
62 | model = None
63 | model_opts = ModelOpts()
64 | model_opts.initialise(json_opts)
65 |
66 | # Print the model type
67 | print('\nInitialising model {}'.format(model_opts.model_type))
68 |
69 | model_type = model_opts.type
70 | if model_type == 'seg':
71 | # Return the model type
72 | from .feedforward_seg_model import FeedForwardSegmentation
73 | model = FeedForwardSegmentation()
74 |
75 | elif model_type == 'classifier':
76 | # Return the model type
77 | from .feedforward_classifier import FeedForwardClassifier
78 | model = FeedForwardClassifier()
79 |
80 | elif model_type == 'aggregated_classifier':
81 | # Return the model type
82 | from .aggregated_classifier import AggregatedClassifier
83 | model = AggregatedClassifier()
84 |     else:
85 |         raise ValueError('Unknown model type: {}'.format(model_type))
86 | # Initialise the created model
87 | model.initialize(model_opts)
88 | print("Model [%s] is created" % (model.name()))
89 |
90 | return model
91 |
--------------------------------------------------------------------------------
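A minimal sketch of building a model from a config's `"model"` section, as the entry scripts do (with `isTrain` and `gpu_ids` set as in the config, this expects a matching checkpoint/GPU setup):

```python
from utils.util import json_file_to_pyobj
from models import get_model

json_opts = json_file_to_pyobj('configs/config_unet_ct_dsv.json')
model = get_model(json_opts.model)  # e.g. FeedForwardSegmentation
```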
/models/networks/unet_nonlocal_2D.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 | from .utils import unetConv2, unetUp
4 | from models.layers.nonlocal_layer import NONLocalBlock2D
5 | import torch.nn.functional as F
6 |
7 |
8 | class unet_nonlocal_2D(nn.Module):
9 |
10 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3,
11 | is_batchnorm=True, nonlocal_mode='embedded_gaussian', nonlocal_sf=4):
12 | super(unet_nonlocal_2D, self).__init__()
13 | self.is_deconv = is_deconv
14 | self.in_channels = in_channels
15 | self.is_batchnorm = is_batchnorm
16 | self.feature_scale = feature_scale
17 |
18 | filters = [64, 128, 256, 512, 1024]
19 | filters = [int(x / self.feature_scale) for x in filters]
20 |
21 | # downsampling
22 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
23 | self.maxpool1 = nn.MaxPool2d(kernel_size=2)
24 | self.nonlocal1 = NONLocalBlock2D(in_channels=filters[0], inter_channels=filters[0] // 4,
25 | sub_sample_factor=nonlocal_sf, mode=nonlocal_mode)
26 |
27 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
28 | self.maxpool2 = nn.MaxPool2d(kernel_size=2)
29 | self.nonlocal2 = NONLocalBlock2D(in_channels=filters[1], inter_channels=filters[1] // 4,
30 | sub_sample_factor=nonlocal_sf, mode=nonlocal_mode)
31 |
32 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
33 | self.maxpool3 = nn.MaxPool2d(kernel_size=2)
34 |
35 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
36 | self.maxpool4 = nn.MaxPool2d(kernel_size=2)
37 |
38 | self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)
39 |
40 | # upsampling
41 | self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv)
42 | self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv)
43 | self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv)
44 | self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv)
45 |
46 | # final conv (without any concat)
47 | self.final = nn.Conv2d(filters[0], n_classes, 1)
48 |
49 | # initialise weights
50 | for m in self.modules():
51 | if isinstance(m, nn.Conv2d):
52 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
53 | m.weight.data.normal_(0, math.sqrt(2. / n))
54 | elif isinstance(m, nn.BatchNorm2d):
55 | m.weight.data.fill_(1)
56 | m.bias.data.zero_()
57 |
58 | def forward(self, inputs):
59 | conv1 = self.conv1(inputs)
60 | maxpool1 = self.maxpool1(conv1)
61 | nonlocal1 = self.nonlocal1(maxpool1)
62 |
63 | conv2 = self.conv2(nonlocal1)
64 | maxpool2 = self.maxpool2(conv2)
65 | nonlocal2 = self.nonlocal2(maxpool2)
66 |
67 | conv3 = self.conv3(nonlocal2)
68 | maxpool3 = self.maxpool3(conv3)
69 |
70 | conv4 = self.conv4(maxpool3)
71 | maxpool4 = self.maxpool4(conv4)
72 |
73 | center = self.center(maxpool4)
74 | up4 = self.up_concat4(conv4, center)
75 | up3 = self.up_concat3(conv3, up4)
76 | up2 = self.up_concat2(conv2, up3)
77 | up1 = self.up_concat1(conv1, up2)
78 |
79 | final = self.final(up1)
80 |
81 | return final
82 |
83 | @staticmethod
84 | def apply_argmax_softmax(pred):
85 | log_p = F.softmax(pred, dim=1)
86 |
87 | return log_p
88 |
--------------------------------------------------------------------------------
/models/networks/unet_nonlocal_3D.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 | from .utils import UnetConv3, UnetUp3
4 | import torch.nn.functional as F
5 | from models.layers.nonlocal_layer import NONLocalBlock3D
6 | from models.networks_other import init_weights
7 |
8 | class unet_nonlocal_3D(nn.Module):
9 |
10 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True,
11 | nonlocal_mode='embedded_gaussian', nonlocal_sf=4):
12 | super(unet_nonlocal_3D, self).__init__()
13 | self.is_deconv = is_deconv
14 | self.in_channels = in_channels
15 | self.is_batchnorm = is_batchnorm
16 | self.feature_scale = feature_scale
17 |
18 | filters = [64, 128, 256, 512, 1024]
19 | filters = [int(x / self.feature_scale) for x in filters]
20 |
21 | # downsampling
22 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm)
23 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 1))
24 |
25 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm)
26 | self.nonlocal2 = NONLocalBlock3D(in_channels=filters[1], inter_channels=filters[1] // 4,
27 | sub_sample_factor=nonlocal_sf, mode=nonlocal_mode)
28 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 1))
29 |
30 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm)
31 | self.nonlocal3 = NONLocalBlock3D(in_channels=filters[2], inter_channels=filters[2] // 4,
32 | sub_sample_factor=nonlocal_sf, mode=nonlocal_mode)
33 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 1))
34 |
35 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm)
36 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 1))
37 |
38 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm)
39 |
40 | # upsampling
41 | self.up_concat4 = UnetUp3(filters[4], filters[3], self.is_deconv)
42 | self.up_concat3 = UnetUp3(filters[3], filters[2], self.is_deconv)
43 | self.up_concat2 = UnetUp3(filters[2], filters[1], self.is_deconv)
44 | self.up_concat1 = UnetUp3(filters[1], filters[0], self.is_deconv)
45 |
46 | # final conv (without any concat)
47 | self.final = nn.Conv3d(filters[0], n_classes, 1)
48 |
49 | # initialise weights
50 | for m in self.modules():
51 | if isinstance(m, nn.Conv3d):
52 | init_weights(m, init_type='kaiming')
53 | elif isinstance(m, nn.BatchNorm3d):
54 | init_weights(m, init_type='kaiming')
55 |
56 | def forward(self, inputs):
57 | conv1 = self.conv1(inputs)
58 | maxpool1 = self.maxpool1(conv1)
59 |
60 | conv2 = self.conv2(maxpool1)
61 | nl2 = self.nonlocal2(conv2)
62 | maxpool2 = self.maxpool2(nl2)
63 |
64 | conv3 = self.conv3(maxpool2)
65 | nl3 = self.nonlocal3(conv3)
66 | maxpool3 = self.maxpool3(nl3)
67 |
68 | conv4 = self.conv4(maxpool3)
69 | maxpool4 = self.maxpool4(conv4)
70 |
71 | center = self.center(maxpool4)
72 | up4 = self.up_concat4(conv4, center)
73 | up3 = self.up_concat3(nl3, up4)
74 | up2 = self.up_concat2(nl2, up3)
75 | up1 = self.up_concat1(conv1, up2)
76 |
77 | final = self.final(up1)
78 |
79 | return final
80 |
81 | @staticmethod
82 | def apply_argmax_softmax(pred):
83 | log_p = F.softmax(pred, dim=1)
84 |
85 | return log_p
86 |
--------------------------------------------------------------------------------
/visualise_fmaps.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 |
3 | from dataio.loader import get_dataset, get_dataset_path
4 | from dataio.transformation import get_dataset_transformation
5 | from utils.util import json_file_to_pyobj
6 | from models import get_model
7 |
8 | import matplotlib.cm as cm
9 | import matplotlib.pyplot as plt
10 | import math, numpy, os
11 | # scipy.misc.imresize was removed in SciPy 1.3; skimage.transform.resize below is the replacement
12 | from skimage.transform import resize
13 | from dataio.loader.utils import write_nifti_img
14 | from torch.nn import functional as F
15 |
16 | def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None):
17 | plt.ion()
18 | filters = units.shape[2]
19 | n_columns = round(math.sqrt(filters))
20 | n_rows = math.ceil(filters / n_columns) + 1
21 | fig = plt.figure(figure_id, figsize=(n_rows*3,n_columns*3))
22 | fig.clf()
23 |
24 | for i in range(filters):
25 | ax1 = plt.subplot(n_rows, n_columns, i+1)
26 | plt.imshow(units[:,:,i].T, interpolation=interp, cmap=colormap)
27 | plt.axis('on')
28 | ax1.set_xticklabels([])
29 | ax1.set_yticklabels([])
30 | plt.colorbar()
31 | if colormap_lim:
32 | plt.clim(colormap_lim[0],colormap_lim[1])
33 |
34 | plt.subplots_adjust(wspace=0, hspace=0)
35 | plt.tight_layout()
36 |
37 | # Load options
38 | json_opts = json_file_to_pyobj('/vol/biomedic2/oo2113/projects/syntAI/ukbb_pytorch/configs_final/debug_ct.json')
39 |
40 | # Setup the NN Model
41 | model = get_model(json_opts.model)
42 |
43 | # Setup Dataset and Augmentation
44 | dataset_class = get_dataset('test_sax')
45 | dataset_path = get_dataset_path('test_sax', json_opts.data_path)
46 | dataset_transform = get_dataset_transformation('test_sax', json_opts.augmentation)
47 |
48 | # Setup Data Loader
49 | dataset = dataset_class(dataset_path, transform=dataset_transform['test'])
50 | data_loader = DataLoader(dataset=dataset, num_workers=1, batch_size=1, shuffle=False)
51 |
52 | # test
53 | for iteration, (input_arr, input_meta, _) in enumerate(data_loader, 1):
54 | model.set_input(input_arr)
55 | layer_name = 'attentionblock1'
56 | inp_fmap, out_fmap = model.get_feature_maps(layer_name=layer_name, upscale=False)
57 |
58 | # Display the input image and Down_sample the input image
59 | orig_input_img = model.input.permute(2, 3, 4, 1, 0).cpu().numpy()
60 |     upsampled_attention   = F.interpolate(out_fmap[1], size=input_arr.size()[2:], mode='trilinear').data.squeeze().permute(1,2,3,0).cpu().numpy()  # F.upsample is deprecated in favour of F.interpolate
61 |     upsampled_fmap_before = F.interpolate(inp_fmap[0], size=input_arr.size()[2:], mode='trilinear').data.squeeze().permute(1,2,3,0).cpu().numpy()
62 |     upsampled_fmap_after  = F.interpolate(out_fmap[2], size=input_arr.size()[2:], mode='trilinear').data.squeeze().permute(1,2,3,0).cpu().numpy()
63 |
64 | # Define the directories
65 | save_directory = os.path.join('/vol/bitbucket/oo2113/tmp/feature_maps', layer_name)
66 | basename = input_meta['name'][0].split('.')[0]
67 |
68 | # Write the attentions to a nifti image
69 | input_meta['name'][0] = basename + '_img.nii.gz'
70 | write_nifti_img(orig_input_img, input_meta, savedir=save_directory)
71 |
72 | input_meta['name'][0] = basename + '_att.nii.gz'
73 | write_nifti_img(upsampled_attention, input_meta, savedir=save_directory)
74 |
75 | input_meta['name'][0] = basename + '_fmap_before.nii.gz'
76 | write_nifti_img(upsampled_fmap_before, input_meta, savedir=save_directory)
77 |
78 | input_meta['name'][0] = basename + '_fmap_after.nii.gz'
79 | write_nifti_img(upsampled_fmap_after, input_meta, savedir=save_directory)
80 |
81 | model.destructor()
82 | #if iteration == 1: break
--------------------------------------------------------------------------------
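
Note on the interpolation calls above: F.upsample was the current API when this script was written; in PyTorch >= 0.4.1 it is deprecated in favour of F.interpolate with the same semantics. A minimal sketch of the equivalent call on a dummy 5D feature map, assuming a recent PyTorch build:

    import torch
    import torch.nn.functional as F

    fmap = torch.randn(1, 2, 4, 16, 16)   # N x C x D x H x W attention/feature map
    up = F.interpolate(fmap, size=(8, 32, 32), mode='trilinear', align_corners=False)
    print(up.shape)                       # torch.Size([1, 2, 8, 32, 32])
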
/visualise_att_maps_epoch.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 |
3 | from dataio.loader import get_dataset, get_dataset_path
4 | from dataio.transformation import get_dataset_transformation
5 | from utils.util import json_file_to_pyobj
6 | from models import get_model
7 |
8 | import matplotlib.cm as cm
9 | import matplotlib.pyplot as plt
10 | import math, numpy, os
11 | from dataio.loader.utils import write_nifti_img
12 | from torch.nn import functional as F
13 |
14 |
15 | def mkdirfun(directory):
16 | if not os.path.exists(directory):
17 | os.makedirs(directory)
18 |
19 |
20 | def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None):
21 | plt.ion()
22 | filters = units.shape[2]
23 | n_columns = round(math.sqrt(filters))
24 | n_rows = math.ceil(filters / n_columns) + 1
25 | fig = plt.figure(figure_id, figsize=(n_rows*3,n_columns*3))
26 | fig.clf()
27 |
28 | for i in range(filters):
29 | ax1 = plt.subplot(n_rows, n_columns, i+1)
30 | plt.imshow(units[:,:,i].T, interpolation=interp, cmap=colormap)
31 | plt.axis('on')
32 | ax1.set_xticklabels([])
33 | ax1.set_yticklabels([])
34 | plt.colorbar()
35 | if colormap_lim:
36 | plt.clim(colormap_lim[0],colormap_lim[1])
37 |
38 | plt.subplots_adjust(wspace=0, hspace=0)
39 | plt.tight_layout()
40 |
41 | # Epochs
42 | layer_name = 'attentionblock2'
43 | layer_save_directory = os.path.join('/vol/bitbucket/oo2113/tmp/attention_maps', layer_name); mkdirfun(layer_save_directory)
44 | epochs = range(225, 230, 3)
45 | att_maps = list()
46 | int_imgs = list()
47 | subject_id = int(2)
48 | for epoch in epochs:
49 |
50 | # Load options and replace the epoch attribute
51 | json_opts = json_file_to_pyobj('/vol/biomedic2/oo2113/projects/syntAI/ukbb_pytorch/configs_final/debug_ct.json')
52 | json_opts = json_opts._replace(model=json_opts.model._replace(which_epoch=epoch))
53 |
54 | # Setup the NN Model
55 | model = get_model(json_opts.model)
56 |
57 | # Setup Dataset and Augmentation
58 | dataset_class = get_dataset('test_sax')
59 | dataset_path = get_dataset_path('test_sax', json_opts.data_path)
60 | dataset_transform = get_dataset_transformation('test_sax', json_opts.augmentation)
61 |
62 | # Setup Data Loader
63 | dataset = dataset_class(dataset_path, transform=dataset_transform['test'])
64 | data_loader = DataLoader(dataset=dataset, num_workers=1, batch_size=1, shuffle=False)
65 |
66 | # test
67 | for iteration, (input_arr, input_meta, _) in enumerate(data_loader, 1):
68 | # look for the subject_id
69 | if iteration == subject_id:
70 | # load the input image into the model
71 | model.set_input(input_arr)
72 | inp_fmap, out_fmap = model.get_feature_maps(layer_name=layer_name, upscale=False)
73 |
74 |             # Fetch the input image and upsample the attention map to the input resolution
75 | orig_input_img = model.input.permute(2, 3, 4, 1, 0).cpu().numpy()
76 | upsampled_attention = F.upsample(out_fmap[1], size=input_arr.size()[2:], mode='trilinear').data.squeeze().permute(1,2,3,0).cpu().numpy()
77 |
78 | # Append it to the list
79 | int_imgs.append(orig_input_img[:,:,:,0,0])
80 | att_maps.append(upsampled_attention[:,:,:,1])
81 |
82 | # return the model
83 | model.destructor()
84 |
85 | # Write the attentions to a nifti image
86 | input_meta['name'][0] = str(subject_id) + '_img_2.nii.gz'
87 | int_imgs = numpy.array(int_imgs).transpose([1,2,3,0])
88 | write_nifti_img(int_imgs, input_meta, savedir=layer_save_directory)
89 |
90 | input_meta['name'][0] = str(subject_id) + '_att_2.nii.gz'
91 | att_maps = numpy.array(att_maps).transpose([1,2,3,0])
92 | write_nifti_img(att_maps, input_meta, savedir=layer_save_directory)
93 |
--------------------------------------------------------------------------------
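
The per-epoch loop above works because json_file_to_pyobj returns nested namedtuples, so single options can be patched with _replace without editing the JSON file. A minimal sketch of the same pattern, assuming a config that defines a model.which_epoch field (e.g. one of the files under configs/):

    from utils.util import json_file_to_pyobj

    opts = json_file_to_pyobj('configs/config_unet_ct_dsv.json')
    for epoch in range(225, 230, 3):
        # namedtuples are immutable; _replace returns a patched copy
        epoch_opts = opts._replace(model=opts.model._replace(which_epoch=epoch))
        print(epoch_opts.model.which_epoch)
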
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy
3 | import torch
4 | from utils.util import mkdir
5 | from .networks_other import get_n_parameters
6 |
7 | class BaseModel():
8 | def __init__(self):
9 | self.input = None
10 | self.net = None
11 | self.isTrain = False
12 | self.use_cuda = True
13 | self.schedulers = []
14 | self.optimizers = []
15 | self.save_dir = None
16 | self.gpu_ids = []
17 | self.which_epoch = int(0)
18 | self.path_pre_trained_model = None
19 |
20 | def name(self):
21 | return 'BaseModel'
22 |
23 | def initialize(self, opt, **kwargs):
24 | self.gpu_ids = opt.gpu_ids
25 | self.isTrain = opt.isTrain
26 | self.ImgTensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
27 | self.LblTensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
28 | self.save_dir = opt.save_dir; mkdir(self.save_dir)
29 |
30 | def set_input(self, input):
31 | self.input = input
32 |
33 | def set_scheduler(self, train_opt):
34 | pass
35 |
36 | def forward(self, split):
37 | pass
38 |
39 |     # used at test time, no backprop
40 | def test(self):
41 | pass
42 |
43 | def get_image_paths(self):
44 | pass
45 |
46 | def optimize_parameters(self):
47 | pass
48 |
49 | def get_current_visuals(self):
50 | return self.input
51 |
52 | def get_current_errors(self):
53 | return {}
54 |
55 | def get_input_size(self):
56 |         return self.input.size() if self.input is not None else None
57 |
58 | def save(self, label):
59 | pass
60 |
61 | # helper saving function that can be used by subclasses
62 | def save_network(self, network, network_label, epoch_label, gpu_ids):
63 | print('Saving the model {0} at the end of epoch {1}'.format(network_label, epoch_label))
64 | save_filename = '{0:03d}_net_{1}.pth'.format(epoch_label, network_label)
65 | save_path = os.path.join(self.save_dir, save_filename)
66 | torch.save(network.cpu().state_dict(), save_path)
67 | if len(gpu_ids) and torch.cuda.is_available():
68 | network.cuda(gpu_ids[0])
69 |
70 | # helper loading function that can be used by subclasses
71 | def load_network(self, network, network_label, epoch_label):
72 | print('Loading the model {0} - epoch {1}'.format(network_label, epoch_label))
73 | save_filename = '{0:03d}_net_{1}.pth'.format(epoch_label, network_label)
74 | save_path = os.path.join(self.save_dir, save_filename)
75 | network.load_state_dict(torch.load(save_path))
76 |
77 | def load_network_from_path(self, network, network_filepath, strict):
78 | network_label = os.path.basename(network_filepath)
79 | epoch_label = network_label.split('_')[0]
80 | print('Loading the model {0} - epoch {1}'.format(network_label, epoch_label))
81 | network.load_state_dict(torch.load(network_filepath), strict=strict)
82 |
83 | # update learning rate (called once every epoch)
84 | def update_learning_rate(self, metric=None, epoch=None):
85 | for scheduler in self.schedulers:
86 | if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
87 | scheduler.step(metrics=metric)
88 | else:
89 | scheduler.step()
90 | lr = self.optimizers[0].param_groups[0]['lr']
91 | print('current learning rate = %.7f' % lr)
92 |
93 | # returns the number of trainable parameters
94 | def get_number_parameters(self):
95 | return get_n_parameters(self.net)
96 |
97 | # clean up the GPU memory
98 | def destructor(self):
99 | del self.net
100 | del self.input
101 |
--------------------------------------------------------------------------------
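
save_network and load_network share the '{epoch:03d}_net_{label}.pth' naming scheme, so a checkpoint can be located deterministically from an epoch number and a network label alone. A small sketch of the convention, with hypothetical save_dir and label values:

    import os

    save_dir, network_label, epoch_label = './checkpoints', 'S', 250  # hypothetical values
    save_filename = '{0:03d}_net_{1}.pth'.format(epoch_label, network_label)
    print(os.path.join(save_dir, save_filename))  # ./checkpoints/250_net_S.pth
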
/utils/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | from PIL import Image
4 | import inspect, re
5 | import numpy as np
6 | import os
7 | import collections
8 | import json
9 | import csv
10 | from skimage.exposure import rescale_intensity
11 |
12 | # Converts a Tensor into a Numpy array
13 | # |imgtype|: the desired type of the converted numpy array
14 | def tensor2im(image_tensor, imgtype='img', datatype=np.uint8):
15 | image_numpy = image_tensor[0].cpu().float().numpy()
16 |     if image_numpy.ndim == 4:  # image_numpy (C x W x H x S)
17 | mid_slice = image_numpy.shape[-1]//2
18 | image_numpy = image_numpy[:,:,:,mid_slice]
19 | if image_numpy.shape[0] == 1:
20 | image_numpy = np.tile(image_numpy, (3, 1, 1))
21 | image_numpy = np.transpose(image_numpy, (1, 2, 0))
22 | if imgtype == 'img':
23 | image_numpy = (image_numpy + 8) / 16.0 * 255.0
24 |     if np.unique(image_numpy).size == 1:
25 | return image_numpy.astype(datatype)
26 | return rescale_intensity(image_numpy.astype(datatype))
27 |
28 |
29 | def diagnose_network(net, name='network'):
30 | mean = 0.0
31 | count = 0
32 | for param in net.parameters():
33 | if param.grad is not None:
34 | mean += torch.mean(torch.abs(param.grad.data))
35 | count += 1
36 | if count > 0:
37 | mean = mean / count
38 | print(name)
39 | print(mean)
40 |
41 |
42 | def save_image(image_numpy, image_path):
43 | image_pil = Image.fromarray(image_numpy)
44 | image_pil.save(image_path)
45 |
46 |
47 | def info(object, spacing=10, collapse=1):
48 | """Print methods and doc strings.
49 | Takes module, class, list, dictionary, or string."""
50 | methodList = [e for e in dir(object) if isinstance(getattr(object, e), collections.Callable)]
51 | processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s)
52 | print( "\n".join(["%s %s" %
53 | (method.ljust(spacing),
54 | processFunc(str(getattr(object, method).__doc__)))
55 | for method in methodList]) )
56 |
57 | def varname(p):
58 | for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:
59 | m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
60 | if m:
61 | return m.group(1)
62 |
63 | def print_numpy(x, val=True, shp=False):
64 | x = x.astype(np.float64)
65 | if shp:
66 | print('shape,', x.shape)
67 | if val:
68 | x = x.flatten()
69 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
70 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
71 |
72 |
73 | def mkdirs(paths):
74 | if isinstance(paths, list) and not isinstance(paths, str):
75 | for path in paths:
76 | mkdir(path)
77 | else:
78 | mkdir(paths)
79 |
80 |
81 | def mkdir(path):
82 | if not os.path.exists(path):
83 | os.makedirs(path)
84 |
85 |
86 | def json_file_to_pyobj(filename):
87 | def _json_object_hook(d): return collections.namedtuple('X', d.keys())(*d.values())
88 | def json2obj(data): return json.loads(data, object_hook=_json_object_hook)
89 | return json2obj(open(filename).read())
90 |
91 |
92 | def determine_crop_size(inp_shape, div_factor):
93 |     div_factor = np.array(div_factor, dtype=np.float32)
94 | new_shape = np.ceil(np.divide(inp_shape, div_factor)) * div_factor
95 | pre_pad = np.round((new_shape - inp_shape) / 2.0).astype(np.int16)
96 | post_pad = ((new_shape - inp_shape) - pre_pad).astype(np.int16)
97 | return pre_pad, post_pad
98 |
99 |
100 | def csv_write(out_filename, in_header_list, in_val_list):
101 | with open(out_filename, 'w') as f:
102 | writer = csv.writer(f)
103 | writer.writerow(in_header_list)
104 | writer.writerows(zip(*in_val_list))
105 |
--------------------------------------------------------------------------------
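
determine_crop_size computes the symmetric padding that makes each dimension divisible by a network's overall downsampling factor (16 for the U-Nets in this repository, which pool four times). A worked example:

    import numpy as np
    from utils.util import determine_crop_size

    pre, post = determine_crop_size(inp_shape=np.array([150, 150, 10]), div_factor=16)
    print(pre, post)   # [5 5 3] [5 5 3] -> padded shape (160, 160, 16)
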
/models/networks/unet_CT_dsv_3D.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .utils import UnetConv3, UnetUp3_CT, UnetDsv3
3 | import torch.nn.functional as F
4 | from models.networks_other import init_weights
5 | import torch
6 |
7 | class unet_CT_dsv_3D(nn.Module):
8 |
9 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True):
10 | super(unet_CT_dsv_3D, self).__init__()
11 | self.is_deconv = is_deconv
12 | self.in_channels = in_channels
13 | self.is_batchnorm = is_batchnorm
14 | self.feature_scale = feature_scale
15 |
16 | filters = [64, 128, 256, 512, 1024]
17 | filters = [int(x / self.feature_scale) for x in filters]
18 |
19 | # downsampling
20 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
21 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))
22 |
23 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
24 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))
25 |
26 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
27 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))
28 |
29 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
30 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))
31 |
32 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
33 |
34 | # upsampling
35 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm)
36 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm)
37 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm)
38 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm)
39 |
40 | # deep supervision
41 | self.dsv4 = UnetDsv3(in_size=filters[3], out_size=n_classes, scale_factor=8)
42 | self.dsv3 = UnetDsv3(in_size=filters[2], out_size=n_classes, scale_factor=4)
43 | self.dsv2 = UnetDsv3(in_size=filters[1], out_size=n_classes, scale_factor=2)
44 | self.dsv1 = nn.Conv3d(in_channels=filters[0], out_channels=n_classes, kernel_size=1)
45 |
46 | # final conv (without any concat)
47 | self.final = nn.Conv3d(n_classes*4, n_classes, 1)
48 |
49 |
50 | # initialise weights
51 | for m in self.modules():
52 | if isinstance(m, nn.Conv3d):
53 | init_weights(m, init_type='kaiming')
54 | elif isinstance(m, nn.BatchNorm3d):
55 | init_weights(m, init_type='kaiming')
56 |
57 | def forward(self, inputs):
58 | conv1 = self.conv1(inputs)
59 | maxpool1 = self.maxpool1(conv1)
60 |
61 | conv2 = self.conv2(maxpool1)
62 | maxpool2 = self.maxpool2(conv2)
63 |
64 | conv3 = self.conv3(maxpool2)
65 | maxpool3 = self.maxpool3(conv3)
66 |
67 | conv4 = self.conv4(maxpool3)
68 | maxpool4 = self.maxpool4(conv4)
69 |
70 | center = self.center(maxpool4)
71 | up4 = self.up_concat4(conv4, center)
72 | up3 = self.up_concat3(conv3, up4)
73 | up2 = self.up_concat2(conv2, up3)
74 | up1 = self.up_concat1(conv1, up2)
75 |
76 | # Deep Supervision
77 | dsv4 = self.dsv4(up4)
78 | dsv3 = self.dsv3(up3)
79 | dsv2 = self.dsv2(up2)
80 | dsv1 = self.dsv1(up1)
81 | final = self.final(torch.cat([dsv1,dsv2,dsv3,dsv4], dim=1))
82 |
83 | return final
84 |
85 | @staticmethod
86 | def apply_argmax_softmax(pred):
87 | log_p = F.softmax(pred, dim=1)
88 |
89 | return log_p
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
--------------------------------------------------------------------------------
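
Because of the four 2x2x2 max-pooling stages, unet_CT_dsv_3D expects every spatial dimension to be divisible by 16, and the deep-supervision branches upsample back to full resolution before the final 1x1x1 convolution. A quick shape check, assuming the repository is importable and a PyTorch >= 0.4 build:

    import torch
    from models.networks.unet_CT_dsv_3D import unet_CT_dsv_3D

    net = unet_CT_dsv_3D(feature_scale=4, n_classes=4, in_channels=1)
    x = torch.randn(1, 1, 96, 96, 48)   # all spatial dims divisible by 16
    print(net(x).shape)                 # torch.Size([1, 4, 96, 96, 48])
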
/models/aggregated_classifier.py:
--------------------------------------------------------------------------------
1 | import os, collections
2 | import numpy as np
3 | import torch
4 | from torch.autograd import Variable
5 | from .feedforward_classifier import FeedForwardClassifier
6 |
7 |
8 | class AggregatedClassifier(FeedForwardClassifier):
9 | def name(self):
10 | return 'AggregatedClassifier'
11 |
12 | def initialize(self, opts, **kwargs):
13 | FeedForwardClassifier.initialize(self, opts, **kwargs)
14 |
15 | weight = self.opts.raw.weight[:] # copy
16 | weight_t = torch.from_numpy(np.array(weight, dtype=np.float32))
17 | self.weight = weight
18 | self.aggregation = opts.raw.aggregation
19 | self.aggregation_param = opts.raw.aggregation_param
20 | self.aggregation_weight = Variable(weight_t, volatile=True).view(-1,1,1).cuda()
21 |
22 | def compute_loss(self):
23 | """Compute loss function. Iterate over multiple output"""
24 | preds = self.predictions
25 | weights = self.weight
26 | if not isinstance(preds, collections.Sequence):
27 | preds = [preds]
28 | weights = [1]
29 |
30 | loss = 0
31 | for lmda, prediction in zip(weights, preds):
32 | if lmda == 0:
33 | continue
34 | loss += lmda * self.criterion(prediction, self.target)
35 |
36 | self.loss = loss
37 |
38 | def aggregate_output(self):
39 | """Given a list of predictions from net, make a decision based on aggreagation rule"""
40 | if isinstance(self.predictions, collections.Sequence):
41 | logits = []
42 | for pred in self.predictions:
43 | logit = self.net.apply_argmax_softmax(pred).unsqueeze(0)
44 | logits.append(logit)
45 |
46 | logits = torch.cat(logits, 0)
47 | if self.aggregation == 'max':
48 | self.pred = logits.data.max(0)[0].max(1)
49 | elif self.aggregation == 'mean':
50 | self.pred = logits.data.mean(0).max(1)
51 | elif self.aggregation == 'weighted_mean':
52 | self.pred = (self.aggregation_weight.expand_as(logits) * logits).data.mean(0).max(1)
53 | elif self.aggregation == 'idx':
54 | self.pred = logits[self.aggregation_param].data.max(1)
55 | else:
56 | # Apply a softmax and return a segmentation map
57 | self.logits = self.net.apply_argmax_softmax(self.predictions)
58 | self.pred = self.logits.data.max(1)
59 |
60 |
61 | def forward(self, split):
62 | if split == 'train':
63 | self.predictions = self.net(Variable(self.input))
64 | elif split in ['validation', 'test']:
65 | self.predictions = self.net(Variable(self.input, volatile=True))
66 | self.aggregate_output()
67 |
68 | def backward(self):
69 | self.compute_loss()
70 | self.loss.backward()
71 |
72 | def validate(self):
73 | self.net.eval()
74 | self.forward(split='test')
75 | self.compute_loss()
76 | self.accumulate_results()
77 |
78 | def update_state(self, epoch):
79 | """ A function that is called at the end of every epoch. Can adjust state of the network here.
80 | For example, if one wants to change the loss weights for prediction during training (e.g. deep supervision), """
81 | if hasattr(self.opts.raw,'late_gate'):
82 | if epoch < self.opts.raw.late_gate:
83 | self.weight[0] = 0
84 | self.weight[1] = 0
85 | print('='*10,'weight={}'.format(self.weight), '='*10)
86 | if epoch == self.opts.raw.late_gate:
87 | self.weight = self.opts.raw.weight[:]
88 | weight_t = torch.from_numpy(np.array(self.weight, dtype=np.float32))
89 | self.aggregation_weight = Variable(weight_t,volatile=True).view(-1,1,1).cuda()
90 | print('='*10,'weight={}'.format(self.weight), '='*10)
91 |
--------------------------------------------------------------------------------
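
The 'weighted_mean' rule above rescales each output's softmax scores by a per-output weight, averages across outputs, and takes the arg-max per sample. A self-contained sketch of the same arithmetic on dummy scores:

    import torch

    n_outputs, batch, n_classes = 3, 2, 5
    logits = torch.softmax(torch.randn(n_outputs, batch, n_classes), dim=2)
    weight = torch.tensor([0.2, 0.3, 0.5]).view(-1, 1, 1)   # one weight per output

    weighted = (weight.expand_as(logits) * logits).mean(0)  # average over the outputs
    max_prob, pred_class = weighted.max(1)                  # decision per sample
    print(pred_class)
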
/models/layers/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.modules.loss import _Loss
5 | from torch.autograd import Function, Variable
6 |
7 | def cross_entropy_2D(input, target, weight=None, size_average=True):
8 | n, c, h, w = input.size()
9 | log_p = F.log_softmax(input, dim=1)
10 | log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
11 | target = target.view(target.numel())
12 | loss = F.nll_loss(log_p, target, weight=weight, size_average=False)
13 | if size_average:
14 | loss /= float(target.numel())
15 | return loss
16 |
17 |
18 | def cross_entropy_3D(input, target, weight=None, size_average=True):
19 | n, c, h, w, s = input.size()
20 | log_p = F.log_softmax(input, dim=1)
21 | log_p = log_p.transpose(1, 2).transpose(2, 3).transpose(3, 4).contiguous().view(-1, c)
22 | target = target.view(target.numel())
23 | loss = F.nll_loss(log_p, target, weight=weight, size_average=False)
24 | if size_average:
25 | loss /= float(target.numel())
26 | return loss
27 |
28 |
29 | class SoftDiceLoss(nn.Module):
30 | def __init__(self, n_classes):
31 | super(SoftDiceLoss, self).__init__()
32 | self.one_hot_encoder = One_Hot(n_classes).forward
33 | self.n_classes = n_classes
34 |
35 | def forward(self, input, target):
36 | smooth = 0.01
37 | batch_size = input.size(0)
38 |
39 | input = F.softmax(input, dim=1).view(batch_size, self.n_classes, -1)
40 | target = self.one_hot_encoder(target).contiguous().view(batch_size, self.n_classes, -1)
41 |
42 | inter = torch.sum(input * target, 2) + smooth
43 | union = torch.sum(input, 2) + torch.sum(target, 2) + smooth
44 |
45 | score = torch.sum(2.0 * inter / union)
46 | score = 1.0 - score / (float(batch_size) * float(self.n_classes))
47 |
48 | return score
49 |
50 |
51 | class CustomSoftDiceLoss(nn.Module):
52 | def __init__(self, n_classes, class_ids):
53 | super(CustomSoftDiceLoss, self).__init__()
54 | self.one_hot_encoder = One_Hot(n_classes).forward
55 | self.n_classes = n_classes
56 | self.class_ids = class_ids
57 |
58 | def forward(self, input, target):
59 | smooth = 0.01
60 | batch_size = input.size(0)
61 |
62 | input = F.softmax(input[:,self.class_ids], dim=1).view(batch_size, len(self.class_ids), -1)
63 | target = self.one_hot_encoder(target).contiguous().view(batch_size, self.n_classes, -1)
64 | target = target[:, self.class_ids, :]
65 |
66 | inter = torch.sum(input * target, 2) + smooth
67 | union = torch.sum(input, 2) + torch.sum(target, 2) + smooth
68 |
69 | score = torch.sum(2.0 * inter / union)
70 |         score = 1.0 - score / (float(batch_size) * float(len(self.class_ids)))
71 |
72 | return score
73 |
74 |
75 | class One_Hot(nn.Module):
76 | def __init__(self, depth):
77 | super(One_Hot, self).__init__()
78 | self.depth = depth
79 | self.ones = torch.sparse.torch.eye(depth).cuda()
80 |
81 | def forward(self, X_in):
82 | n_dim = X_in.dim()
83 | output_size = X_in.size() + torch.Size([self.depth])
84 | num_element = X_in.numel()
85 | X_in = X_in.data.long().view(num_element)
86 | out = Variable(self.ones.index_select(0, X_in)).view(output_size)
87 | return out.permute(0, -1, *range(1, n_dim)).squeeze(dim=2).float()
88 |
89 | def __repr__(self):
90 | return self.__class__.__name__ + "({})".format(self.depth)
91 |
92 |
93 | if __name__ == '__main__':
94 | from torch.autograd import Variable
95 | depth=3
96 | batch_size=2
97 | encoder = One_Hot(depth=depth).forward
98 |     y = Variable(torch.LongTensor(batch_size, 1, 1, 2, 2).random_() % depth).cuda()  # random labels in [0, depth) on a 1x1x2x2 map
99 | y_onehot = encoder(y)
100 | x = Variable(torch.randn(y_onehot.size()).float()).cuda()
101 | dicemetric = SoftDiceLoss(n_classes=depth)
102 | dicemetric(x,y)
--------------------------------------------------------------------------------
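
cross_entropy_2D only flattens the spatial dimensions before calling nll_loss, so with size_average=True and no class weights it should agree with the built-in F.cross_entropy. A minimal sanity check, assuming a PyTorch version (>= 0.4) that supports spatial targets and still accepts the legacy size_average argument:

    import torch
    import torch.nn.functional as F
    from models.layers.loss import cross_entropy_2D

    logits = torch.randn(2, 4, 8, 8)                            # N x C x H x W
    labels = torch.randint(0, 4, (2, 8, 8), dtype=torch.long)   # N x H x W
    assert torch.allclose(cross_entropy_2D(logits, labels),
                          F.cross_entropy(logits, labels))
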
/validation.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 |
3 | from dataio.loader import get_dataset, get_dataset_path
4 | from dataio.transformation import get_dataset_transformation
5 | from utils.util import json_file_to_pyobj
6 |
7 | from models import get_model
8 | import numpy as np
9 | import os
10 | from utils.metrics import dice_score, distance_metric, precision_and_recall
11 | from utils.error_logger import StatLogger
12 |
13 |
14 | def mkdirfun(directory):
15 | if not os.path.exists(directory):
16 | os.makedirs(directory)
17 |
18 |
19 | def validation(json_name):
20 | # Load options
21 | json_opts = json_file_to_pyobj(json_name)
22 | train_opts = json_opts.training
23 |
24 | # Setup the NN Model
25 | model = get_model(json_opts.model)
26 | save_directory = os.path.join(model.save_dir, train_opts.arch_type); mkdirfun(save_directory)
27 |
28 | # Setup Dataset and Augmentation
29 | dataset_class = get_dataset(train_opts.arch_type)
30 | dataset_path = get_dataset_path(train_opts.arch_type, json_opts.data_path)
31 | dataset_transform = get_dataset_transformation(train_opts.arch_type, opts=json_opts.augmentation)
32 |
33 | # Setup Data Loader
34 | dataset = dataset_class(dataset_path, split='validation', transform=dataset_transform['valid'])
35 | data_loader = DataLoader(dataset=dataset, num_workers=8, batch_size=1, shuffle=False)
36 |
37 | # Visualisation Parameters
38 | #visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)
39 |
40 | # Setup stats logger
41 | stat_logger = StatLogger()
42 |
43 | # test
44 | for iteration, data in enumerate(data_loader, 1):
45 | model.set_input(data[0], data[1])
46 | model.test()
47 |
48 | input_arr = np.squeeze(data[0].cpu().numpy()).astype(np.float32)
49 | label_arr = np.squeeze(data[1].cpu().numpy()).astype(np.int16)
50 | output_arr = np.squeeze(model.pred_seg.cpu().byte().numpy()).astype(np.int16)
51 |
52 | # If there is a label image - compute statistics
53 | dice_vals = dice_score(label_arr, output_arr, n_class=int(4))
54 | md, hd = distance_metric(label_arr, output_arr, dx=2.00, k=2)
55 | precision, recall = precision_and_recall(label_arr, output_arr, n_class=int(4))
56 | stat_logger.update(split='test', input_dict={'img_name': '',
57 | 'dice_LV': dice_vals[1],
58 | 'dice_MY': dice_vals[2],
59 | 'dice_RV': dice_vals[3],
60 | 'prec_MYO':precision[2],
61 | 'reca_MYO':recall[2],
62 | 'md_MYO': md,
63 | 'hd_MYO': hd
64 | })
65 |
66 | # Write a nifti image
67 | import SimpleITK as sitk
68 | input_img = sitk.GetImageFromArray(np.transpose(input_arr, (2, 1, 0))); input_img.SetDirection([-1,0,0,0,-1,0,0,0,1])
69 | label_img = sitk.GetImageFromArray(np.transpose(label_arr, (2, 1, 0))); label_img.SetDirection([-1,0,0,0,-1,0,0,0,1])
70 | predi_img = sitk.GetImageFromArray(np.transpose(output_arr,(2, 1, 0))); predi_img.SetDirection([-1,0,0,0,-1,0,0,0,1])
71 |
72 | sitk.WriteImage(input_img, os.path.join(save_directory,'{}_img.nii.gz'.format(iteration)))
73 | sitk.WriteImage(label_img, os.path.join(save_directory,'{}_lbl.nii.gz'.format(iteration)))
74 | sitk.WriteImage(predi_img, os.path.join(save_directory,'{}_pred.nii.gz'.format(iteration)))
75 |
76 | stat_logger.statlogger2csv(split='test', out_csv_name=os.path.join(save_directory,'stats.csv'))
77 | for key, (mean_val, std_val) in stat_logger.get_errors(split='test').items():
78 | print('-',key,': \t{0:.3f}+-{1:.3f}'.format(mean_val, std_val),'-')
79 |
80 |
81 | if __name__ == '__main__':
82 | import argparse
83 |
84 | parser = argparse.ArgumentParser(description='CNN Seg Validation Function')
85 |
86 | parser.add_argument('-c', '--config', help='testing config file', required=True)
87 | args = parser.parse_args()
88 |
89 | validation(args.config)
90 |
--------------------------------------------------------------------------------
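
Like the other entry points, validation.py is driven entirely by a JSON config, e.g. with one of the configs shipped in the repository:

    python validation.py --config configs/config_unet_ct_dsv.json
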
/models/networks/unet_grid_attention_3D.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .utils import UnetConv3, UnetUp3, UnetGridGatingSignal3
3 | import torch.nn.functional as F
4 | from models.layers.grid_attention_layer import GridAttentionBlock3D
5 | from models.networks_other import init_weights
6 |
7 | class unet_grid_attention_3D(nn.Module):
8 |
9 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3,
10 | nonlocal_mode='concatenation', attention_dsample=(2,2,2), is_batchnorm=True):
11 | super(unet_grid_attention_3D, self).__init__()
12 | self.is_deconv = is_deconv
13 | self.in_channels = in_channels
14 | self.is_batchnorm = is_batchnorm
15 | self.feature_scale = feature_scale
16 |
17 | filters = [64, 128, 256, 512, 1024]
18 | filters = [int(x / self.feature_scale) for x in filters]
19 |
20 | # downsampling
21 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm)
22 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 1))
23 |
24 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm)
25 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 1))
26 |
27 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm)
28 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 1))
29 |
30 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm)
31 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 1))
32 |
33 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm)
34 | self.gating = UnetGridGatingSignal3(filters[4], filters[3], kernel_size=(1, 1, 1), is_batchnorm=self.is_batchnorm)
35 |
36 | # attention blocks
37 | self.attentionblock2 = GridAttentionBlock3D(in_channels=filters[1], gating_channels=filters[3],
38 | inter_channels=filters[1], sub_sample_factor=attention_dsample, mode=nonlocal_mode)
39 | self.attentionblock3 = GridAttentionBlock3D(in_channels=filters[2], gating_channels=filters[3],
40 | inter_channels=filters[2], sub_sample_factor=attention_dsample, mode=nonlocal_mode)
41 | self.attentionblock4 = GridAttentionBlock3D(in_channels=filters[3], gating_channels=filters[3],
42 | inter_channels=filters[3], sub_sample_factor=attention_dsample, mode=nonlocal_mode)
43 |
44 | # upsampling
45 | self.up_concat4 = UnetUp3(filters[4], filters[3], self.is_deconv, self.is_batchnorm)
46 | self.up_concat3 = UnetUp3(filters[3], filters[2], self.is_deconv, self.is_batchnorm)
47 | self.up_concat2 = UnetUp3(filters[2], filters[1], self.is_deconv, self.is_batchnorm)
48 | self.up_concat1 = UnetUp3(filters[1], filters[0], self.is_deconv, self.is_batchnorm)
49 |
50 | # final conv (without any concat)
51 | self.final = nn.Conv3d(filters[0], n_classes, 1)
52 |
53 | # initialise weights
54 | for m in self.modules():
55 | if isinstance(m, nn.Conv3d):
56 | init_weights(m, init_type='kaiming')
57 | elif isinstance(m, nn.BatchNorm3d):
58 | init_weights(m, init_type='kaiming')
59 |
60 | def forward(self, inputs):
61 | # Feature Extraction
62 | conv1 = self.conv1(inputs)
63 | maxpool1 = self.maxpool1(conv1)
64 |
65 | conv2 = self.conv2(maxpool1)
66 | maxpool2 = self.maxpool2(conv2)
67 |
68 | conv3 = self.conv3(maxpool2)
69 | maxpool3 = self.maxpool3(conv3)
70 |
71 | conv4 = self.conv4(maxpool3)
72 | maxpool4 = self.maxpool4(conv4)
73 |
74 | # Gating Signal Generation
75 | center = self.center(maxpool4)
76 | gating = self.gating(center)
77 |
78 | # Attention Mechanism
79 | g_conv4, att4 = self.attentionblock4(conv4, gating)
80 | g_conv3, att3 = self.attentionblock3(conv3, gating)
81 | g_conv2, att2 = self.attentionblock2(conv2, gating)
82 |
83 | # Upscaling Part (Decoder)
84 | up4 = self.up_concat4(g_conv4, center)
85 | up3 = self.up_concat3(g_conv3, up4)
86 | up2 = self.up_concat2(g_conv2, up3)
87 | up1 = self.up_concat1(conv1, up2)
88 |
89 | final = self.final(up1)
90 |
91 | return final
92 |
93 | @staticmethod
94 | def apply_argmax_softmax(pred):
95 | log_p = F.softmax(pred, dim=1)
96 |
97 | return log_p
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
--------------------------------------------------------------------------------
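
The attention coefficients att2-att4 are computed inside forward but only the segmentation is returned; utilities such as model.get_feature_maps (used by the visualisation scripts above) can be built on forward hooks instead. A self-contained sketch of the hook mechanism, using a toy attention block that mimics the (gated features, attention map) output pair:

    import torch
    import torch.nn as nn

    class ToyAttention(nn.Module):
        def forward(self, x):
            att = torch.sigmoid(x.mean(dim=1, keepdim=True))
            return x * att, att

    maps = {}
    block = ToyAttention()
    block.register_forward_hook(lambda m, i, o: maps.setdefault('att', o[1].detach()))
    _ = block(torch.randn(1, 4, 8, 8, 8))
    print(maps['att'].shape)   # torch.Size([1, 1, 8, 8, 8])
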
/utils/error_logger.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .util import csv_write
3 |
4 |
5 | class BaseMeter(object):
6 | """Just a place holderb"""
7 |
8 | def __init__(self, name):
9 | self.reset()
10 | self.name = name
11 |
12 | def reset(self):
13 | pass
14 |
15 | def update(self, val):
16 | self.val = val
17 |
18 | def get_value(self):
19 | return self.val
20 |
21 |
22 | class AverageMeter(object):
23 | """Computes and stores the average and current value"""
24 |
25 | def __init__(self, name):
26 | self.reset()
27 | self.name = name
28 |
29 | def reset(self):
30 | self.val = 0.0
31 | self.avg = 0.0
32 | self.sum = 0.0
33 | self.count = 0.0
34 |
35 | def update(self, val, n=1.0):
36 | self.val = val
37 | self.sum += val * n
38 | self.count += n
39 | self.avg = self.sum / self.count
40 |
41 | def get_value(self):
42 | return self.avg
43 |
44 | class StatMeter(object):
45 | """Computes and stores the error vals and image names"""
46 |
47 | def __init__(self, name, csv_name=None):
48 | self.reset()
49 | self.name = name
50 |
51 | def reset(self):
52 | self.vals = []
53 | self.img_names = []
54 |
55 | def update(self, val, img_name):
56 | self.vals.append(val)
57 | self.img_names.append(img_name)
58 |
59 | def return_average(self):
60 |         values_array = np.array(self.vals, dtype=np.float64)
61 | return np.nanmean(values_array)
62 |
63 | def return_std(self):
64 |         values_array = np.array(self.vals, dtype=np.float64)
65 | return np.nanstd(values_array)
66 |
67 |
68 | class ErrorLogger(object):
69 |
70 | def __init__(self):
71 | self.variables = {'train': dict(),
72 | 'validation': dict(),
73 | 'test': dict()
74 | }
75 |
76 | def update(self, input_dict, split):
77 |
78 | for key, value in input_dict.items():
79 | if key not in self.variables[split]:
80 | if np.isscalar(value):
81 | self.variables[split][key] = AverageMeter(name=key)
82 | else:
83 | self.variables[split][key] = BaseMeter(name=key)
84 |
85 | self.variables[split][key].update(value)
86 |
87 |
88 | def get_errors(self, split):
89 | output = dict()
90 | for key, meter_obj in self.variables[split].items():
91 | output[key] = meter_obj.get_value()
92 | return output
93 |
94 | def reset(self):
95 | for key, meter_obj in self.variables['train'].items():
96 | meter_obj.reset()
97 | for key, meter_obj in self.variables['validation'].items():
98 | meter_obj.reset()
99 | for key, meter_obj in self.variables['test'].items():
100 | meter_obj.reset()
101 |
102 |
103 | class StatLogger(object):
104 |
105 | def __init__(self):
106 | self.variables = {'train': dict(),
107 | 'validation': dict(),
108 | 'test': dict()
109 | }
110 |
111 | def update(self, input_dict, split):
112 | img_name = input_dict.pop('img_name', None)
113 | for key, value in input_dict.items():
114 | if key not in self.variables[split]:
115 | self.variables[split][key] = StatMeter(name=key)
116 | self.variables[split][key].update(val=value, img_name=img_name)
117 |
118 | def get_errors(self, split):
119 | output = dict()
120 | for key, meter_obj in self.variables[split].items():
121 | output[key] = (meter_obj.return_average(), meter_obj.return_std())
122 | return output
123 |
124 | def statlogger2csv(self, split, out_csv_name):
125 | csv_values = []; csv_header = []
126 | for loopId, (meter_key, meter_obj) in enumerate(self.variables[split].items(), 1):
127 | if loopId == 1: csv_values.append(meter_obj.img_names); csv_header.append('img_names')
128 | csv_values.append(meter_obj.vals)
129 | csv_header.append(meter_key)
130 | csv_write(out_csv_name, csv_header, csv_values)
131 |
132 | def reset(self):
133 | for key, meter_obj in self.variables['train'].items():
134 | meter_obj.reset()
135 | for key, meter_obj in self.variables['validation'].items():
136 | meter_obj.reset()
137 | for key, meter_obj in self.variables['test'].items():
138 | meter_obj.reset()
139 |
--------------------------------------------------------------------------------
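
StatLogger accumulates per-image statistics and dumps them to CSV, exactly as validation.py uses it. A minimal usage sketch:

    from utils.error_logger import StatLogger

    logger = StatLogger()
    logger.update(split='test', input_dict={'img_name': 'case_01', 'dice_LV': 0.91})
    logger.update(split='test', input_dict={'img_name': 'case_02', 'dice_LV': 0.88})
    print(logger.get_errors(split='test'))    # {'dice_LV': (mean, std)}
    logger.statlogger2csv(split='test', out_csv_name='stats.csv')
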
/train_segmentation.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | from torch.utils.data import DataLoader
3 | from tqdm import tqdm
4 |
5 |
6 | from dataio.loader import get_dataset, get_dataset_path
7 | from dataio.transformation import get_dataset_transformation
8 | from utils.util import json_file_to_pyobj
9 | from utils.visualiser import Visualiser
10 | from utils.error_logger import ErrorLogger
11 |
12 | from models import get_model
13 |
14 | def train(arguments):
15 |
16 | # Parse input arguments
17 | json_filename = arguments.config
18 | network_debug = arguments.debug
19 |
20 | # Load options
21 | json_opts = json_file_to_pyobj(json_filename)
22 | train_opts = json_opts.training
23 |
24 | # Architecture type
25 | arch_type = train_opts.arch_type
26 |
27 | # Setup Dataset and Augmentation
28 | ds_class = get_dataset(arch_type)
29 | ds_path = get_dataset_path(arch_type, json_opts.data_path)
30 | ds_transform = get_dataset_transformation(arch_type, opts=json_opts.augmentation)
31 |
32 | # Setup the NN Model
33 | model = get_model(json_opts.model)
34 | if network_debug:
35 | print('# of pars: ', model.get_number_parameters())
36 | print('fp time: {0:.3f} sec\tbp time: {1:.3f} sec per sample'.format(*model.get_fp_bp_time()))
37 | exit()
38 |
39 | # Setup Data Loader
40 | train_dataset = ds_class(ds_path, split='train', transform=ds_transform['train'], preload_data=train_opts.preloadData)
41 | valid_dataset = ds_class(ds_path, split='validation', transform=ds_transform['valid'], preload_data=train_opts.preloadData)
42 | test_dataset = ds_class(ds_path, split='test', transform=ds_transform['valid'], preload_data=train_opts.preloadData)
43 | train_loader = DataLoader(dataset=train_dataset, num_workers=16, batch_size=train_opts.batchSize, shuffle=True)
44 | valid_loader = DataLoader(dataset=valid_dataset, num_workers=16, batch_size=train_opts.batchSize, shuffle=False)
45 | test_loader = DataLoader(dataset=test_dataset, num_workers=16, batch_size=train_opts.batchSize, shuffle=False)
46 |
47 | # Visualisation Parameters
48 | visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)
49 | error_logger = ErrorLogger()
50 |
51 | # Training Function
52 | model.set_scheduler(train_opts)
53 | for epoch in range(model.which_epoch, train_opts.n_epochs):
54 | print('(epoch: %d, total # iters: %d)' % (epoch, len(train_loader)))
55 |
56 | # Training Iterations
57 | for epoch_iter, (images, labels) in tqdm(enumerate(train_loader, 1), total=len(train_loader)):
58 | # Make a training update
59 | model.set_input(images, labels)
60 | model.optimize_parameters()
61 | #model.optimize_parameters_accumulate_grd(epoch_iter)
62 |
63 | # Error visualisation
64 | errors = model.get_current_errors()
65 | error_logger.update(errors, split='train')
66 |
67 | # Validation and Testing Iterations
68 | for loader, split in zip([valid_loader, test_loader], ['validation', 'test']):
69 | for epoch_iter, (images, labels) in tqdm(enumerate(loader, 1), total=len(loader)):
70 |
71 | # Make a forward pass with the model
72 | model.set_input(images, labels)
73 | model.validate()
74 |
75 | # Error visualisation
76 | errors = model.get_current_errors()
77 | stats = model.get_segmentation_stats()
78 | error_logger.update({**errors, **stats}, split=split)
79 |
80 | # Visualise predictions
81 | visuals = model.get_current_visuals()
82 | visualizer.display_current_results(visuals, epoch=epoch, save_result=False)
83 |
84 | # Update the plots
85 | for split in ['train', 'validation', 'test']:
86 | visualizer.plot_current_errors(epoch, error_logger.get_errors(split), split_name=split)
87 | visualizer.print_current_errors(epoch, error_logger.get_errors(split), split_name=split)
88 | error_logger.reset()
89 |
90 | # Save the model parameters
91 | if epoch % train_opts.save_epoch_freq == 0:
92 | model.save(epoch)
93 |
94 | # Update the model learning rate
95 | model.update_learning_rate()
96 |
97 |
98 | if __name__ == '__main__':
99 | import argparse
100 |
101 | parser = argparse.ArgumentParser(description='CNN Seg Training Function')
102 |
103 | parser.add_argument('-c', '--config', help='training config file', required=True)
104 | parser.add_argument('-d', '--debug', help='returns number of parameters and bp/fp runtime', action='store_true')
105 | args = parser.parse_args()
106 |
107 | train(args)
108 |
--------------------------------------------------------------------------------
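
Training is launched with a JSON config; the --debug flag prints the parameter count and per-sample forward/backward timings instead of training. For example, with one of the bundled configs:

    python train_segmentation.py --config configs/config_unet_ct_dsv.json
    python train_segmentation.py --config configs/config_unet_ct_dsv.json --debug
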
/utils/post_process_crf.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | import numpy as np, nibabel as nib
3 |
4 | import pydensecrf.densecrf as dcrf
5 | from pydensecrf.utils import create_pairwise_bilateral, create_pairwise_gaussian
6 |
7 |
8 | def apply_crf(input_image, input_prob, theta_a, theta_b, theta_r, mu1, mu2):
9 | n_slices = input_image.shape[2]
10 | output = np.zeros(input_image.shape)
11 | for slice_id in range(n_slices):
12 | image = input_image[:,:,slice_id]
13 | prob = input_prob[:,:,slice_id,:]
14 |
15 | n_pixel = image.shape[0] * image.shape[1]
16 | n_class = prob.shape[-1]
17 |
18 | P = np.transpose(prob, axes=(2, 0, 1))
19 |
20 | # Setup the CRF model
21 | d = dcrf.DenseCRF(n_pixel, n_class)
22 |
23 | # Set unary potentials (negative log probability)
24 | U = - np.log(P + 1e-10)
25 | U = np.ascontiguousarray(U.reshape((n_class, n_pixel)))
26 | d.setUnaryEnergy(U)
27 |
28 |         # Set edge potentials
29 |         # This creates the color-dependent features and adds them to the CRF
30 | feats = create_pairwise_bilateral(sdims=(theta_a, theta_a), schan=(theta_b,), img=image, chdim=-1)
31 | d.addPairwiseEnergy(feats, compat=mu1, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC)
32 |
33 |         # This creates the color-independent features and adds them to the CRF
34 | feats = create_pairwise_gaussian(sdims=(theta_r, theta_r), shape=image.shape)
35 | d.addPairwiseEnergy(feats, compat=mu2, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC)
36 |
37 | # Perform the inference
38 | Q = d.inference(5)
39 | res = np.argmax(Q, axis=0).astype('float32')
40 | res = np.reshape(res, image.shape).astype(dtype='int8')
41 | output[:,:,slice_id] = res
42 |
43 | return output
44 |
45 |
46 | if __name__ == '__main__':
47 | parser = argparse.ArgumentParser()
48 | parser.add_argument('--n_train', metavar='int', nargs=1, default=['80'], help='number of training subjects')
49 | args = parser.parse_args()
50 |
51 | # Data path
52 | data_path = '/vol/medic02/users/wbai/data/cardiac_atlas/Biobank'
53 | data_list = sorted(os.listdir(data_path))
54 | dest_path = '/vol/bitbucket/wbai/cardiac_cnn/Biobank/seg'
55 |
56 | # Model name
57 | size = 224
58 | n_train = int(args.n_train[0])
59 | model_name = 'FCN_VGG16_sz{0}_n{1}_2d'.format(size, n_train)
60 | epoch = 500
61 |
62 | # model_name = 'FCN_VGG16_sz{0}_prob_atlas_stepwise'.format(size)
63 | # epoch = 200
64 |
65 | # model_name = 'FCN_VGG16_sz{0}_auto_context_stepwise'.format(size)
66 | # epoch = 200
67 |
68 | for data in data_list:
69 | print(data)
70 | data_dir = os.path.join(data_path, data)
71 | dest_dir = os.path.join(dest_path, data)
72 |
73 | # tune_dir = os.path.join(dest_dir, 'tune')
74 | # if not os.path.exists(tune_dir):
75 | # os.mkdir(tune_dir)
76 |
77 | for fr in ['ED', 'ES']:
78 | # Read image
79 | nim = nib.load(os.path.join(data_dir, 'image_{0}.nii.gz'.format(fr)))
80 | image = np.squeeze(nim.get_data())
81 |
82 | # Scale the intensity to be [0, 1] so that we can set a consistent intensity parameter for CRF
83 | #image = intensity_rescaling(image, 1, 99)
84 |
85 | # Read probability map
86 | nim = nib.load(os.path.join(dest_dir, 'prob_{0}_{1}_epoch{2:03d}.nii.gz'.format(fr, model_name, epoch)))
87 | prob = nim.get_data()
88 |
89 | # Apply CRF
90 | mu1 = 1
91 | theta_a = 0.5
92 | theta_b = 1
93 | mu2 = 2
94 | theta_r = 1
95 | seg = apply_crf(image, prob, theta_a, theta_b, theta_r, mu1, mu2)
96 |
97 | # Save the CRF segmentation
98 | seg_name = os.path.join(dest_dir, 'seg_{0}_{1}_epoch{2:03d}_crf.nii.gz'.format(fr, model_name, epoch))
99 | nib.save(nib.Nifti1Image(seg, nim.affine), seg_name)
100 |
101 | # For parameter tuning
102 | # nib.save(nib.Nifti1Image(seg, nim.affine), os.path.join(tune_dir, 'seg_{0}_crf_mu2{1}_sr{2}.nii.gz'.format(fr, mu2, theta_r)))
103 | # nib.save(nib.Nifti1Image(seg, nim.affine), os.path.join(tune_dir, 'seg_{0}_crf_mu1{1}_sa{2:.1f}_sb{1}.nii.gz'.format(fr, mu1, theta_a, theta_b)))
104 |
105 | # # Fit to the template
106 | # template_dir = '/vol/medic02/users/wbai/data/imperial_atlas/template'
107 | # par_dir = '/vol/vipdata/data/biobank/cardiac/Application_18545/par'
108 | # out_name = os.path.join(dest_dir, 'seg_{0}_{1}_epoch{2:03d}_crf_fit.nii.gz'.format(fr, model_name, epoch))
109 | # fit_to_template(seg_name, fr, template_dir, par_dir, out_name)
--------------------------------------------------------------------------------
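
apply_crf runs a 2D dense CRF slice by slice over an H x W x S image with an H x W x S x C probability map. A dummy-data sketch, assuming pydensecrf is installed:

    import numpy as np
    from utils.post_process_crf import apply_crf

    rng = np.random.RandomState(0)
    image = rng.rand(64, 64, 4).astype(np.float32)       # H x W x S intensities
    prob = rng.dirichlet(np.ones(3), size=(64, 64, 4))   # H x W x S x C, sums to 1
    seg = apply_crf(image, prob, theta_a=0.5, theta_b=1, theta_r=1, mu1=1, mu2=2)
    print(seg.shape)                                     # (64, 64, 4)
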
/models/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Misc Utility functions
3 | '''
4 |
5 | import os
6 | import numpy as np
7 | import torch.optim as optim
8 | from torch.nn import CrossEntropyLoss
9 | from utils.metrics import segmentation_scores, dice_score_list
10 | from sklearn import metrics
11 | from .layers.loss import *
12 |
13 | def get_optimizer(option, params):
14 | opt_alg = 'sgd' if not hasattr(option, 'optim') else option.optim
15 | if opt_alg == 'sgd':
16 | optimizer = optim.SGD(params,
17 | lr=option.lr_rate,
18 | momentum=0.9,
19 | nesterov=True,
20 | weight_decay=option.l2_reg_weight)
21 |
22 | if opt_alg == 'adam':
23 | optimizer = optim.Adam(params,
24 | lr=option.lr_rate,
25 | betas=(0.9, 0.999),
26 | weight_decay=option.l2_reg_weight)
27 |
28 | return optimizer
29 |
30 |
31 | def get_criterion(opts):
32 | if opts.criterion == 'cross_entropy':
33 | if opts.type == 'seg':
34 | criterion = cross_entropy_2D if opts.tensor_dim == '2D' else cross_entropy_3D
35 | elif 'classifier' in opts.type:
36 | criterion = CrossEntropyLoss()
37 | elif opts.criterion == 'dice_loss':
38 | criterion = SoftDiceLoss(opts.output_nc)
39 | elif opts.criterion == 'dice_loss_pancreas_only':
40 | criterion = CustomSoftDiceLoss(opts.output_nc, class_ids=[0, 2])
41 |
42 | return criterion
43 |
44 | def recursive_glob(rootdir='.', suffix=''):
45 | """Performs recursive glob with given suffix and rootdir
46 |     :param rootdir: the root directory
47 |     :param suffix: the suffix to be searched
48 | """
49 | return [os.path.join(looproot, filename)
50 | for looproot, _, filenames in os.walk(rootdir)
51 | for filename in filenames if filename.endswith(suffix)]
52 |
53 | def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1, max_iter=30000, power=0.9):
54 |     """Polynomial decay of the learning rate
55 |     :param init_lr: base learning rate
56 |     :param iter: current iteration
57 |     :param lr_decay_iter: how frequently decay occurs, default is 1
58 |     :param max_iter: maximum number of iterations
59 |     :param power: polynomial power
60 |
61 |     """
62 | if iter % lr_decay_iter or iter > max_iter:
63 | return optimizer
64 |
65 | for param_group in optimizer.param_groups:
66 | param_group['lr'] = init_lr*(1 - iter/max_iter)**power
67 |
68 |
69 | def adjust_learning_rate(optimizer, init_lr, epoch):
70 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
71 | lr = init_lr * (0.1 ** (epoch // 30))
72 | for param_group in optimizer.param_groups:
73 | param_group['lr'] = lr
74 |
75 |
76 | def segmentation_stats(pred_seg, target):
77 | n_classes = pred_seg.size(1)
78 | pred_lbls = pred_seg.data.max(1)[1].cpu().numpy()
79 | gt = np.squeeze(target.data.cpu().numpy(), axis=1)
80 | gts, preds = [], []
81 | for gt_, pred_ in zip(gt, pred_lbls):
82 | gts.append(gt_)
83 | preds.append(pred_)
84 |
85 | iou = segmentation_scores(gts, preds, n_class=n_classes)
86 | dice = dice_score_list(gts, preds, n_class=n_classes)
87 |
88 | return iou, dice
89 |
90 |
91 | def classification_scores(gts, preds, labels):
92 | accuracy = metrics.accuracy_score(gts, preds)
93 | class_accuracies = []
94 | for lab in labels: # TODO Fix
95 | class_accuracies.append(metrics.accuracy_score(gts[gts == lab], preds[gts == lab]))
96 | class_accuracies = np.array(class_accuracies)
97 |
98 | f1_micro = metrics.f1_score(gts, preds, average='micro')
99 | precision_micro = metrics.precision_score(gts, preds, average='micro')
100 | recall_micro = metrics.recall_score(gts, preds, average='micro')
101 | f1_macro = metrics.f1_score(gts, preds, average='macro')
102 | precision_macro = metrics.precision_score(gts, preds, average='macro')
103 | recall_macro = metrics.recall_score(gts, preds, average='macro')
104 |
105 | # class wise score
106 | f1s = metrics.f1_score(gts, preds, average=None)
107 | precisions = metrics.precision_score(gts, preds, average=None)
108 | recalls = metrics.recall_score(gts, preds, average=None)
109 |
110 | confusion = metrics.confusion_matrix(gts,preds, labels=labels)
111 |
112 | #TODO confusion matrix, recall, precision
113 | return accuracy, f1_micro, precision_micro, recall_micro, f1_macro, precision_macro, recall_macro, confusion, class_accuracies, f1s, precisions, recalls
114 |
115 |
116 | def classification_stats(pred_seg, target, labels):
117 | return classification_scores(target, pred_seg, labels)
118 |
--------------------------------------------------------------------------------
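
get_optimizer and get_criterion read plain attributes off the options object, so anything with the expected fields works, not just the namedtuples produced by json_file_to_pyobj. A sketch with a hypothetical options object:

    import torch.nn as nn
    from types import SimpleNamespace
    from models.utils import get_optimizer

    net = nn.Conv2d(1, 4, 3)
    opts = SimpleNamespace(optim='adam', lr_rate=1e-4, l2_reg_weight=1e-5)  # hypothetical values
    optimizer = get_optimizer(opts, net.parameters())
    print(type(optimizer).__name__)   # Adam
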
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # Originally written by wkentaro
2 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
3 |
4 | import numpy as np
5 | import cv2
6 |
7 | def _fast_hist(label_true, label_pred, n_class):
8 | mask = (label_true >= 0) & (label_true < n_class)
9 | hist = np.bincount(
10 | n_class * label_true[mask].astype(int) +
11 | label_pred[mask], minlength=n_class**2).reshape(n_class, n_class)
12 | return hist
13 |
14 |
15 | def segmentation_scores(label_trues, label_preds, n_class):
16 | """Returns accuracy score evaluation result.
17 | - overall accuracy
18 | - mean accuracy
19 | - mean IU
20 | - fwavacc
21 | """
22 | hist = np.zeros((n_class, n_class))
23 | for lt, lp in zip(label_trues, label_preds):
24 | hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
25 | acc = np.diag(hist).sum() / hist.sum()
26 | acc_cls = np.diag(hist) / hist.sum(axis=1)
27 | acc_cls = np.nanmean(acc_cls)
28 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
29 | mean_iu = np.nanmean(iu)
30 | freq = hist.sum(axis=1) / hist.sum()
31 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
32 |
33 | return {'overall_acc': acc,
34 | 'mean_acc': acc_cls,
35 | 'freq_w_acc': fwavacc,
36 | 'mean_iou': mean_iu}
37 |
38 |
39 | def dice_score_list(label_gt, label_pred, n_class):
40 | """
41 |
42 | :param label_gt: [WxH] (2D images)
43 | :param label_pred: [WxH] (2D images)
44 | :param n_class: number of label classes
45 |     :return: mean per-class Dice scores over the batch
46 | """
47 | epsilon = 1.0e-6
48 | assert len(label_gt) == len(label_pred)
49 | batchSize = len(label_gt)
50 | dice_scores = np.zeros((batchSize, n_class), dtype=np.float32)
51 | for batch_id, (l_gt, l_pred) in enumerate(zip(label_gt, label_pred)):
52 | for class_id in range(n_class):
53 | img_A = np.array(l_gt == class_id, dtype=np.float32).flatten()
54 | img_B = np.array(l_pred == class_id, dtype=np.float32).flatten()
55 | score = 2.0 * np.sum(img_A * img_B) / (np.sum(img_A) + np.sum(img_B) + epsilon)
56 | dice_scores[batch_id, class_id] = score
57 |
58 | return np.mean(dice_scores, axis=0)
59 |
60 |
61 | def dice_score(label_gt, label_pred, n_class):
62 | """
63 |
64 |     :param label_gt: ground-truth label map
65 |     :param label_pred: predicted label map
66 |     :param n_class: number of label classes
67 |     :return: per-class Dice scores
68 | """
69 |
70 | epsilon = 1.0e-6
71 | assert np.all(label_gt.shape == label_pred.shape)
72 | dice_scores = np.zeros(n_class, dtype=np.float32)
73 | for class_id in range(n_class):
74 | img_A = np.array(label_gt == class_id, dtype=np.float32).flatten()
75 | img_B = np.array(label_pred == class_id, dtype=np.float32).flatten()
76 | score = 2.0 * np.sum(img_A * img_B) / (np.sum(img_A) + np.sum(img_B) + epsilon)
77 | dice_scores[class_id] = score
78 |
79 | return dice_scores
80 |
81 |
82 | def precision_and_recall(label_gt, label_pred, n_class):
83 | from sklearn.metrics import precision_score, recall_score
84 | assert len(label_gt) == len(label_pred)
85 | precision = np.zeros(n_class, dtype=np.float32)
86 | recall = np.zeros(n_class, dtype=np.float32)
87 | img_A = np.array(label_gt, dtype=np.float32).flatten()
88 | img_B = np.array(label_pred, dtype=np.float32).flatten()
89 | precision[:] = precision_score(img_A, img_B, average=None, labels=range(n_class))
90 | recall[:] = recall_score(img_A, img_B, average=None, labels=range(n_class))
91 |
92 | return precision, recall
93 |
94 |
95 | def distance_metric(seg_A, seg_B, dx, k):
96 | """
97 | Measure the distance errors between the contours of two segmentations.
98 | The manual contours are drawn on 2D slices.
99 | We calculate contour to contour distance for each slice.
100 | """
101 |
102 | # Extract the label k from the segmentation maps to generate binary maps
103 | seg_A = (seg_A == k)
104 | seg_B = (seg_B == k)
105 |
106 | table_md = []
107 | table_hd = []
108 | X, Y, Z = seg_A.shape
109 | for z in range(Z):
110 | # Binary mask at this slice
111 | slice_A = seg_A[:, :, z].astype(np.uint8)
112 | slice_B = seg_B[:, :, z].astype(np.uint8)
113 |
114 | # The distance is defined only when both contours exist on this slice
115 | if np.sum(slice_A) > 0 and np.sum(slice_B) > 0:
116 | # Find contours and retrieve all the points
117 | _, contours, _ = cv2.findContours(cv2.inRange(slice_A, 1, 1),
118 | cv2.RETR_EXTERNAL,
119 | cv2.CHAIN_APPROX_NONE)
120 | pts_A = contours[0]
121 | for i in range(1, len(contours)):
122 | pts_A = np.vstack((pts_A, contours[i]))
123 |
124 | _, contours, _ = cv2.findContours(cv2.inRange(slice_B, 1, 1),
125 | cv2.RETR_EXTERNAL,
126 | cv2.CHAIN_APPROX_NONE)
127 | pts_B = contours[0]
128 | for i in range(1, len(contours)):
129 | pts_B = np.vstack((pts_B, contours[i]))
130 |
131 | # Distance matrix between point sets
132 | M = np.zeros((len(pts_A), len(pts_B)))
133 | for i in range(len(pts_A)):
134 | for j in range(len(pts_B)):
135 | M[i, j] = np.linalg.norm(pts_A[i, 0] - pts_B[j, 0])
136 |
137 | # Mean distance and hausdorff distance
138 | md = 0.5 * (np.mean(np.min(M, axis=0)) + np.mean(np.min(M, axis=1))) * dx
139 | hd = np.max([np.max(np.min(M, axis=0)), np.max(np.min(M, axis=1))]) * dx
140 | table_md += [md]
141 | table_hd += [hd]
142 |
143 | # Return the mean distance and Hausdorff distance across 2D slices
144 | mean_md = np.mean(table_md) if table_md else None
145 | mean_hd = np.mean(table_hd) if table_hd else None
146 | return mean_md, mean_hd
--------------------------------------------------------------------------------
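
A tiny worked example of dice_score: per-class Dice is 2|A∩B| / (|A| + |B|), computed independently for each label id.

    import numpy as np
    from utils.metrics import dice_score

    gt   = np.array([[0, 1], [1, 2]])
    pred = np.array([[0, 1], [2, 2]])
    print(dice_score(gt, pred, n_class=3))
    # class 0: 2*1/(1+1) = 1.0,  class 1: 2*1/(2+1) = 0.667,  class 2: 2*1/(1+2) = 0.667
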
/models/networks/sononet_grid_attention.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import torch.nn as nn
4 | from .utils import unetConv2, unetUp, conv2DBatchNormRelu, conv2DBatchNorm
5 | import torch
6 | import torch.nn.functional as F
7 | from models.layers.grid_attention_layer import GridAttentionBlock2D_TORR as AttentionBlock2D
8 | from models.networks_other import init_weights
9 |
10 | class sononet_grid_attention(nn.Module):
11 |
12 | def __init__(self, feature_scale=4, n_classes=21, in_channels=3, is_batchnorm=True, n_convs=None,
13 | nonlocal_mode='concatenation', aggregation_mode='concat'):
14 | super(sononet_grid_attention, self).__init__()
15 | self.in_channels = in_channels
16 | self.is_batchnorm = is_batchnorm
17 | self.feature_scale = feature_scale
18 |         self.n_classes = n_classes
19 | self.aggregation_mode = aggregation_mode
20 | self.deep_supervised = True
21 |
22 | if n_convs is None:
23 | n_convs = [3, 3, 3, 2, 2]
24 |
25 | filters = [64, 128, 256, 512]
26 | filters = [int(x / self.feature_scale) for x in filters]
27 |
28 | ####################
29 | # Feature Extraction
30 | self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm, n=n_convs[0])
31 | self.maxpool1 = nn.MaxPool2d(kernel_size=2)
32 |
33 | self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm, n=n_convs[1])
34 | self.maxpool2 = nn.MaxPool2d(kernel_size=2)
35 |
36 | self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm, n=n_convs[2])
37 | self.maxpool3 = nn.MaxPool2d(kernel_size=2)
38 |
39 | self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm, n=n_convs[3])
40 | self.maxpool4 = nn.MaxPool2d(kernel_size=2)
41 |
42 | self.conv5 = unetConv2(filters[3], filters[3], self.is_batchnorm, n=n_convs[4])
43 |
44 | ################
45 | # Attention Maps
46 | self.compatibility_score1 = AttentionBlock2D(in_channels=filters[2], gating_channels=filters[3],
47 | inter_channels=filters[3], sub_sample_factor=(1,1),
48 | mode=nonlocal_mode, use_W=False, use_phi=True,
49 | use_theta=True, use_psi=True, nonlinearity1='relu')
50 |
51 | self.compatibility_score2 = AttentionBlock2D(in_channels=filters[3], gating_channels=filters[3],
52 | inter_channels=filters[3], sub_sample_factor=(1,1),
53 | mode=nonlocal_mode, use_W=False, use_phi=True,
54 | use_theta=True, use_psi=True, nonlinearity1='relu')
55 |
56 | #########################
57 |         # Aggregation Strategies
58 | self.attention_filter_sizes = [filters[2], filters[3]]
59 |
60 | if aggregation_mode == 'concat':
61 | self.classifier = nn.Linear(filters[2]+filters[3]+filters[3], n_classes)
62 |             self.aggregate = self.aggregation_concat
63 |
64 | else:
65 | self.classifier1 = nn.Linear(filters[2], n_classes)
66 | self.classifier2 = nn.Linear(filters[3], n_classes)
67 | self.classifier3 = nn.Linear(filters[3], n_classes)
68 | self.classifiers = [self.classifier1, self.classifier2, self.classifier3]
69 |
70 | if aggregation_mode == 'mean':
71 | self.aggregate = self.aggregation_sep
72 |
73 | elif aggregation_mode == 'deep_sup':
74 | self.classifier = nn.Linear(filters[2] + filters[3] + filters[3], n_classes)
75 | self.aggregate = self.aggregation_ds
76 |
77 | elif aggregation_mode == 'ft':
78 | self.classifier = nn.Linear(n_classes*3, n_classes)
79 | self.aggregate = self.aggregation_ft
80 | else:
81 | raise NotImplementedError
82 |
83 | ####################
84 | # initialise weights
85 | for m in self.modules():
86 | if isinstance(m, nn.Conv2d):
87 | init_weights(m, init_type='kaiming')
88 | elif isinstance(m, nn.BatchNorm2d):
89 | init_weights(m, init_type='kaiming')
90 |
91 |
92 | def aggregation_sep(self, *attended_maps):
93 | return [ clf(att) for clf, att in zip(self.classifiers, attended_maps) ]
94 |
95 | def aggregation_ft(self, *attended_maps):
96 | preds = self.aggregation_sep(*attended_maps)
97 | return self.classifier(torch.cat(preds, dim=1))
98 |
99 | def aggregation_ds(self, *attended_maps):
100 | preds_sep = self.aggregation_sep(*attended_maps)
101 | pred = self.aggregation_concat(*attended_maps)
102 | return [pred] + preds_sep
103 |
104 | def aggregation_concat(self, *attended_maps):
105 | return self.classifier(torch.cat(attended_maps, dim=1))
106 |
107 |
108 | def forward(self, inputs):
109 | # Feature Extraction
110 | conv1 = self.conv1(inputs)
111 | maxpool1 = self.maxpool1(conv1)
112 |
113 | conv2 = self.conv2(maxpool1)
114 | maxpool2 = self.maxpool2(conv2)
115 |
116 | conv3 = self.conv3(maxpool2)
117 | maxpool3 = self.maxpool3(conv3)
118 |
119 | conv4 = self.conv4(maxpool3)
120 | maxpool4 = self.maxpool4(conv4)
121 |
122 | conv5 = self.conv5(maxpool4)
123 |
124 | batch_size = inputs.shape[0]
125 | pooled = F.adaptive_avg_pool2d(conv5, (1, 1)).view(batch_size, -1)
126 |
127 | # Attention Mechanism
128 | g_conv1, att1 = self.compatibility_score1(conv3, conv5)
129 | g_conv2, att2 = self.compatibility_score2(conv4, conv5)
130 |
131 | # flatten to get single feature vector
132 | fsizes = self.attention_filter_sizes
133 | g1 = torch.sum(g_conv1.view(batch_size, fsizes[0], -1), dim=-1)
134 | g2 = torch.sum(g_conv2.view(batch_size, fsizes[1], -1), dim=-1)
135 |
136 | return self.aggregate(g1, g2, pooled)
137 |
138 |
139 | @staticmethod
140 | def apply_argmax_softmax(pred):
141 | prob = F.softmax(pred, dim=1)  # softmax probabilities, not log-probabilities
142 |
143 | return prob
144 |
--------------------------------------------------------------------------------
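A quick smoke test for the classifier above (a sketch, not part of the repo; the 1-channel input, 14 classes and feature_scale=8 mirror the ultrasound configs but are assumptions here, and it presumes PyTorch >= 0.4 where plain tensors replace Variables):

import torch
from models.networks.sononet_grid_attention import sononet_grid_attention

net = sononet_grid_attention(feature_scale=8, n_classes=14, in_channels=1, aggregation_mode='concat')
x = torch.randn(2, 1, 224, 288)           # (batch, channels, H, W); H and W must survive four 2x2 poolings
logits = net(x)                           # 'concat' mode returns a single (2, 14) logit tensor
probs = net.apply_argmax_softmax(logits)  # per-class probabilities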
/models/networks/unet_CT_single_att_dsv_3D.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from .utils import UnetConv3, UnetUp3_CT, UnetGridGatingSignal3, UnetDsv3
4 | import torch.nn.functional as F
5 | from models.networks_other import init_weights
6 | from models.layers.grid_attention_layer import GridAttentionBlock3D
7 |
8 |
9 | class unet_CT_single_att_dsv_3D(nn.Module):
10 |
11 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3,
12 | nonlocal_mode='concatenation', attention_dsample=(2,2,2), is_batchnorm=True):
13 | super(unet_CT_single_att_dsv_3D, self).__init__()
14 | self.is_deconv = is_deconv
15 | self.in_channels = in_channels
16 | self.is_batchnorm = is_batchnorm
17 | self.feature_scale = feature_scale
18 |
19 | filters = [64, 128, 256, 512, 1024]
20 | filters = [int(x / self.feature_scale) for x in filters]
21 |
22 | # downsampling
23 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
24 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))
25 |
26 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
27 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))
28 |
29 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
30 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))
31 |
32 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
33 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))
34 |
35 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
36 | self.gating = UnetGridGatingSignal3(filters[4], filters[4], kernel_size=(1, 1, 1), is_batchnorm=self.is_batchnorm)
37 |
38 | # attention blocks
39 | self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1],
40 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)
41 | self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2],
42 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)
43 | self.attentionblock4 = MultiAttentionBlock(in_size=filters[3], gate_size=filters[4], inter_size=filters[3],
44 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)
45 |
46 | # upsampling
47 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm)
48 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm)
49 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm)
50 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm)
51 |
52 | # deep supervision
53 | self.dsv4 = UnetDsv3(in_size=filters[3], out_size=n_classes, scale_factor=8)
54 | self.dsv3 = UnetDsv3(in_size=filters[2], out_size=n_classes, scale_factor=4)
55 | self.dsv2 = UnetDsv3(in_size=filters[1], out_size=n_classes, scale_factor=2)
56 | self.dsv1 = nn.Conv3d(in_channels=filters[0], out_channels=n_classes, kernel_size=1)
57 |
58 | # final conv (without any concat)
59 | self.final = nn.Conv3d(n_classes*4, n_classes, 1)
60 |
61 | # initialise weights
62 | for m in self.modules():
63 | if isinstance(m, nn.Conv3d):
64 | init_weights(m, init_type='kaiming')
65 | elif isinstance(m, nn.BatchNorm3d):
66 | init_weights(m, init_type='kaiming')
67 |
68 | def forward(self, inputs):
69 | # Feature Extraction
70 | conv1 = self.conv1(inputs)
71 | maxpool1 = self.maxpool1(conv1)
72 |
73 | conv2 = self.conv2(maxpool1)
74 | maxpool2 = self.maxpool2(conv2)
75 |
76 | conv3 = self.conv3(maxpool2)
77 | maxpool3 = self.maxpool3(conv3)
78 |
79 | conv4 = self.conv4(maxpool3)
80 | maxpool4 = self.maxpool4(conv4)
81 |
82 | # Gating Signal Generation
83 | center = self.center(maxpool4)
84 | gating = self.gating(center)
85 |
86 | # Attention Mechanism
87 | # Upscaling Part (Decoder)
88 | g_conv4, att4 = self.attentionblock4(conv4, gating)
89 | up4 = self.up_concat4(g_conv4, center)
90 | g_conv3, att3 = self.attentionblock3(conv3, up4)
91 | up3 = self.up_concat3(g_conv3, up4)
92 | g_conv2, att2 = self.attentionblock2(conv2, up3)
93 | up2 = self.up_concat2(g_conv2, up3)
94 | up1 = self.up_concat1(conv1, up2)
95 |
96 | # Deep Supervision
97 | dsv4 = self.dsv4(up4)
98 | dsv3 = self.dsv3(up3)
99 | dsv2 = self.dsv2(up2)
100 | dsv1 = self.dsv1(up1)
101 | final = self.final(torch.cat([dsv1,dsv2,dsv3,dsv4], dim=1))
102 |
103 | return final
104 |
105 |
106 | @staticmethod
107 | def apply_argmax_softmax(pred):
108 | prob = F.softmax(pred, dim=1)  # softmax probabilities, not log-probabilities
109 |
110 | return prob
111 |
112 |
113 | class MultiAttentionBlock(nn.Module):
114 | def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor):
115 | super(MultiAttentionBlock, self).__init__()
116 | self.gate_block_1 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size,
117 | inter_channels=inter_size, mode=nonlocal_mode,
118 | sub_sample_factor= sub_sample_factor)
119 | self.combine_gates = nn.Sequential(nn.Conv3d(in_size, in_size, kernel_size=1, stride=1, padding=0),
120 | nn.BatchNorm3d(in_size),
121 | nn.ReLU(inplace=True)
122 | )
123 |
124 | # initialise the blocks
125 | for m in self.children():
126 | if m.__class__.__name__.find('GridAttentionBlock3D') != -1: continue
127 | init_weights(m, init_type='kaiming')
128 |
129 | def forward(self, input, gating_signal):
130 | gate_1, attention_1 = self.gate_block_1(input, gating_signal)
131 |
132 | return self.combine_gates(gate_1), attention_1
133 |
134 |
135 |
--------------------------------------------------------------------------------
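A minimal forward-pass sketch for the network above (illustrative only; the single-channel CT input and two output classes are assumptions):

import torch
from models.networks.unet_CT_single_att_dsv_3D import unet_CT_single_att_dsv_3D

net = unet_CT_single_att_dsv_3D(feature_scale=4, n_classes=2, in_channels=1)
vol = torch.randn(1, 1, 64, 64, 64)   # each spatial dim must be divisible by 16 (four 2x2x2 poolings)
seg_logits = net(vol)                 # -> (1, 2, 64, 64, 64), fused from the four deep-supervision heads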
/models/feedforward_seg_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | import torch.optim as optim
4 |
5 | from collections import OrderedDict
6 | import utils.util as util
7 | from .base_model import BaseModel
8 | from .networks import get_network
9 | from .layers.loss import *
10 | from .networks_other import get_scheduler, print_network, benchmark_fp_bp_time
11 | from .utils import segmentation_stats, get_optimizer, get_criterion
12 | from .networks.utils import HookBasedFeatureExtractor
13 |
14 |
15 | class FeedForwardSegmentation(BaseModel):
16 |
17 | def name(self):
18 | return 'FeedForwardSegmentation'
19 |
20 | def initialize(self, opts, **kwargs):
21 | BaseModel.initialize(self, opts, **kwargs)
22 | self.isTrain = opts.isTrain
23 |
24 | # define network input and output pars
25 | self.input = None
26 | self.target = None
27 | self.tensor_dim = opts.tensor_dim
28 |
29 | # load/define networks
30 | self.net = get_network(opts.model_type, n_classes=opts.output_nc,
31 | in_channels=opts.input_nc, nonlocal_mode=opts.nonlocal_mode,
32 | tensor_dim=opts.tensor_dim, feature_scale=opts.feature_scale,
33 | attention_dsample=opts.attention_dsample)
34 | if self.use_cuda: self.net = self.net.cuda()
35 |
36 | # load the model if a path is specified or it is in inference mode
37 | if not self.isTrain or opts.continue_train:
38 | self.path_pre_trained_model = opts.path_pre_trained_model
39 | if self.path_pre_trained_model:
40 | self.load_network_from_path(self.net, self.path_pre_trained_model, strict=False)
41 | self.which_epoch = int(0)
42 | else:
43 | self.which_epoch = opts.which_epoch
44 | self.load_network(self.net, 'S', self.which_epoch)
45 |
46 | # training objective
47 | if self.isTrain:
48 | self.criterion = get_criterion(opts)
49 | # initialize optimizers
50 | self.schedulers = []
51 | self.optimizers = []
52 | self.optimizer_S = get_optimizer(opts, self.net.parameters())
53 | self.optimizers.append(self.optimizer_S)
54 |
55 | # print the network details
56 |
57 | if kwargs.get('verbose', True):
58 | print('Network is initialized')
59 | print_network(self.net)
60 |
61 | def set_scheduler(self, train_opt):
62 | for optimizer in self.optimizers:
63 | self.schedulers.append(get_scheduler(optimizer, train_opt))
64 | print('Scheduler is added for optimiser {0}'.format(optimizer))
65 |
66 | def set_input(self, *inputs):
67 | # self.input.resize_(inputs[0].size()).copy_(inputs[0])
68 | for idx, _input in enumerate(inputs):
69 | # If it's a 5D array and 2D model then (B x C x H x W x Z) -> (BZ x C x H x W)
70 | bs = _input.size()
71 | if (self.tensor_dim == '2D') and (len(bs) > 4):
72 | _input = _input.permute(0,4,1,2,3).contiguous().view(bs[0]*bs[4], bs[1], bs[2], bs[3])
73 |
74 | # Define that it's a cuda array
75 | if idx == 0:
76 | self.input = _input.cuda() if self.use_cuda else _input
77 | elif idx == 1:
78 | self.target = Variable(_input.cuda()) if self.use_cuda else Variable(_input)
79 | assert self.input.size() == self.target.size()
80 |
81 | def forward(self, split):
82 | if split == 'train':
83 | self.prediction = self.net(Variable(self.input))
84 | elif split == 'test':
85 | self.prediction = self.net(Variable(self.input, volatile=True))
86 | # Apply a softmax and return a segmentation map
87 | self.logits = self.net.apply_argmax_softmax(self.prediction)
88 | self.pred_seg = self.logits.data.max(1)[1].unsqueeze(1)
89 |
90 | def backward(self):
91 | self.loss_S = self.criterion(self.prediction, self.target)
92 | self.loss_S.backward()
93 |
94 | def optimize_parameters(self):
95 | self.net.train()
96 | self.forward(split='train')
97 |
98 | self.optimizer_S.zero_grad()
99 | self.backward()
100 | self.optimizer_S.step()
101 |
102 | # This function updates the network parameters every "accumulate_iters"
103 | def optimize_parameters_accumulate_grd(self, iteration):
104 | accumulate_iters = int(2)
105 | if iteration == 0: self.optimizer_S.zero_grad()
106 | self.net.train()
107 | self.forward(split='train')
108 | self.backward()
109 |
110 | if iteration % accumulate_iters == 0:
111 | self.optimizer_S.step()
112 | self.optimizer_S.zero_grad()
113 |
114 | def test(self):
115 | self.net.eval()
116 | self.forward(split='test')
117 |
118 | def validate(self):
119 | self.net.eval()
120 | self.forward(split='test')
121 | self.loss_S = self.criterion(self.prediction, self.target)
122 |
123 | def get_segmentation_stats(self):
124 | self.seg_scores, self.dice_score = segmentation_stats(self.prediction, self.target)
125 | seg_stats = [('Overall_Acc', self.seg_scores['overall_acc']), ('Mean_IOU', self.seg_scores['mean_iou'])]
126 | for class_id in range(self.dice_score.size):
127 | seg_stats.append(('Class_{}'.format(class_id), self.dice_score[class_id]))
128 | return OrderedDict(seg_stats)
129 |
130 | def get_current_errors(self):
131 | return OrderedDict([('Seg_Loss', self.loss_S.data[0])
132 | ])
133 |
134 | def get_current_visuals(self):
135 | inp_img = util.tensor2im(self.input, 'img')
136 | seg_img = util.tensor2im(self.pred_seg, 'lbl')
137 | return OrderedDict([('out_S', seg_img), ('inp_S', inp_img)])
138 |
139 | def get_feature_maps(self, layer_name, upscale):
140 | feature_extractor = HookBasedFeatureExtractor(self.net, layer_name, upscale)
141 | return feature_extractor.forward(Variable(self.input))
142 |
143 | # returns the fp/bp times of the model
144 | def get_fp_bp_time (self, size=None):
145 | if size is None:
146 | size = (1, 1, 160, 160, 96)
147 |
148 | inp_array = Variable(torch.zeros(*size)).cuda()
149 | out_array = Variable(torch.zeros(*size)).cuda()
150 | fp, bp = benchmark_fp_bp_time(self.net, inp_array, out_array)
151 |
152 | bsize = size[0]
153 | return fp/float(bsize), bp/float(bsize)
154 |
155 | def save(self, epoch_label):
156 | self.save_network(self.net, 'S', epoch_label, self.gpu_ids)
157 |
--------------------------------------------------------------------------------
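The (B x C x H x W x Z) -> (BZ x C x H x W) reshape in set_input folds the slice axis into the batch axis so a 2D network can process every slice of a volume. A standalone sketch of the same permute/view sequence (shapes are arbitrary):

import torch

x = torch.randn(4, 1, 128, 128, 10)                        # B, C, H, W, Z
b, c, h, w, z = x.size()
slices = x.permute(0, 4, 1, 2, 3).contiguous().view(b * z, c, h, w)
print(slices.shape)                                         # torch.Size([40, 1, 128, 128])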
/models/networks/unet_CT_multi_att_dsv_3D.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from .utils import UnetConv3, UnetUp3_CT, UnetGridGatingSignal3, UnetDsv3
4 | import torch.nn.functional as F
5 | from models.networks_other import init_weights
6 | from models.layers.grid_attention_layer import GridAttentionBlock3D
7 |
8 |
9 | class unet_CT_multi_att_dsv_3D(nn.Module):
10 |
11 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3,
12 | nonlocal_mode='concatenation', attention_dsample=(2,2,2), is_batchnorm=True):
13 | super(unet_CT_multi_att_dsv_3D, self).__init__()
14 | self.is_deconv = is_deconv
15 | self.in_channels = in_channels
16 | self.is_batchnorm = is_batchnorm
17 | self.feature_scale = feature_scale
18 |
19 | filters = [64, 128, 256, 512, 1024]
20 | filters = [int(x / self.feature_scale) for x in filters]
21 |
22 | # downsampling
23 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
24 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))
25 |
26 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
27 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))
28 |
29 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
30 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))
31 |
32 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
33 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))
34 |
35 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
36 | self.gating = UnetGridGatingSignal3(filters[4], filters[4], kernel_size=(1, 1, 1), is_batchnorm=self.is_batchnorm)
37 |
38 | # attention blocks
39 | self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1],
40 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)
41 | self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2],
42 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)
43 | self.attentionblock4 = MultiAttentionBlock(in_size=filters[3], gate_size=filters[4], inter_size=filters[3],
44 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample)
45 |
46 | # upsampling
47 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm)
48 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm)
49 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm)
50 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm)
51 |
52 | # deep supervision
53 | self.dsv4 = UnetDsv3(in_size=filters[3], out_size=n_classes, scale_factor=8)
54 | self.dsv3 = UnetDsv3(in_size=filters[2], out_size=n_classes, scale_factor=4)
55 | self.dsv2 = UnetDsv3(in_size=filters[1], out_size=n_classes, scale_factor=2)
56 | self.dsv1 = nn.Conv3d(in_channels=filters[0], out_channels=n_classes, kernel_size=1)
57 |
58 | # final conv (without any concat)
59 | self.final = nn.Conv3d(n_classes*4, n_classes, 1)
60 |
61 | # initialise weights
62 | for m in self.modules():
63 | if isinstance(m, nn.Conv3d):
64 | init_weights(m, init_type='kaiming')
65 | elif isinstance(m, nn.BatchNorm3d):
66 | init_weights(m, init_type='kaiming')
67 |
68 | def forward(self, inputs):
69 | # Feature Extraction
70 | conv1 = self.conv1(inputs)
71 | maxpool1 = self.maxpool1(conv1)
72 |
73 | conv2 = self.conv2(maxpool1)
74 | maxpool2 = self.maxpool2(conv2)
75 |
76 | conv3 = self.conv3(maxpool2)
77 | maxpool3 = self.maxpool3(conv3)
78 |
79 | conv4 = self.conv4(maxpool3)
80 | maxpool4 = self.maxpool4(conv4)
81 |
82 | # Gating Signal Generation
83 | center = self.center(maxpool4)
84 | gating = self.gating(center)
85 |
86 | # Attention Mechanism
87 | # Upscaling Part (Decoder)
88 | g_conv4, att4 = self.attentionblock4(conv4, gating)
89 | up4 = self.up_concat4(g_conv4, center)
90 | g_conv3, att3 = self.attentionblock3(conv3, up4)
91 | up3 = self.up_concat3(g_conv3, up4)
92 | g_conv2, att2 = self.attentionblock2(conv2, up3)
93 | up2 = self.up_concat2(g_conv2, up3)
94 | up1 = self.up_concat1(conv1, up2)
95 |
96 | # Deep Supervision
97 | dsv4 = self.dsv4(up4)
98 | dsv3 = self.dsv3(up3)
99 | dsv2 = self.dsv2(up2)
100 | dsv1 = self.dsv1(up1)
101 | final = self.final(torch.cat([dsv1,dsv2,dsv3,dsv4], dim=1))
102 |
103 | return final
104 |
105 |
106 | @staticmethod
107 | def apply_argmax_softmax(pred):
108 | prob = F.softmax(pred, dim=1)  # softmax probabilities, not log-probabilities
109 |
110 | return prob
111 |
112 |
113 | class MultiAttentionBlock(nn.Module):
114 | def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor):
115 | super(MultiAttentionBlock, self).__init__()
116 | self.gate_block_1 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size,
117 | inter_channels=inter_size, mode=nonlocal_mode,
118 | sub_sample_factor= sub_sample_factor)
119 | self.gate_block_2 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size,
120 | inter_channels=inter_size, mode=nonlocal_mode,
121 | sub_sample_factor=sub_sample_factor)
122 | self.combine_gates = nn.Sequential(nn.Conv3d(in_size*2, in_size, kernel_size=1, stride=1, padding=0),
123 | nn.BatchNorm3d(in_size),
124 | nn.ReLU(inplace=True)
125 | )
126 |
127 | # initialise the blocks
128 | for m in self.children():
129 | if m.__class__.__name__.find('GridAttentionBlock3D') != -1: continue
130 | init_weights(m, init_type='kaiming')
131 |
132 | def forward(self, input, gating_signal):
133 | gate_1, attention_1 = self.gate_block_1(input, gating_signal)
134 | gate_2, attention_2 = self.gate_block_2(input, gating_signal)
135 |
136 | return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1)
137 |
138 |
139 |
--------------------------------------------------------------------------------
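The MultiAttentionBlock above runs two grid-attention gates in parallel and fuses them with a 1x1x1 convolution. A sketch of its input/output contract (channel and spatial sizes are illustrative):

import torch
from models.networks.unet_CT_multi_att_dsv_3D import MultiAttentionBlock

block = MultiAttentionBlock(in_size=32, gate_size=64, inter_size=32,
                            nonlocal_mode='concatenation', sub_sample_factor=(2, 2, 2))
x = torch.randn(1, 32, 16, 16, 16)   # skip-connection features
g = torch.randn(1, 64, 8, 8, 8)      # coarser gating signal
gated, att = block(x, g)
print(gated.shape, att.shape)         # (1, 32, 16, 16, 16) and the two concatenated attention maps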
/test_classification.py:
--------------------------------------------------------------------------------
1 | import os, sys, numpy as np
2 | from torch.utils.data import DataLoader, sampler
3 | from tqdm import tqdm
4 |
5 |
6 | from dataio.loader import get_dataset, get_dataset_path
7 | from dataio.transformation import get_dataset_transformation
8 | from utils.util import json_file_to_pyobj
9 | from utils.visualiser import Visualiser
10 | from utils.error_logger import ErrorLogger
11 | from models.networks_other import adjust_learning_rate
12 |
13 | from models import get_model
14 |
15 | class HiddenPrints:
16 | def __enter__(self):
17 | self._original_stdout = sys.stdout
18 | sys.stdout = None  # print() silently becomes a no-op while sys.stdout is None
19 |
20 | def __exit__(self, exc_type, exc_val, exc_tb):
21 | sys.stdout = self._original_stdout
22 |
23 | class StratifiedSampler(object):
24 | """Stratified Sampling
25 | Provides equal representation of target classes in each batch
26 | """
27 | def __init__(self, class_vector, batch_size):
28 | """
29 | Arguments
30 | ---------
31 | class_vector : torch tensor
32 | a vector of class labels
33 | batch_size : integer
34 | batch_size
35 | """
36 | self.class_vector = class_vector
37 | self.batch_size = batch_size
38 | self.num_iter = len(class_vector) // 52  # 52 samples per batch: 2 x 13 foreground classes + 26 background
39 | self.n_class = 14
40 | self.sample_n = 2
41 | # create an index pool for each class
42 | indices = {}
43 | for i in range(self.n_class):
44 | indices[i] = np.where(self.class_vector == i)[0]
45 |
46 | self.indices = indices
47 | self.background_index = np.argmax([ len(indices[i]) for i in range(self.n_class)])
48 |
49 |
50 | def gen_sample_array(self):
51 | # sample sample_n per class; the background class gets sample_n * (n_class - 1)
52 | sample_array = []
53 | for _ in range(self.num_iter):  # '_' avoids shadowing the class index below
54 | arrs = []
55 | for i in range(self.n_class):
56 | n = self.sample_n
57 | if i == self.background_index:
58 | n = self.sample_n * (self.n_class-1)
59 | arr = np.random.choice(self.indices[i], n)
60 | arrs.append(arr)
61 |
62 | sample_array.append(np.hstack(arrs))
63 | return np.hstack(sample_array)
64 |
65 | def __iter__(self):
66 | return iter(self.gen_sample_array())
67 |
68 | def __len__(self):
69 | return len(self.class_vector)
70 |
71 |
72 | def test(arguments):
73 |
74 | # Parse input arguments
75 | json_filename = arguments.config
76 | network_debug = arguments.debug
77 |
78 | # Load options
79 | json_opts = json_file_to_pyobj(json_filename)
80 | train_opts = json_opts.training
81 |
82 | # Architecture type
83 | arch_type = train_opts.arch_type
84 |
85 | # Setup Dataset and Augmentation
86 | ds_class = get_dataset(arch_type)
87 | ds_path = get_dataset_path(arch_type, json_opts.data_path)
88 | ds_transform = get_dataset_transformation(arch_type, opts=json_opts.augmentation)
89 |
90 | # Setup the NN Model
91 | with HiddenPrints():
92 | model = get_model(json_opts.model)
93 |
94 | if network_debug:
95 | print('# of pars: ', model.get_number_parameters())
96 | print('fp time: {0:.8f} sec\tbp time: {1:.8f} sec per sample'.format(*model.get_fp_bp_time2((1,1,224,288))))
97 | exit()
98 |
99 | # Setup Data Loader
100 | num_workers = train_opts.num_workers if hasattr(train_opts, 'num_workers') else 16
101 |
102 | valid_dataset = ds_class(ds_path, split='val', transform=ds_transform['valid'], preload_data=train_opts.preloadData)
103 | test_dataset = ds_class(ds_path, split='test', transform=ds_transform['valid'], preload_data=train_opts.preloadData)
104 | # loader
105 | batch_size = train_opts.batchSize
106 | valid_loader = DataLoader(dataset=valid_dataset, num_workers=num_workers, batch_size=batch_size, shuffle=False)
107 | test_loader = DataLoader(dataset=test_dataset, num_workers=0, batch_size=batch_size, shuffle=False)
108 |
109 | # Visualisation Parameters
110 | filename = 'test_loss_log.txt'
111 | visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir,
112 | filename=filename)
113 | error_logger = ErrorLogger()
114 |
115 | # Training Function
116 | track_labels = np.arange(len(valid_dataset.label_names))
117 | model.set_labels(track_labels)
118 | model.set_scheduler(train_opts)
119 |
120 | if hasattr(model.net, 'deep_supervised'):
121 | model.net.deep_supervised = False
122 |
123 | # Validation and Testing Iterations
124 | pr_lbls = []
125 | gt_lbls = []
126 | for loader, split in zip([test_loader], ['test']):
127 | #for loader, split in zip([valid_loader, test_loader], ['validation', 'test']):
128 | model.reset_results()
129 |
130 | for epoch_iter, (images, labels) in tqdm(enumerate(loader, 1), total=len(loader)):
131 |
132 | # Make a forward pass with the model
133 | model.set_input(images, labels)
134 | model.validate()
135 |
136 | # Error visualisation
137 | errors = model.get_accumulated_errors()
138 | stats = model.get_classification_stats()
139 | error_logger.update({**errors, **stats}, split=split)
140 |
141 | # Update the plots
142 | # for split in ['train', 'validation', 'test']:
143 | for split in ['test']:
144 | # exclude background
145 | #track_labels = np.delete(track_labels, 3)
146 | #show_labels = train_dataset.label_names[:3] + train_dataset.label_names[4:]
147 | show_labels = valid_dataset.label_names
148 | visualizer.plot_current_errors(300, error_logger.get_errors(split), split_name=split, labels=show_labels)
149 | visualizer.print_current_errors(300, error_logger.get_errors(split), split_name=split)
150 |
151 | import pickle as pkl
152 | dst_file = os.path.join(model.save_dir, 'test_result.pkl')
153 | with open(dst_file, 'wb') as f:
154 | d = error_logger.get_errors(split)
155 | d['labels'] = valid_dataset.label_names
156 | d['pr_lbls'] = np.hstack(model.pr_lbls)
157 | d['gt_lbls'] = np.hstack(model.gt_lbls)
158 | pkl.dump(d, f)
159 |
160 | error_logger.reset()
161 |
162 | if arguments.time:
163 | print('# of pars: ', model.get_number_parameters())
164 | print('fp time: {0:.8f} sec\tbp time: {1:.8f} sec per sample'.format(*model.get_fp_bp_time2((1,1,224,288))))
165 |
166 |
167 | if __name__ == '__main__':
168 | import argparse
169 |
170 | parser = argparse.ArgumentParser(description='CNN Classification Testing Function')
171 |
172 | parser.add_argument('-c', '--config', help='training config file', required=True)
173 | parser.add_argument('-d', '--debug', help='returns number of parameters and bp/fp runtime', action='store_true')
174 | parser.add_argument('-t', '--time', help='additionally reports number of parameters and fp/bp runtime after testing', action='store_true')
175 | args = parser.parse_args()
176 |
177 | test(args)
178 |
--------------------------------------------------------------------------------
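The StratifiedSampler above builds index batches of 52 (2 per foreground class, 26 background). A quick sanity check with synthetic labels (assumes the class defined in the script above is in scope; the label distribution is made up):

import numpy as np

labels = np.random.randint(0, 14, size=5200)        # fake 14-class label vector
s = StratifiedSampler(labels, batch_size=52)
idx = s.gen_sample_array()
print(np.bincount(labels[idx[:52]], minlength=14))  # ~2 per class, ~26 for the most frequent ('background') class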
/models/feedforward_classifier.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import utils.util as util
4 | from collections import OrderedDict
5 |
6 | import torch
7 | from torch.autograd import Variable
8 | from .base_model import BaseModel
9 | from .networks import get_network
10 | from .layers.loss import *
11 | from .networks_other import get_scheduler, print_network, benchmark_fp_bp_time
12 | from .utils import classification_stats, get_optimizer, get_criterion
13 | from .networks.utils import HookBasedFeatureExtractor
14 |
15 |
16 | class FeedForwardClassifier(BaseModel):
17 |
18 | def name(self):
19 | return 'FeedForwardClassifier'
20 |
21 | def initialize(self, opts, **kwargs):
22 | BaseModel.initialize(self, opts, **kwargs)
23 | self.opts = opts
24 | self.isTrain = opts.isTrain
25 |
26 | # define network input and output pars
27 | self.input = None
28 | self.target = None
29 | self.labels = None
30 | self.tensor_dim = opts.tensor_dim
31 |
32 | # load/define networks
33 | self.net = get_network(opts.model_type, n_classes=opts.output_nc,
34 | in_channels=opts.input_nc, nonlocal_mode=opts.nonlocal_mode,
35 | tensor_dim=opts.tensor_dim, feature_scale=opts.feature_scale,
36 | attention_dsample=opts.attention_dsample,
37 | aggregation_mode=opts.aggregation_mode)
38 | if self.use_cuda: self.net = self.net.cuda()
39 |
40 | # load the model if a path is specified or it is in inference mode
41 | if not self.isTrain or opts.continue_train:
42 | self.path_pre_trained_model = opts.path_pre_trained_model
43 | if self.path_pre_trained_model:
44 | self.load_network_from_path(self.net, self.path_pre_trained_model, strict=False)
45 | self.which_epoch = int(0)
46 | else:
47 | self.which_epoch = opts.which_epoch
48 | self.load_network(self.net, 'S', self.which_epoch)
49 |
50 | # training objective
51 | if self.isTrain:
52 | self.criterion = get_criterion(opts)
53 | # initialize optimizers
54 | self.schedulers = []
55 | self.optimizers = []
56 |
57 | self.optimizer = get_optimizer(opts, self.net.parameters())
58 | self.optimizers.append(self.optimizer)
59 |
60 | # print the network details
61 | if kwargs.get('verbose', True):
62 | print('Network is initialized')
63 | print_network(self.net)
64 |
65 | # for accumulator
66 | self.reset_results()
67 |
68 | def set_scheduler(self, train_opt):
69 | for optimizer in self.optimizers:
70 | self.schedulers.append(get_scheduler(optimizer, train_opt))
71 | print('Scheduler is added for optimiser {0}'.format(optimizer))
72 |
73 | def set_input(self, *inputs):
74 | # self.input.resize_(inputs[0].size()).copy_(inputs[0])
75 | for idx, _input in enumerate(inputs):
76 | # If it's a 5D array and 2D model then (B x C x H x W x Z) -> (BZ x C x H x W)
77 | bs = _input.size()
78 | if (self.tensor_dim == '2D') and (len(bs) > 4):
79 | _input = _input.permute(0,4,1,2,3).contiguous().view(bs[0]*bs[4], bs[1], bs[2], bs[3])
80 |
81 | # Define that it's a cuda array
82 | if idx == 0:
83 | self.input = _input.cuda() if self.use_cuda else _input
84 | elif idx == 1:
85 | self.target = Variable(_input.cuda()) if self.use_cuda else Variable(_input)
86 | assert self.input.shape[0] == self.target.shape[0]
87 |
88 | def forward(self, split):
89 | if split == 'train':
90 | self.prediction = self.net(Variable(self.input))
91 | elif split in ['validation', 'test']:
92 | self.prediction = self.net(Variable(self.input, volatile=True))
93 | # Apply a softmax and return a segmentation map
94 | self.logits = self.net.apply_argmax_softmax(self.prediction)
95 | self.pred = self.logits.data.max(1)
96 |
97 |
98 | def backward(self):
99 | #print(self.net.apply_argmax_softmax(self.prediction), self.target)
100 | self.loss = self.criterion(self.prediction, self.target)
101 | self.loss.backward()
102 |
103 | def optimize_parameters(self):
104 | self.net.train()
105 | self.forward(split='train')
106 |
107 | self.optimizer.zero_grad()
108 | self.backward()
109 | self.optimizer.step()
110 |
111 | def test(self):
112 | self.net.eval()
113 | self.forward(split='test')
114 | self.accumulate_results()
115 |
116 | def validate(self):
117 | self.net.eval()
118 | self.forward(split='test')
119 | self.loss = self.criterion(self.prediction, self.target)
120 | self.accumulate_results()
121 |
122 | def reset_results(self):
123 | self.losses = []
124 | self.pr_lbls = []
125 | self.pr_probs = []
126 | self.gt_lbls = []
127 |
128 | def accumulate_results(self):
129 | self.losses.append(self.loss.data[0])
130 | self.pr_probs.append(self.pred[0].cpu().numpy())
131 | self.pr_lbls.append(self.pred[1].cpu().numpy())
132 | self.gt_lbls.append(self.target.data.cpu().numpy())
133 |
134 | def get_classification_stats(self):
135 | self.pr_lbls = np.concatenate(self.pr_lbls)
136 | self.gt_lbls = np.concatenate(self.gt_lbls)
137 | res = classification_stats(self.pr_lbls, self.gt_lbls, self.labels)
138 | (self.accuracy, self.f1_micro, self.precision_micro,
139 | self.recall_micro, self.f1_macro, self.precision_macro,
140 | self.recall_macro, self.confusion, self.class_accuracies,
141 | self.f1s, self.precisions,self.recalls) = res
142 |
143 | breakdown = dict(type='table',
144 | colnames=['|accuracy|',' precision|',' recall|',' f1_score|'],
145 | rownames=self.labels,
146 | data=[self.class_accuracies, self.precisions,self.recalls, self.f1s])
147 |
148 | return OrderedDict([('accuracy', self.accuracy),
149 | ('confusion', self.confusion),
150 | ('f1', self.f1_macro),
151 | ('precision', self.precision_macro),
152 | ('recall', self.recall_macro),
153 | # ('confusion', ...) is added once above and rendered as a heatmap by the visualiser
154 | ('breakdown', breakdown)])
155 |
156 | def get_current_errors(self):
157 | return OrderedDict([('CE', self.loss.data[0])])
158 |
159 | def get_accumulated_errors(self):
160 | return OrderedDict([('CE', np.mean(self.losses))])
161 |
162 | def get_current_visuals(self):
163 | inp_img = util.tensor2im(self.input, 'img')
164 | return OrderedDict([('inp_S', inp_img)])
165 |
166 | def get_feature_maps(self, layer_name, upscale):
167 | feature_extractor = HookBasedFeatureExtractor(self.net, layer_name, upscale)
168 | return feature_extractor.forward(Variable(self.input))
169 |
170 |
171 | def save(self, epoch_label):
172 | self.save_network(self.net, 'S', epoch_label, self.gpu_ids)
173 |
174 | def set_labels(self, labels):
175 | self.labels = labels
176 |
177 | def load_network_from_path(self, network, network_filepath, strict):
178 | network_label = os.path.basename(network_filepath)
179 | epoch_label = network_label.split('_')[0]
180 | print('Loading the model {0} - epoch {1}'.format(network_label, epoch_label))
181 | network.load_state_dict(torch.load(network_filepath), strict=strict)
182 |
183 | def update_state(self, epoch):
184 | pass
185 |
186 | def get_fp_bp_time2(self, size=None):
187 | # returns the fp/bp times of the model
188 | if size is None:
189 | size = (8, 1, 192, 192)
190 |
191 | inp_array = Variable(torch.rand(*size)).cuda()
192 | out_array = Variable(torch.rand(*size)).cuda()
193 | fp, bp = benchmark_fp_bp_time(self.net, inp_array, out_array)
194 |
195 | bsize = size[0]
196 | return fp/float(bsize), bp/float(bsize)
197 |
--------------------------------------------------------------------------------
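forward() stores self.pred = logits.data.max(1), a (values, indices) pair that accumulate_results later unpacks into pr_probs and pr_lbls. In isolation (a sketch with fake outputs):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 14)                 # fake network outputs for 4 samples
probs = F.softmax(logits, dim=1)
max_probs, pred_labels = probs.max(1)       # the (values, indices) pair
print(max_probs.shape, pred_labels.shape)   # torch.Size([4]) torch.Size([4])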
/dataio/transformation/transforms.py:
--------------------------------------------------------------------------------
1 | import torchsample.transforms as ts
2 | from pprint import pprint
3 |
4 |
5 | class Transformations:
6 |
7 | def __init__(self, name):
8 | self.name = name
9 |
10 | # Input patch and scale size
11 | self.scale_size = (192, 192, 1)
12 | self.patch_size = (128, 128, 1)
13 | # self.patch_size = (208, 272, 1)
14 |
15 | # Affine and Intensity Transformations
16 | self.shift_val = (0.1, 0.1)
17 | self.rotate_val = 15.0
18 | self.scale_val = (0.7, 1.3)
19 | self.inten_val = (1.0, 1.0)
20 | self.random_flip_prob = 0.0
21 |
22 | # Divisibility factor for testing
23 | self.division_factor = (16, 16, 1)
24 |
25 | def get_transformation(self):
26 | return {
27 | 'ukbb_sax': self.cmr_3d_sax_transform,
28 | 'hms_sax': self.hms_sax_transform,
29 | 'test_sax': self.test_3d_sax_transform,
30 | 'acdc_sax': self.cmr_3d_sax_transform,
31 | 'us': self.ultrasound_transform,
32 | }[self.name]()
33 |
34 | def print(self):
35 | print('\n\n############# Augmentation Parameters #############')
36 | pprint(vars(self))
37 | print('###################################################\n\n')
38 |
39 | def initialise(self, opts):
40 | t_opts = getattr(opts, self.name)
41 |
42 | # Affine and Intensity Transformations
43 | if hasattr(t_opts, 'scale_size'): self.scale_size = t_opts.scale_size
44 | if hasattr(t_opts, 'patch_size'): self.patch_size = t_opts.patch_size
45 | if hasattr(t_opts, 'shift_val'): self.shift_val = t_opts.shift_val
46 | if hasattr(t_opts, 'rotate_val'): self.rotate_val = t_opts.rotate_val
47 | if hasattr(t_opts, 'scale_val'): self.scale_val = t_opts.scale_val
48 | if hasattr(t_opts, 'inten_val'): self.inten_val = t_opts.inten_val
49 | if hasattr(t_opts, 'random_flip_prob'): self.random_flip_prob = t_opts.random_flip_prob
50 | if hasattr(t_opts, 'division_factor'): self.division_factor = t_opts.division_factor
51 |
52 | def ukbb_sax_transform(self):
53 |
54 | train_transform = ts.Compose([ts.PadNumpy(size=self.scale_size),
55 | ts.ToTensor(),
56 | ts.ChannelsFirst(),
57 | ts.TypeCast(['float', 'float']),
58 | ts.RandomFlip(h=True, v=True, p=self.random_flip_prob),
59 | ts.RandomAffine(rotation_range=self.rotate_val, translation_range=self.shift_val,
60 | zoom_range=self.scale_val, interp=('bilinear', 'nearest')),
61 | ts.NormalizeMedicPercentile(norm_flag=(True, False)),
62 | ts.RandomCrop(size=self.patch_size),
63 | ts.TypeCast(['float', 'long'])
64 | ])
65 |
66 | valid_transform = ts.Compose([ts.PadNumpy(size=self.scale_size),
67 | ts.ToTensor(),
68 | ts.ChannelsFirst(),
69 | ts.TypeCast(['float', 'float']),
70 | ts.NormalizeMedicPercentile(norm_flag=(True, False)),
71 | ts.SpecialCrop(size=self.patch_size, crop_type=0),
72 | ts.TypeCast(['float', 'long'])
73 | ])
74 |
75 | return {'train': train_transform, 'valid': valid_transform}
76 |
77 | def cmr_3d_sax_transform(self):
78 |
79 | train_transform = ts.Compose([ts.PadNumpy(size=self.scale_size),
80 | ts.ToTensor(),
81 | ts.ChannelsFirst(),
82 | ts.TypeCast(['float', 'float']),
83 | ts.RandomFlip(h=True, v=True, p=self.random_flip_prob),
84 | ts.RandomAffine(rotation_range=self.rotate_val, translation_range=self.shift_val,
85 | zoom_range=self.scale_val, interp=('bilinear', 'nearest')),
86 | #ts.NormalizeMedicPercentile(norm_flag=(True, False)),
87 | ts.NormalizeMedic(norm_flag=(True, False)),
88 | ts.ChannelsLast(),
89 | ts.AddChannel(axis=0),
90 | ts.RandomCrop(size=self.patch_size),
91 | ts.TypeCast(['float', 'long'])
92 | ])
93 |
94 | valid_transform = ts.Compose([ts.PadNumpy(size=self.scale_size),
95 | ts.ToTensor(),
96 | ts.ChannelsFirst(),
97 | ts.TypeCast(['float', 'float']),
98 | #ts.NormalizeMedicPercentile(norm_flag=(True, False)),
99 | ts.NormalizeMedic(norm_flag=(True, False)),
100 | ts.ChannelsLast(),
101 | ts.AddChannel(axis=0),
102 | ts.SpecialCrop(size=self.patch_size, crop_type=0),
103 | ts.TypeCast(['float', 'long'])
104 | ])
105 |
106 | return {'train': train_transform, 'valid': valid_transform}
107 |
108 | def hms_sax_transform(self):
109 |
110 | # Training transformation
111 | # 2D Stack input - 3D High Resolution output segmentation
112 |
113 | train_transform = []
114 | valid_transform = []
115 |
116 | # First pad to a fixed size
117 | # Torch tensor
118 | # Channels first
119 | # Joint affine transformation
120 | # In-plane respiratory motion artefacts (translation and rotation)
121 | # Random Crop
122 | # Normalise the intensity range
123 | train_transform = ts.Compose([])  # placeholder pipelines; the steps listed above are not implemented yet
124 | valid_transform = ts.Compose([])
125 | return {'train': train_transform, 'valid': valid_transform}
126 |
127 | def test_3d_sax_transform(self):
128 | test_transform = ts.Compose([ts.PadFactorNumpy(factor=self.division_factor),
129 | ts.ToTensor(),
130 | ts.ChannelsFirst(),
131 | ts.TypeCast(['float']),
132 | #ts.NormalizeMedicPercentile(norm_flag=True),
133 | ts.NormalizeMedic(norm_flag=True),
134 | ts.ChannelsLast(),
135 | ts.AddChannel(axis=0),
136 | ])
137 |
138 | return {'test': test_transform}
139 |
140 |
141 | def ultrasound_transform(self):
142 |
143 | train_transform = ts.Compose([ts.ToTensor(),
144 | ts.TypeCast(['float']),
145 | ts.AddChannel(axis=0),
146 | ts.SpecialCrop(self.patch_size,0),
147 | ts.RandomFlip(h=True, v=False, p=self.random_flip_prob),
148 | ts.RandomAffine(rotation_range=self.rotate_val,
149 | translation_range=self.shift_val,
150 | zoom_range=self.scale_val,
151 | interp='bilinear'),
152 | ts.StdNormalize(),
153 | ])
154 |
155 | valid_transform = ts.Compose([ts.ToTensor(),
156 | ts.TypeCast(['float']),
157 | ts.AddChannel(axis=0),
158 | ts.SpecialCrop(self.patch_size,0),
159 | ts.StdNormalize(),
160 | ])
161 |
162 | return {'train': train_transform, 'valid': valid_transform}
163 |
--------------------------------------------------------------------------------
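initialise() copies any field present on opts.<name> over the defaults. A sketch with a stand-in options object (SimpleNamespace mimics the parsed JSON 'augmentation' block; the field names follow the hasattr checks above and the values are made up):

from types import SimpleNamespace
from dataio.transformation.transforms import Transformations

aug = SimpleNamespace(us=SimpleNamespace(patch_size=(208, 272, 1), rotate_val=10.0, random_flip_prob=0.5))
trans = Transformations('us')
trans.initialise(aug)
trans.print()   # shows patch_size, rotate_val and random_flip_prob overridden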
/visualise_attention.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 |
3 | from dataio.loader import get_dataset, get_dataset_path
4 | from dataio.transformation import get_dataset_transformation
5 | from utils.util import json_file_to_pyobj
6 | from utils.visualiser import Visualiser
7 | from models import get_model
8 | import os, time
9 |
10 | # import matplotlib
11 | # matplotlib.use('Agg')
12 |
13 | import matplotlib.cm as cm
14 | import matplotlib.pyplot as plt
15 | import math, numpy
16 | import numpy as np
17 | # from scipy.misc import imresize  # only needed by the commented-out resize below; removed in SciPy >= 1.3
18 | from skimage.transform import resize
19 |
20 | def plotNNFilter(units, figure_id, interp='bilinear', colormap=cm.jet, colormap_lim=None, title=''):
21 | plt.ion()
22 | filters = units.shape[2]
23 | n_columns = round(math.sqrt(filters))
24 | n_rows = math.ceil(filters / n_columns) + 1
25 | fig = plt.figure(figure_id, figsize=(n_rows*3,n_columns*3))
26 | fig.clf()
27 |
28 | for i in range(filters):
29 | ax1 = plt.subplot(n_rows, n_columns, i+1)
30 | plt.imshow(units[:,:,i].T, interpolation=interp, cmap=colormap)
31 | plt.axis('on')
32 | ax1.set_xticklabels([])
33 | ax1.set_yticklabels([])
34 | plt.colorbar()
35 | if colormap_lim:
36 | plt.clim(colormap_lim[0],colormap_lim[1])
37 |
38 | plt.subplots_adjust(wspace=0, hspace=0)
39 | plt.tight_layout()
40 | plt.suptitle(title)
41 |
42 | def plotNNFilterOverlay(input_im, units, figure_id, interp='bilinear',
43 | colormap=cm.jet, colormap_lim=None, title='', alpha=0.8):
44 | plt.ion()
45 | filters = units.shape[2]
46 | fig = plt.figure(figure_id, figsize=(5,5))
47 | fig.clf()
48 |
49 | for i in range(filters):
50 | plt.imshow(input_im[:,:,0], interpolation=interp, cmap='gray')
51 | plt.imshow(units[:,:,i], interpolation=interp, cmap=colormap, alpha=alpha)
52 | plt.axis('off')
53 | plt.colorbar()
54 | plt.title(title, fontsize='small')
55 | if colormap_lim:
56 | plt.clim(colormap_lim[0],colormap_lim[1])
57 |
58 | plt.subplots_adjust(wspace=0, hspace=0)
59 | plt.tight_layout()
60 |
61 | # plt.savefig('{}/{}.png'.format(dir_name,time.time()))
62 |
63 |
64 |
65 |
66 | ## Load options
67 | PAUSE = .01
68 | #config_name = 'config_sononet_attention_fs8_v6.json'
69 | #config_name = 'config_sononet_attention_fs8_v8.json'
70 | #config_name = 'config_sononet_attention_fs8_v9.json'
71 | #config_name = 'config_sononet_attention_fs8_v10.json'
72 | #config_name = 'config_sononet_attention_fs8_v11.json'
73 | #config_name = 'config_sononet_attention_fs8_v13.json'
74 | #config_name = 'config_sononet_attention_fs8_v14.json'
75 | #config_name = 'config_sononet_attention_fs8_v15.json'
76 | #config_name = 'config_sononet_attention_fs8_v16.json'
77 | #config_name = 'config_sononet_grid_attention_fs8_v1.json'
78 | config_name = 'config_sononet_grid_attention_fs8_deepsup_v1.json'
79 | config_name = 'config_sononet_grid_attention_fs8_deepsup_v2.json'
80 | config_name = 'config_sononet_grid_attention_fs8_deepsup_v3.json'
81 | config_name = 'config_sononet_grid_attention_fs8_deepsup_v4.json'
82 |
83 | # config_name = 'config_sononet_grid_att_fs8_avg.json'
84 | config_name = 'config_sononet_grid_att_fs8_avg_v2.json'
85 | # config_name = 'config_sononet_grid_att_fs8_avg_v3.json'
86 | #config_name = 'config_sononet_grid_att_fs8_avg_v4.json'
87 | #config_name = 'config_sononet_grid_att_fs8_avg_v5.json'
88 | #config_name = 'config_sononet_grid_att_fs8_avg_v5.json'
89 | #config_name = 'config_sononet_grid_att_fs8_avg_v6.json'
90 | #config_name = 'config_sononet_grid_att_fs8_avg_v7.json'
91 | #config_name = 'config_sononet_grid_att_fs8_avg_v8.json'
92 | #config_name = 'config_sononet_grid_att_fs8_avg_v9.json'
93 | #config_name = 'config_sononet_grid_att_fs8_avg_v10.json'
94 | #config_name = 'config_sononet_grid_att_fs8_avg_v11.json'
95 | #config_name = 'config_sononet_grid_att_fs8_avg_v12.json'
96 |
97 | config_name = 'config_sononet_grid_att_fs8_avg_v12_scratch.json'
98 | config_name = 'config_sononet_grid_att_fs4_avg_v12.json'
99 |
100 | #config_name = 'config_sononet_grid_attention_fs8_v3.json'
101 |
102 | json_opts = json_file_to_pyobj('/vol/bitbucket/js3611/projects/transfer_learning/ultrasound/configs_2/{}'.format(config_name))
103 | train_opts = json_opts.training
104 |
105 | dir_name = os.path.join('visualisation_debug', config_name)
106 | if not os.path.isdir(dir_name):
107 | os.makedirs(dir_name)
108 | os.makedirs(os.path.join(dir_name,'pos'))
109 | os.makedirs(os.path.join(dir_name,'neg'))
110 |
111 | # Setup the NN Model
112 | model = get_model(json_opts.model)
113 | if hasattr(model.net, 'classification_mode'):
114 | model.net.classification_mode = 'attention'
115 | if hasattr(model.net, 'deep_supervised'):
116 | model.net.deep_supervised = False
117 |
118 | # Setup Dataset and Augmentation
119 | dataset_class = get_dataset(train_opts.arch_type)
120 | dataset_path = get_dataset_path(train_opts.arch_type, json_opts.data_path)
121 | dataset_transform = get_dataset_transformation(train_opts.arch_type, opts=json_opts.augmentation)
122 |
123 | # Setup Data Loader
124 | dataset = dataset_class(dataset_path, split='train', transform=dataset_transform['valid'])
125 | data_loader = DataLoader(dataset=dataset, num_workers=1, batch_size=1, shuffle=True)
126 |
127 | # test
128 | for iteration, data in enumerate(data_loader, 1):
129 | model.set_input(data[0], data[1])
130 |
131 | cls = dataset.label_names[int(data[1])]
132 |
133 | model.validate()
134 | pred_class = model.pred[1]
135 | pred_cls = dataset.label_names[int(pred_class)]
136 |
137 | #########################################################
138 | # Display the input image and Down_sample the input image
139 | input_img = model.input[0,0].cpu().numpy()
140 | #input_img = numpy.expand_dims(imresize(input_img, (fmap_size[0], fmap_size[1]), interp='bilinear'), axis=2)
141 | input_img = numpy.expand_dims(input_img, axis=2)
142 |
143 | # plotNNFilter(input_img, figure_id=0, colormap="gray")
144 | plotNNFilterOverlay(input_img, numpy.zeros_like(input_img), figure_id=0, interp='bilinear',
145 | colormap=cm.jet, title='[GT:{}|P:{}]'.format(cls, pred_cls),alpha=0)
146 |
147 | chance = np.random.random() < 0.01 if cls == "BACKGROUND" else True  # save all non-background frames, ~1% of background ones
148 | if cls != pred_cls:
149 | plt.savefig('{}/neg/{:03d}.png'.format(dir_name,iteration))
150 | elif cls == pred_cls and chance:
151 | plt.savefig('{}/pos/{:03d}.png'.format(dir_name,iteration))
152 | #########################################################
153 | # Compatibility Scores overlay with input
154 | attentions = []
155 | for i in [1,2]:
156 | fmap = model.get_feature_maps('compatibility_score%d'%i, upscale=False)
157 | if not fmap:
158 | continue
159 |
160 | # Output of the attention block
161 | fmap_0 = fmap[0].squeeze().permute(1,2,0).cpu().numpy()
162 | fmap_size = fmap_0.shape
163 |
164 | # Attention coefficient (b x c x w x h x s)
165 | attention = fmap[1].squeeze().cpu().numpy()
166 | attention = attention[:, :]
167 | #attention = numpy.expand_dims(resize(attention, (fmap_size[0], fmap_size[1]), mode='constant', preserve_range=True), axis=2)
168 | attention = numpy.expand_dims(resize(attention, (input_img.shape[0], input_img.shape[1]), mode='constant', preserve_range=True), axis=2)
169 |
170 | # this one is useless
171 | #plotNNFilter(fmap_0, figure_id=i+3, interp='bilinear', colormap=cm.jet, title='compat. feature %d' %i)
172 |
173 | plotNNFilterOverlay(input_img, attention, figure_id=i, interp='bilinear', colormap=cm.jet, title='[GT:{}|P:{}] compat. {}'.format(cls,pred_cls,i), alpha=0.5)
174 | attentions.append(attention)
175 |
176 | #plotNNFilterOverlay(input_img, attentions[0], figure_id=4, interp='bilinear', colormap=cm.jet, title='[GT:{}|P:{}] compat. (all)'.format(cls, pred_cls), alpha=0.5)
177 | plotNNFilterOverlay(input_img, numpy.mean(attentions,0), figure_id=4, interp='bilinear', colormap=cm.jet, title='[GT:{}|P:{}] compat. (all)'.format(cls, pred_cls), alpha=0.5)
178 |
179 | if cls != pred_cls:
180 | plt.savefig('{}/neg/{:03d}_hm.png'.format(dir_name,iteration))
181 | elif cls == pred_cls and chance:
182 | plt.savefig('{}/pos/{:03d}_hm.png'.format(dir_name,iteration))
183 | # Linear embedding g(x)
184 | # (b, c, h, w)
185 | #gx = fmap[2].squeeze().permute(1,2,0).cpu().numpy()
186 | #plotNNFilter(gx, figure_id=3, interp='nearest', colormap=cm.jet)
187 |
188 | plt.show()
189 | plt.pause(PAUSE)
190 |
191 | model.destructor()
192 | #if iteration == 1: break
193 |
--------------------------------------------------------------------------------
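For reference, plotNNFilterOverlay expects height x width x n_maps arrays. A synthetic call (assumes the function defined in the script above is in scope; the frame size is made up):

import numpy as np
import matplotlib.cm as cm

img = np.random.rand(224, 288, 1)   # grayscale frame, HWC
att = np.random.rand(224, 288, 1)   # a single attention map
plotNNFilterOverlay(img, att, figure_id=0, colormap=cm.jet,
                    title='synthetic attention', alpha=0.5)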
/utils/visualiser.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import os
4 | import ntpath
5 | import time
6 | from utils import util, html
7 |
8 | # Use the following comment to launch a visdom server
9 | # python -m visdom.server
10 |
11 | class Visualiser():
12 | def __init__(self, opt, save_dir, filename='loss_log.txt'):
13 | self.display_id = opt.display_id
14 | self.use_html = not opt.no_html
15 | self.win_size = opt.display_winsize
16 | self.save_dir = save_dir
17 | self.name = os.path.basename(self.save_dir)
18 | self.saved = False
19 | self.display_single_pane_ncols = opt.display_single_pane_ncols
20 |
21 | # Error plots
22 | self.error_plots = dict()
23 | self.error_wins = dict()
24 |
25 | if self.display_id > 0:
26 | import visdom
27 | self.vis = visdom.Visdom(port=opt.display_port)
28 |
29 | if self.use_html:
30 | self.web_dir = os.path.join(self.save_dir, 'web')
31 | self.img_dir = os.path.join(self.web_dir, 'images')
32 | print('create web directory %s...' % self.web_dir)
33 | util.mkdirs([self.web_dir, self.img_dir])
34 | self.log_name = os.path.join(self.save_dir, filename)
35 | with open(self.log_name, "a") as log_file:
36 | now = time.strftime("%c")
37 | log_file.write('================ Training Loss (%s) ================\n' % now)
38 |
39 | def reset(self):
40 | self.saved = False
41 |
42 | # |visuals|: dictionary of images to display or save
43 | def display_current_results(self, visuals, epoch, save_result):
44 | if self.display_id > 0: # show images in the browser
45 | ncols = self.display_single_pane_ncols
46 | if ncols > 0:
47 | h, w = next(iter(visuals.values())).shape[:2]
48 | table_css = """<style>
49 | table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
50 | table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
51 | </style>""" % (w, h)
52 | title = self.name
53 | label_html = ''
54 | label_html_row = ''
55 | nrows = int(np.ceil(len(visuals.items()) / ncols))
56 | images = []
57 | idx = 0
58 | for label, image_numpy in visuals.items():
59 | label_html_row += '<td>%s</td>' % label
60 | images.append(image_numpy.transpose([2, 0, 1]))
61 | idx += 1
62 | if idx % ncols == 0:
63 | label_html += '<tr>%s</tr>' % label_html_row
64 | label_html_row = ''
65 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
66 | while idx % ncols != 0:
67 | images.append(white_image)
68 | label_html_row += '<td></td>'
69 | idx += 1
70 | if label_html_row != '':
71 | label_html += '<tr>%s</tr>' % label_html_row
72 | # pane col = image row
73 | self.vis.images(images, nrow=ncols, win=self.display_id + 1,
74 | padding=2, opts=dict(title=title + ' images'))
75 | label_html = '<table>%s</table>' % label_html
76 | self.vis.text(table_css + label_html, win=self.display_id + 2,
77 | opts=dict(title=title + ' labels'))
78 | else:
79 | idx = 1
80 | for label, image_numpy in visuals.items():
81 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
82 | win=self.display_id + idx)
83 | idx += 1
84 |
85 | if self.use_html and (save_result or not self.saved): # save images to a html file
86 | self.saved = True
87 | for label, image_numpy in visuals.items():
88 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
89 | util.save_image(image_numpy, img_path)
90 | # update website
91 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
92 | for n in range(epoch, 0, -1):
93 | webpage.add_header('epoch [%d]' % n)
94 | ims = []
95 | txts = []
96 | links = []
97 |
98 | for label, image_numpy in visuals.items():
99 | img_path = 'epoch%.3d_%s.png' % (n, label)
100 | ims.append(img_path)
101 | txts.append(label)
102 | links.append(img_path)
103 | webpage.add_images(ims, txts, links, width=self.win_size)
104 | webpage.save()
105 |
106 | def plot_table_html(self, x, y, key, split_name, **kwargs):
107 | key_s = key+'_'+split_name
108 | if key_s not in self.error_plots:
109 | self.error_wins[key_s] = self.display_id * 3 + len(self.error_wins)
110 | else:
111 | self.vis.close(self.error_plots[key_s])
112 |
113 |
114 | table = pd.DataFrame(np.array(y['data']).transpose(),
115 | index=kwargs['labels'], columns=y['colnames'])
116 | table_html = table.round(2).to_html(col_space=200, bold_rows=True, border=12)
117 |
118 | self.error_plots[key_s] = self.vis.text(table_html,
119 | opts=dict(title=self.name+split_name,
120 | width=350, height=350,
121 | win=self.error_wins[key_s]))
122 |
123 |
124 | def plot_heatmap(self, x, y, key, split_name, **kwargs):
125 | key_s = key+'_'+split_name
126 | if key_s not in self.error_plots:
127 | self.error_wins[key_s] = self.display_id * 3 + len(self.error_wins)
128 | else:
129 | self.vis.close(self.error_plots[key_s])
130 | self.error_plots[key_s] = self.vis.heatmap(
131 | X=y,
132 | opts=dict(
133 | columnnames=kwargs['labels'],
134 | rownames=kwargs['labels'],
135 | title=self.name + ' confusion matrix',
136 | win=self.error_wins[key_s]))
137 |
138 | def plot_line(self, x, y, key, split_name):
139 | if key not in self.error_plots:
140 | self.error_wins[key] = self.display_id * 3 + len(self.error_wins)
141 | self.error_plots[key] = self.vis.line(
142 | X=np.array([x, x]),
143 | Y=np.array([y, y]),
144 | opts=dict(
145 | legend=[split_name],
146 | title=self.name + ' {} over time'.format(key),
147 | xlabel='Epochs',
148 | ylabel=key,
149 | win=self.error_wins[key]
150 | ))
151 | else:
152 | self.vis.updateTrace(X=np.array([x]), Y=np.array([y]), win=self.error_plots[key], name=split_name)
153 | # errors: dictionary of error labels and values
154 | def plot_current_errors(self, epoch, errors, split_name, counter_ratio=0.0, **kwargs):
155 | if self.display_id > 0:
156 | for key in errors.keys():
157 | x = epoch + counter_ratio
158 | y = errors[key]
159 | if isinstance(y, dict):
160 | if y['type'] == 'table':
161 | self.plot_table_html(x,y,key,split_name, **kwargs)
162 | elif np.isscalar(y):
163 | self.plot_line(x,y,key,split_name)
164 | elif y.ndim == 2:
165 | self.plot_heatmap(x,y,key,split_name, **kwargs)
166 |
167 |
168 | # errors: same format as |errors| of plotCurrentErrors
169 | def print_current_errors(self, epoch, errors, split_name):
170 | message = '(epoch: %d, split: %s) ' % (epoch, split_name)
171 | for k, v in errors.items():
172 | if np.isscalar(v):
173 | message += '%s: %.3f ' % (k, v)
174 |
175 | print(message)
176 | with open(self.log_name, "a") as log_file:
177 | log_file.write('%s\n' % message)
178 |
179 | # save image to the disk
180 | def save_images(self, webpage, visuals, image_path):
181 | image_dir = webpage.get_image_dir()
182 | short_path = ntpath.basename(image_path[0])
183 | name = os.path.splitext(short_path)[0]
184 |
185 | webpage.add_header(name)
186 | ims = []
187 | txts = []
188 | links = []
189 |
190 | for label, image_numpy in visuals.items():
191 | image_name = '%s_%s.png' % (name, label)
192 | save_path = os.path.join(image_dir, image_name)
193 | util.save_image(image_numpy, save_path)
194 |
195 | ims.append(image_name)
196 | txts.append(label)
197 | links.append(image_name)
198 | webpage.add_images(ims, txts, links, width=self.win_size)
199 |
--------------------------------------------------------------------------------
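plot_line assumes a visdom server is running (python -m visdom.server). A standalone equivalent of its first-call path, on visdom's default port (an assumption):

import numpy as np
import visdom

vis = visdom.Visdom(port=8097)
win = vis.line(X=np.array([0, 0]), Y=np.array([0.5, 0.5]),
               opts=dict(legend=['train'], title='loss over time',
                         xlabel='Epochs', ylabel='loss'))
# subsequent epochs append points to the same window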
/train_classifaction.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import DataLoader, sampler
3 | from tqdm import tqdm
4 |
5 |
6 | from dataio.loader import get_dataset, get_dataset_path
7 | from dataio.transformation import get_dataset_transformation
8 | from utils.util import json_file_to_pyobj
9 | from utils.visualiser import Visualiser
10 | from utils.error_logger import ErrorLogger
11 | from models.networks_other import adjust_learning_rate
12 |
13 | from models import get_model
14 |
15 |
16 | class StratifiedSampler(object):
17 | """Stratified Sampling
18 | Provides equal representation of target classes in each batch
19 | """
20 | def __init__(self, class_vector, batch_size):
21 | """
22 | Arguments
23 | ---------
24 | class_vector : torch tensor
25 | a vector of class labels
26 | batch_size : integer
27 | batch_size
28 | """
29 | self.class_vector = class_vector
30 | self.batch_size = batch_size
31 | self.num_iter = len(class_vector) // 52  # 52 samples per batch: 2 x 13 foreground classes + 26 background
32 | self.n_class = 14
33 | self.sample_n = 2
34 | # create an index pool for each class
35 | indices = {}
36 | for i in range(self.n_class):
37 | indices[i] = np.where(self.class_vector == i)[0]
38 |
39 | self.indices = indices
40 | self.background_index = np.argmax([ len(indices[i]) for i in range(self.n_class)])
41 |
42 |
43 | def gen_sample_array(self):
44 | # sample sample_n per class; the background class gets sample_n * (n_class - 1)
45 | sample_array = []
46 | for _ in range(self.num_iter):  # '_' avoids shadowing the class index below
47 | arrs = []
48 | for i in range(self.n_class):
49 | n = self.sample_n
50 | if i == self.background_index:
51 | n = self.sample_n * (self.n_class-1)
52 | arr = np.random.choice(self.indices[i], n)
53 | arrs.append(arr)
54 |
55 | sample_array.append(np.hstack(arrs))
56 | return np.hstack(sample_array)
57 |
58 | def __iter__(self):
59 | return iter(self.gen_sample_array())
60 |
61 | def __len__(self):
62 | return len(self.class_vector)
63 |
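# Usage sketch (hedged): assumes `labels` is a 1-D integer array in which
# every class id in [0, 14) occurs at least once (np.random.choice would
# fail on an empty class pool):
#   labels = np.random.randint(0, 14, size=1000)
#   train_sampler = StratifiedSampler(labels, batch_size=52)
#   loader = DataLoader(dataset, batch_size=52, sampler=train_sampler)
# Each batch then holds sample_n=2 items per foreground class plus
# 2 * (14 - 1) = 26 background items, i.e. 52 in total.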
64 |
65 | # No longer used
66 | def check_warm_start(epoch, model, train_opts):
67 | if hasattr(train_opts, "warm_start_epoch"):
68 | if epoch < train_opts.warm_start_epoch:
69 | print('... warm_start: lr={}'.format(train_opts.warm_start_lr))
70 | adjust_learning_rate(model.optimizers[0], train_opts.warm_start_lr)
71 | elif epoch == train_opts.warm_start_epoch:
72 | print('... warm_start ended: lr={}'.format(model.opts.lr_rate))
73 | adjust_learning_rate(model.optimizers[0], model.opts.lr_rate)
74 |
75 |
76 | def train(arguments):
77 |
78 | # Parse input arguments
79 | json_filename = arguments.config
80 | network_debug = arguments.debug
81 |
82 | # Load options
83 | json_opts = json_file_to_pyobj(json_filename)
84 | train_opts = json_opts.training
85 |
86 | # Architecture type
87 | arch_type = train_opts.arch_type
88 |
89 | # Setup Dataset and Augmentation
90 | ds_class = get_dataset(arch_type)
91 | ds_path = get_dataset_path(arch_type, json_opts.data_path)
92 | ds_transform = get_dataset_transformation(arch_type, opts=json_opts.augmentation)
93 |
94 | # Setup the NN Model
95 | model = get_model(json_opts.model)
96 | if network_debug:
97 |         print('# of params: ', model.get_number_parameters())
98 | print('fp time: {0:.3f} sec\tbp time: {1:.3f} sec per sample'.format(*model.get_fp_bp_time()))
99 | exit()
100 |
101 | # Setup Data Loader
102 | num_workers = train_opts.num_workers if hasattr(train_opts, 'num_workers') else 16
103 | train_dataset = ds_class(ds_path, split='train', transform=ds_transform['train'], preload_data=train_opts.preloadData)
104 | valid_dataset = ds_class(ds_path, split='val', transform=ds_transform['valid'], preload_data=train_opts.preloadData)
105 | test_dataset = ds_class(ds_path, split='test', transform=ds_transform['valid'], preload_data=train_opts.preloadData)
106 |
107 | # create sampler
108 | if train_opts.sampler == 'stratified':
109 | print('stratified sampler')
110 | train_sampler = StratifiedSampler(train_dataset.labels, train_opts.batchSize)
111 | batch_size = 52
112 | elif train_opts.sampler == 'weighted2':
113 | print('weighted sampler with background weight={}x'.format(train_opts.bgd_weight_multiplier))
114 | # modify and increase background weight
115 | weight = train_dataset.weight
116 | bgd_weight = np.min(weight)
117 | weight[abs(weight - bgd_weight) < 1e-8] = bgd_weight * train_opts.bgd_weight_multiplier
118 | train_sampler = sampler.WeightedRandomSampler(weight, len(train_dataset.weight))
119 | batch_size = train_opts.batchSize
120 | else:
121 | print('weighted sampler')
122 | train_sampler = sampler.WeightedRandomSampler(train_dataset.weight, len(train_dataset.weight))
123 | batch_size = train_opts.batchSize
124 |
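    # Worked example of the 'weighted2' branch above (numbers are
    # illustrative): with per-sample weights [0.01, 0.5, 0.7, 0.01] and
    # bgd_weight_multiplier = 5, the minimum (background) weight 0.01 is
    # raised to 0.05, so background frames are drawn five times more often.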
125 | # loader
126 | train_loader = DataLoader(dataset=train_dataset, num_workers=num_workers,
127 | batch_size=batch_size, sampler=train_sampler)
128 | valid_loader = DataLoader(dataset=valid_dataset, num_workers=num_workers, batch_size=train_opts.batchSize, shuffle=True)
129 | test_loader = DataLoader(dataset=test_dataset, num_workers=num_workers, batch_size=train_opts.batchSize, shuffle=True)
130 |
131 | # Visualisation Parameters
132 | visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)
133 | error_logger = ErrorLogger()
134 |
135 | # Training Function
136 | track_labels = np.arange(len(train_dataset.label_names))
137 | model.set_labels(track_labels)
138 | model.set_scheduler(train_opts)
139 |
140 | if hasattr(model, 'update_state'):
141 | model.update_state(0)
142 |
143 | for epoch in range(model.which_epoch, train_opts.n_epochs):
144 | print('(epoch: %d, total # iters: %d)' % (epoch, len(train_loader)))
145 |
146 | # # # --- Start ---
147 | # import matplotlib.pyplot as plt
148 | # plt.ion()
149 | # plt.figure()
150 | # target_arr = np.zeros(14)
151 | # # # --- End ---
152 |
153 | # Training Iterations
154 | for epoch_iter, (images, labels) in tqdm(enumerate(train_loader, 1), total=len(train_loader)):
155 | # Make a training update
156 | model.set_input(images, labels)
157 | model.optimize_parameters()
158 |
163 | if train_opts.max_it == epoch_iter:
164 | break
165 |
166 | # # # --- visualise distribution ---
167 | # for lab in labels.numpy():
168 | # target_arr[lab] += 1
169 | # plt.clf(); plt.bar(train_dataset.label_names, target_arr); plt.pause(0.01)
170 | # # # --- End ---
171 |
172 | # Visualise predictions
173 | if epoch_iter <= 100:
174 | visuals = model.get_current_visuals()
175 | visualizer.display_current_results(visuals, epoch=epoch, save_result=False)
176 |
177 | # Error visualisation
178 | errors = model.get_current_errors()
179 | error_logger.update(errors, split='train')
180 |
181 | # Validation and Testing Iterations
182 | pr_lbls = []
183 | gt_lbls = []
184 | for loader, split in zip([valid_loader, test_loader], ['validation', 'test']):
185 | model.reset_results()
186 |
187 | for epoch_iter, (images, labels) in tqdm(enumerate(loader, 1), total=len(loader)):
188 |
189 | # Make a forward pass with the model
190 | model.set_input(images, labels)
191 | model.validate()
192 |
193 | # Visualise predictions
194 | visuals = model.get_current_visuals()
195 | visualizer.display_current_results(visuals, epoch=epoch, save_result=False)
196 |
197 | if train_opts.max_it == epoch_iter:
198 | break
199 |
200 | # Error visualisation
201 | errors = model.get_accumulated_errors()
202 | stats = model.get_classification_stats()
203 | error_logger.update({**errors, **stats}, split=split)
204 |
205 |             # keep the validation loss; it drives the learning-rate scheduler below
206 | if split == 'validation':
207 | valid_err = errors['CE']
208 |
209 | # Update the plots
210 | for split in ['train', 'validation', 'test']:
211 |             # exclude background (disabled)
212 | #track_labels = np.delete(track_labels, 3)
213 | #show_labels = train_dataset.label_names[:3] + train_dataset.label_names[4:]
214 | show_labels = train_dataset.label_names
215 | visualizer.plot_current_errors(epoch, error_logger.get_errors(split), split_name=split, labels=show_labels)
216 | visualizer.print_current_errors(epoch, error_logger.get_errors(split), split_name=split)
217 | error_logger.reset()
218 |
219 | # Save the model parameters
220 | if epoch % train_opts.save_epoch_freq == 0:
221 | model.save(epoch)
222 |
223 | if hasattr(model, 'update_state'):
224 | model.update_state(epoch)
225 |
226 | # Update the model learning rate
227 | model.update_learning_rate(metric=valid_err, epoch=epoch)
228 |
229 |
230 | if __name__ == '__main__':
231 | import argparse
232 |
233 | parser = argparse.ArgumentParser(description='CNN Classification Training Function')
234 |
235 | parser.add_argument('-c', '--config', help='training config file', required=True)
236 | parser.add_argument('-d', '--debug', help='returns number of parameters and bp/fp runtime', action='store_true')
237 | args = parser.parse_args()
238 |
239 | train(args)
240 |
--------------------------------------------------------------------------------
/models/layers/nonlocal_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from models.networks_other import init_weights
5 |
6 |
7 | class _NonLocalBlockND(nn.Module):
8 | def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian',
9 | sub_sample_factor=4, bn_layer=True):
10 | super(_NonLocalBlockND, self).__init__()
11 |
12 | assert dimension in [1, 2, 3]
13 | assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']
14 |
15 | # print('Dimension: %d, mode: %s' % (dimension, mode))
16 |
17 | self.mode = mode
18 | self.dimension = dimension
19 | self.sub_sample_factor = sub_sample_factor if isinstance(sub_sample_factor, list) else [sub_sample_factor]
20 |
21 | self.in_channels = in_channels
22 | self.inter_channels = inter_channels
23 |
24 | if self.inter_channels is None:
25 | self.inter_channels = in_channels // 2
26 | if self.inter_channels == 0:
27 | self.inter_channels = 1
28 |
29 | if dimension == 3:
30 | conv_nd = nn.Conv3d
31 | max_pool = nn.MaxPool3d
32 | bn = nn.BatchNorm3d
33 | elif dimension == 2:
34 | conv_nd = nn.Conv2d
35 | max_pool = nn.MaxPool2d
36 | bn = nn.BatchNorm2d
37 | else:
38 | conv_nd = nn.Conv1d
39 | max_pool = nn.MaxPool1d
40 | bn = nn.BatchNorm1d
41 |
42 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
43 | kernel_size=1, stride=1, padding=0)
44 |
45 | if bn_layer:
46 | self.W = nn.Sequential(
47 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
48 | kernel_size=1, stride=1, padding=0),
49 | bn(self.in_channels)
50 | )
51 | nn.init.constant(self.W[1].weight, 0)
52 | nn.init.constant(self.W[1].bias, 0)
53 | else:
54 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
55 | kernel_size=1, stride=1, padding=0)
56 | nn.init.constant(self.W.weight, 0)
57 | nn.init.constant(self.W.bias, 0)
58 |
59 | self.theta = None
60 | self.phi = None
61 |
62 | if mode in ['embedded_gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']:
63 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
64 | kernel_size=1, stride=1, padding=0)
65 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
66 | kernel_size=1, stride=1, padding=0)
67 |
68 | if mode in ['concatenation']:
69 | self.wf_phi = nn.Linear(self.inter_channels, 1, bias=False)
70 | self.wf_theta = nn.Linear(self.inter_channels, 1, bias=False)
71 | elif mode in ['concat_proper', 'concat_proper_down']:
72 | self.psi = nn.Conv2d(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1,
73 | padding=0, bias=True)
74 |
75 | if mode == 'embedded_gaussian':
76 | self.operation_function = self._embedded_gaussian
77 | elif mode == 'dot_product':
78 | self.operation_function = self._dot_product
79 | elif mode == 'gaussian':
80 | self.operation_function = self._gaussian
81 | elif mode == 'concatenation':
82 | self.operation_function = self._concatenation
83 | elif mode == 'concat_proper':
84 | self.operation_function = self._concatenation_proper
85 | elif mode == 'concat_proper_down':
86 | self.operation_function = self._concatenation_proper_down
87 | else:
88 | raise NotImplementedError('Unknown operation function.')
89 |
90 | if any(ss > 1 for ss in self.sub_sample_factor):
91 | self.g = nn.Sequential(self.g, max_pool(kernel_size=sub_sample_factor))
92 | if self.phi is None:
93 | self.phi = max_pool(kernel_size=sub_sample_factor)
94 | else:
95 | self.phi = nn.Sequential(self.phi, max_pool(kernel_size=sub_sample_factor))
96 | if mode == 'concat_proper_down':
97 | self.theta = nn.Sequential(self.theta, max_pool(kernel_size=sub_sample_factor))
98 |
99 | # Initialise weights
100 | for m in self.children():
101 | init_weights(m, init_type='kaiming')
102 |
103 | def forward(self, x):
104 | '''
105 | :param x: (b, c, t, h, w)
106 |         :return: (b, c, t, h, w) tensor with non-local context added to x
107 | '''
108 |
109 | output = self.operation_function(x)
110 | return output
111 |
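    # The pairwise functions below are variants of the non-local operation of
    # Wang et al. (2018):
    #     y_i = (1 / C(x)) * sum_j f(x_i, x_j) * g(x_j),    z = W * y + x
    # 'embedded_gaussian': f(x_i, x_j) = exp(theta(x_i)^T phi(x_j)), C = softmax
    # 'gaussian':          the same, but with theta and phi as identity maps
    # 'dot_product':       f = theta(x_i)^T phi(x_j), C = N (number of positions)
    # 'concatenation':     f scores each pair with a ReLU over projected features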
112 | def _embedded_gaussian(self, x):
113 | batch_size = x.size(0)
114 |
115 | # g=>(b, c, t, h, w)->(b, 0.5c, t, h, w)->(b, thw, 0.5c)
116 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
117 | g_x = g_x.permute(0, 2, 1)
118 |
119 | # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c)
120 | # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)
121 |         # f=>(b, thw, 0.5c)dot(b, 0.5c, thw) = (b, thw, thw)
122 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
123 | theta_x = theta_x.permute(0, 2, 1)
124 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
125 | f = torch.matmul(theta_x, phi_x)
126 | f_div_C = F.softmax(f, dim=-1)
127 |
128 | # (b, thw, thw)dot(b, thw, 0.5c) = (b, thw, 0.5c)->(b, 0.5c, t, h, w)->(b, c, t, h, w)
129 | y = torch.matmul(f_div_C, g_x)
130 | y = y.permute(0, 2, 1).contiguous()
131 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
132 | W_y = self.W(y)
133 | z = W_y + x
134 |
135 | return z
136 |
137 | def _gaussian(self, x):
138 | batch_size = x.size(0)
139 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
140 | g_x = g_x.permute(0, 2, 1)
141 |
142 | theta_x = x.view(batch_size, self.in_channels, -1)
143 | theta_x = theta_x.permute(0, 2, 1)
144 |
145 |         if any(ss > 1 for ss in self.sub_sample_factor):
146 | phi_x = self.phi(x).view(batch_size, self.in_channels, -1)
147 | else:
148 | phi_x = x.view(batch_size, self.in_channels, -1)
149 |
150 | f = torch.matmul(theta_x, phi_x)
151 | f_div_C = F.softmax(f, dim=-1)
152 |
153 | y = torch.matmul(f_div_C, g_x)
154 | y = y.permute(0, 2, 1).contiguous()
155 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
156 | W_y = self.W(y)
157 | z = W_y + x
158 |
159 | return z
160 |
161 | def _dot_product(self, x):
162 | batch_size = x.size(0)
163 |
164 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
165 | g_x = g_x.permute(0, 2, 1)
166 |
167 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
168 | theta_x = theta_x.permute(0, 2, 1)
169 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
170 | f = torch.matmul(theta_x, phi_x)
171 | N = f.size(-1)
172 | f_div_C = f / N
173 |
174 | y = torch.matmul(f_div_C, g_x)
175 | y = y.permute(0, 2, 1).contiguous()
176 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
177 | W_y = self.W(y)
178 | z = W_y + x
179 |
180 | return z
181 |
182 | def _concatenation(self, x):
183 | batch_size = x.size(0)
184 |
185 | # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2)
186 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
187 |
188 | # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c)
189 | # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw/s**2, 0.5c)
190 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1)
191 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1)
192 |
193 | # theta => (b, thw, 0.5c) -> (b, thw, 1) -> (b, 1, thw) -> (expand) (b, thw/s**2, thw)
194 | # phi => (b, thw/s**2, 0.5c) -> (b, thw/s**2, 1) -> (expand) (b, thw/s**2, thw)
195 | # f=> RELU[(b, thw/s**2, thw) + (b, thw/s**2, thw)] = (b, thw/s**2, thw)
196 | f = self.wf_theta(theta_x).permute(0, 2, 1).repeat(1, phi_x.size(1), 1) + \
197 | self.wf_phi(phi_x).repeat(1, 1, theta_x.size(1))
198 | f = F.relu(f, inplace=True)
199 |
200 | # Normalise the relations
201 | N = f.size(-1)
202 | f_div_c = f / N
203 |
204 | # g(x_j) * f(x_j, x_i)
205 | # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw)
206 | y = torch.matmul(g_x, f_div_c)
207 | y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:])
208 | W_y = self.W(y)
209 | z = W_y + x
210 |
211 | return z
212 |
213 | def _concatenation_proper(self, x):
214 | batch_size = x.size(0)
215 |
216 | # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2)
217 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
218 |
219 | # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)
220 | # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2)
221 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
222 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
223 |
224 | # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw)
225 | # phi => (b, 0.5c, thw/s**2) -> (expand) (b, 0.5c, thw/s**2, thw)
226 | # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw)
227 | f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \
228 | phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2))
229 | f = F.relu(f, inplace=True)
230 |
231 | # psi -> W_psi^t * f -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw)
232 | f = torch.squeeze(self.psi(f), dim=1)
233 |
234 | # Normalise the relations
235 | f_div_c = F.softmax(f, dim=1)
236 |
237 | # g(x_j) * f(x_j, x_i)
238 | # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw)
239 | y = torch.matmul(g_x, f_div_c)
240 | y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:])
241 | W_y = self.W(y)
242 | z = W_y + x
243 |
244 | return z
245 |
246 | def _concatenation_proper_down(self, x):
247 | batch_size = x.size(0)
248 |
249 | # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2)
250 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
251 |
252 | # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)
253 | # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2)
254 | theta_x = self.theta(x)
255 | downsampled_size = theta_x.size()
256 | theta_x = theta_x.view(batch_size, self.inter_channels, -1)
257 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
258 |
259 | # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw)
260 |         # phi  => (b, 0.5c, thw/s**2) ->  (expand) (b, 0.5c, thw/s**2, thw)
261 | # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw)
262 | f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \
263 | phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2))
264 | f = F.relu(f, inplace=True)
265 |
266 | # psi -> W_psi^t * f -> (b, 0.5c, thw/s**2, thw) -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw)
267 | f = torch.squeeze(self.psi(f), dim=1)
268 |
269 | # Normalise the relations
270 | f_div_c = F.softmax(f, dim=1)
271 |
272 | # g(x_j) * f(x_j, x_i)
273 | # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw)
274 | y = torch.matmul(g_x, f_div_c)
275 | y = y.contiguous().view(batch_size, self.inter_channels, *downsampled_size[2:])
276 |
277 | # upsample the final featuremaps # (b,0.5c,t/s1,h/s2,w/s3)
278 | y = F.upsample(y, size=x.size()[2:], mode='trilinear')
279 |
280 | # attention block output
281 | W_y = self.W(y)
282 | z = W_y + x
283 |
284 | return z
285 |
286 |
287 | class NONLocalBlock1D(_NonLocalBlockND):
288 | def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True):
289 | super(NONLocalBlock1D, self).__init__(in_channels,
290 | inter_channels=inter_channels,
291 | dimension=1, mode=mode,
292 | sub_sample_factor=sub_sample_factor,
293 | bn_layer=bn_layer)
294 |
295 |
296 | class NONLocalBlock2D(_NonLocalBlockND):
297 | def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True):
298 | super(NONLocalBlock2D, self).__init__(in_channels,
299 | inter_channels=inter_channels,
300 | dimension=2, mode=mode,
301 | sub_sample_factor=sub_sample_factor,
302 | bn_layer=bn_layer)
303 |
304 |
305 | class NONLocalBlock3D(_NonLocalBlockND):
306 | def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True):
307 | super(NONLocalBlock3D, self).__init__(in_channels,
308 | inter_channels=inter_channels,
309 | dimension=3, mode=mode,
310 | sub_sample_factor=sub_sample_factor,
311 | bn_layer=bn_layer)
312 |
313 |
314 | if __name__ == '__main__':
315 | from torch.autograd import Variable
316 |
317 | mode_list = ['concatenation']
318 | #mode_list = ['embedded_gaussian', 'gaussian', 'dot_product', ]
319 |
320 | for mode in mode_list:
321 | print(mode)
322 | img = Variable(torch.zeros(2, 4, 5))
323 | net = NONLocalBlock1D(4, mode=mode, sub_sample_factor=2)
324 | out = net(img)
325 | print(out.size())
326 |
327 | img = Variable(torch.zeros(2, 4, 5, 3))
328 | net = NONLocalBlock2D(4, mode=mode, sub_sample_factor=1, bn_layer=False)
329 | out = net(img)
330 | print(out.size())
331 |
332 | img = Variable(torch.zeros(2, 4, 5, 4, 5))
333 | net = NONLocalBlock3D(4, mode=mode)
334 | out = net(img)
335 | print(out.size())
336 |
--------------------------------------------------------------------------------
/dataio/transformation/myImageTransformations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 | import scipy.ndimage
4 | from scipy.ndimage.filters import gaussian_filter
5 | from scipy.ndimage.interpolation import map_coordinates
6 | import collections
7 | from PIL import Image
8 | import numbers
9 |
10 |
11 | def center_crop(x, center_crop_size):
12 | assert x.ndim == 3
13 | centerw, centerh = x.shape[1] // 2, x.shape[2] // 2
14 | halfw, halfh = center_crop_size[0] // 2, center_crop_size[1] // 2
15 | return x[:, centerw - halfw:centerw + halfw, centerh - halfh:centerh + halfh]
16 |
17 |
18 | def to_tensor(x):
19 | import torch
20 |     x = x.transpose((2, 0, 1))  # HWC -> CHW
21 |     return torch.from_numpy(x).float()
23 |
24 |
25 | def random_num_generator(config, random_state=np.random):
26 | if config[0] == 'uniform':
27 | ret = random_state.uniform(config[1], config[2], 1)[0]
28 | elif config[0] == 'lognormal':
29 | ret = random_state.lognormal(config[1], config[2], 1)[0]
30 |     else:
31 |         raise Exception('unsupported format: {}'.format(config))
33 | return ret
34 |
35 |
36 | def poisson_downsampling(image, peak, random_state=np.random):
37 | if not isinstance(image, np.ndarray):
38 | imgArr = np.array(image, dtype='float32')
39 | else:
40 | imgArr = image.astype('float32')
41 | Q = imgArr.max(axis=(0, 1)) / peak
42 | if Q[0] == 0:
43 | return imgArr
44 | ima_lambda = imgArr / Q
45 | noisy_img = random_state.poisson(lam=ima_lambda)
46 | return noisy_img.astype('float32')
47 |
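# Usage sketch (hedged): simulate shot noise on an (H, W, C) float image.
# `peak` is the expected photon count at the brightest pixel, so a smaller
# peak gives a noisier result:
#   img = np.random.rand(64, 64, 3).astype('float32')
#   noisy = poisson_downsampling(img, peak=30.0)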
48 |
49 | def elastic_transform(image, alpha=1000, sigma=30, spline_order=1, mode='nearest', random_state=np.random):
50 | """Elastic deformation of image as described in [Simard2003]_.
51 | .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
52 | Convolutional Neural Networks applied to Visual Document Analysis", in
53 | Proc. of the International Conference on Document Analysis and
54 | Recognition, 2003.
55 | """
56 | assert image.ndim == 3
57 | shape = image.shape[:2]
58 |
59 | dx = gaussian_filter((random_state.rand(*shape) * 2 - 1),
60 | sigma, mode="constant", cval=0) * alpha
61 | dy = gaussian_filter((random_state.rand(*shape) * 2 - 1),
62 | sigma, mode="constant", cval=0) * alpha
63 |
64 | x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
65 | indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))]
66 | result = np.empty_like(image)
67 | for i in range(image.shape[2]):
68 | result[:, :, i] = map_coordinates(
69 | image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape)
70 | return result
71 |
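# Usage sketch (hedged): mildly deform an (H, W, C) image; alpha scales the
# random displacement field and sigma controls its smoothness:
#   img = np.random.rand(128, 128, 3).astype('float32')
#   warped = elastic_transform(img, alpha=300, sigma=20)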
72 |
73 | class Merge(object):
74 | """Merge a group of images
75 | """
76 |
77 | def __init__(self, axis=-1):
78 | self.axis = axis
79 |
80 | def __call__(self, images):
81 | if isinstance(images, collections.Sequence) or isinstance(images, np.ndarray):
82 | assert all([isinstance(i, np.ndarray)
83 | for i in images]), 'only numpy array is supported'
84 | shapes = [list(i.shape) for i in images]
85 | for s in shapes:
86 | s[self.axis] = None
87 | assert all([s == shapes[0] for s in shapes]
88 | ), 'shapes must be the same except the merge axis'
89 | return np.concatenate(images, axis=self.axis)
90 | else:
91 | raise Exception("obj is not a sequence (list, tuple, etc)")
92 |
93 |
94 | class Split(object):
95 | """Split images into individual arraies
96 | """
97 |
98 | def __init__(self, *slices, **kwargs):
99 | assert isinstance(slices, collections.Sequence)
100 | slices_ = []
101 | for s in slices:
102 | if isinstance(s, collections.Sequence):
103 | slices_.append(slice(*s))
104 | else:
105 | slices_.append(s)
106 | assert all([isinstance(s, slice) for s in slices_]
107 |                    ), 'slices must consist of slice instances'
108 | self.slices = slices_
109 | self.axis = kwargs.get('axis', -1)
110 |
111 | def __call__(self, image):
112 | if isinstance(image, np.ndarray):
113 | ret = []
114 | for s in self.slices:
115 | sl = [slice(None)] * image.ndim
116 | sl[self.axis] = s
117 |                 ret.append(image[tuple(sl)])  # index with a tuple of slices
118 | return ret
119 | else:
120 | raise Exception("obj is not an numpy array")
121 |
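# Usage sketch (hedged): undo a Merge of a 3-channel input and a 3-channel
# target stacked along the last axis:
#   split = Split([0, 3], [3, 6])
#   x, y = split(np.zeros((64, 64, 6)))  # each part has shape (64, 64, 3)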
122 |
123 | class ElasticTransform(object):
124 | """Apply elastic transformation on a numpy.ndarray (H x W x C)
125 | """
126 |
127 | def __init__(self, alpha, sigma):
128 | self.alpha = alpha
129 | self.sigma = sigma
130 |
131 | def __call__(self, image):
132 | if isinstance(self.alpha, collections.Sequence):
133 | alpha = random_num_generator(self.alpha)
134 | else:
135 | alpha = self.alpha
136 | if isinstance(self.sigma, collections.Sequence):
137 | sigma = random_num_generator(self.sigma)
138 | else:
139 | sigma = self.sigma
140 | return elastic_transform(image, alpha=alpha, sigma=sigma)
141 |
142 |
143 | class PoissonSubsampling(object):
144 | """Poisson subsampling on a numpy.ndarray (H x W x C)
145 | """
146 |
147 | def __init__(self, peak, random_state=np.random):
148 | self.peak = peak
149 | self.random_state = random_state
150 |
151 | def __call__(self, image):
152 | if isinstance(self.peak, collections.Sequence):
153 | peak = random_num_generator(
154 | self.peak, random_state=self.random_state)
155 | else:
156 | peak = self.peak
157 | return poisson_downsampling(image, peak, random_state=self.random_state)
158 |
159 |
160 | class AddGaussianNoise(object):
161 | """Add gaussian noise to a numpy.ndarray (H x W x C)
162 | """
163 |
164 | def __init__(self, mean, sigma, random_state=np.random):
165 | self.sigma = sigma
166 | self.mean = mean
167 | self.random_state = random_state
168 |
169 | def __call__(self, image):
170 | if isinstance(self.sigma, collections.Sequence):
171 | sigma = random_num_generator(self.sigma, random_state=self.random_state)
172 | else:
173 | sigma = self.sigma
174 | if isinstance(self.mean, collections.Sequence):
175 | mean = random_num_generator(self.mean, random_state=self.random_state)
176 | else:
177 | mean = self.mean
178 | row, col, ch = image.shape
179 | gauss = self.random_state.normal(mean, sigma, (row, col, ch))
180 | gauss = gauss.reshape(row, col, ch)
181 | image += gauss
182 | return image
183 |
184 |
185 | class AddSpeckleNoise(object):
186 | """Add speckle noise to a numpy.ndarray (H x W x C)
187 | """
188 |
189 | def __init__(self, mean, sigma, random_state=np.random):
190 | self.sigma = sigma
191 | self.mean = mean
192 | self.random_state = random_state
193 |
194 | def __call__(self, image):
195 | if isinstance(self.sigma, collections.Sequence):
196 | sigma = random_num_generator(
197 | self.sigma, random_state=self.random_state)
198 | else:
199 | sigma = self.sigma
200 | if isinstance(self.mean, collections.Sequence):
201 | mean = random_num_generator(
202 | self.mean, random_state=self.random_state)
203 | else:
204 | mean = self.mean
205 | row, col, ch = image.shape
206 | gauss = self.random_state.normal(mean, sigma, (row, col, ch))
207 | gauss = gauss.reshape(row, col, ch)
208 | image += image * gauss
209 | return image
210 |
211 |
212 | class GaussianBlurring(object):
213 | """Apply gaussian blur to a numpy.ndarray (H x W x C)
214 | """
215 |
216 | def __init__(self, sigma, random_state=np.random):
217 | self.sigma = sigma
218 | self.random_state = random_state
219 |
220 | def __call__(self, image):
221 | if isinstance(self.sigma, collections.Sequence):
222 | sigma = random_num_generator(
223 | self.sigma, random_state=self.random_state)
224 | else:
225 | sigma = self.sigma
226 | image = gaussian_filter(image, sigma=(sigma, sigma, 0))
227 | return image
228 |
229 |
230 | class AddGaussianPoissonNoise(object):
231 | """Add poisson noise with gaussian blurred image to a numpy.ndarray (H x W x C)
232 | """
233 |
234 | def __init__(self, sigma, peak, random_state=np.random):
235 | self.sigma = sigma
236 | self.peak = peak
237 | self.random_state = random_state
238 |
239 | def __call__(self, image):
240 | if isinstance(self.sigma, collections.Sequence):
241 | sigma = random_num_generator(
242 | self.sigma, random_state=self.random_state)
243 | else:
244 | sigma = self.sigma
245 | if isinstance(self.peak, collections.Sequence):
246 | peak = random_num_generator(
247 | self.peak, random_state=self.random_state)
248 | else:
249 | peak = self.peak
250 | bg = gaussian_filter(image, sigma=(sigma, sigma, 0))
251 | bg = poisson_downsampling(
252 | bg, peak=peak, random_state=self.random_state)
253 | return image + bg
254 |
255 |
256 | class MaxScaleNumpy(object):
257 | """scale with max and min of each channel of the numpy array i.e.
258 | channel = (channel - mean) / std
259 | """
260 |
261 | def __init__(self, range_min=0.0, range_max=1.0):
262 | self.scale = (range_min, range_max)
263 |
264 | def __call__(self, image):
265 | mn = image.min(axis=(0, 1))
266 | mx = image.max(axis=(0, 1))
267 | return self.scale[0] + (image - mn) * (self.scale[1] - self.scale[0]) / (mx - mn)
268 |
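# Worked example: for a channel spanning [10, 20] and the default range
# (0.0, 1.0), a pixel of value 15 maps to 0.0 + (15 - 10) * 1.0 / 10 = 0.5.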
269 |
270 | class MedianScaleNumpy(object):
271 | """Scale with median and mean of each channel of the numpy array i.e.
272 | channel = (channel - mean) / std
273 | """
274 |
275 | def __init__(self, range_min=0.0, range_max=1.0):
276 | self.scale = (range_min, range_max)
277 |
278 | def __call__(self, image):
279 | mn = image.min(axis=(0, 1))
280 | md = np.median(image, axis=(0, 1))
281 | return self.scale[0] + (image - mn) * (self.scale[1] - self.scale[0]) / (md - mn)
282 |
283 |
284 | class NormalizeNumpy(object):
285 | """Normalize each channel of the numpy array i.e.
286 | channel = (channel - mean) / std
287 | """
288 |
289 | def __call__(self, image):
290 | image -= image.mean(axis=(0, 1))
291 | s = image.std(axis=(0, 1))
292 | s[s == 0] = 1.0
293 | image /= s
294 | return image
295 |
296 |
297 | class MutualExclude(object):
298 | """Remove elements from one channel
299 | """
300 |
301 | def __init__(self, exclude_channel, from_channel):
302 | self.from_channel = from_channel
303 | self.exclude_channel = exclude_channel
304 |
305 | def __call__(self, image):
306 | mask = image[:, :, self.exclude_channel] > 0
307 | image[:, :, self.from_channel][mask] = 0
308 | return image
309 |
310 |
311 | class RandomCropNumpy(object):
312 | """Crops the given numpy array at a random location to have a region of
313 | the given size. size can be a tuple (target_height, target_width)
314 | or an integer, in which case the target will be of a square shape (size, size)
315 | """
316 |
317 | def __init__(self, size, random_state=np.random):
318 | if isinstance(size, numbers.Number):
319 | self.size = (int(size), int(size))
320 | else:
321 | self.size = size
322 | self.random_state = random_state
323 |
324 | def __call__(self, img):
325 |         h, w = img.shape[:2]
326 |         th, tw = self.size
327 |         if h == th and w == tw:
328 |             return img
329 |
330 |         y1 = self.random_state.randint(0, h - th)
331 |         x1 = self.random_state.randint(0, w - tw)
332 |         return img[y1:y1 + th, x1:x1 + tw, :]
333 |
334 |
335 | class CenterCropNumpy(object):
336 | """Crops the given numpy array at the center to have a region of
337 | the given size. size can be a tuple (target_height, target_width)
338 | or an integer, in which case the target will be of a square shape (size, size)
339 | """
340 |
341 | def __init__(self, size):
342 | if isinstance(size, numbers.Number):
343 | self.size = (int(size), int(size))
344 | else:
345 | self.size = size
346 |
347 | def __call__(self, img):
348 |         h, w = img.shape[:2]
349 |         th, tw = self.size
350 |         y1 = int(round((h - th) / 2.))
351 |         x1 = int(round((w - tw) / 2.))
352 |         return img[y1:y1 + th, x1:x1 + tw, :]
353 |
354 |
355 | class RandomRotate(object):
356 | """Rotate a PIL.Image or numpy.ndarray (H x W x C) randomly
357 | """
358 |
359 | def __init__(self, angle_range=(0.0, 360.0), axes=(0, 1), mode='reflect', random_state=np.random):
360 | assert isinstance(angle_range, tuple)
361 | self.angle_range = angle_range
362 | self.random_state = random_state
363 | self.axes = axes
364 | self.mode = mode
365 |
366 | def __call__(self, image):
367 | angle = self.random_state.uniform(
368 | self.angle_range[0], self.angle_range[1])
369 | if isinstance(image, np.ndarray):
370 | mi, ma = image.min(), image.max()
371 | image = scipy.ndimage.interpolation.rotate(
372 | image, angle, reshape=False, axes=self.axes, mode=self.mode)
373 | return np.clip(image, mi, ma)
374 | elif isinstance(image, Image.Image):
375 | return image.rotate(angle)
376 | else:
377 | raise Exception('unsupported type')
378 |
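# Usage sketch (hedged): restrict the rotation to +/-10 degrees:
#   rot = RandomRotate(angle_range=(-10.0, 10.0))
#   out = rot(np.random.rand(64, 64, 1))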
379 |
380 | class BilinearResize(object):
381 | """Resize a PIL.Image or numpy.ndarray (H x W x C)
382 | """
383 |
384 | def __init__(self, zoom):
385 | self.zoom = [zoom, zoom, 1]
386 |
387 | def __call__(self, image):
388 | if isinstance(image, np.ndarray):
389 |             return scipy.ndimage.interpolation.zoom(image, self.zoom, order=1)  # order=1 -> bilinear
390 |         elif isinstance(image, Image.Image):
391 |             return image.resize(tuple(int(round(s * self.zoom[0])) for s in image.size), Image.BILINEAR)
392 | else:
393 | raise Exception('unsupported type')
394 |
395 |
396 | class EnhancedCompose(object):
397 | """Composes several transforms together.
398 | Args:
399 | transforms (List[Transform]): list of transforms to compose.
400 | Example:
401 |         >>> EnhancedCompose([
402 |         >>>     CenterCropNumpy(size=(10, 10)),
403 |         >>>     NormalizeNumpy(),
404 |         >>> ])
405 | """
406 |
407 | def __init__(self, transforms):
408 | self.transforms = transforms
409 |
410 | def __call__(self, img):
411 | for t in self.transforms:
412 | if isinstance(t, collections.Sequence):
413 | assert isinstance(img, collections.Sequence) and len(img) == len(
414 | t), "size of image group and transform group does not fit"
415 | tmp_ = []
416 | for i, im_ in enumerate(img):
417 | if callable(t[i]):
418 | tmp_.append(t[i](im_))
419 | else:
420 | tmp_.append(im_)
421 | img = tmp_
422 | elif callable(t):
423 | img = t(img)
424 | elif t is None:
425 | continue
426 | else:
427 | raise Exception('unexpected type')
428 | return img
429 |
430 |
431 | if __name__ == '__main__':
432 | from torchvision.transforms import Lambda
433 |
434 | input_channel = 3
435 | target_channel = 3
436 |
437 | # define a transform pipeline
438 | transform = EnhancedCompose([
439 | Merge(),
440 | RandomCropNumpy(size=(512, 512)),
441 | RandomRotate(),
442 | Split([0, input_channel], [input_channel, input_channel + target_channel]),
443 | [CenterCropNumpy(size=(256, 256)), CenterCropNumpy(size=(256, 256))],
444 | [NormalizeNumpy(), MaxScaleNumpy(0, 1.0)],
445 | # for non-pytorch usage, remove to_tensor conversion
446 | [Lambda(to_tensor), Lambda(to_tensor)]
447 | ])
448 |     # read input images for the test
449 | image_in = np.array(Image.open('input.jpg'))
450 | image_target = np.array(Image.open('target.jpg'))
451 |
452 | # apply the transform
453 | x, y = transform([image_in, image_target])
--------------------------------------------------------------------------------