├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── data.py
├── losses.py
├── metrics.py
├── models
│   ├── __init__.py
│   ├── build.py
│   └── unet.py
├── nvidia-smi.py
├── solver
│   ├── __init__.py
│   ├── build.py
│   └── scheduler.py
├── train_val.py
└── utils
    ├── __init__.py
    ├── logger.py
    ├── metric
    │   ├── __init__.py
    │   ├── binary.py
    │   ├── histogram.py
    │   ├── image.py
    │   ├── inception_score.py
    │   ├── metrics_torch.py
    │   ├── sim_torch.py
    │   ├── similarity.py
    │   └── utils.py
    ├── metric_logger.py
    └── misc.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# log file
/logs
# pycharm
/.idea
# vscode
/.vscode
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Eric Ching

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# brats_segmentation-pytorch
3D U-Net + VAE: a PyTorch reproduction of the BraTS 2018 winning solution (Myronenko, "3D MRI brain tumor segmentation using autoencoder regularization").

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# encoding: utf-8
# @Time : 2019/5/9 14:13
# @Author : Eric Ching
from yacs.config import CfgNode as CN
import platform

_C = CN()
_C.DATASET = CN()

if "Win" in platform.system():
    _C.DATASET.DATA_ROOT = 'G:/data_repos/Brats2018'
else:
    _C.DATASET.DATA_ROOT = "/home/share/data_repos/Brats2018"

_C.DATASET.NUM_FOLDS = 4
_C.DATASET.SELECT_FOLD = 0
_C.DATASET.USE_MODES = ("t1", "t2", "flair", "t1ce")
_C.DATASET.INPUT_SHAPE = (160, 192, 128)

_C.DATALOADER = CN()
_C.DATALOADER.BATCH_SIZE = 1
_C.DATALOADER.NUM_WORKERS = 6

_C.MODEL = CN()
_C.MODEL.NAME = 'unet-vae'
_C.MODEL.INIT_CHANNELS = 16
_C.MODEL.DROPOUT = 0.2
_C.MODEL.LOSS_WEIGHT = 0.1

_C.SOLVER = CN()
_C.SOLVER.LEARNING_RATE = 1e-3
_C.SOLVER.WEIGHT_DECAY = 1e-5
_C.SOLVER.POWER = 0.9
_C.SOLVER.NUM_EPOCHS = 300

_C.MISC = CN()
_C.LOG_DIR = './logs'
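How this node is consumed: train_val.py imports `_C` directly and later assigns `cfg.TASK_NAME` at runtime, so the config is never frozen. A minimal sketch of overriding defaults per experiment, assuming the standard yacs API (the YAML path below is hypothetical):

from config import _C as cfg

# Override individual keys in code or from argv...
cfg.merge_from_list(["SOLVER.LEARNING_RATE", 3e-4, "MODEL.NAME", "unet"])
# ...or from an experiment file (hypothetical path):
# cfg.merge_from_file("experiments/unet_vae.yaml")
print(cfg.SOLVER.LEARNING_RATE)  # 0.0003
# cfg.freeze() would make the later cfg.TASK_NAME assignment in train_val.py
# raise, which is presumably why the config stays mutable in this repo.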
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# encoding: utf-8
# @Time : 2019/5/9 14:18
# @Author : Eric Ching
import glob
import os
import numpy as np
import nibabel as nib
import torch
from torch.utils.data import Dataset, DataLoader
import random

class Brats2018(Dataset):

    def __init__(self, patients_dir, crop_size, modes, train=True):
        self.patients_dir = patients_dir
        self.modes = modes
        self.train = train
        self.crop_size = crop_size

    def __len__(self):
        return len(self.patients_dir)

    def __getitem__(self, index):
        patient_dir = self.patients_dir[index]
        volumes = []
        modes = list(self.modes) + ['seg']
        for mode in modes:
            patient_id = os.path.split(patient_dir)[-1]
            volume_path = os.path.join(patient_dir, patient_id + "_" + mode + '.nii')
            volume = nib.load(volume_path).get_fdata()  # get_data() is deprecated in nibabel
            if not mode == "seg":
                volume = self.normalize(volume)  # [0, 1.0]
            volumes.append(volume)  # [h, w, d]
        seg_volume = volumes[-1]
        volumes = volumes[:-1]
        volume, seg_volume = self.aug_sample(volumes, seg_volume)
        # BraTS labels: 1 = necrotic/non-enhancing tumor core, 2 = peritumoral edema, 4 = enhancing tumor
        wt_volume = seg_volume > 0
        tc_volume = np.logical_or(seg_volume == 4, seg_volume == 1)
        et_volume = (seg_volume == 4)
        seg_volume = [wt_volume, tc_volume, et_volume]
        seg_volume = np.concatenate(seg_volume, axis=0).astype("float32")

        return (torch.tensor(volume.copy(), dtype=torch.float),
                torch.tensor(seg_volume.copy(), dtype=torch.float))


    def aug_sample(self, volumes, mask):
        """
        Args:
            volumes: list of array, [h, w, d]
            mask: array [h, w, d], segmentation volume
        Ret: x, y: [channel, h, w, d]

        """
        x = np.stack(volumes, axis=0)     # [channel, h, w, d]
        y = np.expand_dims(mask, axis=0)  # [channel, h, w, d]

        if self.train:
            # crop volume
            x, y = self.random_crop(x, y)
            if random.random() < 0.5:
                x = np.flip(x, axis=1)
                y = np.flip(y, axis=1)
            if random.random() < 0.5:
                x = np.flip(x, axis=2)
                y = np.flip(y, axis=2)
            if random.random() < 0.5:
                x = np.flip(x, axis=3)
                y = np.flip(y, axis=3)
        else:
            x, y = self.center_crop(x, y)

        return x, y

    def random_crop(self, x, y):
        """
        Args:
            x: 4d array, [channel, h, w, d]
        """
        crop_size = self.crop_size
        height, width, depth = x.shape[-3:]
        # randint is inclusive; the original upper bound of (size - crop - 1) wasted
        # the last valid offset and crashed when size == crop_size
        sx = random.randint(0, height - crop_size[0])
        sy = random.randint(0, width - crop_size[1])
        sz = random.randint(0, depth - crop_size[2])
        crop_volume = x[:, sx:sx + crop_size[0], sy:sy + crop_size[1], sz:sz + crop_size[2]]
        crop_seg = y[:, sx:sx + crop_size[0], sy:sy + crop_size[1], sz:sz + crop_size[2]]

        return crop_volume, crop_seg

    def center_crop(self, x, y):
        crop_size = self.crop_size
        height, width, depth = x.shape[-3:]
        # the original (size - crop - 1) // 2 biased the crop off-center
        sx = (height - crop_size[0]) // 2
        sy = (width - crop_size[1]) // 2
        sz = (depth - crop_size[2]) // 2
        crop_volume = x[:, sx:sx + crop_size[0], sy:sy + crop_size[1], sz:sz + crop_size[2]]
        crop_seg = y[:, sx:sx + crop_size[0], sy:sy + crop_size[1], sz:sz + crop_size[2]]

        return crop_volume, crop_seg

    def normalize(self, x):
        return (x - x.min()) / (x.max() - x.min())

def split_dataset(data_root, nfold=4, seed=42, select=0):
    patients_dir = glob.glob(os.path.join(data_root, "*GG", "Brats18*"))
    n_patients = len(patients_dir)
    print(f"total patients: {n_patients}")
    pid_idx = np.arange(n_patients)
    np.random.seed(seed)
    np.random.shuffle(pid_idx)
    n_fold_list = np.array_split(pid_idx, nfold)  # np.split raises unless n_patients divides evenly into nfold
    print(f"split into {len(n_fold_list)} folds; fold sizes: {[len(fold) for fold in n_fold_list]}")
    val_patients_list = []
    train_patients_list = []
    for i, fold in enumerate(n_fold_list):
        if i == select:
            for idx in fold:
                val_patients_list.append(patients_dir[idx])
        else:
            for idx in fold:
                train_patients_list.append(patients_dir[idx])
    print(f"train patients: {len(train_patients_list)}, test patients: {len(val_patients_list)}")

    return train_patients_list, val_patients_list

def make_data_loaders(cfg):
    # pass SELECT_FOLD by keyword: the third positional parameter of split_dataset is `seed`, not `select`
    train_list, val_list = split_dataset(cfg.DATASET.DATA_ROOT, cfg.DATASET.NUM_FOLDS,
                                         select=cfg.DATASET.SELECT_FOLD)
    train_ds = Brats2018(train_list, crop_size=cfg.DATASET.INPUT_SHAPE, modes=cfg.DATASET.USE_MODES, train=True)
    val_ds = Brats2018(val_list, crop_size=cfg.DATASET.INPUT_SHAPE, modes=cfg.DATASET.USE_MODES, train=False)
    loaders = {}
    loaders['train'] = DataLoader(train_ds, batch_size=cfg.DATALOADER.BATCH_SIZE,
                                  num_workers=cfg.DATALOADER.NUM_WORKERS,
                                  pin_memory=True,
                                  shuffle=True)
    loaders['eval'] = DataLoader(val_ds, batch_size=cfg.DATALOADER.BATCH_SIZE,
                                 num_workers=cfg.DATALOADER.NUM_WORKERS,
                                 pin_memory=True,
                                 shuffle=False)
    return loaders


if __name__ == "__main__":
    from config import _C as cfg
    train_list, val_list = split_dataset(cfg.DATASET.DATA_ROOT, cfg.DATASET.NUM_FOLDS,
                                         select=cfg.DATASET.SELECT_FOLD)
    train_ds = Brats2018(train_list, crop_size=cfg.DATASET.INPUT_SHAPE, modes=cfg.DATASET.USE_MODES, train=True)
    val_ds = Brats2018(val_list, crop_size=cfg.DATASET.INPUT_SHAPE, modes=cfg.DATASET.USE_MODES, train=False)

    for i in range(len(train_ds)):
        x, y = train_ds[i]
        volume = (x.numpy()[0] * 255).astype('uint8')
        seg = (np.sum(y.numpy(), axis=0)).astype('uint8')
        volume = nib.Nifti1Image(volume, np.eye(4))
        seg = nib.Nifti1Image(seg, np.eye(4))
        nib.save(volume, 'test' + str(i) + '_volume.nii')
        nib.save(seg, 'test' + str(i) + '_seg.nii')
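The three target channels are nested binary masks derived from the BraTS label map rather than a one-hot encoding. A quick numpy check of the mapping used in `__getitem__`, with one voxel per label value:

import numpy as np

seg = np.array([0, 1, 2, 4])             # background, necrotic core, edema, enhancing
wt = seg > 0                             # whole tumor: labels 1, 2, 4
tc = np.logical_or(seg == 4, seg == 1)   # tumor core: labels 1, 4
et = seg == 4                            # enhancing tumor: label 4
print(np.stack([wt, tc, et]).astype(int))
# [[0 1 1 1]
#  [0 1 0 1]
#  [0 0 0 1]]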
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# encoding: utf-8
# @Time : 2019/5/9 16:31
# @Author : Eric Ching
import torch
from torch.nn import functional as F


def dice_loss(input, target):
    """soft dice loss"""
    eps = 1e-7
    iflat = input.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()

    return 1 - 2. * intersection / ((iflat ** 2).sum() + (tflat ** 2).sum() + eps)


def vae_loss(recon_x, x, mu, logvar):
    loss_dict = {}
    # note: KLD is summed while the reconstruction term is averaged;
    # the paper normalizes its KL term by the number of voxels N
    loss_dict['KLD'] = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    loss_dict['recon_loss'] = F.mse_loss(recon_x, x, reduction='mean')

    return loss_dict

def unet_vae_loss(cfg, batch_pred, batch_x, batch_y, vout, mu, logvar):
    loss_dict = {}
    loss_dict['wt_loss'] = dice_loss(batch_pred[:, 0], batch_y[:, 0])  # whole tumor
    loss_dict['tc_loss'] = dice_loss(batch_pred[:, 1], batch_y[:, 1])  # tumor core
    loss_dict['et_loss'] = dice_loss(batch_pred[:, 2], batch_y[:, 2])  # enhancing tumor
    loss_dict.update(vae_loss(vout, batch_x, mu, logvar))
    weight = cfg.MODEL.LOSS_WEIGHT
    loss_dict['loss'] = loss_dict['wt_loss'] + loss_dict['tc_loss'] + loss_dict['et_loss'] + \
                        weight * loss_dict['recon_loss'] + weight * loss_dict['KLD']

    return loss_dict

def get_losses(cfg):
    losses = {}
    losses['vae'] = vae_loss
    losses['dice'] = dice_loss
    losses['dice_vae'] = unet_vae_loss

    return losses
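Note that `dice_loss` squares the terms in the denominator, i.e. it is the V-Net-style squared soft Dice rather than the plain |A| + |B| form. A hand-checkable example on a toy prediction:

import torch
from losses import dice_loss

pred = torch.tensor([0.9, 0.8, 0.1, 0.0])
gt   = torch.tensor([1.0, 1.0, 0.0, 0.0])
# intersection = 1.7; sum of squares = 1.46 + 2.0 = 3.46
print(dice_loss(pred, gt))  # 1 - 3.4 / 3.46 ≈ 0.0173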
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# encoding: utf-8
# @Time : 2019/5/9 16:33
# @Author : Eric Ching
import torch.nn as nn
import numpy as np
from utils.metric.binary import hd

def dice_coef(input, target, threshold=0.5):
    smooth = 1.
    iflat = (input.view(-1) > threshold).float()
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()

    return (2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)

def dice_coef_np(input, target, eps=1e-7):
    input = np.ravel(input)
    target = np.ravel(target)
    intersection = (input * target).sum()

    return (2. * intersection) / (input.sum() + target.sum() + eps)

def hausdorff(batch_pred, batch_y, threshold=0.5):
    """batch size must equal 1"""
    batch_pred = batch_pred.cpu().squeeze().numpy() > threshold
    batch_y = batch_y.cpu().squeeze().numpy()
    metric_dict = {}
    # hd() raises when either mask contains no positive voxels; fall back to 1.0
    try:
        metric_dict['wt_hd'] = hd(batch_pred[0], batch_y[0])
    except Exception:
        metric_dict['wt_hd'] = 1.0
        print("wt has no object")
    try:
        metric_dict['tc_hd'] = hd(batch_pred[1], batch_y[1])
    except Exception:
        metric_dict['tc_hd'] = 1.0
        print("tc has no object")
    try:
        metric_dict['et_hd'] = hd(batch_pred[2], batch_y[2])
    except Exception:
        metric_dict['et_hd'] = 1.0
        print("et has no object")

    return metric_dict

def get_metrics(cfg):
    metrics = {}
    metrics["mse"] = nn.MSELoss().cuda()
    metrics["hd"] = hausdorff

    return metrics
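`hausdorff` assumes a single-sample batch whose channels are the WT/TC/ET masks produced by the loader. A toy call (CPU tensors are fine; `hd` comes from the bundled utils.metric package and needs scipy):

import torch
from metrics import dice_coef, hausdorff

pred = torch.zeros(1, 3, 8, 8, 8)
gt = torch.zeros(1, 3, 8, 8, 8)
pred[0, :, 2:5, 2:5, 2:5] = 0.9   # predicted cube, partially overlapping the target
gt[0, :, 3:6, 3:6, 3:6] = 1.0
print(dice_coef(pred, gt))   # overlap score in (0, 1]
print(hausdorff(pred, gt))   # {'wt_hd': ..., 'tc_hd': ..., 'et_hd': ...}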
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# encoding: utf-8
# @Time : 2019/5/9 15:22
# @Author : Eric Ching
from .build import build_model

--------------------------------------------------------------------------------
/models/build.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# encoding: utf-8
# @Time : 2019/5/9 15:56
# @Author : Eric Ching
from .unet import UNet3D, UnetVAE3D


def build_model(cfg):
    if cfg.MODEL.NAME == 'unet-vae':
        model = UnetVAE3D(cfg.DATASET.INPUT_SHAPE,
                          in_channels=len(cfg.DATASET.USE_MODES),
                          out_channels=3,
                          init_channels=cfg.MODEL.INIT_CHANNELS,
                          p=cfg.MODEL.DROPOUT)
    else:
        model = UNet3D(cfg.DATASET.INPUT_SHAPE,
                       in_channels=len(cfg.DATASET.USE_MODES),
                       out_channels=3,
                       init_channels=cfg.MODEL.INIT_CHANNELS,
                       p=cfg.MODEL.DROPOUT)

    return model

--------------------------------------------------------------------------------
/models/unet.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# encoding: utf-8
# @Time : 2019/5/9 15:22
# @Author : Eric Ching
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_groups=8):
        super(BasicBlock, self).__init__()
        self.gn1 = nn.GroupNorm(n_groups, in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.gn2 = nn.GroupNorm(n_groups, out_channels)  # was in_channels; conv1's output has out_channels
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1))

    def forward(self, x):
        residual = x
        x = self.relu1(self.gn1(x))
        x = self.conv1(x)

        x = self.relu2(self.gn2(x))
        x = self.conv2(x)
        x = x + residual

        return x


class VAEBranch(nn.Module):

    def __init__(self, input_shape, init_channels, out_channels, squeeze_channels=None):
        super(VAEBranch, self).__init__()
        self.input_shape = input_shape

        if squeeze_channels:
            self.squeeze_channels = squeeze_channels
        else:
            self.squeeze_channels = init_channels * 4

        self.hidden_conv = nn.Sequential(nn.GroupNorm(8, init_channels * 8),
                                         nn.ReLU(inplace=True),
                                         nn.Conv3d(init_channels * 8, self.squeeze_channels, (3, 3, 3),
                                                   padding=(1, 1, 1)),
                                         nn.AdaptiveAvgPool3d(1))

        self.mu_fc = nn.Linear(self.squeeze_channels // 2, self.squeeze_channels // 2)
        self.logvar_fc = nn.Linear(self.squeeze_channels // 2, self.squeeze_channels // 2)

        recon_shape = np.prod(self.input_shape) // (16 ** 3)

        self.reconstruction = nn.Sequential(nn.Linear(self.squeeze_channels // 2, init_channels * 8 * recon_shape),
                                            nn.ReLU(inplace=True))

        self.vconv4 = nn.Sequential(nn.Conv3d(init_channels * 8, init_channels * 8, (1, 1, 1)),
                                    nn.Upsample(scale_factor=2))

        self.vconv3 = nn.Sequential(nn.Conv3d(init_channels * 8, init_channels * 4, (3, 3, 3), padding=(1, 1, 1)),
                                    nn.Upsample(scale_factor=2),
                                    BasicBlock(init_channels * 4, init_channels * 4))

        self.vconv2 = nn.Sequential(nn.Conv3d(init_channels * 4, init_channels * 2, (3, 3, 3), padding=(1, 1, 1)),
                                    nn.Upsample(scale_factor=2),
                                    BasicBlock(init_channels * 2, init_channels * 2))

        self.vconv1 = nn.Sequential(nn.Conv3d(init_channels * 2, init_channels, (3, 3, 3), padding=(1, 1, 1)),
                                    nn.Upsample(scale_factor=2),
                                    BasicBlock(init_channels, init_channels))

        self.vconv0 = nn.Conv3d(init_channels, out_channels, (1, 1, 1))

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)

        return eps.mul(std).add_(mu)

    def forward(self, x):
        x = self.hidden_conv(x)
        batch_size = x.size()[0]
        x = x.view((batch_size, -1))
        mu = x[:, :self.squeeze_channels // 2]
        mu = self.mu_fc(mu)
        logvar = x[:, self.squeeze_channels // 2:]
        logvar = self.logvar_fc(logvar)
        z = self.reparameterize(mu, logvar)
        re_x = self.reconstruction(z)  # renamed from the misspelled `reconstraction`
        recon_shape = [batch_size,
                       -1,  # infer init_channels * 8; the hard-coded squeeze_channels // 2 did not match the Linear above
                       self.input_shape[0] // 16,
                       self.input_shape[1] // 16,
                       self.input_shape[2] // 16]
        re_x = re_x.view(recon_shape)
        x = self.vconv4(re_x)
        x = self.vconv3(x)
        x = self.vconv2(x)
        x = self.vconv1(x)
        vout = self.vconv0(x)

        return vout, mu, logvar

class UNet3D(nn.Module):
    """3d unet
    Ref:
        3D MRI brain tumor segmentation using autoencoder regularization. Andriy Myronenko
    Args:
        input_shape: tuple, (height, width, depth)
    """

    def __init__(self, input_shape, in_channels=4, out_channels=3, init_channels=32, p=0.2):
        super(UNet3D, self).__init__()
        self.input_shape = input_shape
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.init_channels = init_channels
        self.make_encoder()
        self.make_decoder()
        self.dropout = nn.Dropout(p=p)

    def make_encoder(self):
        init_channels = self.init_channels
        self.conv1a = nn.Conv3d(self.in_channels, init_channels, (3, 3, 3), padding=(1, 1, 1))
        self.conv1b = BasicBlock(init_channels, init_channels)  # 32

        self.ds1 = nn.Conv3d(init_channels, init_channels * 2, (3, 3, 3), stride=(2, 2, 2),
                             padding=(1, 1, 1))  # down sampling and add channels

        self.conv2a = BasicBlock(init_channels * 2, init_channels * 2)
        self.conv2b = BasicBlock(init_channels * 2, init_channels * 2)

        self.ds2 = nn.Conv3d(init_channels * 2, init_channels * 4, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))

        self.conv3a = BasicBlock(init_channels * 4, init_channels * 4)
        self.conv3b = BasicBlock(init_channels * 4, init_channels * 4)

        self.ds3 = nn.Conv3d(init_channels * 4, init_channels * 8, (3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))

        self.conv4a = BasicBlock(init_channels * 8, init_channels * 8)
        self.conv4b = BasicBlock(init_channels * 8, init_channels * 8)
        self.conv4c = BasicBlock(init_channels * 8, init_channels * 8)
        self.conv4d = BasicBlock(init_channels * 8, init_channels * 8)

    def make_decoder(self):
        init_channels = self.init_channels
        self.up4conva = nn.Conv3d(init_channels * 8, init_channels * 4, (1, 1, 1))
        self.up4 = nn.Upsample(scale_factor=2)  # nearest by default; 'bilinear' is invalid for 5-D input, use mode='trilinear' for smooth upsampling
        self.up4convb = BasicBlock(init_channels * 4, init_channels * 4)

        self.up3conva = nn.Conv3d(init_channels * 4, init_channels * 2, (1, 1, 1))
        self.up3 = nn.Upsample(scale_factor=2)
        self.up3convb = BasicBlock(init_channels * 2, init_channels * 2)

        self.up2conva = nn.Conv3d(init_channels * 2, init_channels, (1, 1, 1))
        self.up2 = nn.Upsample(scale_factor=2)
        self.up2convb = BasicBlock(init_channels, init_channels)

        self.up1conv = nn.Conv3d(init_channels, self.out_channels, (1, 1, 1))

    def forward(self, x):
        c1 = self.conv1a(x)
        c1 = self.conv1b(c1)
        c1d = self.ds1(c1)

        c2 = self.conv2a(c1d)
        c2 = self.conv2b(c2)
        c2d = self.ds2(c2)

        c3 = self.conv3a(c2d)
        c3 = self.conv3b(c3)
        c3d = self.ds3(c3)

        c4 = self.conv4a(c3d)
        c4 = self.conv4b(c4)
        c4 = self.conv4c(c4)
        c4d = self.conv4d(c4)

        c4d = self.dropout(c4d)

        u4 = self.up4conva(c4d)
        u4 = self.up4(u4)
        u4 = u4 + c3
        u4 = self.up4convb(u4)

        u3 = self.up3conva(u4)
        u3 = self.up3(u3)
        u3 = u3 + c2
        u3 = self.up3convb(u3)

        u2 = self.up2conva(u3)
        u2 = self.up2(u2)
        u2 = u2 + c1
        u2 = self.up2convb(u2)

        uout = self.up1conv(u2)
        uout = torch.sigmoid(uout)  # F.sigmoid is deprecated

        return uout, c4d


class UnetVAE3D(nn.Module):

    def __init__(self, input_shape, in_channels=4, out_channels=3, init_channels=32, p=0.2):
        super(UnetVAE3D, self).__init__()
        self.unet = UNet3D(input_shape, in_channels, out_channels, init_channels, p)
        self.vae_branch = VAEBranch(input_shape, init_channels, out_channels=in_channels)

    def forward(self, x):
        uout, c4d = self.unet(x)
        vout, mu, logvar = self.vae_branch(c4d)

        return uout, vout, mu, logvar
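A minimal shape check of the two-headed model, assuming the reshape fix above. A (32, 32, 32) crop keeps the dummy forward cheap (any shape divisible by 16 works); the configured INPUT_SHAPE of (160, 192, 128) needs several GB of memory:

import torch
from models.unet import UnetVAE3D

model = UnetVAE3D(input_shape=(32, 32, 32), in_channels=4, out_channels=3, init_channels=16)
x = torch.randn(1, 4, 32, 32, 32)  # (batch, modality, h, w, d)
uout, vout, mu, logvar = model(x)
print(uout.shape)  # torch.Size([1, 3, 32, 32, 32]) - WT/TC/ET probability maps
print(vout.shape)  # torch.Size([1, 4, 32, 32, 32]) - VAE reconstruction of the input
print(mu.shape)    # torch.Size([1, 32])            - squeeze_channels // 2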
--------------------------------------------------------------------------------
/nvidia-smi.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# encoding: utf-8
# @Time : 2019/4/2 19:30
# @Author : Eric Ching
import subprocess
import platform

if 'Win' in platform.system():
    nvidia_smi_path = 'C:/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe'
else:
    nvidia_smi_path = 'nvidia-smi'
subprocess.call(nvidia_smi_path, shell=True)

--------------------------------------------------------------------------------
/solver/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# encoding: utf-8
# @Time : 2019/5/9 16:22
# @Author : Eric Ching
from .build import make_optimizer

--------------------------------------------------------------------------------
/solver/build.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# encoding: utf-8
# @Time : 2019/5/9 16:24
# @Author : Eric Ching
import torch
from .scheduler import PolyLR

def make_optimizer(cfg, model):
    lr = cfg.SOLVER.LEARNING_RATE
    print('initial learning rate is ', lr)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    scheduler = PolyLR(optimizer, max_epoch=cfg.SOLVER.NUM_EPOCHS, power=cfg.SOLVER.POWER)

    return optimizer, scheduler
15 | """ 16 | 17 | def __init__(self, optimizer, max_epoch, power=0.9, last_epoch=-1): 18 | self.max_epoch = max_epoch 19 | self.power = power 20 | super(PolyLR, self).__init__(optimizer, last_epoch) 21 | 22 | def get_lr(self): 23 | return [base_lr * (1 - self.last_epoch / self.max_epoch) ** self.power 24 | for base_lr in self.base_lrs] -------------------------------------------------------------------------------- /train_val.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | # @Time : 2019/5/9 16:14 4 | # @Author : Eric Ching 5 | from config import _C as cfg 6 | from data import make_data_loaders 7 | from models import build_model 8 | from utils import init_env, mkdir 9 | from solver import make_optimizer 10 | import os 11 | import torch 12 | from utils.logger import setup_logger 13 | from utils.metric_logger import MetricLogger 14 | import logging 15 | import time 16 | from losses import get_losses 17 | from metrics import get_metrics 18 | import shutil 19 | import nibabel as nib 20 | import numpy as np 21 | 22 | def save_sample(batch_pred, batch_x, batch_y, epoch, batch_id): 23 | def get_mask(seg_volume): 24 | seg_volume = seg_volume.cpu().numpy() 25 | seg_volume = np.squeeze(seg_volume) 26 | wt_pred = seg_volume[0] 27 | tc_pred = seg_volume[1] 28 | et_pred = seg_volume[2] 29 | mask = np.zeros_like(wt_pred) 30 | mask[wt_pred > 0.5] = 2 31 | mask[tc_pred > 0.5] = 1 32 | mask[et_pred > 0.5] = 4 33 | mask = mask.astype("uint8") 34 | mask_nii = nib.Nifti1Image(mask, np.eye(4)) 35 | return mask_nii 36 | 37 | volume = batch_x[:, 0].cpu().numpy() 38 | volume = volume.astype("uint8") 39 | volume_nii = nib.Nifti1Image(volume, np.eye(4)) 40 | log_dir = os.path.join(cfg.LOG_DIR, cfg.TASK_NAME, 'epoch'+str(epoch)) 41 | mkdir(log_dir) 42 | nib.save(volume_nii, os.path.join(log_dir, 'batch'+str(batch_id)+'_volume.nii')) 43 | pred_nii = get_mask(batch_pred) 44 | gt_nii = get_mask(batch_y) 45 | nib.save(pred_nii, os.path.join(log_dir, 'batch' + str(batch_id) + '_pred.nii')) 46 | nib.save(gt_nii, os.path.join(log_dir, 'batch' + str(batch_id) + '_gt.nii')) 47 | 48 | def train_val(model, loaders, optimizer, scheduler, losses, metrics=None): 49 | n_epochs = cfg.SOLVER.NUM_EPOCHS 50 | end = time.time() 51 | best_dice = 0.0 52 | for epoch in range(n_epochs): 53 | scheduler.step() 54 | for phase in ['train', 'eval']: 55 | meters = MetricLogger(delimiter=" ") 56 | loader = loaders[phase] 57 | getattr(model, phase)() 58 | logger = logging.getLogger(phase) 59 | total = len(loader) 60 | for batch_id, (batch_x, batch_y) in enumerate(loader): 61 | batch_x, batch_y = batch_x.cuda(async=True), batch_y.cuda(async=True) 62 | with torch.set_grad_enabled(phase == 'train'): 63 | output, vout, mu, logvar = model(batch_x) 64 | loss_dict = losses['dice_vae'](cfg, output, batch_x, batch_y, vout, mu, logvar) 65 | meters.update(**loss_dict) 66 | if phase == 'train': 67 | optimizer.zero_grad() 68 | loss_dict['loss'].backward() 69 | optimizer.step() 70 | else: 71 | if metrics and (epoch + 1) % 20 == 0: 72 | with torch.no_grad(): 73 | hausdorff = metrics['hd'] 74 | metric_dict = hausdorff(output, batch_y) 75 | meters.update(**metric_dict) 76 | save_sample(output, batch_x, batch_y, epoch, batch_id) 77 | logger.info(meters.delimiter.join([f"Epoch: {epoch}, Batch:{batch_id}/{total}", 78 | f"{str(meters)}", 79 | f"Time: {time.time() - end: .3f}" 80 | ])) 81 | end = time.time() 82 | 83 | if phase == 'eval': 84 | dice = 1 - (meters.wt_loss.global_avg 
--------------------------------------------------------------------------------
/train_val.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# encoding: utf-8
# @Time : 2019/5/9 16:14
# @Author : Eric Ching
from config import _C as cfg
from data import make_data_loaders
from models import build_model
from utils import init_env, mkdir
from solver import make_optimizer
import os
import torch
from utils.logger import setup_logger
from utils.metric_logger import MetricLogger
import logging
import time
from losses import get_losses
from metrics import get_metrics
import shutil
import nibabel as nib
import numpy as np

def save_sample(batch_pred, batch_x, batch_y, epoch, batch_id):
    def get_mask(seg_volume):
        seg_volume = seg_volume.cpu().numpy()
        seg_volume = np.squeeze(seg_volume)
        wt_pred = seg_volume[0]
        tc_pred = seg_volume[1]
        et_pred = seg_volume[2]
        mask = np.zeros_like(wt_pred)
        mask[wt_pred > 0.5] = 2
        mask[tc_pred > 0.5] = 1
        mask[et_pred > 0.5] = 4
        mask = mask.astype("uint8")
        mask_nii = nib.Nifti1Image(mask, np.eye(4))
        return mask_nii

    # inputs are normalized to [0, 1]; rescale before the uint8 cast, which would otherwise zero the volume
    volume = np.squeeze(batch_x[:, 0].cpu().numpy()) * 255
    volume = volume.astype("uint8")
    volume_nii = nib.Nifti1Image(volume, np.eye(4))
    log_dir = os.path.join(cfg.LOG_DIR, cfg.TASK_NAME, 'epoch' + str(epoch))
    mkdir(log_dir)
    nib.save(volume_nii, os.path.join(log_dir, 'batch' + str(batch_id) + '_volume.nii'))
    pred_nii = get_mask(batch_pred)
    gt_nii = get_mask(batch_y)
    nib.save(pred_nii, os.path.join(log_dir, 'batch' + str(batch_id) + '_pred.nii'))
    nib.save(gt_nii, os.path.join(log_dir, 'batch' + str(batch_id) + '_gt.nii'))

def train_val(model, loaders, optimizer, scheduler, losses, metrics=None):
    n_epochs = cfg.SOLVER.NUM_EPOCHS
    end = time.time()
    best_dice = 0.0
    for epoch in range(n_epochs):
        scheduler.step()  # PyTorch >= 1.1 expects scheduler.step() after the epoch's optimizer steps
        for phase in ['train', 'eval']:
            meters = MetricLogger(delimiter=" ")
            loader = loaders[phase]
            getattr(model, phase)()  # model.train() / model.eval()
            logger = logging.getLogger(phase)
            total = len(loader)
            for batch_id, (batch_x, batch_y) in enumerate(loader):
                # `async` is a reserved word since Python 3.7; use non_blocking instead
                batch_x, batch_y = batch_x.cuda(non_blocking=True), batch_y.cuda(non_blocking=True)
                with torch.set_grad_enabled(phase == 'train'):
                    output, vout, mu, logvar = model(batch_x)
                    loss_dict = losses['dice_vae'](cfg, output, batch_x, batch_y, vout, mu, logvar)
                    meters.update(**loss_dict)
                    if phase == 'train':
                        optimizer.zero_grad()
                        loss_dict['loss'].backward()
                        optimizer.step()
                    else:
                        if metrics and (epoch + 1) % 20 == 0:
                            with torch.no_grad():
                                hausdorff = metrics['hd']
                                metric_dict = hausdorff(output, batch_y)
                                meters.update(**metric_dict)
                                save_sample(output, batch_x, batch_y, epoch, batch_id)
                logger.info(meters.delimiter.join([f"Epoch: {epoch}, Batch:{batch_id}/{total}",
                                                   f"{str(meters)}",
                                                   f"Time: {time.time() - end: .3f}"
                                                   ]))
                end = time.time()

            if phase == 'eval':
                dice = 1 - (meters.wt_loss.global_avg + meters.tc_loss.global_avg + meters.et_loss.global_avg) / 3
                state = {}
                state['model'] = model.state_dict()
                state['optimizer'] = optimizer.state_dict()
                # use one extension ('.pth') for both per-epoch and best checkpoints
                file_name = os.path.join(cfg.LOG_DIR, cfg.TASK_NAME, 'epoch' + str(epoch) + '.pth')
                torch.save(state, file_name)
                if dice > best_dice:
                    best_dice = dice
                    shutil.copyfile(file_name, os.path.join(cfg.LOG_DIR, cfg.TASK_NAME, 'best_model.pth'))

    return model

def main():
    init_env('1')
    loaders = make_data_loaders(cfg)
    model = build_model(cfg)
    model = model.cuda()
    task_name = 'base_unet'
    log_dir = os.path.join(cfg.LOG_DIR, task_name)
    cfg.TASK_NAME = task_name
    mkdir(log_dir)
    logger = setup_logger('train', log_dir, filename='train.log')
    logger.info(cfg)
    logger = setup_logger('eval', log_dir, filename='eval.log')
    optimizer, scheduler = make_optimizer(cfg, model)
    metrics = get_metrics(cfg)
    losses = get_losses(cfg)
    train_val(model, loaders, optimizer, scheduler, losses, metrics)

if __name__ == "__main__":
    main()
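For inference, the checkpoint layout saved above can be restored like this (a sketch; 'base_unet' is the task name set in main(), and the dict keys follow the `state` dict in train_val):

import os
import torch
from config import _C as cfg
from models import build_model

model = build_model(cfg).cuda()
state = torch.load(os.path.join(cfg.LOG_DIR, 'base_unet', 'best_model.pth'))
model.load_state_dict(state['model'])
model.eval()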
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# encoding: utf-8
# @Time : 2019/5/9 16:17
# @Author : Eric Ching
from .misc import mkdir, init_env

--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# encoding: utf-8
# @Time : 2019/5/9 16:36
# @Author : Eric Ching
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import os
import sys

def setup_logger(name, save_dir, filename="log.txt"):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(message)s", "%Y-%m-%d %H:%M")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if save_dir:
        fh = logging.FileHandler(os.path.join(save_dir, filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger

--------------------------------------------------------------------------------
/utils/metric/__init__.py:
--------------------------------------------------------------------------------
"""
=====================================
Metric measures :mod:`metric`
=====================================
.. currentmodule:: metric

This package provides a number of metric measures that e.g. can be used for testing
and/or evaluation purposes on two binary masks (i.e. measuring their similarity) or
distance between histograms.

Binary metrics (:mod:`metric.binary`)
===========================================
Metrics to compare binary objects and classification results.

Compare two binary objects
**************************

.. module:: metric.binary

.. autosummary::
    :toctree: generated/

    dc
    jc
    hd
    asd
    assd
    precision
    recall
    sensitivity
    specificity
    true_positive_rate
    true_negative_rate
    positive_predictive_value
    ravd

Compare two sets of binary objects
**********************************

.. autosummary::
    :toctree: generated/

    obj_tpr
    obj_fpr
    obj_asd
    obj_assd

Compare two sequences of binary objects
***************************************

.. autosummary::
    :toctree: generated/

    volume_correlation
    volume_change_correlation

Image metrics (:mod:`metric.image`)
=========================================
Some more image metrics (e.g. `~medpy.filter.image.sls` and `~medpy.filter.image.ssd`)
can be found in :mod:`medpy.filter.image`.

.. module:: medpy.metric.image
.. autosummary::
    :toctree: generated/

    mutual_information

Histogram metrics (:mod:`medpy.metric.histogram`)
=================================================

.. module:: medpy.metric.histogram
.. autosummary::
    :toctree: generated/

    chebyshev
    chebyshev_neg
    chi_square
    correlate
    correlate_1
    cosine
    cosine_1
    cosine_2
    cosine_alt
    euclidean
    fidelity_based
    histogram_intersection
    histogram_intersection_1
    jensen_shannon
    kullback_leibler
    manhattan
    minowski
    noelle_1
    noelle_2
    noelle_3
    noelle_4
    noelle_5
    quadratic_forms
    relative_bin_deviation
    relative_deviation
"""
--------------------------------------------------------------------------------
/utils/metric/binary.py:
--------------------------------------------------------------------------------
#encoding: utf-8
"""Measures on binary images, usable for evaluating single-object and
object-wise (multi-object) segmentation results.
Metrics:
    dice coefficient(dc)
    jaccard coefficient(jc)
    precision
    recall
    sensitivity
    Average surface distance metric(asd)
    95th percentile of the Hausdorff Distance(hd95)
    Average symmetric surface distance(assd)
    Relative absolute volume difference.(ravd)
"""
import numpy
from scipy.ndimage import _ni_support
from scipy.ndimage.morphology import distance_transform_edt, binary_erosion,\
    generate_binary_structure
from scipy.ndimage.measurements import label, find_objects
from scipy.stats import pearsonr


def dc(result, reference):
    r"""
    Dice coefficient

    Computes the Dice coefficient (also known as Sorensen index) between the binary
    objects in two images.

    The metric is defined as

    .. math::

        DC=\frac{2|A\cap B|}{|A|+|B|}

    , where :math:`A` is the first and :math:`B` the second set of samples (here: binary objects).

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    dc : float
        The Dice coefficient between the object(s) in ```result``` and the
        object(s) in ```reference```. It ranges from 0 (no overlap) to 1 (perfect overlap).

    Notes
    -----
    This is a real metric. The binary images can therefore be supplied in any order.
    """
    result = numpy.atleast_1d(result.astype(numpy.bool))
    reference = numpy.atleast_1d(reference.astype(numpy.bool))

    intersection = numpy.count_nonzero(result & reference)

    size_i1 = numpy.count_nonzero(result)
    size_i2 = numpy.count_nonzero(reference)

    try:
        dc = 2. * intersection / float(size_i1 + size_i2)
    except ZeroDivisionError:
        dc = 0.0

    return dc

def jc(result, reference):
    """
    Jaccard coefficient

    Computes the Jaccard coefficient between the binary objects in two images.

    Parameters
    ----------
    result: array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference: array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    jc: float
        The Jaccard coefficient between the object(s) in `result` and the
        object(s) in `reference`. It ranges from 0 (no overlap) to 1 (perfect overlap).

    Notes
    -----
    This is a real metric. The binary images can therefore be supplied in any order.
    """
    result = numpy.atleast_1d(result.astype(numpy.bool))
    reference = numpy.atleast_1d(reference.astype(numpy.bool))

    intersection = numpy.count_nonzero(result & reference)
    union = numpy.count_nonzero(result | reference)

    jc = float(intersection) / float(union) if union > 0 else 0.0  # guard two empty masks, mirroring dc()

    return jc
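dc and jc are related by dc = 2 * jc / (1 + jc); a toy check of both on a pair of overlapping masks:

import numpy
from utils.metric.binary import dc, jc

a = numpy.array([[0, 1, 1], [0, 1, 0]])
b = numpy.array([[0, 1, 0], [0, 1, 1]])
print(jc(a, b))  # intersection 2 / union 4 = 0.5
print(dc(a, b))  # 2 * 2 / (3 + 3) ≈ 0.667, i.e. 2 * 0.5 / 1.5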
def precision(result, reference):
    """
    Precision.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    precision : float
        The precision between two binary datasets, here mostly binary objects in images,
        which is defined as the fraction of retrieved instances that are relevant. The
        precision is not symmetric.

    See also
    --------
    :func:`recall`

    Notes
    -----
    Not symmetric. The counterpart of precision is :func:`recall`.
    High precision means that an algorithm returned substantially more relevant results than irrelevant.

    References
    ----------
    .. [1] http://en.wikipedia.org/wiki/Precision_and_recall
    .. [2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion
    """
    result = numpy.atleast_1d(result.astype(numpy.bool))
    reference = numpy.atleast_1d(reference.astype(numpy.bool))

    tp = numpy.count_nonzero(result & reference)
    fp = numpy.count_nonzero(result & ~reference)

    try:
        precision = tp / float(tp + fp)
    except ZeroDivisionError:
        precision = 0.0

    return precision

def recall(result, reference):
    """
    Recall.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    recall : float
        The recall between two binary datasets, here mostly binary objects in images,
        which is defined as the fraction of relevant instances that are retrieved. The
        recall is not symmetric.

    See also
    --------
    :func:`precision`

    Notes
    -----
    Not symmetric. The counterpart of recall is :func:`precision`.
    High recall means that an algorithm returned most of the relevant results.

    References
    ----------
    .. [1] http://en.wikipedia.org/wiki/Precision_and_recall
    .. [2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion
    """
    result = numpy.atleast_1d(result.astype(numpy.bool))
    reference = numpy.atleast_1d(reference.astype(numpy.bool))

    tp = numpy.count_nonzero(result & reference)
    fn = numpy.count_nonzero(~result & reference)

    try:
        recall = tp / float(tp + fn)
    except ZeroDivisionError:
        recall = 0.0

    return recall

def sensitivity(result, reference):
    """
    Sensitivity.
    Same as :func:`recall`, see there for a detailed description.

    See also
    --------
    :func:`specificity`
    """
    return recall(result, reference)

def specificity(result, reference):
    """
    Specificity.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    specificity : float
        The specificity between two binary datasets, here mostly binary objects in images,
        which denotes the fraction of correctly returned negatives. The
        specificity is not symmetric.

    See also
    --------
    :func:`sensitivity`

    Notes
    -----
    Not symmetric. The counterpart of the specificity is :func:`sensitivity`.
    High specificity means that an algorithm correctly rejected most of the irrelevant results.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Sensitivity_and_specificity
    .. [2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion
    """
    result = numpy.atleast_1d(result.astype(numpy.bool))
    reference = numpy.atleast_1d(reference.astype(numpy.bool))

    tn = numpy.count_nonzero(~result & ~reference)
    fp = numpy.count_nonzero(result & ~reference)

    try:
        specificity = tn / float(tn + fp)
    except ZeroDivisionError:
        specificity = 0.0

    return specificity

def true_negative_rate(result, reference):
    """
    True negative rate.
    Same as :func:`specificity`, see there for a detailed description.

    See also
    --------
    :func:`true_positive_rate`
    :func:`positive_predictive_value`
    """
    return specificity(result, reference)

def true_positive_rate(result, reference):
    """
    True positive rate.
    Same as :func:`recall` and :func:`sensitivity`, see there for a detailed description.

    See also
    --------
    :func:`positive_predictive_value`
    :func:`true_negative_rate`
    """
    return recall(result, reference)

def positive_predictive_value(result, reference):
    """
    Positive predictive value.
    Same as :func:`precision`, see there for a detailed description.

    See also
    --------
    :func:`true_positive_rate`
    :func:`true_negative_rate`
    """
    return precision(result, reference)

def hd(result, reference, voxelspacing=None, connectivity=1):
    """
    Hausdorff Distance.

    Computes the (symmetric) Hausdorff Distance (HD) between the binary objects in two
    images. It is defined as the maximum surface distance between the objects.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        The voxelspacing in a distance unit i.e. spacing of elements
        along each dimension. If a sequence, must be of length equal to
        the input rank; if a single number, this is used for all axes. If
        not specified, a grid spacing of unity is implied.
    connectivity : int
        The neighbourhood/connectivity considered when determining the surface
        of the binary objects. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.
        Note that the connectivity influences the result in the case of the Hausdorff distance.

    Returns
    -------
    hd : float
        The symmetric Hausdorff Distance between the object(s) in ```result``` and the
        object(s) in ```reference```. The distance unit is the same as for the spacing of
        elements along each dimension, which is usually given in mm.

    See also
    --------
    :func:`assd`
    :func:`asd`

    Notes
    -----
    This is a real metric. The binary images can therefore be supplied in any order.
    """
    hd1 = __surface_distances(result, reference, voxelspacing, connectivity).max()
    hd2 = __surface_distances(reference, result, voxelspacing, connectivity).max()
    hd = max(hd1, hd2)
    return hd


def hd95(result, reference, voxelspacing=None, connectivity=1):
    """
    95th percentile of the Hausdorff Distance.

    Computes the 95th percentile of the (symmetric) Hausdorff Distance (HD) between the binary objects in two
    images. Compared to the Hausdorff Distance, this metric is slightly more stable to small outliers and is
    commonly used in Biomedical Segmentation challenges.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        The voxelspacing in a distance unit i.e. spacing of elements
        along each dimension.
        If a sequence, must be of length equal to
        the input rank; if a single number, this is used for all axes. If
        not specified, a grid spacing of unity is implied.
    connectivity : int
        The neighbourhood/connectivity considered when determining the surface
        of the binary objects. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.
        Note that the connectivity influences the result in the case of the Hausdorff distance.

    Returns
    -------
    hd95 : float
        The 95th percentile of the symmetric Hausdorff Distance between the object(s) in
        ```result``` and the object(s) in ```reference```. The distance unit is the same as
        for the spacing of elements along each dimension, which is usually given in mm.

    See also
    --------
    :func:`hd`

    Notes
    -----
    This is a real metric. The binary images can therefore be supplied in any order.
    """
    hd1 = __surface_distances(result, reference, voxelspacing, connectivity)
    hd2 = __surface_distances(reference, result, voxelspacing, connectivity)
    hd95 = numpy.percentile(numpy.hstack((hd1, hd2)), 95)
    return hd95


def assd(result, reference, voxelspacing=None, connectivity=1):
    """
    Average symmetric surface distance.

    Computes the average symmetric surface distance (ASD) between the binary objects in
    two images.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        The voxelspacing in a distance unit i.e. spacing of elements
        along each dimension. If a sequence, must be of length equal to
        the input rank; if a single number, this is used for all axes. If
        not specified, a grid spacing of unity is implied.
    connectivity : int
        The neighbourhood/connectivity considered when determining the surface
        of the binary objects. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.
        The decision on the connectivity is important, as it can influence the results
        strongly. If in doubt, leave it as it is.

    Returns
    -------
    assd : float
        The average symmetric surface distance between the object(s) in ``result`` and the
        object(s) in ``reference``. The distance unit is the same as for the spacing of
        elements along each dimension, which is usually given in mm.

    See also
    --------
    :func:`asd`
    :func:`hd`

    Notes
    -----
    This is a real metric, obtained by calling and averaging

    >>> asd(result, reference)

    and

    >>> asd(reference, result)

    The binary images can therefore be supplied in any order.
    """
    assd = numpy.mean( (asd(result, reference, voxelspacing, connectivity), asd(reference, result, voxelspacing, connectivity)) )
    return assd

def asd(result, reference, voxelspacing=None, connectivity=1):
    """
    Average surface distance metric.

    Computes the average surface distance (ASD) between the binary objects in two images.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    voxelspacing : float or sequence of floats, optional
        The voxelspacing in a distance unit i.e. spacing of elements
        along each dimension. If a sequence, must be of length equal to
        the input rank; if a single number, this is used for all axes. If
        not specified, a grid spacing of unity is implied.
    connectivity : int
        The neighbourhood/connectivity considered when determining the surface
        of the binary objects. This value is passed to
        `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.
        The decision on the connectivity is important, as it can influence the results
        strongly. If in doubt, leave it as it is.

    Returns
    -------
    asd : float
        The average surface distance between the object(s) in ``result`` and the
        object(s) in ``reference``. The distance unit is the same as for the spacing
        of elements along each dimension, which is usually given in mm.

    See also
    --------
    :func:`assd`
    :func:`hd`


    Notes
    -----
    This is not a real metric, as it is directed. See `assd` for a real metric of this.

    The method is implemented making use of distance images and simple binary morphology
    to achieve high computational speed.

    Examples
    --------
    The `connectivity` determines what pixels/voxels are considered the surface of a
    binary object. Take the following binary image showing a cross

    >>> from scipy.ndimage.morphology import generate_binary_structure
    >>> cross = generate_binary_structure(2, 1)
    array([[0, 1, 0],
           [1, 1, 1],
           [0, 1, 0]])

    With `connectivity` set to `1` a 4-neighbourhood is considered when determining the
    object surface, resulting in the surface

    .. code-block:: python

        array([[0, 1, 0],
               [1, 0, 1],
               [0, 1, 0]])

    Changing `connectivity` to `2`, a 8-neighbourhood is considered and we get:

    .. code-block:: python

        array([[0, 1, 0],
               [1, 1, 1],
               [0, 1, 0]])

    , as a diagonal connection no longer qualifies as valid object surface.

    This influences the results `asd` returns. Imagine we want to compute the surface
    distance of our cross to a cube-like object:

    >>> cube = generate_binary_structure(2, 2)
    array([[1, 1, 1],
           [1, 1, 1],
           [1, 1, 1]])

    , whose surface is, independent of the `connectivity` value set, always

    .. code-block:: python

        array([[1, 1, 1],
               [1, 0, 1],
               [1, 1, 1]])

    Using a `connectivity` of `1` we get

    >>> asd(cross, cube, connectivity=1)
    0.0

    while a value of `2` returns us

    >>> asd(cross, cube, connectivity=2)
    0.20000000000000001

    due to the center of the cross being considered surface as well.

    """
    sds = __surface_distances(result, reference, voxelspacing, connectivity)
    asd = sds.mean()
    return asd
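The surface-distance family can be exercised on toy arrays; with two 4x4 squares offset by (2, 2), the maximum surface distance is 2 * sqrt(2):

import numpy
from utils.metric.binary import hd, hd95, assd

result = numpy.zeros((16, 16), dtype=bool)
reference = numpy.zeros((16, 16), dtype=bool)
result[4:8, 4:8] = True
reference[6:10, 6:10] = True
print(hd(result, reference))    # 2.828...
print(hd95(result, reference))  # robust 95th-percentile variant
print(assd(result, reference))  # average symmetric surface distance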

def ravd(result, reference):
    """
    Relative absolute volume difference.

    Compute the relative absolute volume difference between the (joined) binary objects
    in the two images.

    Parameters
    ----------
    result : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.
    reference : array_like
        Input data containing objects. Can be any type but will be converted
        into binary: background where 0, object everywhere else.

    Returns
    -------
    ravd : float
        The relative absolute volume difference between the object(s) in ``result``
        and the object(s) in ``reference``. This is a percentage value in the range
        :math:`[-1.0, +inf]` for which a :math:`0` denotes an ideal score.

    Raises
    ------
    RuntimeError
        If the reference object is empty.

    See also
    --------
    :func:`dc`
    :func:`precision`
    :func:`recall`

    Notes
    -----
    This is not a real metric, as it is directed. Negative values denote a smaller
    and positive values a larger volume than the reference.
    This implementation does not check, whether the two supplied arrays are of the same
    size.

    Examples
    --------
    Considering the following inputs

    >>> import numpy
    >>> arr1 = numpy.asarray([[0,1,0],[1,1,1],[0,1,0]])
    >>> arr1
    array([[0, 1, 0],
           [1, 1, 1],
           [0, 1, 0]])
    >>> arr2 = numpy.asarray([[0,1,0],[1,0,1],[0,1,0]])
    >>> arr2
    array([[0, 1, 0],
           [1, 0, 1],
           [0, 1, 0]])

    comparing `arr1` to `arr2` we get

    >>> ravd(arr1, arr2)
    0.25

    and reversing the inputs the directedness of the metric becomes evident

    >>> ravd(arr2, arr1)
    -0.2

    It is important to keep in mind that a perfect score of `0` does not mean that the
    binary objects fit exactly, as only the volumes are compared:

    >>> arr1 = numpy.asarray([1,0,0])
    >>> arr2 = numpy.asarray([0,0,1])
    >>> ravd(arr1, arr2)
    0.0

    """
    result = numpy.atleast_1d(result.astype(numpy.bool))
    reference = numpy.atleast_1d(reference.astype(numpy.bool))

    vol1 = numpy.count_nonzero(result)
    vol2 = numpy.count_nonzero(reference)

    if 0 == vol2:
        raise RuntimeError('The second supplied array does not contain any binary object.')

    return (vol1 - vol2) / float(vol2)


def volume_correlation(results, references):
    r"""
    Volume correlation.

    Computes the linear correlation in binary object volume between the
    contents of the successive binary images supplied. Measured through
    the Pearson product-moment correlation coefficient.

    Parameters
    ----------
    results : sequence of array_like
        Ordered list of input data containing objects. Each array_like will be
        converted into binary: background where 0, object everywhere else.
    references : sequence of array_like
        Ordered list of input data containing objects. Each array_like will be
        converted into binary: background where 0, object everywhere else.
        The order must be the same as for ``results``.

    Returns
    -------
    r : float
        The correlation coefficient between -1 and 1.
    p : float
        The two-side p value.

    """
    results = numpy.atleast_2d(numpy.array(results).astype(numpy.bool))
    references = numpy.atleast_2d(numpy.array(references).astype(numpy.bool))

    results_volumes = [numpy.count_nonzero(r) for r in results]
    references_volumes = [numpy.count_nonzero(r) for r in references]

    return pearsonr(results_volumes, references_volumes)  # returns (Pearson's correlation coefficient, 2-tailed p-value)
658 | 
659 |     Returns
660 |     -------
661 |     r : float
662 |         The correlation coefficient between -1 and 1.
663 |     p : float
664 |         The two-sided p-value.
665 | 
666 |     """
667 |     results = numpy.atleast_2d(numpy.array(results).astype(bool))
668 |     references = numpy.atleast_2d(numpy.array(references).astype(bool))
669 | 
670 |     results_volumes = [numpy.count_nonzero(r) for r in results]
671 |     references_volumes = [numpy.count_nonzero(r) for r in references]
672 | 
673 |     return pearsonr(results_volumes, references_volumes) # returns (Pearson's correlation coefficient, 2-tailed p-value)
674 | 
675 | def volume_change_correlation(results, references):
676 |     r"""Compute the correlation coefficient of volume change between binary images.
677 |     Volume change correlation.
678 | 
679 |     Computes the linear correlation of change in binary object volume between
680 |     the contents of the successive binary images supplied. Measured through
681 |     the Pearson product-moment correlation coefficient.
682 | 
683 |     Parameters
684 |     ----------
685 |     results : sequence of array_like
686 |         Ordered list of input data containing objects. Each array_like will be
687 |         converted into binary: background where 0, object everywhere else.
688 |     references : sequence of array_like
689 |         Ordered list of input data containing objects. Each array_like will be
690 |         converted into binary: background where 0, object everywhere else.
691 |         The order must be the same as for ``results``.
692 | 
693 |     Returns
694 |     -------
695 |     r : float
696 |         The correlation coefficient between -1 and 1.
697 |     p : float
698 |         The two-sided p-value.
699 | 
700 |     """
701 |     # convert to 2-D arrays
702 |     results = numpy.atleast_2d(numpy.array(results).astype(bool))
703 |     references = numpy.atleast_2d(numpy.array(references).astype(bool))
704 | 
705 |     results_volumes = numpy.asarray([numpy.count_nonzero(r) for r in results])
706 |     references_volumes = numpy.asarray([numpy.count_nonzero(r) for r in references])
707 | 
708 |     results_volumes_changes = results_volumes[1:] - results_volumes[:-1]
709 |     references_volumes_changes = references_volumes[1:] - references_volumes[:-1]
710 | 
711 |     return pearsonr(results_volumes_changes, references_volumes_changes) # returns (Pearson's correlation coefficient, 2-tailed p-value)
712 | 
713 | def obj_assd(result, reference, voxelspacing=None, connectivity=1):
714 |     """
715 |     Average symmetric surface distance.
716 | 
717 |     Computes the average symmetric surface distance (ASSD) between the binary objects in
718 |     two images.
719 | 
720 |     Parameters
721 |     ----------
722 |     result : array_like
723 |         Input data containing objects. Can be any type but will be converted
724 |         into binary: background where 0, object everywhere else.
725 |     reference : array_like
726 |         Input data containing objects. Can be any type but will be converted
727 |         into binary: background where 0, object everywhere else.
728 |     voxelspacing : float or sequence of floats, optional
729 |         The voxelspacing in a distance unit i.e. spacing of elements
730 |         along each dimension. If a sequence, must be of length equal to
731 |         the input rank; if a single number, this is used for all axes. If
732 |         not specified, a grid spacing of unity is implied.
733 |     connectivity : int
734 |         The neighbourhood/connectivity considered when determining what accounts
735 |         for a distinct binary object as well as when determining the surface
736 |         of the binary objects. This value is passed to
737 |         `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.
738 |         The decision on the connectivity is important, as it can influence the results
739 |         strongly.
If in doubt, leave it as it is. 740 | 741 | Returns 742 | ------- 743 | assd : float 744 | The average symmetric surface distance between all mutually existing distinct 745 | binary object(s) in ``result`` and ``reference``. The distance unit is the same as for 746 | the spacing of elements along each dimension, which is usually given in mm. 747 | 748 | See also 749 | -------- 750 | :func:`obj_asd` 751 | 752 | Notes 753 | ----- 754 | This is a real metric, obtained by calling and averaging 755 | 756 | >>> obj_asd(result, reference) 757 | 758 | and 759 | 760 | >>> obj_asd(reference, result) 761 | 762 | The binary images can therefore be supplied in any order. 763 | """ 764 | assd = numpy.mean( (obj_asd(result, reference, voxelspacing, connectivity), obj_asd(reference, result, voxelspacing, connectivity)) ) 765 | 766 | return assd 767 | 768 | 769 | def obj_asd(result, reference, voxelspacing=None, connectivity=1): 770 | """ 771 | Average surface distance between objects. 772 | 773 | First correspondences between distinct binary objects in reference and result are 774 | established. Then the average surface distance is only computed between corresponding 775 | objects. Correspondence is defined as unique and at least one voxel overlap. 776 | 777 | Parameters 778 | ---------- 779 | result : array_like 780 | Input data containing objects. Can be any type but will be converted 781 | into binary: background where 0, object everywhere else. 782 | reference : array_like 783 | Input data containing objects. Can be any type but will be converted 784 | into binary: background where 0, object everywhere else. 785 | voxelspacing : float or sequence of floats, optional 786 | The voxelspacing in a distance unit i.e. spacing of elements 787 | along each dimension. If a sequence, must be of length equal to 788 | the input rank; if a single number, this is used for all axes. If 789 | not specified, a grid spacing of unity is implied. 790 | connectivity : int 791 | The neighbourhood/connectivity considered when determining what accounts 792 | for a distinct binary object as well as when determining the surface 793 | of the binary objects. This value is passed to 794 | `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. 795 | The decision on the connectivity is important, as it can influence the results 796 | strongly. If in doubt, leave it as it is. 797 | 798 | Returns 799 | ------- 800 | asd : float 801 | The average surface distance between all mutually existing distinct binary 802 | object(s) in ``result`` and ``reference``. The distance unit is the same as for the 803 | spacing of elements along each dimension, which is usually given in mm. 804 | 805 | See also 806 | -------- 807 | :func:`obj_assd` 808 | :func:`obj_tpr` 809 | :func:`obj_fpr` 810 | 811 | Notes 812 | ----- 813 | This is not a real metric, as it is directed. See `obj_assd` for a real metric of this. 814 | 815 | For the understanding of this metric, both the notions of connectedness and surface 816 | distance are essential. Please see :func:`obj_tpr` and :func:`obj_fpr` for more 817 | information on the first and :func:`asd` on the second. 
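What counts as a "distinct object" here follows scipy's connected-component labelling. A minimal sketch (using the scipy.ndimage API this module already relies on) of how the connectivity changes the object count, mirroring the last example below:

import numpy
from scipy.ndimage import label, generate_binary_structure

a = numpy.array([[1, 0, 0],
                 [0, 1, 1],
                 [0, 1, 1]], dtype=bool)

# 4-connectivity: the corner voxel and the 2x2 block are two separate objects
_, n4 = label(a, generate_binary_structure(2, 1))  # n4 == 2
# 8-connectivity: the diagonal touch merges them into one object
_, n8 = label(a, generate_binary_structure(2, 2))  # n8 == 1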
818 | 
819 |     Examples
820 |     --------
821 |     >>> arr1 = numpy.asarray([[1,1,1],[1,1,1],[1,1,1]])
822 |     >>> arr2 = numpy.asarray([[0,1,0],[0,1,0],[0,1,0]])
823 |     >>> arr1
824 |     array([[1, 1, 1],
825 |            [1, 1, 1],
826 |            [1, 1, 1]])
827 |     >>> arr2
828 |     array([[0, 1, 0],
829 |            [0, 1, 0],
830 |            [0, 1, 0]])
831 |     >>> obj_asd(arr1, arr2)
832 |     1.5
833 |     >>> obj_asd(arr2, arr1)
834 |     0.333333333333
835 | 
836 |     With the `voxelspacing` parameter, the distances between the voxels can be set for
837 |     each dimension separately:
838 | 
839 |     >>> obj_asd(arr1, arr2, voxelspacing=(1,2))
840 |     1.5
841 |     >>> obj_asd(arr2, arr1, voxelspacing=(1,2))
842 |     0.333333333333
843 | 
844 |     More examples depicting the notion of object connectedness:
845 | 
846 |     >>> arr1 = numpy.asarray([[1,0,1],[1,0,0],[0,0,0]])
847 |     >>> arr2 = numpy.asarray([[1,0,1],[1,0,0],[0,0,1]])
848 |     >>> arr1
849 |     array([[1, 0, 1],
850 |            [1, 0, 0],
851 |            [0, 0, 0]])
852 |     >>> arr2
853 |     array([[1, 0, 1],
854 |            [1, 0, 0],
855 |            [0, 0, 1]])
856 |     >>> obj_asd(arr1, arr2)
857 |     0.0
858 |     >>> obj_asd(arr2, arr1)
859 |     0.0
860 | 
861 |     >>> arr1 = numpy.asarray([[1,0,1],[1,0,1],[0,0,1]])
862 |     >>> arr2 = numpy.asarray([[1,0,1],[1,0,0],[0,0,1]])
863 |     >>> arr1
864 |     array([[1, 0, 1],
865 |            [1, 0, 1],
866 |            [0, 0, 1]])
867 |     >>> arr2
868 |     array([[1, 0, 1],
869 |            [1, 0, 0],
870 |            [0, 0, 1]])
871 |     >>> obj_asd(arr1, arr2)
872 |     0.6
873 |     >>> obj_asd(arr2, arr1)
874 |     0.0
875 | 
876 |     Influence of the `connectivity` parameter can be seen in the following example, where
877 |     with the (default) connectivity of `1` the first array is considered to contain two
878 |     objects, while with an increased connectivity of `2`, just one large object is
879 |     detected.
880 | 
881 |     >>> arr1 = numpy.asarray([[1,0,0],[0,1,1],[0,1,1]])
882 |     >>> arr2 = numpy.asarray([[1,0,0],[0,0,0],[0,0,0]])
883 |     >>> arr1
884 |     array([[1, 0, 0],
885 |            [0, 1, 1],
886 |            [0, 1, 1]])
887 |     >>> arr2
888 |     array([[1, 0, 0],
889 |            [0, 0, 0],
890 |            [0, 0, 0]])
891 |     >>> obj_asd(arr1, arr2)
892 |     0.0
893 |     >>> obj_asd(arr1, arr2, connectivity=2)
894 |     1.742955328
895 | 
896 |     Note that the connectivity also influences the notion of what is considered an object
897 |     surface voxel.
898 |     """
899 |     sds = list()
900 |     labelmap1, labelmap2, _a, _b, mapping = __distinct_binary_object_correspondences(result, reference, connectivity)
901 |     slicers1 = find_objects(labelmap1)
902 |     slicers2 = find_objects(labelmap2)
903 |     for lid2, lid1 in list(mapping.items()):
904 |         window = __combine_windows(slicers1[lid1 - 1], slicers2[lid2 - 1])
905 |         object1 = labelmap1[window] == lid1
906 |         object2 = labelmap2[window] == lid2
907 |         sds.extend(__surface_distances(object1, object2, voxelspacing, connectivity))
908 |     asd = numpy.mean(sds)
909 | 
910 |     return asd
911 | 
912 | def obj_fpr(result, reference, connectivity=1):
913 |     """
914 |     The false positive rate of distinct binary object detection.
915 | 
916 |     The false positive rate gives a percentage measure of how many distinct binary
917 |     objects in the first array do not exist in the second array. A partial overlap
918 |     (of minimum one voxel) is here considered sufficient.
919 | 
920 |     In cases where two distinct binary objects in the first array overlap with a single
921 |     distinct object in the second array, only one is considered to have been detected
922 |     successfully and the other is added to the count of false positives.
923 | 
924 |     Parameters
925 |     ----------
926 |     result : array_like
927 |         Input data containing objects. Can be any type but will be converted
928 |         into binary: background where 0, object everywhere else.
929 |     reference : array_like
930 |         Input data containing objects. Can be any type but will be converted
931 |         into binary: background where 0, object everywhere else.
932 |     connectivity : int
933 |         The neighbourhood/connectivity considered when determining what accounts
934 |         for a distinct binary object. This value is passed to
935 |         `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.
936 |         The decision on the connectivity is important, as it can influence the results
937 |         strongly. If in doubt, leave it as it is.
938 | 
939 |     Returns
940 |     -------
941 |     fpr : float
942 |         A percentage measure of how many distinct binary objects in ``result`` have no
943 |         corresponding binary object in ``reference``. It has the range :math:`[0, 1]`, where a :math:`0`
944 |         denotes an ideal score.
945 | 
946 |     Raises
947 |     ------
948 |     RuntimeError
949 |         If the second array is empty.
950 | 
951 |     See also
952 |     --------
953 |     :func:`obj_tpr`
954 | 
955 |     Notes
956 |     -----
957 |     This is not a real metric, as it is directed. Whatever array is considered as
958 |     reference should be passed second. A perfect score of :math:`0` tells that there are no
959 |     distinct binary objects in the result array that do not also exist in the reference
960 |     array, but does not reveal anything about objects in the reference array also
961 |     existing in the result array (use :func:`obj_tpr` for this).
962 | 
963 |     Examples
964 |     --------
965 |     >>> arr2 = numpy.asarray([[1,0,0],[1,0,1],[0,0,1]])
966 |     >>> arr1 = numpy.asarray([[0,0,1],[1,0,1],[0,0,1]])
967 |     >>> arr2
968 |     array([[1, 0, 0],
969 |            [1, 0, 1],
970 |            [0, 0, 1]])
971 |     >>> arr1
972 |     array([[0, 0, 1],
973 |            [1, 0, 1],
974 |            [0, 0, 1]])
975 |     >>> obj_fpr(arr1, arr2)
976 |     0.0
977 |     >>> obj_fpr(arr2, arr1)
978 |     0.0
979 | 
980 |     Example of directedness:
981 | 
982 |     >>> arr2 = numpy.asarray([1,0,1,0,1])
983 |     >>> arr1 = numpy.asarray([1,0,1,0,0])
984 |     >>> obj_fpr(arr1, arr2)
985 |     0.0
986 |     >>> obj_fpr(arr2, arr1)
987 |     0.3333333333333333
988 | 
989 |     Examples of multiple overlap treatment:
990 | 
991 |     >>> arr2 = numpy.asarray([1,0,1,0,1,1,1])
992 |     >>> arr1 = numpy.asarray([1,1,1,0,1,0,1])
993 |     >>> obj_fpr(arr1, arr2)
994 |     0.3333333333333333
995 |     >>> obj_fpr(arr2, arr1)
996 |     0.3333333333333333
997 | 
998 |     >>> arr2 = numpy.asarray([1,0,1,1,1,0,1])
999 |     >>> arr1 = numpy.asarray([1,1,1,0,1,1,1])
1000 |     >>> obj_fpr(arr1, arr2)
1001 |     0.0
1002 |     >>> obj_fpr(arr2, arr1)
1003 |     0.3333333333333333
1004 | 
1005 |     >>> arr2 = numpy.asarray([[1,0,1,0,0],
1006 |                               [1,0,0,0,0],
1007 |                               [1,0,1,1,1],
1008 |                               [0,0,0,0,0],
1009 |                               [1,0,1,0,0]])
1010 |     >>> arr1 = numpy.asarray([[1,1,1,0,0],
1011 |                               [0,0,0,0,0],
1012 |                               [1,1,1,0,1],
1013 |                               [0,0,0,0,0],
1014 |                               [1,1,1,0,0]])
1015 |     >>> obj_fpr(arr1, arr2)
1016 |     0.0
1017 |     >>> obj_fpr(arr2, arr1)
1018 |     0.2
1019 |     """
1020 |     _, _, _, n_obj_result, mapping = __distinct_binary_object_correspondences(result, reference, connectivity)  # the fourth return value counts the objects in ``result``
1021 |     return (n_obj_result - len(mapping)) / float(n_obj_result)
1022 | 
1023 | def obj_tpr(result, reference, connectivity=1):
1024 |     """
1025 |     The true positive rate of distinct binary object detection.
1026 | 
1027 |     The true positive rate gives a percentage measure of how many distinct binary
1028 |     objects in the second array also exist in the first array. A partial overlap
1029 |     (of minimum one voxel) is here considered sufficient.
1030 | 
1031 |     In cases where two distinct binary objects in the first array overlap with a single
1032 |     distinct object in the second array, only one is considered to have been detected
1033 |     successfully.
1034 | 
1035 |     Parameters
1036 |     ----------
1037 |     result : array_like
1038 |         Input data containing objects. Can be any type but will be converted
1039 |         into binary: background where 0, object everywhere else.
1040 |     reference : array_like
1041 |         Input data containing objects. Can be any type but will be converted
1042 |         into binary: background where 0, object everywhere else.
1043 |     connectivity : int
1044 |         The neighbourhood/connectivity considered when determining what accounts
1045 |         for a distinct binary object. This value is passed to
1046 |         `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`.
1047 |         The decision on the connectivity is important, as it can influence the results
1048 |         strongly. If in doubt, leave it as it is.
1049 | 
1050 |     Returns
1051 |     -------
1052 |     tpr : float
1053 |         A percentage measure of how many distinct binary objects in ``reference`` also exist
1054 |         in ``result``. It has the range :math:`[0, 1]`, where a :math:`1` denotes an ideal score.
1055 | 
1056 |     Raises
1057 |     ------
1058 |     RuntimeError
1059 |         If the reference object is empty.
1060 | 
1061 |     See also
1062 |     --------
1063 |     :func:`obj_fpr`
1064 | 
1065 |     Notes
1066 |     -----
1067 |     This is not a real metric, as it is directed. Whatever array is considered as
1068 |     reference should be passed second. A perfect score of :math:`1` tells that all distinct
1069 |     binary objects in the reference array also exist in the result array, but does not
1070 |     reveal anything about additional binary objects in the result array
1071 |     (use :func:`obj_fpr` for this).
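Used together, the two directed rates act as precision- and recall-like summaries for object detection. A minimal usage sketch (`pred` and `gt` are hypothetical binary masks, not part of this module; values follow the correspondence rules described above):

import numpy
pred = numpy.zeros((64, 64), dtype=bool)
pred[5:9, 5:9] = True      # one true detection ...
pred[40:44, 40:44] = True  # ... and one spurious object
gt = numpy.zeros((64, 64), dtype=bool)
gt[6:10, 6:10] = True

fpr = obj_fpr(pred, gt)  # 0.5: one of the two predicted objects is unmatched
tpr = obj_tpr(pred, gt)  # 1.0: the single reference object is recovered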
1072 | 
1073 |     Examples
1074 |     --------
1075 |     >>> arr2 = numpy.asarray([[1,0,0],[1,0,1],[0,0,1]])
1076 |     >>> arr1 = numpy.asarray([[0,0,1],[1,0,1],[0,0,1]])
1077 |     >>> arr2
1078 |     array([[1, 0, 0],
1079 |            [1, 0, 1],
1080 |            [0, 0, 1]])
1081 |     >>> arr1
1082 |     array([[0, 0, 1],
1083 |            [1, 0, 1],
1084 |            [0, 0, 1]])
1085 |     >>> obj_tpr(arr1, arr2)
1086 |     1.0
1087 |     >>> obj_tpr(arr2, arr1)
1088 |     1.0
1089 | 
1090 |     Example of directedness:
1091 | 
1092 |     >>> arr2 = numpy.asarray([1,0,1,0,1])
1093 |     >>> arr1 = numpy.asarray([1,0,1,0,0])
1094 |     >>> obj_tpr(arr1, arr2)
1095 |     0.6666666666666666
1096 |     >>> obj_tpr(arr2, arr1)
1097 |     1.0
1098 | 
1099 |     Examples of multiple overlap treatment:
1100 | 
1101 |     >>> arr2 = numpy.asarray([1,0,1,0,1,1,1])
1102 |     >>> arr1 = numpy.asarray([1,1,1,0,1,0,1])
1103 |     >>> obj_tpr(arr1, arr2)
1104 |     0.6666666666666666
1105 |     >>> obj_tpr(arr2, arr1)
1106 |     0.6666666666666666
1107 | 
1108 |     >>> arr2 = numpy.asarray([1,0,1,1,1,0,1])
1109 |     >>> arr1 = numpy.asarray([1,1,1,0,1,1,1])
1110 |     >>> obj_tpr(arr1, arr2)
1111 |     0.6666666666666666
1112 |     >>> obj_tpr(arr2, arr1)
1113 |     1.0
1114 | 
1115 |     >>> arr2 = numpy.asarray([[1,0,1,0,0],
1116 |                               [1,0,0,0,0],
1117 |                               [1,0,1,1,1],
1118 |                               [0,0,0,0,0],
1119 |                               [1,0,1,0,0]])
1120 |     >>> arr1 = numpy.asarray([[1,1,1,0,0],
1121 |                               [0,0,0,0,0],
1122 |                               [1,1,1,0,1],
1123 |                               [0,0,0,0,0],
1124 |                               [1,1,1,0,0]])
1125 |     >>> obj_tpr(arr1, arr2)
1126 |     0.8
1127 |     >>> obj_tpr(arr2, arr1)
1128 |     1.0
1129 |     """
1130 |     _, _, n_obj_reference, _, mapping = __distinct_binary_object_correspondences(result, reference, connectivity)  # the third return value counts the objects in ``reference``
1131 |     return len(mapping) / float(n_obj_reference)
1132 | 
1133 | 
1134 | def __distinct_binary_object_correspondences(reference, result, connectivity=1):
1135 |     """
1136 |     Determines all distinct (where connectivity is defined by the connectivity parameter
1137 |     passed to scipy's `generate_binary_structure`) binary objects in both of the input
1138 |     parameters and returns a one-to-one mapping from the labelled objects in reference to the
1139 |     corresponding (whereas a one-voxel overlap suffices for correspondence) objects in
1140 |     result.
1141 | 
1142 |     All of this stems from the problem that the relationship is non-surjective and many-to-many.
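(Ambiguous one-to-many overlaps are resolved greedily in the loop below: the object with the fewest remaining candidate partners is matched first, so that objects with only one possible partner are not starved by objects that still have alternatives.)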
1143 | 
1144 |     @return (labelmap1, labelmap2, n_labels1, n_labels2, labelmapping2to1)
1145 |     """
1146 |     result = numpy.atleast_1d(result.astype(bool))
1147 |     reference = numpy.atleast_1d(reference.astype(bool))
1148 | 
1149 |     # binary structure
1150 |     footprint = generate_binary_structure(result.ndim, connectivity)
1151 | 
1152 |     # label distinct binary objects
1153 |     labelmap1, n_obj_result = label(result, footprint)
1154 |     labelmap2, n_obj_reference = label(reference, footprint)
1155 | 
1156 |     # find all overlaps from labelmap2 to labelmap1; collect one-to-one relationships and store all one-to-many for later processing
1157 |     slicers = find_objects(labelmap2)  # get windows of labelled objects
1158 |     mapping = dict()  # mappings from labels in labelmap2 to corresponding object labels in labelmap1
1159 |     used_labels = set()  # set to collect all already used labels from labelmap1
1160 |     one_to_many = list()  # list to collect all one-to-many mappings
1161 |     for l1id, slicer in enumerate(slicers):  # iterate over objects in labelmap2 and their windows
1162 |         l1id += 1  # labelled objects have ids starting from 1
1163 |         bobj = (l1id) == labelmap2[slicer]  # find binary object corresponding to the label1 id in the segmentation
1164 |         l2ids = numpy.unique(labelmap1[slicer][bobj])  # extract all unique object identifiers at the corresponding positions in the reference (i.e. the mapping)
1165 |         l2ids = l2ids[0 != l2ids]  # remove background identifiers (=0)
1166 |         if 1 == len(l2ids):  # one-to-one mapping: if target label not already used, add to final list of object-to-object mappings and mark target label as used
1167 |             l2id = l2ids[0]
1168 |             if not l2id in used_labels:
1169 |                 mapping[l1id] = l2id
1170 |                 used_labels.add(l2id)
1171 |         elif 1 < len(l2ids):  # one-to-many mapping: store relationship for later processing
1172 |             one_to_many.append((l1id, set(l2ids)))
1173 | 
1174 |     # process one-to-many mappings, always choosing the one with the least labelmap2 correspondences first
1175 |     while True:
1176 |         one_to_many = [(l1id, l2ids - used_labels) for l1id, l2ids in one_to_many]  # remove already used ids from all sets
1177 |         one_to_many = [x for x in one_to_many if x[1]]  # remove empty sets
1178 |         one_to_many = sorted(one_to_many, key=lambda x: len(x[1]))  # sort by set length
1179 |         if 0 == len(one_to_many):
1180 |             break
1181 |         l2id = one_to_many[0][1].pop()  # select an arbitrary target label id from the shortest set
1182 |         mapping[one_to_many[0][0]] = l2id  # add to one-to-one mappings
1183 |         used_labels.add(l2id)  # mark target label as used
1184 |         one_to_many = one_to_many[1:]  # delete the processed set from all sets
1185 | 
1186 |     return labelmap1, labelmap2, n_obj_result, n_obj_reference, mapping
1187 | 
1188 | 
1189 | def __surface_distances(result, reference, voxelspacing=None, connectivity=1):
1190 |     """
1191 |     The distances between the surface voxel of binary objects in result and their
1192 |     nearest partner surface voxel of a binary object in reference.
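The body below obtains each surface by XOR-ing a mask with its one-iteration binary erosion, then reads the Euclidean distance transform of the reference border at the result-border voxels. For unit voxel spacing the same distances could be obtained with a KD-tree over surface coordinates; a sketch (scipy.spatial assumed, not part of this module):

from scipy.spatial import cKDTree
ref_pts = numpy.argwhere(reference_border)    # coordinates of reference-surface voxels
res_pts = numpy.argwhere(result_border)       # coordinates of result-surface voxels
sds_alt, _ = cKDTree(ref_pts).query(res_pts)  # nearest-neighbour distances, equal to dt[result_border]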
1193 |     """
1194 |     result = numpy.atleast_1d(result.astype(bool))
1195 |     reference = numpy.atleast_1d(reference.astype(bool))
1196 |     if voxelspacing is not None:
1197 |         voxelspacing = _ni_support._normalize_sequence(voxelspacing, result.ndim)
1198 |         voxelspacing = numpy.asarray(voxelspacing, dtype=numpy.float64)
1199 |         if not voxelspacing.flags.contiguous:
1200 |             voxelspacing = voxelspacing.copy()
1201 | 
1202 |     # binary structure
1203 |     footprint = generate_binary_structure(result.ndim, connectivity)
1204 | 
1205 |     # test for emptiness
1206 |     if 0 == numpy.count_nonzero(result):
1207 |         raise RuntimeError('The first supplied array does not contain any binary object.')
1208 |     if 0 == numpy.count_nonzero(reference):
1209 |         raise RuntimeError('The second supplied array does not contain any binary object.')
1210 | 
1211 |     # extract only 1-pixel border line of objects
1212 |     result_border = result ^ binary_erosion(result, structure=footprint, iterations=1)
1213 |     reference_border = reference ^ binary_erosion(reference, structure=footprint, iterations=1)
1214 | 
1215 |     # compute average surface distance
1216 |     # Note: scipys distance transform is calculated only inside the borders of the
1217 |     # foreground objects, therefore the input has to be reversed
1218 |     dt = distance_transform_edt(~reference_border, sampling=voxelspacing)
1219 |     sds = dt[result_border]
1220 | 
1221 |     return sds
1222 | 
1223 | 
1224 | def __combine_windows(w1, w2):
1225 |     """
1226 |     Joins two windows (defined by tuple of slices) such that their maximum
1227 |     combined extent is covered by the new returned window.
1228 |     """
1229 |     res = []
1230 |     for s1, s2 in zip(w1, w2):
1231 |         res.append(slice(min(s1.start, s2.start), max(s1.stop, s2.stop)))
1232 | 
1233 |     return tuple(res)
1234 | 
--------------------------------------------------------------------------------
/utils/metric/histogram.py:
--------------------------------------------------------------------------------
1 | # built-in modules
2 | import math
3 | 
4 | # third-party modules
5 | import scipy
6 | 
7 | # own modules
8 | 
9 | # ////////////////////////////// #
10 | # Bin-by-bin comparison measures #
11 | # ////////////////////////////// #
12 | 
13 | def minowski(h1, h2, p = 2):  # 46..45..14,11..43..44 / 45 us for p=int(-inf..-24..-1,1..24..inf) / float @array, +20 us @list \w 100 bins
14 |     r"""
15 |     Minowski distance.
16 | 
17 |     With :math:`p=2` equal to the Euclidean distance, with :math:`p=1` equal to the Manhattan distance,
18 |     and the Chebyshev distance implementation represents the case of :math:`p=\pm inf`.
19 | 
20 |     The Minowski distance between two histograms :math:`H` and :math:`H'` of size :math:`m` is
21 |     defined as:
22 | 
23 |     .. math::
24 | 
25 |         d_p(H, H') = \left(\sum_{m=1}^M|H_m - H'_m|^p
26 |         \right)^{\frac{1}{p}}
27 | 
28 |     *Attributes:*
29 | 
30 |     - a real metric
31 | 
32 |     *Attributes for normalized histograms:*
33 | 
34 |     - :math:`d(H, H')\in[0, \sqrt[p]{2}]`
35 |     - :math:`d(H, H) = 0`
36 |     - :math:`d(H, H') = d(H', H)`
37 | 
38 |     *Attributes for not-normalized histograms:*
39 | 
40 |     - :math:`d(H, H')\in[0, \infty)`
41 |     - :math:`d(H, H) = 0`
42 |     - :math:`d(H, H') = d(H', H)`
43 | 
44 |     *Attributes for not-equal histograms:*
45 | 
46 |     - not applicable
47 | 
48 |     Parameters
49 |     ----------
50 |     h1 : sequence
51 |         The first histogram.
52 |     h2 : sequence
53 |         The second histogram.
54 |     p : float
55 |         The :math:`p` value in the Minowski distance formula.
56 | 
57 |     Returns
58 |     -------
59 |     minowski : float
60 |         Minowski distance.
61 | 62 | Raises 63 | ------ 64 | ValueError 65 | If ``p`` is zero. 66 | """ 67 | h1, h2 = __prepare_histogram(h1, h2) 68 | if 0 == p: raise ValueError('p can not be zero') 69 | elif int == type(p): 70 | if p > 0 and p < 25: return __minowski_low_positive_integer_p(h1, h2, p) 71 | elif p < 0 and p > -25: return __minowski_low_negative_integer_p(h1, h2, p) 72 | return math.pow(scipy.sum(scipy.power(scipy.absolute(h1 - h2), p)), 1./p) 73 | 74 | def __minowski_low_positive_integer_p(h1, h2, p = 2): # 11..43 us for p = 1..24 \w 100 bins 75 | """ 76 | A faster implementation of the Minowski distance for positive integer < 25. 77 | @note do not use this function directly, but the general @link minowski() method. 78 | @note the passed histograms must be scipy arrays. 79 | """ 80 | mult = scipy.absolute(h1 - h2) 81 | dif = mult 82 | for _ in range(p - 1): dif = scipy.multiply(dif, mult) 83 | return math.pow(scipy.sum(dif), 1./p) 84 | 85 | def __minowski_low_negative_integer_p(h1, h2, p = 2): # 14..46 us for p = -1..-24 \w 100 bins 86 | """ 87 | A faster implementation of the Minowski distance for negative integer > -25. 88 | @note do not use this function directly, but the general @link minowski() method. 89 | @note the passed histograms must be scipy arrays. 90 | """ 91 | mult = scipy.absolute(h1 - h2) 92 | dif = mult 93 | for _ in range(-p + 1): dif = scipy.multiply(dif, mult) 94 | return math.pow(scipy.sum(1./dif), 1./p) 95 | 96 | def manhattan(h1, h2): # # 7 us @array, 31 us @list \w 100 bins 97 | r""" 98 | Equal to Minowski distance with :math:`p=1`. 99 | 100 | See also 101 | -------- 102 | minowski 103 | """ 104 | h1, h2 = __prepare_histogram(h1, h2) 105 | return scipy.sum(scipy.absolute(h1 - h2)) 106 | 107 | def euclidean(h1, h2): # 9 us @array, 33 us @list \w 100 bins 108 | r""" 109 | Equal to Minowski distance with :math:`p=2`. 110 | 111 | See also 112 | -------- 113 | minowski 114 | """ 115 | h1, h2 = __prepare_histogram(h1, h2) 116 | return math.sqrt(scipy.sum(scipy.square(scipy.absolute(h1 - h2)))) 117 | 118 | def chebyshev(h1, h2): # 12 us @array, 36 us @list \w 100 bins 119 | r""" 120 | Chebyshev distance. 121 | 122 | Also Tchebychev distance, Maximum or :math:`L_{\infty}` metric; equal to Minowski 123 | distance with :math:`p=+\infty`. For the case of :math:`p=-\infty`, use `chebyshev_neg`. 124 | 125 | The Chebyshev distance between two histograms :math:`H` and :math:`H'` of size :math:`m` is 126 | defined as: 127 | 128 | .. math:: 129 | 130 | d_{\infty}(H, H') = \max_{m=1}^M|H_m-H'_m| 131 | 132 | *Attributes:* 133 | 134 | - semimetric (triangle equation satisfied?) 135 | 136 | *Attributes for normalized histograms:* 137 | 138 | - :math:`d(H, H')\in[0, 1]` 139 | - :math:`d(H, H) = 0` 140 | - :math:`d(H, H') = d(H', H)` 141 | 142 | *Attributes for not-normalized histograms:* 143 | 144 | - :math:`d(H, H')\in[0, \infty)` 145 | - :math:`d(H, H) = 0` 146 | - :math:`d(H, H') = d(H', H)` 147 | 148 | *Attributes for not-equal histograms:* 149 | 150 | - not applicable 151 | 152 | Parameters 153 | ---------- 154 | h1 : sequence 155 | The first histogram. 156 | h2 : sequence 157 | The second histogram. 158 | 159 | Returns 160 | ------- 161 | chebyshev : float 162 | Chebyshev distance. 163 | 164 | See also 165 | -------- 166 | minowski, chebyshev_neg 167 | """ 168 | h1, h2 = __prepare_histogram(h1, h2) 169 | return max(scipy.absolute(h1 - h2)) 170 | 171 | def chebyshev_neg(h1, h2): # 12 us @array, 36 us @list \w 100 bins 172 | r""" 173 | Chebyshev negative distance. 
174 | 175 | Also Tchebychev distance, Minimum or :math:`L_{-\infty}` metric; equal to Minowski 176 | distance with :math:`p=-\infty`. For the case of :math:`p=+\infty`, use `chebyshev`. 177 | 178 | The Chebyshev distance between two histograms :math:`H` and :math:`H'` of size :math:`m` is 179 | defined as: 180 | 181 | .. math:: 182 | 183 | d_{-\infty}(H, H') = \min_{m=1}^M|H_m-H'_m| 184 | 185 | *Attributes:* 186 | 187 | - semimetric (triangle equation satisfied?) 188 | 189 | *Attributes for normalized histograms:* 190 | 191 | - :math:`d(H, H')\in[0, 1]` 192 | - :math:`d(H, H) = 0` 193 | - :math:`d(H, H') = d(H', H)` 194 | 195 | *Attributes for not-normalized histograms:* 196 | 197 | - :math:`d(H, H')\in[0, \infty)` 198 | - :math:`d(H, H) = 0` 199 | - :math:`d(H, H') = d(H', H)` 200 | 201 | *Attributes for not-equal histograms:* 202 | 203 | - not applicable 204 | 205 | Parameters 206 | ---------- 207 | h1 : sequence 208 | The first histogram. 209 | h2 : sequence 210 | The second histogram. 211 | 212 | Returns 213 | ------- 214 | chebyshev_neg : float 215 | Chebyshev negative distance. 216 | 217 | See also 218 | -------- 219 | minowski, chebyshev 220 | """ 221 | h1, h2 = __prepare_histogram(h1, h2) 222 | return min(scipy.absolute(h1 - h2)) 223 | 224 | def histogram_intersection(h1, h2): # 6 us @array, 30 us @list \w 100 bins 225 | r""" 226 | Calculate the common part of two histograms. 227 | 228 | The histogram intersection between two histograms :math:`H` and :math:`H'` of size :math:`m` is 229 | defined as: 230 | 231 | .. math:: 232 | 233 | d_{\cap}(H, H') = \sum_{m=1}^M\min(H_m, H'_m) 234 | 235 | *Attributes:* 236 | 237 | - a real metric 238 | 239 | *Attributes for normalized histograms:* 240 | 241 | - :math:`d(H, H')\in[0, 1]` 242 | - :math:`d(H, H) = 1` 243 | - :math:`d(H, H') = d(H', H)` 244 | 245 | *Attributes for not-normalized histograms:* 246 | 247 | - not applicable 248 | 249 | *Attributes for not-equal histograms:* 250 | 251 | - not applicable 252 | 253 | Parameters 254 | ---------- 255 | h1 : sequence 256 | The first histogram, normalized. 257 | h2 : sequence 258 | The second histogram, normalized, same bins as ``h1``. 259 | 260 | Returns 261 | ------- 262 | histogram_intersection : float 263 | Intersection between the two histograms. 264 | """ 265 | h1, h2 = __prepare_histogram(h1, h2) 266 | return scipy.sum(scipy.minimum(h1, h2)) 267 | 268 | def histogram_intersection_1(h1, h2): # 7 us @array, 31 us @list \w 100 bins 269 | r""" 270 | Turns the histogram intersection similarity into a distance measure for normalized, 271 | positive histograms. 272 | 273 | .. math:: 274 | 275 | d_{\bar{\cos}}(H, H') = 1 - d_{\cap}(H, H') 276 | 277 | See `histogram_intersection` for the definition of :math:`d_{\cap}(H, H')`. 278 | 279 | *Attributes:* 280 | 281 | - semimetric 282 | 283 | *Attributes for normalized histograms:* 284 | 285 | - :math:`d(H, H')\in[0, 1]` 286 | - :math:`d(H, H) = 0` 287 | - :math:`d(H, H') = d(H', H)` 288 | 289 | *Attributes for not-normalized histograms:* 290 | 291 | - not applicable 292 | 293 | *Attributes for not-equal histograms:* 294 | 295 | - not applicable 296 | 297 | Parameters 298 | ---------- 299 | h1 : sequence 300 | The first histogram, normalized. 301 | h2 : sequence 302 | The second histogram, normalized, same bins as ``h1``. 303 | 304 | Returns 305 | ------- 306 | histogram_intersection : float 307 | Intersection between the two histograms. 308 | """ 309 | return 1. 
- histogram_intersection(h1, h2) 310 | 311 | def relative_deviation(h1, h2): # 18 us @array, 42 us @list \w 100 bins 312 | r""" 313 | Calculate the deviation between two histograms. 314 | 315 | The relative deviation between two histograms :math:`H` and :math:`H'` of size :math:`m` is 316 | defined as: 317 | 318 | .. math:: 319 | 320 | d_{rd}(H, H') = 321 | \frac{ 322 | \sqrt{\sum_{m=1}^M(H_m - H'_m)^2} 323 | }{ 324 | \frac{1}{2} 325 | \left( 326 | \sqrt{\sum_{m=1}^M H_m^2} + 327 | \sqrt{\sum_{m=1}^M {H'}_m^2} 328 | \right) 329 | } 330 | 331 | *Attributes:* 332 | 333 | - semimetric (triangle equation satisfied?) 334 | 335 | *Attributes for normalized histograms:* 336 | 337 | - :math:`d(H, H')\in[0, \sqrt{2}]` 338 | - :math:`d(H, H) = 0` 339 | - :math:`d(H, H') = d(H', H)` 340 | 341 | *Attributes for not-normalized histograms:* 342 | 343 | - :math:`d(H, H')\in[0, 2]` 344 | - :math:`d(H, H) = 0` 345 | - :math:`d(H, H') = d(H', H)` 346 | 347 | *Attributes for not-equal histograms:* 348 | 349 | - not applicable 350 | 351 | Parameters 352 | ---------- 353 | h1 : sequence 354 | The first histogram. 355 | h2 : sequence 356 | The second histogram, same bins as ``h1``. 357 | 358 | Returns 359 | ------- 360 | relative_deviation : float 361 | Relative deviation between the two histograms. 362 | """ 363 | h1, h2 = __prepare_histogram(h1, h2) 364 | numerator = math.sqrt(scipy.sum(scipy.square(h1 - h2))) 365 | denominator = (math.sqrt(scipy.sum(scipy.square(h1))) + math.sqrt(scipy.sum(scipy.square(h2)))) / 2. 366 | return numerator / denominator 367 | 368 | def relative_bin_deviation(h1, h2): # 79 us @array, 104 us @list \w 100 bins 369 | r""" 370 | Calculate the bin-wise deviation between two histograms. 371 | 372 | The relative bin deviation between two histograms :math:`H` and :math:`H'` of size 373 | :math:`m` is defined as: 374 | 375 | .. math:: 376 | 377 | d_{rbd}(H, H') = \sum_{m=1}^M 378 | \frac{ 379 | \sqrt{(H_m - H'_m)^2} 380 | }{ 381 | \frac{1}{2} 382 | \left( 383 | \sqrt{H_m^2} + 384 | \sqrt{{H'}_m^2} 385 | \right) 386 | } 387 | 388 | *Attributes:* 389 | 390 | - a real metric 391 | 392 | *Attributes for normalized histograms:* 393 | 394 | - :math:`d(H, H')\in[0, \infty)` 395 | - :math:`d(H, H) = 0` 396 | - :math:`d(H, H') = d(H', H)` 397 | 398 | *Attributes for not-normalized histograms:* 399 | 400 | - :math:`d(H, H')\in[0, \infty)` 401 | - :math:`d(H, H) = 0` 402 | - :math:`d(H, H') = d(H', H)` 403 | 404 | *Attributes for not-equal histograms:* 405 | 406 | - not applicable 407 | 408 | Parameters 409 | ---------- 410 | h1 : sequence 411 | The first histogram. 412 | h2 : sequence 413 | The second histogram, same bins as ``h1``. 414 | 415 | Returns 416 | ------- 417 | relative_bin_deviation : float 418 | Relative bin deviation between the two histograms. 419 | """ 420 | h1, h2 = __prepare_histogram(h1, h2) 421 | numerator = scipy.sqrt(scipy.square(h1 - h2)) 422 | denominator = (scipy.sqrt(scipy.square(h1)) + scipy.sqrt(scipy.square(h2))) / 2. 423 | old_err_state = scipy.seterr(invalid='ignore') # divide through zero only occurs when the bin is zero in both histograms, in which case the division is 0/0 and leads to (and should lead to) 0 424 | result = numerator / denominator 425 | scipy.seterr(**old_err_state) 426 | result[scipy.isnan(result)] = 0 # faster than scipy.nan_to_num, which checks for +inf and -inf also 427 | return scipy.sum(result) 428 | 429 | def chi_square(h1, h2): # 23 us @array, 49 us @list \w 100 430 | r""" 431 | Chi-square distance. 
432 | 
433 |     Measure how unlikely it is that one distribution (histogram) was drawn from the
434 |     other. The Chi-square distance between two histograms :math:`H` and :math:`H'` of size
435 |     :math:`m` is defined as:
436 | 
437 |     .. math::
438 | 
439 |         d_{\chi^2}(H, H') = \sum_{m=1}^M
440 |             \frac{
441 |                 (H_m - H'_m)^2
442 |             }{
443 |                 H_m + H'_m
444 |             }
445 | 
446 |     *Attributes:*
447 | 
448 |     - semimetric
449 | 
450 |     *Attributes for normalized histograms:*
451 | 
452 |     - :math:`d(H, H')\in[0, 2]`
453 |     - :math:`d(H, H) = 0`
454 |     - :math:`d(H, H') = d(H', H)`
455 | 
456 |     *Attributes for not-normalized histograms:*
457 | 
458 |     - :math:`d(H, H')\in[0, \infty)`
459 |     - :math:`d(H, H) = 0`
460 |     - :math:`d(H, H') = d(H', H)`
461 | 
462 |     *Attributes for not-equal histograms:*
463 | 
464 |     - not applicable
465 | 
466 |     Parameters
467 |     ----------
468 |     h1 : sequence
469 |         The first histogram.
470 |     h2 : sequence
471 |         The second histogram.
472 | 
473 |     Returns
474 |     -------
475 |     chi_square : float
476 |         Chi-square distance.
477 |     """
478 |     h1, h2 = __prepare_histogram(h1, h2)
479 |     old_err_state = scipy.seterr(invalid='ignore')  # division by zero only occurs when a bin is zero in both histograms, in which case the division is 0/0 and leads to (and should lead to) 0
480 |     result = scipy.square(h1 - h2) / (h1 + h2)
481 |     scipy.seterr(**old_err_state)
482 |     result[scipy.isnan(result)] = 0  # faster than scipy.nan_to_num, which checks for +inf and -inf also
483 |     return scipy.sum(result)
484 | 
485 | 
486 | def kullback_leibler(h1, h2):  # 83 us @array, 109 us @list \w 100 bins
487 |     r"""
488 |     Kullback-Leibler divergence.
489 | 
490 |     Compute how inefficient it would be to code one histogram in terms of the other.
491 |     Actually computes :math:`\frac{d_{KL}(h1, h2) + d_{KL}(h2, h1)}{2}` to achieve symmetry.
492 | 
493 |     The Kullback-Leibler divergence between two histograms :math:`H` and :math:`H'` of size
494 |     :math:`m` is defined as:
495 | 
496 |     .. math::
497 | 
498 |         d_{KL}(H, H') = \sum_{m=1}^M H_m\log\frac{H_m}{H'_m}
499 | 
500 |     *Attributes:*
501 | 
502 |     - quasimetric (but made symmetric)
503 | 
504 |     *Attributes for normalized histograms:*
505 | 
506 |     - :math:`d(H, H')\in[0, \infty)`
507 |     - :math:`d(H, H) = 0`
508 |     - :math:`d(H, H') = d(H', H)`
509 | 
510 |     *Attributes for not-normalized histograms:*
511 | 
512 |     - not applicable
513 | 
514 |     *Attributes for not-equal histograms:*
515 | 
516 |     - not applicable
517 | 
518 |     Parameters
519 |     ----------
520 |     h1 : sequence
521 |         The first histogram, where h1[i] > 0 for any i such that h2[i] > 0, normalized.
522 |     h2 : sequence
523 |         The second histogram, where h2[i] > 0 for any i such that h1[i] > 0, normalized, same bins as ``h1``.
524 | 
525 |     Returns
526 |     -------
527 |     kullback_leibler : float
528 |         Kullback-Leibler divergence.
529 | 
530 |     """
531 |     old_err_state = scipy.seterr(divide='raise')
532 |     try:
533 |         h1, h2 = __prepare_histogram(h1, h2)
534 |         result = (__kullback_leibler(h1, h2) + __kullback_leibler(h2, h1)) / 2.
535 |         scipy.seterr(**old_err_state)
536 |         return result
537 |     except FloatingPointError:
538 |         scipy.seterr(**old_err_state)
539 |         raise ValueError('h1 can only contain zero values where h2 also contains zero values and vice-versa')
540 | 
541 | def __kullback_leibler(h1, h2):  # 36.3 us
542 |     """
543 |     The actual KL implementation. @see kullback_leibler() for details.
544 |     Expects the histograms to be of type scipy.ndarray.
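A quick check of the symmetrisation (a sketch with hypothetical two-bin histograms, not part of this module):

import numpy
h1 = numpy.array([0.5, 0.5]); h2 = numpy.array([0.25, 0.75])
kl = lambda p, q: numpy.sum(p * numpy.log(p / q))
sym = (kl(h1, h2) + kl(h2, h1)) / 2.  # ~0.137, and identical with the arguments swapped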
545 | """ 546 | result = h1.astype(scipy.float_) 547 | mask = h1 != 0 548 | result[mask] = scipy.multiply(h1[mask], scipy.log(h1[mask] / h2[mask])) 549 | return scipy.sum(result) 550 | 551 | def jensen_shannon(h1, h2): # 85 us @array, 110 us @list \w 100 bins 552 | r""" 553 | Jensen-Shannon divergence. 554 | 555 | A symmetric and numerically more stable empirical extension of the Kullback-Leibler 556 | divergence. 557 | 558 | The Jensen Shannon divergence between two histograms :math:`H` and :math:`H'` of size 559 | :math:`m` is defined as: 560 | 561 | .. math:: 562 | 563 | d_{JSD}(H, H') = 564 | \frac{1}{2} d_{KL}(H, H^*) + 565 | \frac{1}{2} d_{KL}(H', H^*) 566 | 567 | with :math:`H^*=\frac{1}{2}(H + H')`. 568 | 569 | *Attributes:* 570 | 571 | - semimetric 572 | 573 | *Attributes for normalized histograms:* 574 | 575 | - :math:`d(H, H')\in[0, 1]` 576 | - :math:`d(H, H) = 0` 577 | - :math:`d(H, H') = d(H', H)` 578 | 579 | *Attributes for not-normalized histograms:* 580 | 581 | - :math:`d(H, H')\in[0, \infty)` 582 | - :math:`d(H, H) = 0` 583 | - :math:`d(H, H') = d(H', H)` 584 | 585 | *Attributes for not-equal histograms:* 586 | 587 | - not applicable 588 | 589 | Parameters 590 | ---------- 591 | h1 : sequence 592 | The first histogram. 593 | h2 : sequence 594 | The second histogram, same bins as ``h1``. 595 | 596 | Returns 597 | ------- 598 | jensen_shannon : float 599 | Jensen-Shannon divergence. 600 | 601 | """ 602 | h1, h2 = __prepare_histogram(h1, h2) 603 | s = (h1 + h2) / 2. 604 | return __kullback_leibler(h1, s) / 2. + __kullback_leibler(h2, s) / 2. 605 | 606 | def fidelity_based(h1, h2): # 25 us @array, 51 us @list \w 100 bins 607 | r""" 608 | Fidelity based distance. 609 | 610 | Also Bhattacharyya distance; see also the extensions `noelle_1` to `noelle_5`. 611 | 612 | The metric between two histograms :math:`H` and :math:`H'` of size :math:`m` is defined as: 613 | 614 | .. math:: 615 | 616 | d_{F}(H, H') = \sum_{m=1}^M\sqrt{H_m * H'_m} 617 | 618 | 619 | *Attributes:* 620 | 621 | - not a metric, a similarity 622 | 623 | *Attributes for normalized histograms:* 624 | 625 | - :math:`d(H, H')\in[0, 1]` 626 | - :math:`d(H, H) = 1` 627 | - :math:`d(H, H') = d(H', H)` 628 | 629 | *Attributes for not-normalized histograms:* 630 | 631 | - not applicable 632 | 633 | *Attributes for not-equal histograms:* 634 | 635 | - not applicable 636 | 637 | Parameters 638 | ---------- 639 | h1 : sequence 640 | The first histogram, normalized. 641 | h2 : sequence 642 | The second histogram, normalized, same bins as ``h1``. 643 | 644 | Returns 645 | ------- 646 | fidelity_based : float 647 | Fidelity based distance. 648 | 649 | Notes 650 | ----- 651 | The fidelity between two histograms :math:`H` and :math:`H'` is the same as the 652 | cosine between their square roots :math:`\sqrt{H}` and :math:`\sqrt{H'}`. 653 | """ 654 | h1, h2 = __prepare_histogram(h1, h2) 655 | result = scipy.sum(scipy.sqrt(h1 * h2)) 656 | result = 0 if 0 > result else result # for rounding errors 657 | result = 1 if 1 < result else result # for rounding errors 658 | return result 659 | 660 | def noelle_1(h1, h2): # 26 us @array, 52 us @list \w 100 bins 661 | r""" 662 | Extension of `fidelity_based` proposed by [1]_. 663 | 664 | .. math:: 665 | 666 | d_{\bar{F}}(H, H') = 1 - d_{F}(H, H') 667 | 668 | See `fidelity_based` for the definition of :math:`d_{F}(H, H')`. 
669 | 670 | *Attributes:* 671 | 672 | - semimetric 673 | 674 | *Attributes for normalized histograms:* 675 | 676 | - :math:`d(H, H')\in[0, 1]` 677 | - :math:`d(H, H) = 0` 678 | - :math:`d(H, H') = d(H', H)` 679 | 680 | *Attributes for not-normalized histograms:* 681 | 682 | - not applicable 683 | 684 | *Attributes for not-equal histograms:* 685 | 686 | - not applicable 687 | 688 | Parameters 689 | ---------- 690 | h1 : sequence 691 | The first histogram, normalized. 692 | h2 : sequence 693 | The second histogram, normalized, same bins as ``h1``. 694 | 695 | Returns 696 | ------- 697 | fidelity_based : float 698 | Fidelity based distance. 699 | 700 | References 701 | ---------- 702 | .. [1] M. Noelle "Distribution Distance Measures Applied to 3-D Object Recognition", 2003 703 | """ 704 | return 1. - fidelity_based(h1, h2) 705 | 706 | def noelle_2(h1, h2): # 26 us @array, 52 us @list \w 100 bins 707 | r""" 708 | Extension of `fidelity_based` proposed by [1]_. 709 | 710 | .. math:: 711 | 712 | d_{\sqrt{1-F}}(H, H') = \sqrt{1 - d_{F}(H, H')} 713 | 714 | See `fidelity_based` for the definition of :math:`d_{F}(H, H')`. 715 | 716 | *Attributes:* 717 | 718 | - metric 719 | 720 | *Attributes for normalized histograms:* 721 | 722 | - :math:`d(H, H')\in[0, 1]` 723 | - :math:`d(H, H) = 0` 724 | - :math:`d(H, H') = d(H', H)` 725 | 726 | *Attributes for not-normalized histograms:* 727 | 728 | - not applicable 729 | 730 | *Attributes for not-equal histograms:* 731 | 732 | - not applicable 733 | 734 | Parameters 735 | ---------- 736 | h1 : sequence 737 | The first histogram, normalized. 738 | h2 : sequence 739 | The second histogram, normalized, same bins as ``h1``. 740 | 741 | Returns 742 | ------- 743 | fidelity_based : float 744 | Fidelity based distance. 745 | 746 | References 747 | ---------- 748 | .. [1] M. Noelle "Distribution Distance Measures Applied to 3-D Object Recognition", 2003 749 | """ 750 | return math.sqrt(1. - fidelity_based(h1, h2)) 751 | 752 | def noelle_3(h1, h2): # 26 us @array, 52 us @list \w 100 bins 753 | r""" 754 | Extension of `fidelity_based` proposed by [1]_. 755 | 756 | .. math:: 757 | 758 | d_{\log(2-F)}(H, H') = \log(2 - d_{F}(H, H')) 759 | 760 | See `fidelity_based` for the definition of :math:`d_{F}(H, H')`. 761 | 762 | *Attributes:* 763 | 764 | - semimetric 765 | 766 | *Attributes for normalized histograms:* 767 | 768 | - :math:`d(H, H')\in[0, log(2)]` 769 | - :math:`d(H, H) = 0` 770 | - :math:`d(H, H') = d(H', H)` 771 | 772 | *Attributes for not-normalized histograms:* 773 | 774 | - not applicable 775 | 776 | *Attributes for not-equal histograms:* 777 | 778 | - not applicable 779 | 780 | Parameters 781 | ---------- 782 | h1 : sequence 783 | The first histogram, normalized. 784 | h2 : sequence 785 | The second histogram, normalized, same bins as ``h1``. 786 | 787 | Returns 788 | ------- 789 | fidelity_based : float 790 | Fidelity based distance. 791 | 792 | References 793 | ---------- 794 | .. [1] M. Noelle "Distribution Distance Measures Applied to 3-D Object Recognition", 2003 795 | """ 796 | return math.log(2 - fidelity_based(h1, h2)) 797 | 798 | def noelle_4(h1, h2): # 26 us @array, 52 us @list \w 100 bins 799 | r""" 800 | Extension of `fidelity_based` proposed by [1]_. 801 | 802 | .. math:: 803 | 804 | d_{\arccos F}(H, H') = \frac{2}{\pi} \arccos d_{F}(H, H') 805 | 806 | See `fidelity_based` for the definition of :math:`d_{F}(H, H')`. 
807 | 808 | *Attributes:* 809 | 810 | - metric 811 | 812 | *Attributes for normalized histograms:* 813 | 814 | - :math:`d(H, H')\in[0, 1]` 815 | - :math:`d(H, H) = 0` 816 | - :math:`d(H, H') = d(H', H)` 817 | 818 | *Attributes for not-normalized histograms:* 819 | 820 | - not applicable 821 | 822 | *Attributes for not-equal histograms:* 823 | 824 | - not applicable 825 | 826 | Parameters 827 | ---------- 828 | h1 : sequence 829 | The first histogram, normalized. 830 | h2 : sequence 831 | The second histogram, normalized, same bins as ``h1``. 832 | 833 | Returns 834 | ------- 835 | fidelity_based : float 836 | Fidelity based distance. 837 | 838 | References 839 | ---------- 840 | .. [1] M. Noelle "Distribution Distance Measures Applied to 3-D Object Recognition", 2003 841 | """ 842 | return 2. / math.pi * math.acos(fidelity_based(h1, h2)) 843 | 844 | def noelle_5(h1, h2): # 26 us @array, 52 us @list \w 100 bins 845 | r""" 846 | Extension of `fidelity_based` proposed by [1]_. 847 | 848 | .. math:: 849 | 850 | d_{\sin F}(H, H') = \sqrt{1 -d_{F}^2(H, H')} 851 | 852 | See `fidelity_based` for the definition of :math:`d_{F}(H, H')`. 853 | 854 | *Attributes:* 855 | 856 | - metric 857 | 858 | *Attributes for normalized histograms:* 859 | 860 | - :math:`d(H, H')\in[0, 1]` 861 | - :math:`d(H, H) = 0` 862 | - :math:`d(H, H') = d(H', H)` 863 | 864 | *Attributes for not-normalized histograms:* 865 | 866 | - not applicable 867 | 868 | *Attributes for not-equal histograms:* 869 | 870 | - not applicable 871 | 872 | Parameters 873 | ---------- 874 | h1 : sequence 875 | The first histogram, normalized. 876 | h2 : sequence 877 | The second histogram, normalized, same bins as ``h1``. 878 | 879 | Returns 880 | ------- 881 | fidelity_based : float 882 | Fidelity based distance. 883 | 884 | References 885 | ---------- 886 | .. [1] M. Noelle "Distribution Distance Measures Applied to 3-D Object Recognition", 2003 887 | """ 888 | return math.sqrt(1 - math.pow(fidelity_based(h1, h2), 2)) 889 | 890 | 891 | def cosine_alt(h1, h2): # 17 us @array, 42 us @list \w 100 bins 892 | r""" 893 | Alternative implementation of the `cosine` distance measure. 894 | 895 | Notes 896 | ----- 897 | Under development. 898 | """ 899 | h1, h2 = __prepare_histogram(h1, h2) 900 | return -1 * float(scipy.sum(h1 * h2)) / (scipy.sum(scipy.power(h1, 2)) * scipy.sum(scipy.power(h2, 2))) 901 | 902 | def cosine(h1, h2): # 17 us @array, 42 us @list \w 100 bins 903 | r""" 904 | Cosine simmilarity. 905 | 906 | Compute the angle between the two histograms in vector space irrespective of their 907 | length. The cosine similarity between two histograms :math:`H` and :math:`H'` of size 908 | :math:`m` is defined as: 909 | 910 | .. math:: 911 | 912 | d_{\cos}(H, H') = \cos\alpha = \frac{H * H'}{\|H\| \|H'\|} = \frac{\sum_{m=1}^M H_m*H'_m}{\sqrt{\sum_{m=1}^M H_m^2} * \sqrt{\sum_{m=1}^M {H'}_m^2}} 913 | 914 | 915 | *Attributes:* 916 | 917 | - not a metric, a similarity 918 | 919 | *Attributes for normalized histograms:* 920 | 921 | - :math:`d(H, H')\in[0, 1]` 922 | - :math:`d(H, H) = 1` 923 | - :math:`d(H, H') = d(H', H)` 924 | 925 | *Attributes for not-normalized histograms:* 926 | 927 | - :math:`d(H, H')\in[-1, 1]` 928 | - :math:`d(H, H) = 1` 929 | - :math:`d(H, H') = d(H', H)` 930 | 931 | *Attributes for not-equal histograms:* 932 | 933 | - not applicable 934 | 935 | Parameters 936 | ---------- 937 | h1 : sequence 938 | The first histogram. 939 | h2 : sequence 940 | The second histogram, same bins as ``h1``. 
941 | 
942 |     Returns
943 |     -------
944 |     cosine : float
945 |         Cosine similarity.
946 | 
947 |     Notes
948 |     -----
949 |     The resulting similarity ranges from -1 meaning exactly opposite, to 1 meaning
950 |     exactly the same, with 0 usually indicating independence, and in-between values
951 |     indicating intermediate similarity or dissimilarity.
952 |     """
953 |     h1, h2 = __prepare_histogram(h1, h2)
954 |     return scipy.sum(h1 * h2) / math.sqrt(scipy.sum(scipy.square(h1)) * scipy.sum(scipy.square(h2)))
955 | 
956 | def cosine_1(h1, h2):  # 18 us @array, 43 us @list \w 100 bins
957 |     r"""
958 |     Cosine similarity.
959 | 
960 |     Turns the cosine similarity into a distance measure for normalized, positive
961 |     histograms.
962 | 
963 |     .. math::
964 | 
965 |         d_{\bar{\cos}}(H, H') = 1 - d_{\cos}(H, H')
966 | 
967 |     See `cosine` for the definition of :math:`d_{\cos}(H, H')`.
968 | 
969 |     *Attributes:*
970 | 
971 |     - metric
972 | 
973 |     *Attributes for normalized histograms:*
974 | 
975 |     - :math:`d(H, H')\in[0, 1]`
976 |     - :math:`d(H, H) = 0`
977 |     - :math:`d(H, H') = d(H', H)`
978 | 
979 |     *Attributes for not-normalized histograms:*
980 | 
981 |     - not applicable
982 | 
983 |     *Attributes for not-equal histograms:*
984 | 
985 |     - not applicable
986 | 
987 |     Parameters
988 |     ----------
989 |     h1 : sequence
990 |         The first histogram, normalized.
991 |     h2 : sequence
992 |         The second histogram, normalized, same bins as ``h1``.
993 | 
994 |     Returns
995 |     -------
996 |     cosine : float
997 |         Cosine distance.
998 |     """
999 |     return 1. - cosine(h1, h2)
1000 | 
1001 | def cosine_2(h1, h2):  # 19 us @array, 44 us @list \w 100 bins
1002 |     r"""
1003 |     Cosine similarity.
1004 | 
1005 |     Turns the cosine similarity into a distance measure for normalized, positive
1006 |     histograms.
1007 | 
1008 |     .. math::
1009 | 
1010 |         d_{\bar{\cos}}(H, H') = \frac{2 \cdot \arccos d_{\cos}(H, H')}{\pi}
1011 | 
1012 |     See `cosine` for the definition of :math:`d_{\cos}(H, H')`.
1013 | 
1014 |     *Attributes:*
1015 | 
1016 |     - metric
1017 | 
1018 |     *Attributes for normalized histograms:*
1019 | 
1020 |     - :math:`d(H, H')\in[0, 1]`
1021 |     - :math:`d(H, H) = 0`
1022 |     - :math:`d(H, H') = d(H', H)`
1023 | 
1024 |     *Attributes for not-normalized histograms:*
1025 | 
1026 |     - not applicable
1027 | 
1028 |     *Attributes for not-equal histograms:*
1029 | 
1030 |     - not applicable
1031 | 
1032 |     Parameters
1033 |     ----------
1034 |     h1 : sequence
1035 |         The first histogram, normalized.
1036 |     h2 : sequence
1037 |         The second histogram, normalized, same bins as ``h1``.
1038 | 
1039 |     Returns
1040 |     -------
1041 |     cosine : float
1042 |         Cosine distance.
1043 |     """
1044 |     return 2. * math.acos(cosine(h1, h2)) / math.pi  # angular distance: 0 for identical histograms, consistent with the attributes above
1045 | 
1046 | def correlate(h1, h2):  # 31 us @array, 55 us @list \w 100 bins
1047 |     r"""
1048 |     Correlation between two histograms.
1049 | 
1050 |     The histogram correlation between two histograms :math:`H` and :math:`H'` of size :math:`m`
1051 |     is defined as:
1052 | 
1053 |     .. math::
1054 | 
1055 |         d_{corr}(H, H') =
1056 |         \frac{
1057 |             \sum_{m=1}^M (H_m-\bar{H}) \cdot (H'_m-\bar{H'})
1058 |         }{
1059 |             \sqrt{\sum_{m=1}^M (H_m-\bar{H})^2 \cdot \sum_{m=1}^M (H'_m-\bar{H'})^2}
1060 |         }
1061 | 
1062 |     with :math:`\bar{H}` and :math:`\bar{H'}` being the mean values of :math:`H` resp. :math:`H'`
1063 | 
1064 |     *Attributes:*
1065 | 
1066 |     - not a metric, a similarity
1067 | 
1068 |     *Attributes for normalized histograms:*
1069 | 
1070 |     - :math:`d(H, H')\in[-1, 1]`
1071 |     - :math:`d(H, H) = 1`
1072 |     - :math:`d(H, H') = d(H', H)`
1073 | 
1074 |     *Attributes for not-normalized histograms:*
1075 | 
1076 |     - :math:`d(H, H')\in[-1, 1]`
1077 |     - :math:`d(H, H) = 1`
1078 |     - :math:`d(H, H') = d(H', H)`
1079 | 
1080 |     *Attributes for not-equal histograms:*
1081 | 
1082 |     - not applicable
1083 | 
1084 |     Parameters
1085 |     ----------
1086 |     h1 : sequence
1087 |         The first histogram.
1088 |     h2 : sequence
1089 |         The second histogram, same bins as ``h1``.
1090 | 
1091 |     Returns
1092 |     -------
1093 |     correlate : float
1094 |         Correlation between the histograms.
1095 | 
1096 |     Notes
1097 |     -----
1098 |     Returns 0 if one of h1 or h2 contains only zeros.
1099 | 
1100 |     """
1101 |     h1, h2 = __prepare_histogram(h1, h2)
1102 |     h1m = h1 - scipy.sum(h1) / float(h1.size)
1103 |     h2m = h2 - scipy.sum(h2) / float(h2.size)
1104 |     a = scipy.sum(scipy.multiply(h1m, h2m))
1105 |     b = math.sqrt(scipy.sum(scipy.square(h1m)) * scipy.sum(scipy.square(h2m)))
1106 |     return 0 if 0 == b else a / b
1107 | 
1108 | def correlate_1(h1, h2):  # 32 us @array, 56 us @list \w 100 bins
1109 |     r"""
1110 |     Correlation distance.
1111 | 
1112 |     Turns the histogram correlation into a distance measure for normalized, positive
1113 |     histograms.
1114 | 
1115 |     .. math::
1116 | 
1117 |         d_{\bar{corr}}(H, H') = \frac{1 - d_{corr}(H, H')}{2}.
1118 | 
1119 |     See `correlate` for the definition of :math:`d_{corr}(H, H')`.
1120 | 
1121 |     *Attributes:*
1122 | 
1123 |     - semimetric
1124 | 
1125 |     *Attributes for normalized histograms:*
1126 | 
1127 |     - :math:`d(H, H')\in[0, 1]`
1128 |     - :math:`d(H, H) = 0`
1129 |     - :math:`d(H, H') = d(H', H)`
1130 | 
1131 |     *Attributes for not-normalized histograms:*
1132 | 
1133 |     - :math:`d(H, H')\in[0, 1]`
1134 |     - :math:`d(H, H) = 0`
1135 |     - :math:`d(H, H') = d(H', H)`
1136 | 
1137 |     *Attributes for not-equal histograms:*
1138 | 
1139 |     - not applicable
1140 | 
1141 |     Parameters
1142 |     ----------
1143 |     h1 : sequence
1144 |         The first histogram.
1145 |     h2 : sequence
1146 |         The second histogram, same bins as ``h1``.
1147 | 
1148 |     Returns
1149 |     -------
1150 |     correlate : float
1151 |         Correlation distance between the histograms.
1152 | 
1153 |     Notes
1154 |     -----
1155 |     Returns 0.5 if one of h1 or h2 contains only zeros.
1156 |     """
1157 |     return (1. - correlate(h1, h2))/2.
1158 | 
1159 | 
1160 | # ///////////////////////////// #
1161 | # Cross-bin comparison measures #
1162 | # ///////////////////////////// #
1163 | 
1164 | def quadratic_forms(h1, h2):
1165 |     r"""
1166 |     Quadratic forms metric.
1167 | 
1168 |     Notes
1169 |     -----
1170 |     UNDER DEVELOPMENT
1171 | 
1172 |     This distance measure shows very strange behaviour. The expression
1173 |     transpose(h1-h2) * A * (h1-h2) yields negative values that cannot be processed by the
1174 |     square root.
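The negative values are possible because the bin-similarity matrix :math:`A` is not guaranteed to be positive semi-definite. The -2.0 in the first row of the table below can be reproduced directly (a sketch inlining the module-private helper defined further down):

import numpy
h1 = numpy.array([1., 0.]); h2 = numpy.array([0., 1.])
E = numpy.abs(numpy.repeat(h2[:, numpy.newaxis], h1.size, 1) - h1)  # pairwise bin distances
A = 1 - E / E.max()                                                 # bin-similarity matrix
d = h1 - h2
val = d.dot(A.dot(d))                                               # -2.0, matching the table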
Some examples::
1175 | 
1176 |         h1                                                h2                    transpose(h1-h2) * A * (h1-h2)
1177 |         [1, 0] to [0.0, 1.0] :                                                  -2.0
1178 |         [1, 0] to [0.5, 0.5] :                                                  0.0
1179 |         [1, 0] to [0.6666666666666667, 0.3333333333333333] :                    0.111111111111
1180 |         [1, 0] to [0.75, 0.25] :                                                0.0833333333333
1181 |         [1, 0] to [0.8, 0.2] :                                                  0.06
1182 |         [1, 0] to [0.8333333333333334, 0.16666666666666666] :                   0.0444444444444
1183 |         [1, 0] to [0.8571428571428572, 0.14285714285714285] :                   0.0340136054422
1184 |         [1, 0] to [0.875, 0.125] :                                              0.0267857142857
1185 |         [1, 0] to [0.8888888888888888, 0.1111111111111111] :                    0.0216049382716
1186 |         [1, 0] to [0.9, 0.1] :                                                  0.0177777777778
1187 |         [1, 0] to [1, 0]:                                                       0.0
1188 | 
1189 |     It is clearly undesirable to receive negative values and even worse to get a value
1190 |     of zero for cases other than identical histograms.
1191 |     """
1192 |     h1, h2 = __prepare_histogram(h1, h2)
1193 |     A = __quadratic_forms_matrix_euclidean(h1, h2)
1194 |     return math.sqrt((h1-h2).dot(A.dot(h1-h2)))  # transpose(h1-h2) * A * (h1-h2)
1195 | 
1196 | def __quadratic_forms_matrix_euclidean(h1, h2):
1197 |     r"""
1198 |     Compute the bin-similarity matrix for the quadratic form distance measure.
1199 |     The matrix :math:`A` for two histograms :math:`H` and :math:`H'` of size :math:`m` and
1200 |     :math:`n` respectively is defined as
1201 | 
1202 |     .. math::
1203 | 
1204 |         A_{m,n} = 1 - \frac{d_2(H_m, {H'}_n)}{d_{max}}
1205 | 
1206 |     with
1207 | 
1208 |     .. math::
1209 | 
1210 |        d_{max} = \max_{m,n}d_2(H_m, {H'}_n)
1211 | 
1212 |     See also
1213 |     --------
1214 |     quadratic_forms
1215 |     """
1216 |     A = scipy.repeat(h2[:,scipy.newaxis], h1.size, 1)  # repeat second array to form a matrix
1217 |     A = scipy.absolute(A - h1)  # euclidean distances
1218 |     return 1 - (A / float(A.max()))
1219 | 
1220 | 
1221 | # //////////////// #
1222 | # Helper functions #
1223 | # //////////////// #
1224 | 
1225 | def __prepare_histogram(h1, h2):
1226 |     """Convert the histograms to scipy.ndarrays if required."""
1227 |     h1 = h1 if scipy.ndarray == type(h1) else scipy.asarray(h1)
1228 |     h2 = h2 if scipy.ndarray == type(h2) else scipy.asarray(h2)
1229 |     if h1.shape != h2.shape or h1.size != h2.size:
1230 |         raise ValueError('h1 and h2 must be of same shape and size')
1231 |     return h1, h2
1232 | 
--------------------------------------------------------------------------------
/utils/metric/image.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import numpy
3 | 
4 | # own modules
5 | from .utils import ArgumentError
6 | 
7 | 
8 | def mutual_information(i1, i2, bins=256):
9 |     r"""Mutual information.
10 |     Computes the mutual information (MI) (a measure of entropy) between two images.
11 | 
12 |     MI is not a real metric, but a symmetric and nonnegative similarity measure that
13 |     takes high values for similar images. Negative values are also possible.
14 | 
15 |     Intuitively, mutual information measures the information that ``i1`` and ``i2`` share: it
16 |     measures how much knowing one of these variables reduces uncertainty about the other.
17 | 
18 |     The entropy is defined as:
19 | 
20 |     .. math::
21 | 
22 |         H(X) = - \sum_i p(g_i) \ln p(g_i)
23 | 
24 |     with :math:`p(g_i)` being the intensity probability of the images grey value :math:`g_i`.
25 | 
26 |     Assuming two images :math:`R` and :math:`T`, the mutual information is then computed by comparing the
27 |     images entropy values (i.e. a measure how well-structured the common histogram is).
28 |     The distance metric is then calculated as follows:
29 | 
30 |     ..
--------------------------------------------------------------------------------
/utils/metric/image.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import numpy
3 | 
4 | # own modules
5 | from .utils import ArgumentError
6 | 
7 | 
8 | def mutual_information(i1, i2, bins=256):
9 |     r"""Mutual information.
10 |     Computes the mutual information (MI) (a measure of entropy) between two images.
11 | 
12 |     MI is not a real metric, but a symmetric and nonnegative similarity measure that
13 |     takes high values for similar images. Negative values are also possible.
14 | 
15 |     Intuitively, mutual information measures the information that ``i1`` and ``i2`` share: it
16 |     measures how much knowing one of these variables reduces uncertainty about the other.
17 | 
18 |     The entropy is defined as:
19 | 
20 |     .. math::
21 | 
22 |         H(X) = - \sum_i p(g_i) \ln(p(g_i))
23 | 
24 |     with :math:`p(g_i)` being the intensity probability of the image's grey value :math:`g_i`.
25 | 
26 |     Assuming two images :math:`R` and :math:`T`, the mutual information is then computed by comparing the
27 |     images' entropy values (i.e. a measure of how well-structured the common histogram is).
28 |     The distance metric is then calculated as follows:
29 | 
30 |     .. math::
31 | 
32 |         MI(R,T) = H(R) + H(T) - H(R,T) = H(R) - H(R|T) = H(T) - H(T|R)
33 | 
34 |     A maximization of the mutual information is equal to a minimization of the joint
35 |     entropy.
36 | 
37 |     Parameters
38 |     ----------
39 |     i1 : array_like
40 |         The first image.
41 |     i2 : array_like
42 |         The second image.
43 |     bins : integer
44 |         The number of histogram bins (squared for the joined histogram).
45 | 
46 |     Returns
47 |     -------
48 |     mutual_information : float
49 |         The mutual information distance value between the supplied images.
50 | 
51 |     Raises
52 |     ------
53 |     ArgumentError
54 |         If the supplied arrays are of different shape.
55 |     """
56 |     # pre-process function arguments
57 |     i1 = numpy.asarray(i1)
58 |     i2 = numpy.asarray(i2)
59 | 
60 |     # validate function arguments
61 |     if not i1.shape == i2.shape:
62 |         raise ArgumentError('the two supplied array-like sequences i1 and i2 must be of the same shape')
63 | 
64 |     # compute i1 and i2 histogram range
65 |     i1_range = __range(i1, bins)
66 |     i2_range = __range(i2, bins)
67 | 
68 |     # compute joined and separated normed histograms
69 |     i1i2_hist, _, _ = numpy.histogram2d(i1.flatten(), i2.flatten(), bins=bins,
70 |                                         range=[i1_range, i2_range])  # Note: histogram2d does not flatten array on its own
71 |     i1_hist, _ = numpy.histogram(i1, bins=bins, range=i1_range)
72 |     i2_hist, _ = numpy.histogram(i2, bins=bins, range=i2_range)
73 | 
74 |     # compute joined and separated entropy
75 |     i1i2_entropy = __entropy(i1i2_hist)
76 |     i1_entropy = __entropy(i1_hist)
77 |     i2_entropy = __entropy(i2_hist)
78 | 
79 |     # compute and return the mutual information distance
80 |     return i1_entropy + i2_entropy - i1i2_entropy
81 | 
82 | 
83 | def __range(a, bins):
84 |     '''Compute the histogram range of the values in the array a according to
85 |     scipy.stats.histogram.'''
86 |     a = numpy.asarray(a)
87 |     a_max = a.max()
88 |     a_min = a.min()
89 |     s = 0.5 * (a_max - a_min) / float(bins - 1)
90 |     return (a_min - s, a_max + s)
91 | 
92 | 
93 | def __entropy(data):
94 |     '''Compute entropy of the flattened data set (e.g. a density distribution).'''
95 |     # normalize and convert to float
96 |     data = data / float(numpy.sum(data))
97 |     # for each grey-value g with a probability p(g) = 0, the entropy is defined as 0, therefore we remove these values and also flatten the histogram
98 |     data = data[numpy.nonzero(data)]
99 |     # compute entropy
100 |     return -1. * numpy.sum(data * numpy.log2(data))
101 | 
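A usage sketch for mutual_information (hypothetical random images; the relative import of `.utils` means the module has to be called through the package):

    import numpy
    rng = numpy.random.RandomState(0)
    img = rng.rand(64, 64)
    noisy = (img + 0.1 * rng.rand(64, 64)) / 1.1
    print(mutual_information(img, img, bins=32))    # maximal: equals the entropy of img (in bits)
    print(mutual_information(img, noisy, bins=32))  # lower, but still clearly positive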
--------------------------------------------------------------------------------
/utils/metric/inception_score.py:
--------------------------------------------------------------------------------
1 | #encoding: utf-8
2 | import torch
3 | from torch import nn
4 | from torch.autograd import Variable
5 | from torch.nn import functional as F
6 | import torch.utils.data
7 | from torchvision.models.inception import inception_v3
8 | import numpy as np
9 | from scipy.stats import entropy
10 | 
11 | 
12 | def inception_score(img_ds, cuda=True, batch_size=32, resize=False, splits=1):
13 |     """Computes the inception score of the generated images in ``img_ds``
14 | 
15 |     Ref:
16 |         https://github.com/sbarratt/inception-score-pytorch/blob/master/inception_score.py
17 |     Args:
18 |         img_ds: Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
19 |         cuda: whether or not to run on GPU
20 |         batch_size: batch size for feeding into Inception v3
21 |         splits: number of splits to average the score over
22 |     """
23 |     N = len(img_ds)
24 |     assert batch_size > 0
25 |     assert N > batch_size
26 |     # Set up dtype
27 |     if cuda:
28 |         dtype = torch.cuda.FloatTensor
29 |     else:
30 |         if torch.cuda.is_available():
31 |             print("WARNING: You have a CUDA device, so you should probably set cuda=True")
32 |         dtype = torch.FloatTensor
33 |     # Set up dataloader
34 |     dataloader = torch.utils.data.DataLoader(img_ds, batch_size=batch_size)
35 |     # Load inception model
36 |     inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)
37 |     inception_model.eval()
38 |     up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype)
39 | 
40 |     def get_pred(x):
41 |         if resize:
42 |             x = up(x)
43 |         x = inception_model(x)
44 |         return F.softmax(x, dim=1).data.cpu().numpy()
45 |     # Get predictions
46 |     preds = np.zeros((N, 1000))
47 |     for i, batch in enumerate(dataloader, 0):
48 |         batch = batch.type(dtype)
49 |         batchv = Variable(batch)
50 |         batch_size_i = batch.size()[0]
51 | 
52 |         preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv)
53 | 
54 |     # Now compute the mean kl-div
55 |     split_scores = []
56 | 
57 |     for k in range(splits):
58 |         part = preds[k * (N // splits): (k + 1) * (N // splits), :]
59 |         py = np.mean(part, axis=0)
60 |         scores = []
61 |         for i in range(part.shape[0]):
62 |             pyx = part[i, :]
63 |             scores.append(entropy(pyx, py))
64 |         split_scores.append(np.exp(np.mean(scores)))
65 | 
66 |     return np.mean(split_scores), np.std(split_scores)
67 | 
68 | 
69 | if __name__ == '__main__':
70 |     """Example: compute the score on the CIFAR-10 dataset."""
71 |     class IgnoreLabelDataset(torch.utils.data.Dataset):
72 |         def __init__(self, orig):
73 |             self.orig = orig
74 |         def __getitem__(self, index):
75 |             return self.orig[index][0]
76 | 
77 |         def __len__(self):
78 |             return len(self.orig)
79 |     import torchvision.datasets as dset
80 |     import torchvision.transforms as transforms
81 |     cifar = dset.CIFAR10(root='data/', download=True,
82 |                          transform=transforms.Compose([
83 |                              transforms.Resize(32),
84 |                              transforms.ToTensor(),
85 |                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
86 |                          ]))
87 | 
88 |     print("Calculating Inception Score...")
89 |     print(inception_score(IgnoreLabelDataset(cifar), cuda=True, batch_size=32, resize=True, splits=10))
90 | 
--------------------------------------------------------------------------------
/utils/metric/metrics_torch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # encoding: utf-8
3 | # @Time : 2019/3/28 15:38
4 | # @Author : Eric Ching
5 | import numpy as np
6 | import torch
7 | from skimage import measure
8 | SUPPORTED_METRICS = ['dice', 'iou', 'ap']
9 | 
10 | 
11 | class MeanIoU:
12 |     """
13 |     Computes IoU for each class separately and then averages over all classes.
14 |     """
15 |     def __init__(self, ignore_index=None):
16 |         """
17 |         :param ignore_index: id of the label to be ignored from IoU computation
18 |         """
19 |         self.ignore_index = ignore_index
20 | 
21 |     def __call__(self, input, target):
22 |         """
23 |         Args:
24 |             input: 5D probability maps torch float tensor (NxCxDxHxW)
25 |             target: 5D ground truth torch tensor (NxCxDxHxW)
26 |         Returns:
27 |             intersection over union averaged over all channels
28 |         """
29 |         n_classes = input.size()[1]
30 |         # batch dim must be 1
31 |         input = input[0]
32 |         target = target[0]
33 |         assert input.size() == target.size()
34 | 
35 |         binary_prediction = self._binarize_predictions(input)
36 | 
37 |         if self.ignore_index is not None:
38 |             # zero out ignore_index
39 |             mask = target == self.ignore_index
40 |             binary_prediction[mask] = 0
41 |             target[mask] = 0
42 | 
43 |         # convert to uint8 just in case
44 |         binary_prediction = binary_prediction.byte()
45 |         target = target.byte()
46 |         per_channel_iou = []
47 |         for c in range(n_classes):
48 |             per_channel_iou.append(self._jaccard_index(binary_prediction[c], target[c]))
49 | 
50 |         return torch.mean(torch.tensor(per_channel_iou))
51 | 
52 |     def _binarize_predictions(self, input):
53 |         """
54 |         Puts 1 for the class/channel with the highest probability and 0 in other channels. Returns byte tensor of the
55 |         same size as the input tensor.
56 |         """
57 |         _, max_index = torch.max(input, dim=0, keepdim=True)
58 |         return torch.zeros_like(input, dtype=torch.uint8).scatter_(0, max_index, 1)
59 | 
60 |     def _jaccard_index(self, prediction, target):
61 |         """
62 |         Computes IoU for the given prediction and target tensors
63 |         """
64 |         return torch.sum(prediction & target).float() / torch.sum(prediction | target).float()
65 | 
66 | 
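A usage sketch for MeanIoU (hypothetical shapes; the batch dimension must be 1, as `__call__` assumes):

    import torch
    metric = MeanIoU()
    probs = torch.rand(1, 3, 8, 16, 16)   # N x C x D x H x W probability maps
    target = torch.zeros(1, 3, 8, 16, 16)
    target[:, 0] = 1                      # ground truth: everything belongs to class 0
    print(metric(probs, target))          # scalar: IoU averaged over the 3 channels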
67 | class AveragePrecision:
68 |     """
69 |     Computes Average Precision given boundary prediction and ground truth instance segmentation.
70 |     """
71 | 
72 |     def __init__(self, threshold=0.4, iou_range=(0.5, 1.0), ignore_index=-1, min_instance_size=None,
73 |                  use_last_target=False):
74 |         """
75 |         :param threshold: probability value at which the input is going to be thresholded
76 |         :param iou_range: compute the precision/recall curve for the range of IoU values: range(min, max, 0.1)
77 |         :param ignore_index: label to be ignored during computation
78 |         :param min_instance_size: minimum size of the predicted instances to be considered
79 |         :param use_last_target: if True use the last target channel to compute AP
80 |         """
81 |         self.threshold = threshold
82 |         # always have well defined ignore_index
83 |         if ignore_index is None:
84 |             ignore_index = -1
85 |         self.iou_range = iou_range
86 |         self.ignore_index = ignore_index
87 |         self.min_instance_size = min_instance_size
88 |         self.use_last_target = use_last_target
89 | 
90 |     def __call__(self, input, target):
91 |         """
92 |         :param input: 5D probability maps torch float tensor (NxCxDxHxW) / or 4D numpy.ndarray
93 |         :param target: 4D or 5D ground truth instance segmentation torch long tensor / or 3D numpy.ndarray
94 |         :return: highest average precision among channels
95 |         """
96 |         if isinstance(input, torch.Tensor):
97 |             assert input.dim() == 5
98 |             # convert to numpy array
99 |             input = input[0].detach().cpu().numpy()  # 4D
100 |         if isinstance(target, torch.Tensor):
101 |             if not self.use_last_target:
102 |                 assert target.dim() == 4
103 |                 # convert to numpy array
104 |                 target = target[0].detach().cpu().numpy()  # 3D
105 |             else:
106 |                 # if use_last_target == True the target must be 5D (NxCxDxHxW)
107 |                 assert target.dim() == 5
108 |                 target = target[0, -1].detach().cpu().numpy()  # 3D
109 | 
110 |         if isinstance(input, np.ndarray):
111 |             assert input.ndim == 4
112 |         if isinstance(target, np.ndarray):
113 |             assert target.ndim == 3
114 | 
115 |         # filter small instances from the target and get ground truth label set (without 'ignore_index')
116 |         target, target_instances = self._filter_instances(target)
117 | 
118 |         per_channel_ap = []
119 |         n_channels = input.shape[0]
120 |         for c in range(n_channels):
121 |             predictions = input[c]
122 |             # threshold probability maps
123 |             predictions = predictions > self.threshold
124 |             # for connected component analysis we need to treat boundary signal as background
125 |             # assign 0-label to boundary mask
126 |             predictions = np.logical_not(predictions).astype(np.uint8)
127 |             # run connected components on the predicted mask; consider only 1-connectivity
128 |             predicted = measure.label(predictions, background=0, connectivity=1)
129 |             ap = self._calculate_average_precision(predicted, target, target_instances)
130 |             per_channel_ap.append(ap)
131 | 
132 |         # get maximum average precision across channels
133 |         max_ap, c_index = np.max(per_channel_ap), np.argmax(per_channel_ap)
134 |         print(f'Max average precision: {max_ap}, channel: {c_index}')
135 |         return max_ap
136 | 
137 |     def _calculate_average_precision(self, predicted, target, target_instances):
138 |         recall, precision = self._roc_curve(predicted, target, target_instances)
139 |         recall.insert(0, 0.0)  # insert 0.0 at beginning of list
140 |         recall.append(1.0)  # insert 1.0 at end of list
141 |         precision.insert(0, 0.0)  # insert 0.0 at beginning of list
142 |         precision.append(0.0)  # insert 0.0 at end of list
143 |         # make the precision(recall) piece-wise constant and monotonically decreasing
144 |         # by iterating backwards starting from the last precision value (0.0)
145 |         # see e.g. https://www.jeremyjordan.me/evaluating-image-segmentation-models/
146 |         for i in range(len(precision) - 2, -1, -1):
147 |             precision[i] = max(precision[i], precision[i + 1])
148 |         # compute the area under precision recall curve by simple integration of piece-wise constant function
149 |         ap = 0.0
150 |         for i in range(1, len(recall)):
151 |             ap += ((recall[i] - recall[i - 1]) * precision[i])
152 |         return ap
153 | 
154 |     def _roc_curve(self, predicted, target, target_instances):
155 |         ROC = []
156 |         predicted, predicted_instances = self._filter_instances(predicted)
157 | 
158 |         # compute precision/recall curve points for various IoU values from a given range
159 |         for min_iou in np.arange(self.iou_range[0], self.iou_range[1], 0.1):
160 |             # initialize false negatives set
161 |             false_negatives = set(target_instances)
162 |             # initialize false positives set
163 |             false_positives = set(predicted_instances)
164 |             # initialize true positives set
165 |             true_positives = set()
166 | 
167 |             for pred_label in predicted_instances:
168 |                 target_label = self._find_overlapping_target(pred_label, predicted, target, min_iou)
169 |                 if target_label is not None:
170 |                     # update TP, FP and FN
171 |                     if target_label == self.ignore_index:
172 |                         # ignore if 'ignore_index' is the biggest overlapping
173 |                         false_positives.discard(pred_label)
174 |                     else:
175 |                         true_positives.add(pred_label)
176 |                         false_positives.discard(pred_label)
177 |                         false_negatives.discard(target_label)
178 | 
179 |             tp = len(true_positives)
180 |             fp = len(false_positives)
181 |             fn = len(false_negatives)
182 | 
183 |             recall = tp / (tp + fn)
184 |             precision = tp / (tp + fp)
185 |             ROC.append((recall, precision))
186 | 
187 |         # sort points by recall
188 |         ROC = np.array(sorted(ROC, key=lambda t: t[0]))
189 |         # return recall and precision values
190 |         return list(ROC[:, 0]), list(ROC[:, 1])
191 | 
192 |     def _find_overlapping_target(self, predicted_label, predicted, target, min_iou):
193 |         """
194 |         Return ground truth label which overlaps by at least 'min_iou' with the given input label 'predicted_label'
195 |         or None if such ground truth label does not exist.
196 |         """
197 |         mask_predicted = predicted == predicted_label
198 |         overlapping_labels = target[mask_predicted]
199 |         labels, counts = np.unique(overlapping_labels, return_counts=True)
200 |         # retrieve the biggest overlapping label
201 |         target_label_ind = np.argmax(counts)
202 |         target_label = labels[target_label_ind]
203 |         # return target label if IoU greater than 'min_iou'; since we're starting from 0.5 IoU there might be
204 |         # only one target label that fulfills this criterion
205 |         mask_target = target == target_label
206 |         # return target_label if IoU > min_iou
207 |         if self._iou(mask_predicted, mask_target) > min_iou:
208 |             return target_label
209 |         return None
210 | 
211 |     @staticmethod
212 |     def _iou(prediction, target):
213 |         """
214 |         Computes intersection over union
215 |         """
216 |         intersection = np.logical_and(prediction, target)
217 |         union = np.logical_or(prediction, target)
218 |         return np.sum(intersection) / np.sum(union)
219 | 
220 |     def _filter_instances(self, input):
221 |         """
222 |         Filters instances smaller than 'min_instance_size' by overriding them with 'ignore_index'
223 |         :param input: input instance segmentation
224 |         :return: tuple: (instance segmentation with small instances filtered, set of unique labels without the 'ignore_index')
225 |         """
226 |         if self.min_instance_size is not None:
227 |             labels, counts = np.unique(input, return_counts=True)
228 |             for label, count in zip(labels, counts):
229 |                 if count < self.min_instance_size:
230 |                     mask = input == label
231 |                     input[mask] = self.ignore_index
232 | 
233 |         labels = set(np.unique(input))
234 |         labels.discard(self.ignore_index)
235 | 
236 |         return input, labels
237 | 
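A toy walk-through of the envelope-and-integrate logic used in `_calculate_average_precision` above (hypothetical recall/precision values):

    recall = [0.0, 0.2, 0.6, 1.0, 1.0]
    precision = [0.0, 1.0, 0.5, 0.7, 0.0]
    # make the precision envelope monotonically decreasing by iterating backwards
    for i in range(len(precision) - 2, -1, -1):
        precision[i] = max(precision[i], precision[i + 1])
    # precision is now [1.0, 1.0, 0.7, 0.7, 0.0]
    ap = sum((recall[i] - recall[i - 1]) * precision[i] for i in range(1, len(recall)))
    print(ap)  # 0.2*1.0 + 0.4*0.7 + 0.4*0.7 = 0.76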
--------------------------------------------------------------------------------
/utils/metric/sim_torch.py:
--------------------------------------------------------------------------------
1 | #encoding: utf-8
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from math import exp
6 | 
7 | def psnr_torch(pred, target):
8 |     mse = F.mse_loss(pred, target)
9 |     psnr = 10 * torch.log10(1.0 / mse)
10 | 
11 |     return psnr
12 | 
13 | def ssim(img1, img2, window_size=11, size_average=True):
14 |     (_, channel, _, _) = img1.size()
15 |     window = create_window(window_size, channel)
16 | 
17 |     if img1.is_cuda:
18 |         window = window.cuda(img1.get_device())
19 |     window = window.type_as(img1)
20 | 
21 |     return _ssim(img1, img2, window, window_size, channel, size_average)
22 | 
23 | class SSIM(torch.nn.Module):
24 |     def __init__(self, window_size=11, size_average=True):
25 |         super(SSIM, self).__init__()
26 |         self.window_size = window_size
27 |         self.size_average = size_average
28 |         self.channel = 1
29 |         self.window = create_window(window_size, self.channel)
30 | 
31 |     def forward(self, img1, img2):
32 |         (_, channel, _, _) = img1.size()
33 |         if channel == self.channel and self.window.data.type() == img1.data.type():
34 |             window = self.window
35 |         else:
36 |             window = create_window(self.window_size, channel)
37 | 
38 |             if img1.is_cuda:
39 |                 window = window.cuda(img1.get_device())
40 |             window = window.type_as(img1)
41 | 
42 |             self.window = window
43 |             self.channel = channel
44 | 
45 |         return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
46 | 
47 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
48 |     mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
49 |     mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
50 | 
51 |     mu1_sq = mu1.pow(2)
52 |     mu2_sq = mu2.pow(2)
53 |     mu1_mu2 = mu1 * mu2
54 | 
55 |     sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
56 |     sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
57 |     sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
58 | 
59 |     C1 = 0.01 ** 2  # stability constants (0.01*L)^2 and (0.03*L)^2, assuming dynamic range L = 1
60 |     C2 = 0.03 ** 2  # i.e. inputs scaled to [0, 1]
61 | 
62 |     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
63 | 
64 |     if size_average:
65 |         return ssim_map.mean()
66 |     else:
67 |         return ssim_map.mean(1).mean(1).mean(1)
68 | 
69 | def gaussian(window_size, sigma):
70 |     gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
71 |     return gauss / gauss.sum()
72 | 
73 | 
74 | def create_window(window_size, channel):
75 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
76 |     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
77 |     window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
78 | 
79 |     return window
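A usage sketch for the PSNR/SSIM helpers above (hypothetical tensors; both expect 4D NxCxHxW inputs, and psnr_torch assumes a peak signal value of 1.0):

    import torch
    a = torch.rand(1, 3, 64, 64)
    b = (a + 0.05 * torch.rand(1, 3, 64, 64)).clamp(0, 1)
    print(psnr_torch(b, a))            # PSNR in dB; higher means more similar
    print(ssim(a, b))                  # SSIM; 1.0 only for identical images
    print(SSIM(window_size=11)(a, b))  # module form, caches the Gaussian window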
--------------------------------------------------------------------------------
/utils/metric/similarity.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """Whole-image evaluation metrics, usable for evaluating image registration and image inpainting."""
3 | import itertools
4 | import numbers
5 | import math
6 | 
7 | # third-party modules
8 | import numpy
9 | from scipy.ndimage.filters import convolve, gaussian_filter, minimum_filter
10 | from scipy.ndimage._ni_support import _get_output
11 | from scipy.ndimage.interpolation import zoom
12 | 
13 | # own modules
14 | from .utils import pad, __make_footprint
15 | 
16 | 
17 | def sls(minuend, subtrahend, metric = "ssd", noise = "global", signed = True,
18 |         sn_size = None, sn_footprint = None, sn_mode = "reflect", sn_cval = 0.0,
19 |         pn_size = None, pn_footprint = None, pn_mode = "reflect", pn_cval = 0.0):
20 |     r"""
21 |     Computes the signed local similarity between two images.
22 | 
23 |     Compares a patch around each voxel of the minuend array to a number of patches
24 |     centered at the points of a search neighbourhood in the subtrahend. Thus, creates
25 |     a multi-dimensional measure of patch similarity between the minuend and a
26 |     corresponding search area in the subtrahend.
27 | 
28 |     This filter can also be used to compute local self-similarity, obtaining a
29 |     descriptor similar to the one described in [1]_.
30 | 
31 |     Parameters
32 |     ----------
33 |     minuend : array_like
34 |         Input array from which to subtract the subtrahend.
35 |     subtrahend : array_like
36 |         Input array to subtract from the minuend.
37 |     metric : {'ssd', 'mi', 'nmi', 'ncc'}, optional
38 |         The `metric` parameter determines the metric used to compute the
39 |         filter output. Default is 'ssd'.
40 |     noise : {'global', 'local'}, optional
41 |         The `noise` parameter determines how the noise is handled. If set
42 |         to 'global', the variance determining the noise is a scalar, if
43 |         set to 'local', it is a Gaussian smoothed field of estimated local
44 |         noise. Default is 'global'.
45 |     signed : bool, optional
46 |         Whether the filter output should be signed or not. If set to 'False',
47 |         only the absolute values will be returned. Default is 'True'.
48 |     sn_size : scalar or tuple, optional
49 |         See sn_footprint, below
50 |     sn_footprint : array, optional
51 |         The search neighbourhood.
52 |         Either `sn_size` or `sn_footprint` must be defined. `sn_size` gives
53 |         the shape that is taken from the input array, at every element
54 |         position, to define the input to the filter function.
55 |         `sn_footprint` is a boolean array that specifies (implicitly) a
56 |         shape, but also which of the elements within this shape will get
57 |         passed to the filter function. Thus ``sn_size=(n,m)`` is equivalent
58 |         to ``sn_footprint=np.ones((n,m))``. We adjust `sn_size` to the number
59 |         of dimensions of the input array, so that, if the input array is
60 |         shape (10,10,10), and `sn_size` is 2, then the actual size used is
61 |         (2,2,2).
62 |     sn_mode : {'reflect', 'constant', 'nearest', 'mirror', 'wrap'}, optional
63 |         The `sn_mode` parameter determines how the array borders are
64 |         handled, where `sn_cval` is the value when mode is equal to
65 |         'constant'. Default is 'reflect'
66 |     sn_cval : scalar, optional
67 |         Value to fill past edges of input if `sn_mode` is 'constant'. Default
68 |         is 0.0
69 |     pn_size : scalar or tuple, optional
70 |         See pn_footprint, below
71 |     pn_footprint : array, optional
72 |         The patch over which the distance measure is applied.
73 |         Either `pn_size` or `pn_footprint` must be defined. `pn_size` gives
74 |         the shape that is taken from the input array, at every element
75 |         position, to define the input to the filter function.
76 |         `pn_footprint` is a boolean array that specifies (implicitly) a
77 |         shape, but also which of the elements within this shape will get
78 |         passed to the filter function. Thus ``pn_size=(n,m)`` is equivalent to ``pn_footprint=np.ones((n,m))``. We adjust `pn_size` to the number
79 |         of dimensions of the input array, so that, if the input array is
80 |         shape (10,10,10), and `pn_size` is 2, then the actual size used is
81 |         (2,2,2).
82 |     pn_mode : {'reflect', 'constant', 'nearest', 'mirror', 'wrap'}, optional
83 |         The `pn_mode` parameter determines how the array borders are
84 |         handled, where `pn_cval` is the value when mode is equal to
85 |         'constant'. Default is 'reflect'
86 |     pn_cval : scalar, optional
87 |         Value to fill past edges of input if `pn_mode` is 'constant'. Default
88 |         is 0.0
89 | 
90 |     Returns
91 |     -------
92 |     sls : ndarray
93 |         The signed local similarity image between subtrahend and minuend.
94 | 
95 |     References
96 |     ----------
97 | 
98 |     .. [1] Mattias P. Heinrich, Mark Jenkinson, Manav Bhushan, Tahreema Matin, Fergus V. Gleeson, Sir Michael Brady, Julia A. Schnabel
99 |            MIND: Modality independent neighbourhood descriptor for multi-modal deformable registration
100 |            Medical Image Analysis, Volume 16, Issue 7, October 2012, Pages 1423-1435, ISSN 1361-8415
101 |            http://dx.doi.org/10.1016/j.media.2012.05.008
102 |     """
103 |     minuend = numpy.asarray(minuend)
104 |     subtrahend = numpy.asarray(subtrahend)
105 | 
106 |     if numpy.iscomplexobj(minuend):
107 |         raise TypeError('complex type not supported')
108 |     if numpy.iscomplexobj(subtrahend):
109 |         raise TypeError('complex type not supported')
110 | 
111 |     mshape = [ii for ii in minuend.shape if ii > 0]
112 |     sshape = [ii for ii in subtrahend.shape if ii > 0]
113 |     if not len(mshape) == len(sshape):
114 |         raise RuntimeError("minuend and subtrahend must be of same shape")
115 |     if not numpy.all([sm == ss for sm, ss in zip(mshape, sshape)]):
116 |         raise RuntimeError("minuend and subtrahend must be of same shape")
117 | 
118 |     sn_footprint = __make_footprint(minuend, sn_size, sn_footprint)
119 |     sn_fshape = [ii for ii in sn_footprint.shape if ii > 0]
120 |     if len(sn_fshape) != minuend.ndim:
121 |         raise RuntimeError('search neighbourhood footprint array has incorrect shape.')
122 | 
123 |     #!TODO: Is this required?
124 |     if not sn_footprint.flags.contiguous:
125 |         sn_footprint = sn_footprint.copy()
126 | 
127 |     # create a padded copy of the subtrahend, using the supplied boundary mode
128 |     subtrahend = pad(subtrahend, footprint=sn_footprint, mode=sn_mode, cval=sn_cval)
129 | 
130 |     # compute slicers for position where the search neighbourhood sn_footprint is TRUE
131 |     slicers = [[slice(x, (x + 1) - d if 0 != (x + 1) - d else None) for x in range(d)] for d in sn_fshape]
132 |     slicers = [sl for sl, tv in zip(itertools.product(*slicers), sn_footprint.flat) if tv]
133 | 
134 |     # compute difference images and sign images for search neighbourhood elements
135 |     ssds = [ssd(minuend, subtrahend[slicer], normalized=True, signed=signed, size=pn_size, footprint=pn_footprint, mode=pn_mode, cval=pn_cval) for slicer in slicers]
136 |     distance = [x[0] for x in ssds]
137 |     distance_sign = [x[1] for x in ssds]
138 | 
139 |     # compute local variance, which constitutes an approximation of local noise, out of patch-distances over the neighbourhood structure
140 |     variance = numpy.average(distance, 0)
141 |     variance = gaussian_filter(variance, sigma=3)  #!TODO: Figure out if a fixed sigma is desirable here... I think that yes
142 |     if 'global' == noise:
143 |         variance = variance.sum() / float(numpy.prod(variance.shape))
144 |     # variance[variance < variance_global / 10.] = variance_global / 10.  #!TODO: Should I keep this i.e. regularizing the variance to be at least 10% of the global one?
145 | 
146 |     # compute sls
147 |     sls = [dist_sign * numpy.exp(-1 * (dist / variance)) for dist_sign, dist in zip(distance_sign, distance)]
148 | 
149 |     # convert into sls image, swapping dimensions to have varying patches in the last dimension
150 |     return numpy.rollaxis(numpy.asarray(sls), 0, minuend.ndim + 1)
151 | 
152 | 
153 | def ssd(minuend, subtrahend, normalized=True, signed=False, size=None, footprint=None, mode="reflect", cval=0.0, origin=0):
154 |     r"""
155 |     Computes the sum of squared difference (SSD) between patches of minuend and subtrahend.
156 | 
157 |     Parameters
158 |     ----------
159 |     minuend : array_like
160 |         Input array from which to subtract the subtrahend.
161 |     subtrahend : array_like
162 |         Input array to subtract from the minuend.
163 |     normalized : bool, optional
164 |         Whether the SSD of each patch should be divided through the filter size for
165 |         normalization. Default is 'True'.
166 |     signed : bool, optional
167 |         Whether the accumulative sign of each patch should be returned as well. If
168 |         'True', the second return value is a numpy.sign array, otherwise the scalar '1'.
169 |         Default is 'False'.
170 |     size : scalar or tuple, optional
171 |         See footprint, below
172 |     footprint : array, optional
173 |         The patch over which to compute the SSD.
174 |         Either `size` or `footprint` must be defined. `size` gives
175 |         the shape that is taken from the input array, at every element
176 |         position, to define the input to the filter function.
177 |         `footprint` is a boolean array that specifies (implicitly) a
178 |         shape, but also which of the elements within this shape will get
179 |         passed to the filter function. Thus ``size=(n,m)`` is equivalent
180 |         to ``footprint=np.ones((n,m))``. We adjust `size` to the number
181 |         of dimensions of the input array, so that, if the input array is
182 |         shape (10,10,10), and `size` is 2, then the actual size used is
183 |         (2,2,2).
184 |     mode : {'reflect', 'constant', 'nearest', 'mirror', 'wrap'}, optional
185 |         The `mode` parameter determines how the array borders are
186 |         handled, where `cval` is the value when mode is equal to
187 |         'constant'. Default is 'reflect'
188 |     cval : scalar, optional
189 |         Value to fill past edges of input if `mode` is 'constant'. Default
190 |         is 0.0
191 | 
192 |     Returns
193 |     -------
194 |     ssd : ndarray
195 |         The patchwise sum of squared differences between minuend and subtrahend.
196 |     """
197 |     convolution_filter = average_filter if normalized else sum_filter
198 |     output = float if normalized else minuend.dtype
199 | 
200 |     if signed:
201 |         difference = minuend - subtrahend
202 |         difference_squared = numpy.square(difference)
203 |         distance_sign = numpy.sign(convolution_filter(numpy.sign(difference) * difference_squared, size=size, footprint=footprint, mode=mode, cval=cval, origin=origin, output=output))
204 |         distance = convolution_filter(difference_squared, size=size, footprint=footprint, mode=mode, cval=cval, output=output)
205 |     else:
206 |         distance = convolution_filter(numpy.square(minuend - subtrahend), size=size, footprint=footprint, mode=mode, cval=cval, origin=origin, output=output)
207 |         distance_sign = 1
208 | 
209 |     return distance, distance_sign
210 | 
211 | 
212 | def average_filter(input, size=None, footprint=None, output=None, mode="reflect", cval=0.0, origin=0):
213 |     r"""
214 |     Calculates a multi-dimensional average filter.
215 | 
216 |     Parameters
217 |     ----------
218 |     input : array-like
219 |         input array to filter
220 |     size : scalar or tuple, optional
221 |         See footprint, below
222 |     footprint : array, optional
223 |         Either `size` or `footprint` must be defined. `size` gives
224 |         the shape that is taken from the input array, at every element
225 |         position, to define the input to the filter function.
226 |         `footprint` is a boolean array that specifies (implicitly) a
227 |         shape, but also which of the elements within this shape will get
228 |         passed to the filter function. Thus ``size=(n,m)`` is equivalent
229 |         to ``footprint=np.ones((n,m))``. We adjust `size` to the number
230 |         of dimensions of the input array, so that, if the input array is
231 |         shape (10,10,10), and `size` is 2, then the actual size used is
232 |         (2,2,2).
233 |     output : array, optional
234 |         The ``output`` parameter passes an array in which to store the
235 |         filter output.
236 |     mode : {'reflect','constant','nearest','mirror', 'wrap'}, optional
237 |         The ``mode`` parameter determines how the array borders are
238 |         handled, where ``cval`` is the value when mode is equal to
239 |         'constant'. Default is 'reflect'
240 |     cval : scalar, optional
241 |         Value to fill past edges of input if ``mode`` is 'constant'. Default
242 |         is 0.0
243 |     origin : scalar, optional
244 |         The ``origin`` parameter controls the placement of the filter.
245 |         Default 0
246 | 
247 |     Returns
248 |     -------
249 |     average_filter : ndarray
250 |         Returned array of same shape as `input`.
251 | 
252 |     Notes
253 |     -----
254 |     Convenience implementation employing convolve.
255 | 
256 |     See Also
257 |     --------
258 |     scipy.ndimage.filters.convolve : Convolve an image with a kernel.
259 |     """
260 |     footprint = __make_footprint(input, size, footprint)
261 |     filter_size = footprint.sum()
262 | 
263 |     output = _get_output(output, input)
264 |     sum_filter(input, footprint=footprint, output=output, mode=mode, cval=cval, origin=origin)
265 |     output /= filter_size
266 | 
267 |     return output
268 | 
269 | def sum_filter(input, size=None, footprint=None, output=None, mode="reflect", cval=0.0, origin=0):
270 |     r"""
271 |     Calculates a multi-dimensional sum filter.
272 | 
273 |     Parameters
274 |     ----------
275 |     input : array-like
276 |         input array to filter
277 |     size : scalar or tuple, optional
278 |         See footprint, below
279 |     footprint : array, optional
280 |         Either `size` or `footprint` must be defined. `size` gives
281 |         the shape that is taken from the input array, at every element
282 |         position, to define the input to the filter function.
283 |         `footprint` is a boolean array that specifies (implicitly) a
284 |         shape, but also which of the elements within this shape will get
285 |         passed to the filter function. Thus ``size=(n,m)`` is equivalent
286 |         to ``footprint=np.ones((n,m))``. We adjust `size` to the number
287 |         of dimensions of the input array, so that, if the input array is
288 |         shape (10,10,10), and `size` is 2, then the actual size used is
289 |         (2,2,2).
290 |     output : array, optional
291 |         The ``output`` parameter passes an array in which to store the
292 |         filter output.
293 |     mode : {'reflect','constant','nearest','mirror', 'wrap'}, optional
294 |         The ``mode`` parameter determines how the array borders are
295 |         handled, where ``cval`` is the value when mode is equal to
296 |         'constant'. Default is 'reflect'
297 |     cval : scalar, optional
298 |         Value to fill past edges of input if ``mode`` is 'constant'. Default
299 |         is 0.0
300 |     origin : scalar, optional
301 |         The ``origin`` parameter controls the placement of the filter.
302 |         Default 0
303 | 
304 |     Returns
305 |     -------
306 |     sum_filter : ndarray
307 |         Returned array of same shape as `input`.
308 | 
309 |     Notes
310 |     -----
311 |     Convenience implementation employing convolve.
312 | 
313 |     See Also
314 |     --------
315 |     scipy.ndimage.filters.convolve : Convolve an image with a kernel.
316 |     """
317 |     footprint = __make_footprint(input, size, footprint)
318 |     slicer = [slice(None, None, -1)] * footprint.ndim
319 |     return convolve(input, footprint[tuple(slicer)], output, mode, cval, origin)
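A usage sketch for the two convenience filters (hypothetical array; assumes the package-relative import of __make_footprint resolves):

    import numpy
    a = numpy.arange(9, dtype=float).reshape(3, 3)
    print(sum_filter(a, size=3))      # 3x3 moving-window sum with reflected borders
    print(average_filter(a, size=3))  # the same window sum divided by the footprint size (9)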
--------------------------------------------------------------------------------
/utils/metric/utils.py:
--------------------------------------------------------------------------------
1 | #encoding: utf-8
2 | import numpy
3 | from scipy.ndimage import _ni_support
4 | 
5 | 
6 | #!TODO: Utilise the numpy.pad function that is available since 1.7.0. The numpy version should go inside this function, since it does not support the supplying of a template/footprint on its own.
7 | def pad(input, size=None, footprint=None, output=None, mode="reflect", cval=0.0):
8 |     r"""
9 |     Returns a copy of the input, padded by the supplied structuring element.
10 | 
11 |     In the case of odd dimensionality, the structure element will be centered as
12 |     following on the currently processed position::
13 | 
14 |         [[T, Tx, T],
15 |          [T, T , T]]
16 | 
17 |     , where Tx denotes the center of the structure element.
18 | 
19 |     Simulates the behaviour of scipy.ndimage filters.
20 | 
21 |     Parameters
22 |     ----------
23 |     input : array_like
24 |         Input array to pad.
25 |     size : scalar or tuple, optional
26 |         See footprint, below
27 |     footprint : array, optional
28 |         Either `size` or `footprint` must be defined. `size` gives
29 |         the shape that is taken from the input array, at every element
30 |         position, to define the input to the filter function.
31 |         `footprint` is a boolean array that specifies (implicitly) a
32 |         shape, but also which of the elements within this shape will get
33 |         passed to the filter function. Thus ``size=(n,m)`` is equivalent
34 |         to ``footprint=np.ones((n,m))``. We adjust `size` to the number
35 |         of dimensions of the input array, so that, if the input array is
36 |         shape (10,10,10), and `size` is 2, then the actual size used is
37 |         (2,2,2).
38 |     output : array, optional
39 |         The `output` parameter passes an array in which to store the
40 |         filter output.
41 |     mode : {'reflect', 'constant', 'nearest', 'mirror', 'wrap'}, optional
42 |         The `mode` parameter determines how the array borders are
43 |         handled, where `cval` is the value when mode is equal to
44 |         'constant'. Default is 'reflect'.
45 |     cval : scalar, optional
46 |         Value to fill past edges of input if `mode` is 'constant'. Default
47 |         is 0.0
48 | 
49 |     Returns
50 |     -------
51 |     output : ndarray
52 |         The padded version of the input image.
53 | 
54 |     Notes
55 |     -----
56 |     Since version 1.7.0, numpy supplied a pad function `numpy.pad` that provides
57 |     the same functionality and should be preferred.
58 | 
59 |     Raises
60 |     ------
61 |     ValueError
62 |         If the provided footprint/size is more than double the image size.
63 |     """
64 |     input = numpy.asarray(input)
65 |     if footprint is None:
66 |         if size is None:
67 |             raise RuntimeError("no footprint or filter size provided")
68 |         sizes = _ni_support._normalize_sequence(size, input.ndim)
69 |         footprint = numpy.ones(sizes, dtype=bool)
70 |     else:
71 |         footprint = numpy.asarray(footprint, dtype=bool)
72 |     fshape = [ii for ii in footprint.shape if ii > 0]
73 |     if len(fshape) != input.ndim:
74 |         raise RuntimeError('filter footprint array has incorrect shape.')
75 | 
76 |     if numpy.any([x > 2*y for x, y in zip(footprint.shape, input.shape)]):
77 |         raise ValueError('The size of the padding element is not allowed to be more than double the size of the input array in any dimension.')
78 | 
79 |     padding_offset = [((s - 1) // 2, s // 2) for s in fshape]  # integer division; float slice bounds would fail
80 |     input_slicer = tuple(slice(l, None if 0 == r else -1 * r) for l, r in padding_offset)
81 |     output_shape = [s + sum(os) for s, os in zip(input.shape, padding_offset)]
82 |     output = _ni_support._get_output(output, input, output_shape)
83 | 
84 |     if 'constant' == mode:
85 |         output += cval
86 |         output[input_slicer] = input
87 |         return output
88 |     elif 'nearest' == mode:
89 |         output[input_slicer] = input
90 |         dim_mult_slices = [(d, l, slice(None, l), slice(l, l + 1)) for d, (l, _) in zip(list(range(output.ndim)), padding_offset) if not 0 == l]
91 |         dim_mult_slices.extend([(d, r, slice(-1 * r, None), slice(-2 * r, -2 * r + 1)) for d, (_, r) in zip(list(range(output.ndim)), padding_offset) if not 0 == r])
92 |         for dim, mult, to_slice, from_slice in dim_mult_slices:
93 |             slicer_to = tuple(to_slice if d == dim else slice(None) for d in range(output.ndim))
94 |             slicer_from = tuple(from_slice if d == dim else slice(None) for d in range(output.ndim))
95 |             if not 0 == mult:
96 |                 output[slicer_to] = numpy.concatenate([output[slicer_from]] * mult, dim)
97 |         return output
98 |     elif 'mirror' == mode:
99 |         dim_slices = [(d, slice(None, l), slice(l + 1, 2 * l + 1)) for d, (l, _) in zip(list(range(output.ndim)), padding_offset) if not 0 == l]
100 |         dim_slices.extend([(d, slice(-1 * r, None), slice(-2 * r - 1, -1 * r - 1)) for d, (_, r) in zip(list(range(output.ndim)), padding_offset) if not 0 == r])
101 |         reverse_slice = slice(None, None, -1)
102 |     elif 'reflect' == mode:
103 |         dim_slices = [(d, slice(None, l), slice(l, 2 * l)) for d, (l, _) in zip(list(range(output.ndim)), padding_offset) if not 0 == l]
104 |         dim_slices.extend([(d, slice(-1 * r, None), slice(-2 * r, -1 * r)) for d, (_, r) in zip(list(range(output.ndim)), padding_offset) if not 0 == r])
105 |         reverse_slice = slice(None, None, -1)
106 |     elif 'wrap' == mode:
107 |         dim_slices = [(d, slice(None, l), slice(-1 * (l + r), -1 * r if not 0 == r else None)) for d, (l, r) in zip(list(range(output.ndim)), padding_offset) if not 0 == l]
108 |         dim_slices.extend([(d, slice(-1 * r, None), slice(l, r + l)) for d, (l, r) in zip(list(range(output.ndim)), padding_offset) if not 0 == r])
109 |         reverse_slice = slice(None)
110 |     else:
111 |         raise RuntimeError('boundary mode not supported')
112 | 
113 |     output[input_slicer] = input
114 |     for dim, to_slice, from_slice in dim_slices:
115 |         slicer_reverse = tuple(reverse_slice if d == dim else slice(None) for d in range(output.ndim))
116 |         slicer_to = tuple(to_slice if d == dim else slice(None) for d in range(output.ndim))
117 |         slicer_from = tuple(from_slice if d == dim else slice(None) for d in range(output.ndim))
118 |         output[slicer_to] = output[slicer_from][slicer_reverse]
119 | 
120 |     return output
121 | 
122 | 
123 | def __make_footprint(input, size, footprint):
124 |     "Creates a standard footprint element ala scipy.ndimage."
125 |     if footprint is None:
126 |         if size is None:
127 |             raise RuntimeError("no footprint or filter size provided")
128 |         sizes = _ni_support._normalize_sequence(size, input.ndim)
129 |         footprint = numpy.ones(sizes, dtype=bool)
130 |     else:
131 |         footprint = numpy.asarray(footprint, dtype=bool)
132 |     return footprint
133 | 
134 | class ArgumentError(Exception):
135 |     r"""Thrown by an application when an invalid command line argument has been supplied.
136 |     """
137 |     pass
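A usage sketch for pad (hypothetical array; as the TODO above suggests, plain cases are covered by numpy.pad, whose 'symmetric' mode corresponds to this function's 'reflect'):

    import numpy
    a = numpy.array([[1, 2], [3, 4]])
    print(pad(a, size=3, mode='reflect'))
    # [[1 1 2 2]
    #  [1 1 2 2]
    #  [3 3 4 4]
    #  [3 3 4 4]]
    print(numpy.pad(a, 1, mode='symmetric'))  # same result for this footprint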
--------------------------------------------------------------------------------
/utils/metric_logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | from collections import defaultdict
3 | from collections import deque
4 | 
5 | import torch
6 | 
7 | 
8 | class SmoothedValue(object):
9 |     """Track a series of values and provide access to smoothed values over a
10 |     window or the global series average.
11 |     """
12 | 
13 |     def __init__(self, window_size=20):
14 |         self.deque = deque(maxlen=window_size)
15 |         self.series = []
16 |         self.total = 0.0
17 |         self.count = 0
18 | 
19 |     def update(self, value):
20 |         self.deque.append(value)
21 |         self.series.append(value)
22 |         self.count += 1
23 |         self.total += value
24 | 
25 |     @property
26 |     def median(self):
27 |         d = torch.tensor(list(self.deque))
28 |         return d.median().item()
29 | 
30 |     @property
31 |     def avg(self):
32 |         d = torch.tensor(list(self.deque))
33 |         return d.mean().item()
34 | 
35 |     @property
36 |     def global_avg(self):
37 |         return self.total / self.count
38 | 
39 | 
40 | class MetricLogger(object):
41 |     def __init__(self, delimiter="\t"):
42 |         self.meters = defaultdict(SmoothedValue)
43 |         self.delimiter = delimiter
44 | 
45 |     def update(self, **kwargs):
46 |         for k, v in kwargs.items():
47 |             if isinstance(v, torch.Tensor):
48 |                 v = v.item()
49 |             assert isinstance(v, (float, int))
50 |             self.meters[k].update(v)
51 | 
52 |     def __getattr__(self, attr):
53 |         if attr in self.meters:
54 |             return self.meters[attr]
55 |         if attr in self.__dict__:
56 |             return self.__dict__[attr]
57 |         raise AttributeError("'{}' object has no attribute '{}'".format(
58 |             type(self).__name__, attr))
59 | 
60 |     def __str__(self):
61 |         loss_str = []
62 |         for name, meter in sorted(self.meters.items()):
63 |             loss_str.append(
64 |                 "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
65 |             )
66 |         return self.delimiter.join(loss_str)
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # encoding: utf-8
3 | # @Time : 2019/5/9 16:06
4 | # @Author : Eric Ching
5 | import os
6 | import random
7 | import torch
8 | import warnings
9 | import numpy as np
10 | 
11 | def init_env(gpu_id='0', seed=42):
12 |     os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
13 |     random.seed(seed)
14 |     np.random.seed(seed)
15 |     torch.manual_seed(seed)
16 |     torch.cuda.manual_seed_all(seed)
17 |     torch.backends.cudnn.benchmark = True  # note: benchmark mode trades exact reproducibility for speed
18 |     warnings.filterwarnings('ignore')
19 | 
20 | def mkdir(path):
21 |     if not os.path.exists(path):
22 |         os.makedirs(path)
23 |     else:
24 |         print('exist path: ', path)
25 | 
26 | 
--------------------------------------------------------------------------------
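A closing usage sketch for the SmoothedValue/MetricLogger pair from utils/metric_logger.py (hypothetical values; the printed numbers are illustrative):

    import torch
    logger = MetricLogger(delimiter="  ")
    for step in range(100):
        logger.update(loss=torch.rand(1).item(), dice=0.8)
    print(logger)           # e.g. "dice: 0.8000 (0.8000)  loss: 0.5123 (0.4987)"
    print(logger.loss.avg)  # mean over the last 20 updates (the smoothing window)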