├── community ├── weights │ ├── v0.0 │ │ ├── vgg.pth │ │ ├── alex.pth │ │ └── squeeze.pth │ └── v0.1 │ │ ├── vgg.pth │ │ ├── alex.pth │ │ └── squeeze.pth ├── fid.py └── lpips.py ├── LICENSE ├── README.md └── metrics.py /community/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huster-wgm/Pytorch-metrics/HEAD/community/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /community/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huster-wgm/Pytorch-metrics/HEAD/community/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /community/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huster-wgm/Pytorch-metrics/HEAD/community/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /community/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huster-wgm/Pytorch-metrics/HEAD/community/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /community/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huster-wgm/Pytorch-metrics/HEAD/community/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /community/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huster-wgm/Pytorch-metrics/HEAD/community/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 HIROAKI GO 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch-metrics 2 | 3 | This is a repo. for evaluation metrics using Pytorch. 
4 | metrics.py is designed for evaluation tasks that take two PyTorch tensors as input. 5 | All implemented metrics are compatible with any batch size and device (CPU or GPU). 6 | 7 | ``` 8 | y_pred << 4D tensor in [batch_size, channels, img_rows, img_cols] 9 | y_true << 4D tensor in [batch_size, channels, img_rows, img_cols] 10 | 11 | metric = MSE() 12 | acc = metric(y_pred, y_true).item() 13 | print("{} ==> {}".format(repr(metric), acc)) 14 | ``` 15 | 16 | ## Requirement 17 | - python3 18 | - pytorch >= 1.0 19 | - torchvision >= 0.2.0 20 | 21 | ## Implementation 22 | 23 | - Image similarity 24 | * AE (Average Angular Error) 25 | * MSE (Mean Square Error) 26 | * PSNR (Peak Signal-to-Noise Ratio) 27 | * SSIM (Structural Similarity) 28 | * LPIPS (Learned Perceptual Image Patch Similarity) 29 | 30 | - Accuracy 31 | * OA (Overall Accuracy) 32 | * Precision 33 | * Recall 34 | * F1-score 35 | * Kappa coefficient 36 | * Jaccard Index 37 | 38 | ## Ongoing 39 | - FID (Fréchet Inception Distance) 40 | 41 | ## Acknowledgment 42 | Our implementations are largely inspired by many open-source codebases, repos, and papers. 43 | Many thanks to the authors. 44 | * Richard Zhang, LPIPS (https://github.com/richzhang/PerceptualSimilarity) 45 | * Jorge Pessoa, SSIM (https://github.com/jorge-pessoa/pytorch-msssim) 46 | 47 | ## LICENSE 48 | This implementation is licensed under the MIT License. 49 | 50 | -------------------------------------------------------------------------------- /community/fid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | """ 4 | @Email: guangmingwu2010@gmail.com \ 5 | guozhilingty@gmail.com 6 | @Copyright: go-hiroaki & Chokurei 7 | @License: MIT 8 | mainly borrowed from https://github.com/mseitzer/pytorch-fid/blob/master/inception.py 9 | """ 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torchvision import models 14 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 15 | 16 | 17 | # Inception weights ported to Pytorch from 18 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 19 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 20 | 21 | 22 | class InceptionV3(nn.Module): 23 | """Pretrained InceptionV3 network returning feature maps""" 24 | 25 | def __init__(self, 26 | output_blocks=[3], 27 | resize_input=True, 28 | normalize_input=True, 29 | requires_grad=False, 30 | use_fid_inception=True): 31 | """Build pretrained InceptionV3 32 | 33 | Parameters 34 | ---------- 35 | output_blocks : list of int 36 | Indices of blocks to return features of. Possible values are: 37 | - 0: corresponds to output of first max pooling 38 | - 1: corresponds to output of second max pooling 39 | - 2: corresponds to output which is fed to aux classifier 40 | - 3: corresponds to output of final average pooling 41 | resize_input : bool 42 | If true, bilinearly resizes input to width and height 299 before 43 | feeding input to model. As the network without fully connected 44 | layers is fully convolutional, it should be able to handle inputs 45 | of arbitrary size, so resizing might not be strictly needed 46 | normalize_input : bool 47 | If true, scales the input from range (0, 1) to the range the 48 | pretrained Inception network expects, namely (-1, 1) 49 | requires_grad : bool 50 | If true, parameters of the model require gradients.
Possibly useful 51 | for finetuning the network 52 | use_fid_inception : bool 53 | If true, uses the pretrained Inception model used in Tensorflow's 54 | FID implementation. If false, uses the pretrained Inception model 55 | available in torchvision. The FID Inception model has different 56 | weights and a slightly different structure from torchvision's 57 | Inception model. If you want to compute FID scores, you are 58 | strongly advised to set this parameter to true to get comparable 59 | results. 60 | """ 61 | super(InceptionV3, self).__init__() 62 | 63 | self.resize_input = resize_input 64 | self.normalize_input = normalize_input 65 | self.output_blocks = sorted(output_blocks) 66 | self.last_needed_block = max(output_blocks) 67 | 68 | assert self.last_needed_block <= 3, \ 69 | 'Last possible output block index is 3' 70 | 71 | self.blocks = nn.ModuleList() 72 | 73 | if use_fid_inception: 74 | inception = fid_inception_v3() 75 | else: 76 | inception = models.inception_v3(pretrained=True) 77 | 78 | # Block 0: input to maxpool1 79 | block0 = [ 80 | inception.Conv2d_1a_3x3, 81 | inception.Conv2d_2a_3x3, 82 | inception.Conv2d_2b_3x3, 83 | nn.MaxPool2d(kernel_size=3, stride=2) 84 | ] 85 | self.blocks.append(nn.Sequential(*block0)) 86 | 87 | # Block 1: maxpool1 to maxpool2 88 | if self.last_needed_block >= 1: 89 | block1 = [ 90 | inception.Conv2d_3b_1x1, 91 | inception.Conv2d_4a_3x3, 92 | nn.MaxPool2d(kernel_size=3, stride=2) 93 | ] 94 | self.blocks.append(nn.Sequential(*block1)) 95 | 96 | # Block 2: maxpool2 to aux classifier 97 | if self.last_needed_block >= 2: 98 | block2 = [ 99 | inception.Mixed_5b, 100 | inception.Mixed_5c, 101 | inception.Mixed_5d, 102 | inception.Mixed_6a, 103 | inception.Mixed_6b, 104 | inception.Mixed_6c, 105 | inception.Mixed_6d, 106 | inception.Mixed_6e, 107 | ] 108 | self.blocks.append(nn.Sequential(*block2)) 109 | 110 | # Block 3: aux classifier to final avgpool 111 | if self.last_needed_block >= 3: 112 | block3 = [ 113 | inception.Mixed_7a, 114 | inception.Mixed_7b, 115 | inception.Mixed_7c, 116 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 117 | ] 118 | self.blocks.append(nn.Sequential(*block3)) 119 | 120 | for param in self.parameters(): 121 | param.requires_grad = requires_grad 122 | 123 | def forward(self, inp): 124 | """Get Inception feature maps 125 | 126 | Parameters 127 | ---------- 128 | inp : torch.autograd.Variable 129 | Input tensor of shape Bx3xHxW. Values are expected to be in 130 | range (0, 1) 131 | 132 | Returns 133 | ------- 134 | List of torch.autograd.Variable, corresponding to the selected output 135 | block, sorted ascending by index 136 | """ 137 | outp = [] 138 | x = inp 139 | 140 | if self.resize_input: 141 | x = F.interpolate(x, 142 | size=(299, 299), 143 | mode='bilinear', 144 | align_corners=False) 145 | 146 | if self.normalize_input: 147 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 148 | 149 | for idx, block in enumerate(self.blocks): 150 | x = block(x) 151 | if idx in self.output_blocks: 152 | outp.append(x) 153 | 154 | if idx == self.last_needed_block: 155 | break 156 | 157 | return outp 158 | 159 | 160 | def fid_inception_v3(): 161 | """Build pretrained Inception model for FID computation 162 | 163 | The Inception model for FID computation uses a different set of weights 164 | and has a slightly different structure than torchvision's Inception. 165 | 166 | This method first constructs torchvision's Inception and then patches the 167 | necessary parts that are different in the FID Inception model. 
168 | """ 169 | inception = models.inception_v3(num_classes=1008, 170 | aux_logits=False, 171 | pretrained=False) 172 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 173 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 174 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 175 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 176 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 177 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 178 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 179 | inception.Mixed_7b = FIDInceptionE_1(1280) 180 | inception.Mixed_7c = FIDInceptionE_2(2048) 181 | 182 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 183 | inception.load_state_dict(state_dict) 184 | return inception 185 | 186 | 187 | class FIDInceptionA(models.inception.InceptionA): 188 | """InceptionA block patched for FID computation""" 189 | def __init__(self, in_channels, pool_features): 190 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 191 | 192 | def forward(self, x): 193 | branch1x1 = self.branch1x1(x) 194 | 195 | branch5x5 = self.branch5x5_1(x) 196 | branch5x5 = self.branch5x5_2(branch5x5) 197 | 198 | branch3x3dbl = self.branch3x3dbl_1(x) 199 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 200 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 201 | 202 | # Patch: Tensorflow's average pool does not use the padded zero's in 203 | # its average calculation 204 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 205 | count_include_pad=False) 206 | branch_pool = self.branch_pool(branch_pool) 207 | 208 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 209 | return torch.cat(outputs, 1) 210 | 211 | 212 | class FIDInceptionC(models.inception.InceptionC): 213 | """InceptionC block patched for FID computation""" 214 | def __init__(self, in_channels, channels_7x7): 215 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 216 | 217 | def forward(self, x): 218 | branch1x1 = self.branch1x1(x) 219 | 220 | branch7x7 = self.branch7x7_1(x) 221 | branch7x7 = self.branch7x7_2(branch7x7) 222 | branch7x7 = self.branch7x7_3(branch7x7) 223 | 224 | branch7x7dbl = self.branch7x7dbl_1(x) 225 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 226 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 227 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 228 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 229 | 230 | # Patch: Tensorflow's average pool does not use the padded zero's in 231 | # its average calculation 232 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 233 | count_include_pad=False) 234 | branch_pool = self.branch_pool(branch_pool) 235 | 236 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 237 | return torch.cat(outputs, 1) 238 | 239 | 240 | class FIDInceptionE_1(models.inception.InceptionE): 241 | """First InceptionE block patched for FID computation""" 242 | def __init__(self, in_channels): 243 | super(FIDInceptionE_1, self).__init__(in_channels) 244 | 245 | def forward(self, x): 246 | branch1x1 = self.branch1x1(x) 247 | 248 | branch3x3 = self.branch3x3_1(x) 249 | branch3x3 = [ 250 | self.branch3x3_2a(branch3x3), 251 | self.branch3x3_2b(branch3x3), 252 | ] 253 | branch3x3 = torch.cat(branch3x3, 1) 254 | 255 | branch3x3dbl = self.branch3x3dbl_1(x) 256 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 257 | branch3x3dbl = [ 258 | self.branch3x3dbl_3a(branch3x3dbl), 259 | self.branch3x3dbl_3b(branch3x3dbl), 260 | ] 
261 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 262 | 263 | # Patch: Tensorflow's average pool does not use the padded zero's in 264 | # its average calculation 265 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 266 | count_include_pad=False) 267 | branch_pool = self.branch_pool(branch_pool) 268 | 269 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 270 | return torch.cat(outputs, 1) 271 | 272 | 273 | class FIDInceptionE_2(models.inception.InceptionE): 274 | """Second InceptionE block patched for FID computation""" 275 | def __init__(self, in_channels): 276 | super(FIDInceptionE_2, self).__init__(in_channels) 277 | 278 | def forward(self, x): 279 | branch1x1 = self.branch1x1(x) 280 | 281 | branch3x3 = self.branch3x3_1(x) 282 | branch3x3 = [ 283 | self.branch3x3_2a(branch3x3), 284 | self.branch3x3_2b(branch3x3), 285 | ] 286 | branch3x3 = torch.cat(branch3x3, 1) 287 | 288 | branch3x3dbl = self.branch3x3dbl_1(x) 289 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 290 | branch3x3dbl = [ 291 | self.branch3x3dbl_3a(branch3x3dbl), 292 | self.branch3x3dbl_3b(branch3x3dbl), 293 | ] 294 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 295 | 296 | # Patch: The FID Inception model uses max pooling instead of average 297 | # pooling. This is likely an error in this specific Inception 298 | # implementation, as other Inception models use average pooling here 299 | # (which matches the description in the paper). 300 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 301 | branch_pool = self.branch_pool(branch_pool) 302 | 303 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 304 | return torch.cat(outputs, 1) -------------------------------------------------------------------------------- /community/lpips.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | """ 4 | @Email: guangmingwu2010@gmail.com \ 5 | @Copyright: go-hiroaki 6 | @License: MIT 7 | Modified from https://github.com/richzhang/PerceptualSimilarity 8 | """ 9 | from __future__ import absolute_import 10 | 11 | import os 12 | import torch 13 | import torch.nn as nn 14 | from collections import namedtuple 15 | from torchvision import models as tv 16 | 17 | 18 | class squeezenet(torch.nn.Module): 19 | def __init__(self, requires_grad=False, pretrained=True): 20 | super(squeezenet, self).__init__() 21 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 22 | self.slice1 = torch.nn.Sequential() 23 | self.slice2 = torch.nn.Sequential() 24 | self.slice3 = torch.nn.Sequential() 25 | self.slice4 = torch.nn.Sequential() 26 | self.slice5 = torch.nn.Sequential() 27 | self.slice6 = torch.nn.Sequential() 28 | self.slice7 = torch.nn.Sequential() 29 | self.N_slices = 7 30 | for x in range(2): 31 | self.slice1.add_module(str(x), pretrained_features[x]) 32 | for x in range(2,5): 33 | self.slice2.add_module(str(x), pretrained_features[x]) 34 | for x in range(5, 8): 35 | self.slice3.add_module(str(x), pretrained_features[x]) 36 | for x in range(8, 10): 37 | self.slice4.add_module(str(x), pretrained_features[x]) 38 | for x in range(10, 11): 39 | self.slice5.add_module(str(x), pretrained_features[x]) 40 | for x in range(11, 12): 41 | self.slice6.add_module(str(x), pretrained_features[x]) 42 | for x in range(12, 13): 43 | self.slice7.add_module(str(x), pretrained_features[x]) 44 | if not requires_grad: 45 | for param in self.parameters(): 46 | param.requires_grad = False 47 | 48 | def forward(self, 
X): 49 | h = self.slice1(X) 50 | h_relu1 = h 51 | h = self.slice2(h) 52 | h_relu2 = h 53 | h = self.slice3(h) 54 | h_relu3 = h 55 | h = self.slice4(h) 56 | h_relu4 = h 57 | h = self.slice5(h) 58 | h_relu5 = h 59 | h = self.slice6(h) 60 | h_relu6 = h 61 | h = self.slice7(h) 62 | h_relu7 = h 63 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 64 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 65 | 66 | return out 67 | 68 | 69 | class alexnet(torch.nn.Module): 70 | def __init__(self, requires_grad=False, pretrained=True): 71 | super(alexnet, self).__init__() 72 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 73 | self.slice1 = torch.nn.Sequential() 74 | self.slice2 = torch.nn.Sequential() 75 | self.slice3 = torch.nn.Sequential() 76 | self.slice4 = torch.nn.Sequential() 77 | self.slice5 = torch.nn.Sequential() 78 | self.N_slices = 5 79 | for x in range(2): 80 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 81 | for x in range(2, 5): 82 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 83 | for x in range(5, 8): 84 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 85 | for x in range(8, 10): 86 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 87 | for x in range(10, 12): 88 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 89 | if not requires_grad: 90 | for param in self.parameters(): 91 | param.requires_grad = False 92 | 93 | def forward(self, X): 94 | h = self.slice1(X) 95 | h_relu1 = h 96 | h = self.slice2(h) 97 | h_relu2 = h 98 | h = self.slice3(h) 99 | h_relu3 = h 100 | h = self.slice4(h) 101 | h_relu4 = h 102 | h = self.slice5(h) 103 | h_relu5 = h 104 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 105 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 106 | 107 | return out 108 | 109 | class vgg16(torch.nn.Module): 110 | def __init__(self, requires_grad=False, pretrained=True): 111 | super(vgg16, self).__init__() 112 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 113 | self.slice1 = torch.nn.Sequential() 114 | self.slice2 = torch.nn.Sequential() 115 | self.slice3 = torch.nn.Sequential() 116 | self.slice4 = torch.nn.Sequential() 117 | self.slice5 = torch.nn.Sequential() 118 | self.N_slices = 5 119 | for x in range(4): 120 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 121 | for x in range(4, 9): 122 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 123 | for x in range(9, 16): 124 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 125 | for x in range(16, 23): 126 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 127 | for x in range(23, 30): 128 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 129 | if not requires_grad: 130 | for param in self.parameters(): 131 | param.requires_grad = False 132 | 133 | def forward(self, X): 134 | h = self.slice1(X) 135 | h_relu1_2 = h 136 | h = self.slice2(h) 137 | h_relu2_2 = h 138 | h = self.slice3(h) 139 | h_relu3_3 = h 140 | h = self.slice4(h) 141 | h_relu4_3 = h 142 | h = self.slice5(h) 143 | h_relu5_3 = h 144 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 145 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 146 | 147 | return out 148 | 149 | 150 | 151 | class resnet(torch.nn.Module): 152 | def __init__(self, requires_grad=False, 
pretrained=True, num=18): 153 | super(resnet, self).__init__() 154 | if(num==18): 155 | self.net = tv.resnet18(pretrained=pretrained) 156 | elif(num==34): 157 | self.net = tv.resnet34(pretrained=pretrained) 158 | elif(num==50): 159 | self.net = tv.resnet50(pretrained=pretrained) 160 | elif(num==101): 161 | self.net = tv.resnet101(pretrained=pretrained) 162 | elif(num==152): 163 | self.net = tv.resnet152(pretrained=pretrained) 164 | self.N_slices = 5 165 | 166 | self.conv1 = self.net.conv1 167 | self.bn1 = self.net.bn1 168 | self.relu = self.net.relu 169 | self.maxpool = self.net.maxpool 170 | self.layer1 = self.net.layer1 171 | self.layer2 = self.net.layer2 172 | self.layer3 = self.net.layer3 173 | self.layer4 = self.net.layer4 174 | 175 | def forward(self, X): 176 | h = self.conv1(X) 177 | h = self.bn1(h) 178 | h = self.relu(h) 179 | h_relu1 = h 180 | h = self.maxpool(h) 181 | h = self.layer1(h) 182 | h_conv2 = h 183 | h = self.layer2(h) 184 | h_conv3 = h 185 | h = self.layer3(h) 186 | h_conv4 = h 187 | h = self.layer4(h) 188 | h_conv5 = h 189 | 190 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 191 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 192 | 193 | return out 194 | 195 | 196 | def spatial_average(in_tens, keepdim=True): 197 | return in_tens.mean([2,3],keepdim=keepdim) 198 | 199 | def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W 200 | in_H = in_tens.shape[2] 201 | scale_factor = 1.*out_H/in_H 202 | 203 | return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens) 204 | 205 | class ScalingLayer(nn.Module): 206 | def __init__(self): 207 | super(ScalingLayer, self).__init__() 208 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) 209 | self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) 210 | 211 | def forward(self, inp): 212 | return (inp - self.shift) / self.scale 213 | 214 | class NetLinLayer(nn.Module): 215 | ''' A single linear layer which does a 1x1 conv ''' 216 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 217 | super(NetLinLayer, self).__init__() 218 | 219 | layers = [nn.Dropout(),] if(use_dropout) else [] 220 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] 221 | self.model = nn.Sequential(*layers) 222 | 223 | # Learned perceptual metric 224 | class PNetLin(nn.Module): 225 | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True): 226 | super(PNetLin, self).__init__() 227 | 228 | self.pnet_type = pnet_type 229 | self.pnet_tune = pnet_tune 230 | self.pnet_rand = pnet_rand 231 | self.spatial = spatial 232 | self.lpips = lpips 233 | self.version = version 234 | self.scaling_layer = ScalingLayer() 235 | 236 | if(self.pnet_type in ['vgg','vgg16']): 237 | net_type = vgg16 238 | self.chns = [64,128,256,512,512] 239 | elif(self.pnet_type=='alex'): 240 | net_type = alexnet 241 | self.chns = [64,192,384,256,256] 242 | elif(self.pnet_type=='squeeze'): 243 | net_type = squeezenet 244 | self.chns = [64,128,256,384,384,512,512] 245 | self.L = len(self.chns) 246 | 247 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 248 | 249 | if(lpips): 250 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 251 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 252 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 253 | self.lin3 = 
NetLinLayer(self.chns[3], use_dropout=use_dropout) 254 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 255 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] 256 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet 257 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 258 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 259 | self.lins+=[self.lin5,self.lin6] 260 | 261 | def forward(self, in0, in1, retPerLayer=False): 262 | # v0.0 - original release had a bug, where input was not scaled 263 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) 264 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 265 | feats0, feats1, diffs = {}, {}, {} 266 | 267 | for kk in range(self.L): 268 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 269 | diffs[kk] = (feats0[kk]-feats1[kk])**2 270 | 271 | if(self.lpips): 272 | if(self.spatial): 273 | res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)] 274 | else: 275 | res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)] 276 | else: 277 | if(self.spatial): 278 | res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)] 279 | else: 280 | res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] 281 | 282 | val = res[0] 283 | for l in range(1,self.L): 284 | val += res[l] 285 | 286 | if(retPerLayer): 287 | return (val, res) 288 | else: 289 | return val 290 | 291 | 292 | def normalize_tensor(in_feat,eps=1e-10): 293 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 294 | return in_feat/(norm_factor+eps) 295 | 296 | 297 | class DistModel(object): 298 | 299 | def __init__(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, 300 | use_gpu=True, printNet=False, spatial=False, 301 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): 302 | ''' 303 | INPUTS 304 | model - ['net-lin'] for linearly calibrated network 305 | ['net'] for off-the-shelf network 306 | ['L2'] for L2 distance in Lab colorspace 307 | ['SSIM'] for ssim in RGB colorspace 308 | net - ['squeeze','alex','vgg'] 309 | model_path - if None, will look in weights/[NET_NAME].pth 310 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 311 | use_gpu - bool - whether or not to use a GPU 312 | printNet - bool - whether or not to print network architecture out 313 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 314 | spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). 315 | spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. 316 | spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). 
317 | is_train - bool - [True] for training mode 318 | lr - float - initial learning rate 319 | beta1 - float - initial momentum term for adam 320 | version - 0.1 for latest, 0.0 was original (with a bug) 321 | gpu_ids - int array - [0] by default, gpus to use 322 | ''' 323 | 324 | self.model = model 325 | self.net = net 326 | self.is_train = is_train 327 | self.spatial = spatial 328 | self.gpu_ids = gpu_ids 329 | self.model_name = '%s [%s]'%(model,net) 330 | 331 | if(self.model == 'net-lin'): # pretrained net + linear layer 332 | self.net = PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, 333 | use_dropout=True, spatial=spatial, version=version, lpips=True) 334 | kw = {} 335 | if not use_gpu: 336 | kw['map_location'] = 'cpu' 337 | if(model_path is None): 338 | DIR = os.path.dirname(os.path.abspath(__file__)) 339 | model_path = os.path.join(DIR, 'weights/v%s/%s.pth'%(version,net)) 340 | 341 | if(not is_train): 342 | # print('Loading model from: %s'%model_path) 343 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) 344 | 345 | elif(self.model=='net'): # pretrained network 346 | self.net = PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) 347 | else: 348 | raise ValueError("Model [%s] not recognized." % self.model) 349 | 350 | self.parameters = list(self.net.parameters()) 351 | 352 | self.net.eval() 353 | 354 | if(use_gpu): 355 | self.net.to(gpu_ids[0]) 356 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) 357 | 358 | 359 | def name(self): 360 | return self.model_name 361 | 362 | def forward(self, in0, in1, retPerLayer=False): 363 | ''' Function computes the distance between image patches in0 and in1 364 | INPUTS 365 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 366 | OUTPUT 367 | computed distances between in0 and in1 368 | ''' 369 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 370 | 371 | 372 | class PerceptualLoss(torch.nn.Module): 373 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): 374 | super(PerceptualLoss, self).__init__() 375 | # print('Setting up Perceptual loss...') 376 | self.use_gpu = use_gpu 377 | self.spatial = spatial 378 | self.gpu_ids = gpu_ids 379 | self.model = DistModel(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids) 380 | 381 | def forward(self, pred, target, normalize=False): 382 | """ 383 | Pred and target are Variables. 
384 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 385 | If normalize is False, assumes the images are already between [-1,+1] 386 | 387 | Inputs pred and target are Nx3xHxW 388 | Output pytorch Variable N long 389 | """ 390 | 391 | if normalize: 392 | target = 2 * target - 1 393 | pred = 2 * pred - 1 394 | 395 | return self.model.forward(target, pred) -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | """ 4 | @Email: guangmingwu2010@gmail.com \ 5 | guozhilingty@gmail.com 6 | @Copyright: go-hiroaki & Chokurei 7 | @License: MIT 8 | """ 9 | import math 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from community import lpips 14 | from community import fid 15 | 16 | eps = 1e-6 17 | 18 | def _binarize(y_data, threshold): 19 | """ 20 | args: 21 | y_data : [float] 4-d tensor in [batch_size, channels, img_rows, img_cols] 22 | threshold : [float] [0.0, 1.0] 23 | return 4-d binarized y_data 24 | """ 25 | y_data[y_data < threshold] = 0.0 26 | y_data[y_data >= threshold] = 1.0 27 | return y_data 28 | 29 | def _argmax(y_data, dim): 30 | """ 31 | args: 32 | y_data : 4-d tensor in [batch_size, chs, img_rows, img_cols] 33 | dim : int 34 | return 3-d [int] y_data 35 | """ 36 | return torch.argmax(y_data, dim).int() 37 | 38 | 39 | def _get_tp(y_pred, y_true): 40 | """ 41 | args: 42 | y_true : [int] 3-d in [batch_size, img_rows, img_cols] 43 | y_pred : [int] 3-d in [batch_size, img_rows, img_cols] 44 | return [float] true_positive 45 | """ 46 | return torch.sum(y_true * y_pred).float() 47 | 48 | 49 | def _get_fp(y_pred, y_true): 50 | """ 51 | args: 52 | y_true : 3-d ndarray in [batch_size, img_rows, img_cols] 53 | y_pred : 3-d ndarray in [batch_size, img_rows, img_cols] 54 | return [float] false_positive 55 | """ 56 | return torch.sum((1 - y_true) * y_pred).float() 57 | 58 | 59 | def _get_tn(y_pred, y_true): 60 | """ 61 | args: 62 | y_true : 3-d ndarray in [batch_size, img_rows, img_cols] 63 | y_pred : 3-d ndarray in [batch_size, img_rows, img_cols] 64 | return [float] true_negative 65 | """ 66 | return torch.sum((1 - y_true) * (1 - y_pred)).float() 67 | 68 | 69 | def _get_fn(y_pred, y_true): 70 | """ 71 | args: 72 | y_true : 3-d ndarray in [batch_size, img_rows, img_cols] 73 | y_pred : 3-d ndarray in [batch_size, img_rows, img_cols] 74 | return [float] false_negative 75 | """ 76 | return torch.sum(y_true * (1 - y_pred)).float() 77 | 78 | 79 | def _get_weights(y_true, nb_ch): 80 | """ 81 | args: 82 | y_true : 3-d ndarray in [batch_size, img_rows, img_cols] 83 | nb_ch : int 84 | return [float] weights 85 | """ 86 | batch_size, img_rows, img_cols = y_true.shape 87 | pixels = batch_size * img_rows * img_cols 88 | weights = [torch.sum(y_true==ch).item() / pixels for ch in range(nb_ch)] 89 | return weights 90 | 91 | 92 | class CFMatrix(object): 93 | def __init__(self, des=None): 94 | self.des = des 95 | 96 | def __repr__(self): 97 | return "ConfusionMatrix" 98 | 99 | def __call__(self, y_pred, y_true, threshold=0.5): 100 | 101 | """ 102 | args: 103 | y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols] 104 | y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols] 105 | threshold : [0.0, 1.0] 106 | return confusion matrix 107 | """ 108 | batch_size, chs, img_rows, img_cols = y_true.shape 109 | device = y_true.device 110 | if chs 
== 1: 111 | y_pred = _binarize(y_pred, threshold) 112 | y_true = _binarize(y_true, threshold) 113 | nb_tp = _get_tp(y_pred, y_true) 114 | nb_fp = _get_fp(y_pred, y_true) 115 | nb_tn = _get_tn(y_pred, y_true) 116 | nb_fn = _get_fn(y_pred, y_true) 117 | mperforms = [nb_tp, nb_fp, nb_tn, nb_fn] 118 | performs = None 119 | else: 120 | y_pred = _argmax(y_pred, 1) 121 | y_true = _argmax(y_true, 1) 122 | performs = torch.zeros(chs, 4).to(device) 123 | weights = _get_weights(y_true, chs) 124 | for ch in range(chs): 125 | y_true_ch = torch.zeros(batch_size, img_rows, img_cols) 126 | y_pred_ch = torch.zeros(batch_size, img_rows, img_cols) 127 | y_true_ch[y_true == ch] = 1 128 | y_pred_ch[y_pred == ch] = 1 129 | nb_tp = _get_tp(y_pred_ch, y_true_ch) 130 | nb_fp = _get_fp(y_pred_ch, y_true_ch) 131 | nb_tn = _get_tn(y_pred_ch, y_true_ch) 132 | nb_fn = _get_fn(y_pred_ch, y_true_ch) 133 | performs[int(ch), :] = [nb_tp, nb_fp, nb_tn, nb_fn] 134 | mperforms = sum([i*j for (i, j) in zip(performs, weights)]) 135 | return mperforms, performs 136 | 137 | 138 | class OAAcc(object): 139 | def __init__(self, des="Overall Accuracy"): 140 | self.des = des 141 | 142 | def __repr__(self): 143 | return "OAcc" 144 | 145 | def __call__(self, y_pred, y_true, threshold=0.5): 146 | """ 147 | args: 148 | y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols] 149 | y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols] 150 | threshold : [0.0, 1.0] 151 | return (tp+tn)/total 152 | """ 153 | batch_size, chs, img_rows, img_cols = y_true.shape 154 | device = y_true.device 155 | if chs == 1: 156 | y_pred = _binarize(y_pred, threshold) 157 | y_true = _binarize(y_true, threshold) 158 | else: 159 | y_pred = _argmax(y_pred, 1) 160 | y_true = _argmax(y_true, 1) 161 | 162 | nb_tp_tn = torch.sum(y_true == y_pred).float() 163 | mperforms = nb_tp_tn / (batch_size * img_rows * img_cols) 164 | performs = None 165 | return mperforms, performs 166 | 167 | 168 | class Precision(object): 169 | def __init__(self, des="Precision"): 170 | self.des = des 171 | 172 | def __repr__(self): 173 | return "Prec" 174 | 175 | def __call__(self, y_pred, y_true, threshold=0.5): 176 | """ 177 | args: 178 | y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols] 179 | y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols] 180 | threshold : [0.0, 1.0] 181 | return tp/(tp+fp) 182 | """ 183 | batch_size, chs, img_rows, img_cols = y_true.shape 184 | device = y_true.device 185 | if chs == 1: 186 | y_pred = _binarize(y_pred, threshold) 187 | y_true = _binarize(y_true, threshold) 188 | nb_tp = _get_tp(y_pred, y_true) 189 | nb_fp = _get_fp(y_pred, y_true) 190 | mperforms = nb_tp / (nb_tp + nb_fp + esp) 191 | performs = None 192 | else: 193 | y_pred = _argmax(y_pred, 1) 194 | y_true = _argmax(y_true, 1) 195 | performs = torch.zeros(chs, 1).to(device) 196 | weights = _get_weights(y_true, chs) 197 | for ch in range(chs): 198 | y_true_ch = torch.zeros(batch_size, img_rows, img_cols) 199 | y_pred_ch = torch.zeros(batch_size, img_rows, img_cols) 200 | y_true_ch[y_true == ch] = 1 201 | y_pred_ch[y_pred == ch] = 1 202 | nb_tp = _get_tp(y_pred_ch, y_true_ch) 203 | nb_fp = _get_fp(y_pred_ch, y_true_ch) 204 | performs[int(ch)] = nb_tp / (nb_tp + nb_fp + esp) 205 | mperforms = sum([i*j for (i, j) in zip(performs, weights)]) 206 | return mperforms, performs 207 | 208 | 209 | class Recall(object): 210 | def __init__(self, des="Recall"): 211 | self.des = des 212 | 213 | def __repr__(self): 214 | return "Reca" 215 | 216 | def __call__(self, y_pred, y_true, 
threshold=0.5): 217 | """ 218 | args: 219 | y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols] 220 | y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols] 221 | threshold : [0.0, 1.0] 222 | return tp/(tp+fn) 223 | """ 224 | batch_size, chs, img_rows, img_cols = y_true.shape 225 | device = y_true.device 226 | if chs == 1: 227 | y_pred = _binarize(y_pred, threshold) 228 | y_true = _binarize(y_true, threshold) 229 | nb_tp = _get_tp(y_pred, y_true) 230 | nb_fn = _get_fn(y_pred, y_true) 231 | mperforms = nb_tp / (nb_tp + nb_fn + esp) 232 | performs = None 233 | else: 234 | y_pred = _argmax(y_pred, 1) 235 | y_true = _argmax(y_true, 1) 236 | performs = torch.zeros(chs, 1).to(device) 237 | weights = _get_weights(y_true, chs) 238 | for ch in range(chs): 239 | y_true_ch = torch.zeros(batch_size, img_rows, img_cols) 240 | y_pred_ch = torch.zeros(batch_size, img_rows, img_cols) 241 | y_true_ch[y_true == ch] = 1 242 | y_pred_ch[y_pred == ch] = 1 243 | nb_tp = _get_tp(y_pred_ch, y_true_ch) 244 | nb_fn = _get_fn(y_pred_ch, y_true_ch) 245 | performs[int(ch)] = nb_tp / (nb_tp + nb_fn + esp) 246 | mperforms = sum([i*j for (i, j) in zip(performs, weights)]) 247 | return mperforms, performs 248 | 249 | 250 | class F1Score(object): 251 | def __init__(self, des="F1Score"): 252 | self.des = des 253 | 254 | def __repr__(self): 255 | return "F1Sc" 256 | 257 | def __call__(self, y_pred, y_true, threshold=0.5): 258 | 259 | """ 260 | args: 261 | y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols] 262 | y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols] 263 | threshold : [0.0, 1.0] 264 | return 2*precision*recall/(precision+recall) 265 | """ 266 | batch_size, chs, img_rows, img_cols = y_true.shape 267 | device = y_true.device 268 | if chs == 1: 269 | y_pred = _binarize(y_pred, threshold) 270 | y_true = _binarize(y_true, threshold) 271 | nb_tp = _get_tp(y_pred, y_true) 272 | nb_fp = _get_fp(y_pred, y_true) 273 | nb_fn = _get_fn(y_pred, y_true) 274 | _precision = nb_tp / (nb_tp + nb_fp + esp) 275 | _recall = nb_tp / (nb_tp + nb_fn + esp) 276 | mperforms = 2 * _precision * _recall / (_precision + _recall + esp) 277 | performs = None 278 | else: 279 | y_pred = _argmax(y_pred, 1) 280 | y_true = _argmax(y_true, 1) 281 | performs = torch.zeros(chs, 1).to(device) 282 | weights = _get_weights(y_true, chs) 283 | for ch in range(chs): 284 | y_true_ch = torch.zeros(batch_size, img_rows, img_cols) 285 | y_pred_ch = torch.zeros(batch_size, img_rows, img_cols) 286 | y_true_ch[y_true == ch] = 1 287 | y_pred_ch[y_pred == ch] = 1 288 | nb_tp = _get_tp(y_pred_ch, y_true_ch) 289 | nb_fp = _get_fp(y_pred_ch, y_true_ch) 290 | nb_fn = _get_fn(y_pred_ch, y_true_ch) 291 | _precision = nb_tp / (nb_tp + nb_fp + esp) 292 | _recall = nb_tp / (nb_tp + nb_fn + esp) 293 | performs[int(ch)] = 2 * _precision * \ 294 | _recall / (_precision + _recall + esp) 295 | mperforms = sum([i*j for (i, j) in zip(performs, weights)]) 296 | return mperforms, performs 297 | 298 | 299 | class Kappa(object): 300 | def __init__(self, des="Kappa"): 301 | self.des = des 302 | 303 | def __repr__(self): 304 | return "Kapp" 305 | 306 | def __call__(self, y_pred, y_true, threshold=0.5): 307 | 308 | """ 309 | args: 310 | y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols] 311 | y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols] 312 | threshold : [0.0, 1.0] 313 | return (Po-Pe)/(1-Pe) 314 | """ 315 | batch_size, chs, img_rows, img_cols = y_true.shape 316 | device = y_true.device 317 | if chs == 1: 318 | y_pred = 
_binarize(y_pred, threshold) 319 | y_true = _binarize(y_true, threshold) 320 | nb_tp = _get_tp(y_pred, y_true) 321 | nb_fp = _get_fp(y_pred, y_true) 322 | nb_tn = _get_tn(y_pred, y_true) 323 | nb_fn = _get_fn(y_pred, y_true) 324 | nb_total = nb_tp + nb_fp + nb_tn + nb_fn 325 | Po = (nb_tp + nb_tn) / nb_total 326 | Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn) + 327 | (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total**2) 328 | mperforms = (Po - Pe) / (1 - Pe + esp) 329 | performs = None 330 | else: 331 | y_pred = _argmax(y_pred, 1) 332 | y_true = _argmax(y_true, 1) 333 | performs = torch.zeros(chs, 1).to(device) 334 | weights = _get_weights(y_true, chs) 335 | for ch in range(chs): 336 | y_true_ch = torch.zeros(batch_size, img_rows, img_cols) 337 | y_pred_ch = torch.zeros(batch_size, img_rows, img_cols) 338 | y_true_ch[y_true == ch] = 1 339 | y_pred_ch[y_pred == ch] = 1 340 | nb_tp = _get_tp(y_pred_ch, y_true_ch) 341 | nb_fp = _get_fp(y_pred_ch, y_true_ch) 342 | nb_tn = _get_tn(y_pred_ch, y_true_ch) 343 | nb_fn = _get_fn(y_pred_ch, y_true_ch) 344 | nb_total = nb_tp + nb_fp + nb_tn + nb_fn 345 | Po = (nb_tp + nb_tn) / nb_total 346 | Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn) 347 | + (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total**2) 348 | performs[int(ch)] = (Po - Pe) / (1 - Pe + esp) 349 | mperforms = sum([i*j for (i, j) in zip(performs, weights)]) 350 | return mperforms, performs 351 | 352 | 353 | class Jaccard(object): 354 | def __init__(self, des="Jaccard"): 355 | self.des = des 356 | 357 | def __repr__(self): 358 | return "Jacc" 359 | 360 | def __call__(self, y_pred, y_true, threshold=0.5): 361 | """ 362 | args: 363 | y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols] 364 | y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols] 365 | threshold : [0.0, 1.0] 366 | return intersection / (sum-intersection) 367 | """ 368 | batch_size, chs, img_rows, img_cols = y_true.shape 369 | device = y_true.device 370 | if chs == 1: 371 | y_pred = _binarize(y_pred, threshold) 372 | y_true = _binarize(y_true, threshold) 373 | _intersec = torch.sum(y_true * y_pred).float() 374 | _sum = torch.sum(y_true + y_pred).float() 375 | mperforms = _intersec / (_sum - _intersec + esp) 376 | performs = None 377 | else: 378 | y_pred = _argmax(y_pred, 1) 379 | y_true = _argmax(y_true, 1) 380 | performs = torch.zeros(chs, 1).to(device) 381 | weights = _get_weights(y_true, chs) 382 | for ch in range(chs): 383 | y_true_ch = torch.zeros(batch_size, img_rows, img_cols) 384 | y_pred_ch = torch.zeros(batch_size, img_rows, img_cols) 385 | y_true_ch[y_true == ch] = 1 386 | y_pred_ch[y_pred == ch] = 1 387 | _intersec = torch.sum(y_true_ch * y_pred_ch).float() 388 | _sum = torch.sum(y_true_ch + y_pred_ch).float() 389 | performs[int(ch)] = _intersec / (_sum - _intersec + esp) 390 | mperforms = sum([i*j for (i, j) in zip(performs, weights)]) 391 | return mperforms, performs 392 | 393 | 394 | class MSE(object): 395 | def __init__(self, des="Mean Square Error"): 396 | self.des = des 397 | 398 | def __repr__(self): 399 | return "MSE" 400 | 401 | def __call__(self, y_pred, y_true, dim=1, threshold=None): 402 | """ 403 | args: 404 | y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols] 405 | y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols] 406 | threshold : [0.0, 1.0] 407 | return mean_squared_error, smaller the better 408 | """ 409 | if threshold: 410 | y_pred = _binarize(y_pred, threshold) 411 | return torch.mean((y_pred - y_true) ** 2) 412 | 413 | 414 | class PSNR(object): 415 | def __init__(self, 
des="Peak Signal to Noise Ratio"): 416 | self.des = des 417 | 418 | def __repr__(self): 419 | return "PSNR" 420 | 421 | def __call__(self, y_pred, y_true, dim=1, threshold=None): 422 | """ 423 | args: 424 | y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols] 425 | y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols] 426 | threshold : [0.0, 1.0] 427 | return PSNR, larger the better 428 | """ 429 | if threshold: 430 | y_pred = _binarize(y_pred, threshold) 431 | mse = torch.mean((y_pred - y_true) ** 2) 432 | return 10 * torch.log10(1 / mse) 433 | 434 | 435 | class SSIM(object): 436 | ''' 437 | modified from https://github.com/jorge-pessoa/pytorch-msssim 438 | ''' 439 | def __init__(self, des="structural similarity index"): 440 | self.des = des 441 | 442 | def __repr__(self): 443 | return "SSIM" 444 | 445 | def gaussian(self, w_size, sigma): 446 | gauss = torch.Tensor([math.exp(-(x - w_size//2)**2/float(2*sigma**2)) for x in range(w_size)]) 447 | return gauss/gauss.sum() 448 | 449 | def create_window(self, w_size, channel=1): 450 | _1D_window = self.gaussian(w_size, 1.5).unsqueeze(1) 451 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 452 | window = _2D_window.expand(channel, 1, w_size, w_size).contiguous() 453 | return window 454 | 455 | def __call__(self, y_pred, y_true, w_size=11, size_average=True, full=False): 456 | """ 457 | args: 458 | y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols] 459 | y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols] 460 | w_size : int, default 11 461 | size_average : boolean, default True 462 | full : boolean, default False 463 | return ssim, larger the better 464 | """ 465 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 
466 | if torch.max(y_pred) > 128: 467 | max_val = 255 468 | else: 469 | max_val = 1 470 | 471 | if torch.min(y_pred) < -0.5: 472 | min_val = -1 473 | else: 474 | min_val = 0 475 | L = max_val - min_val 476 | 477 | padd = 0 478 | (_, channel, height, width) = y_pred.size() 479 | window = self.create_window(w_size, channel=channel).to(y_pred.device) 480 | 481 | mu1 = F.conv2d(y_pred, window, padding=padd, groups=channel) 482 | mu2 = F.conv2d(y_true, window, padding=padd, groups=channel) 483 | 484 | mu1_sq = mu1.pow(2) 485 | mu2_sq = mu2.pow(2) 486 | mu1_mu2 = mu1 * mu2 487 | 488 | sigma1_sq = F.conv2d(y_pred * y_pred, window, padding=padd, groups=channel) - mu1_sq 489 | sigma2_sq = F.conv2d(y_true * y_true, window, padding=padd, groups=channel) - mu2_sq 490 | sigma12 = F.conv2d(y_pred * y_true, window, padding=padd, groups=channel) - mu1_mu2 491 | 492 | C1 = (0.01 * L) ** 2 493 | C2 = (0.03 * L) ** 2 494 | 495 | v1 = 2.0 * sigma12 + C2 496 | v2 = sigma1_sq + sigma2_sq + C2 497 | cs = torch.mean(v1 / v2) # contrast sensitivity 498 | 499 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 500 | 501 | if size_average: 502 | ret = ssim_map.mean() 503 | else: 504 | ret = ssim_map.mean(1).mean(1).mean(1) 505 | 506 | if full: 507 | return ret, cs 508 | return ret 509 | 510 | 511 | class LPIPS(object): 512 | ''' 513 | borrowed from https://github.com/richzhang/PerceptualSimilarity 514 | ''' 515 | def __init__(self, cuda, des="Learned Perceptual Image Patch Similarity", version="0.1"): 516 | self.des = des 517 | self.version = version 518 | self.model = lpips.PerceptualLoss(model='net-lin',net='alex',use_gpu=cuda) 519 | 520 | def __repr__(self): 521 | return "LPIPS" 522 | 523 | def __call__(self, y_pred, y_true, normalized=True): 524 | """ 525 | args: 526 | y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols] 527 | y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols] 528 | normalized : change [0,1] => [-1,1] (default by LPIPS) 529 | return LPIPS, smaller the better 530 | """ 531 | if normalized: 532 | y_pred = y_pred * 2.0 - 1.0 533 | y_true = y_true * 2.0 - 1.0 534 | return self.model.forward(y_pred, y_true) 535 | 536 | 537 | class AE(object): 538 | """ 539 | Modified from matlab : colorangle.m, MATLAB V2019b 540 | angle = acos(RGB1' * RGB2 / (norm(RGB1) * norm(RGB2))); 541 | angle = 180 / pi * angle; 542 | """ 543 | def __init__(self, des='average Angular Error'): 544 | self.des = des 545 | 546 | def __repr__(self): 547 | return "AE" 548 | 549 | def __call__(self, y_pred, y_true): 550 | """ 551 | args: 552 | y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols] 553 | y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols] 554 | return average AE, smaller the better 555 | """ 556 | dotP = torch.sum(y_pred * y_true, dim=1) 557 | Norm_pred = torch.sqrt(torch.sum(y_pred * y_pred, dim=1)) 558 | Norm_true = torch.sqrt(torch.sum(y_true * y_true, dim=1)) 559 | ae = 180 / math.pi * torch.acos(dotP / (Norm_pred * Norm_true + eps)) 560 | return ae.mean(1).mean(1) 561 | 562 | 563 | if __name__ == "__main__": 564 | for ch in [3, 1]: 565 | batch_size, img_row, img_col = 1, 224, 224 566 | y_true = torch.rand(batch_size, ch, img_row, img_col) 567 | noise = torch.zeros(y_true.size()).data.normal_(0, std=0.1) 568 | y_pred = y_true + noise 569 | for cuda in [False, True]: 570 | if cuda: 571 | y_pred = y_pred.cuda() 572 | y_true = y_true.cuda() 573 | 574 | print('#'*20, 'Cuda : {} ; size : {}'.format(cuda, y_true.size())) 575 | ########### 
similarity metrics 576 | metric = MSE() 577 | acc = metric(y_pred, y_true).item() 578 | print("{} ==> {}".format(repr(metric), acc)) 579 | 580 | metric = PSNR() 581 | acc = metric(y_pred, y_true).item() 582 | print("{} ==> {}".format(repr(metric), acc)) 583 | 584 | metric = SSIM() 585 | acc = metric(y_pred, y_true).item() 586 | print("{} ==> {}".format(repr(metric), acc)) 587 | 588 | metric = LPIPS(cuda) 589 | acc = metric(y_pred, y_true).item() 590 | print("{} ==> {}".format(repr(metric), acc)) 591 | 592 | metric = AE() 593 | acc = metric(y_pred, y_true).item() 594 | print("{} ==> {}".format(repr(metric), acc)) 595 | 596 | ########### accuracy metrics 597 | metric = OAAcc() 598 | maccu, accu = metric(y_pred, y_true) 599 | print('mAccu:', maccu, 'Accu', accu) 600 | 601 | metric = Precision() 602 | mprec, prec = metric(y_pred, y_true) 603 | print('mPrec:', mprec, 'Prec', prec) 604 | 605 | metric = Recall() 606 | mreca, reca = metric(y_pred, y_true) 607 | print('mReca:', mreca, 'Reca', reca) 608 | 609 | metric = F1Score() 610 | mf1sc, f1sc = metric(y_pred, y_true) 611 | print('mF1sc:', mf1sc, 'F1sc', f1sc) 612 | 613 | metric = Kappa() 614 | mkapp, kapp = metric(y_pred, y_true) 615 | print('mKapp:', mkapp, 'Kapp', kapp) 616 | 617 | metric = Jaccard() 618 | mjacc, jacc = metric(y_pred, y_true) 619 | print('mJacc:', mjacc, 'Jacc', jacc) 620 | --------------------------------------------------------------------------------
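FID is listed as ongoing in the README, and community/fid.py currently ships only the InceptionV3 feature extractor. Below is a minimal sketch of how a score could be computed from its final-average-pool activations; the helper names (`compute_fid`, `_activations`) and the NumPy/SciPy dependency are assumptions for illustration, not part of the repo.

```
import numpy as np
import torch
from scipy import linalg

from community.fid import InceptionV3


@torch.no_grad()
def _activations(images, model):
    # images: 4D tensor in [batch_size, 3, img_rows, img_cols], values in [0, 1]
    feats = model(images)[0]          # block-3 output, N x 2048 x 1 x 1
    return feats.squeeze(-1).squeeze(-1).cpu().numpy()


def compute_fid(real, fake):
    # Frechet distance between Gaussians fitted to Inception activations:
    # ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrtm(S1 @ S2))
    model = InceptionV3(output_blocks=[3]).eval()
    act1, act2 = _activations(real, model), _activations(fake, model)
    mu1, mu2 = act1.mean(axis=0), act2.mean(axis=0)
    sigma1 = np.cov(act1, rowvar=False)
    sigma2 = np.cov(act2, rowvar=False)
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean):      # drop tiny imaginary parts from sqrtm
        covmean = covmean.real
    diff = mu1 - mu2
    return float(diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean))
```

In practice the activation statistics should be accumulated over a few thousand images per set; a single small batch gives a meaningless estimate.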
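The self-test in metrics.py exercises the accuracy metrics on random float tensors; for segmentation-style evaluation they expect the class dimension to be one-hot (or a probability map) along the channel axis. A small sketch of that input contract, assuming it is run from the repo root with hypothetical data:

```
import torch
import torch.nn.functional as F
from metrics import OAAcc

batch_size, n_classes, img_rows, img_cols = 2, 4, 64, 64

# Hypothetical data: random class probabilities for the prediction and
# random integer labels for the ground truth, both as 4D tensors.
logits = torch.rand(batch_size, n_classes, img_rows, img_cols)
labels = torch.randint(0, n_classes, (batch_size, img_rows, img_cols))

y_pred = F.softmax(logits, dim=1)
y_true = F.one_hot(labels, n_classes).permute(0, 3, 1, 2).float()

metric = OAAcc()
mean_acc, per_class = metric(y_pred, y_true)   # per_class is None for OAAcc
print("{} ==> {}".format(repr(metric), mean_acc.item()))
```

Precision, Recall, F1Score, Kappa and Jaccard take the same 4-D inputs and additionally return a per-class tensor alongside the class-frequency-weighted mean.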
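The AE class follows the colorangle.m formula quoted in its docstring, angle = 180 / pi * acos(RGB1' * RGB2 / (norm(RGB1) * norm(RGB2))), averaged over pixels. A quick sanity check on two constant-colour images (hypothetical data):

```
import torch
from metrics import AE

# Pure red vs. pure green: the RGB vectors are orthogonal at every pixel.
red = torch.zeros(1, 3, 8, 8)
red[:, 0] = 1.0
green = torch.zeros(1, 3, 8, 8)
green[:, 1] = 1.0

metric = AE()
print(metric(red, green).item())   # ~90 degrees
print(metric(red, red).item())     # ~0 degrees
```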