├── README.md ├── data └── README.md ├── main.py ├── requirements.txt └── src ├── framework.py ├── model.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Face Parsing 3 | 4 | Implementation of a one-stage face parsing model built on the U-Net architecture.
5 | The face is divided into 10 classes (background, eyes, nose, lips, ears, hair, teeth, eyebrows, general face, and beard).
6 | Several of these classes cover small, imbalanced regions that are difficult to segment. We address this problem by implementing and comparing several loss functions:
7 | 8 | - Tversky focal loss 9 | - Weighted cross-entropy 10 | - α-balanced focal loss 11 | 12 | The last of these gives the best results: 0.90 IoU on train and 0.72 IoU on test for the non-fine-tuned model, with a U-Net-16 trained on a single GPU. 13 | 14 | ## Implementation 15 | 16 | Language version: Python 3.7.6
17 | Operating System: macOS Catalina 10.15.4
18 | Framework: PyTorch 19 | 20 | ## Results 21 | 22 | The model outputs one probability map per channel, i.e. one per class.
23 | Each map estimates how strongly a pixel belongs to a specific part of the face.
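For example, once a prediction has been produced with `Context.predict` (see Getting Started below), a single class channel can be inspected on its own. A minimal sketch — the channel index used here for "hair" is an assumption and depends on the dataset's label ordering:

```python
import matplotlib.pyplot as plt

# yhat comes from ctx.predict(...) and has shape (H, W, 10):
# one probability map per class
hair_channel = 5  # assumed index; depends on the dataset's label ordering
plt.imshow(yhat[:, :, hair_channel], cmap="viridis")
plt.colorbar(label="probability")
plt.show()
```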
24 | 25 | 26 | Segmentation sample:
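A hard segmentation like the sample above can be recovered from the probability maps by taking the most likely class for each pixel — a short sketch, reusing `yhat` and `plt` from the snippet above:

```python
import numpy as np

# collapse the 10 probability maps into a single label map
segmentation = np.argmax(yhat, axis=-1)  # shape (H, W), labels in 0..9
plt.imshow(segmentation, cmap="tab10")
plt.show()
```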
27 | 28 | 29 | 30 | ## Getting Started 31 | Install dependencies:
32 | ```bash 33 | # use a virtual environment to keep your global packages unchanged 34 | $ pip install -r requirements.txt 35 | ``` 36 | Train the model:
37 | ```bash 38 | # train the model 39 | $ ./main.py 40 | ``` 41 | Inference:
42 | ```python 43 | import torch 44 | from src.model import Unet 45 | from src.framework import Context 46 | 47 | img = "./img_relative_path.png" 48 | 49 | # load model 50 | model = torch.load("my_model.pt", map_location=torch.device('cpu')) 51 | # create context 52 | ctx = Context() 53 | # predict 54 | yhat = ctx.predict(model, img, plot=True) 55 | ``` 56 | 57 | ## References 58 | 59 | - [Mut1ny](https://www.mut1ny.com/face-headsegmentation-dataset) Face/head dataset 60 | - [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/pdf/1505.04597.pdf) 61 | Olaf Ronneberger, Philipp Fischer, and Thomas Brox. 62 | Tech report, arXiv, May 2015. 63 | - [A novel Focal Tversky loss function with improved Attention U-Net for lesion segmentation](https://arxiv.org/pdf/1810.07842.pdf) 64 | Nabila Abraham and Naimul Mefraz Khan. 65 | Tech report, arXiv, Oct 2018. 66 | - [Focal Loss for Dense Object Detection](https://arxiv.org/pdf/1708.02002.pdf) 67 | Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár. 68 | Tech report, arXiv, Feb 2018. 69 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | 2 | [Mut1ny](https://www.mut1ny.com/face-headsegmentation-dataset) Face/head dataset 3 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from src.framework import Context 7 | from src.model import Unet 8 | from src.utils import show_learning, focal_loss 9 | 10 | 11 | if __name__ == "__main__": 12 | 13 | # 1) Train the model 14 | 15 | # use gpu if available 16 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 17 | print(f"\nAvailable device: {DEVICE}") 18 | 19 | # data folders 20 | img_path = "./data/images" 21 | masks_path = "./data/masks" 22 | 23 | # create model 24 | model = Unet(in_channels=3, out_channels=10, nb_features=16).to(DEVICE) 25 | 26 | print("nb_parameters: ", sum([param.nelement() for param in model.parameters()])) 27 | 28 | # Set training context 29 | ctx = Context().fit( 30 | img_dir=img_path, 31 | mask_dir=masks_path, 32 | img_size=256, 33 | mask_size=256, 34 | shuffle_data=False, 35 | device=DEVICE, 36 | ) 37 | ctx.set(nb_epochs=100, batch_size=16) 38 | 39 | # Train 40 | ctx.train(model) 41 | 42 | # save model 43 | torch.save(model, "unet_16.pt") 44 | # training stats 45 | show_learning(model) 46 | 47 | # 2) Inference 48 | """ 49 | saved_model = torch.load("unet_16.pt", map_location=torch.device('cpu')) 50 | 51 | img_name = "./data/images/001_scaled.png" 52 | mask_name = "./data/masks/001_scaled.png" 53 | 54 | ctx2 = Context() 55 | yhat = ctx2.predict(saved_model, img_name, mask_name, plot=True) 56 | """ 57 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.18.0 2 | matplotlib==3.1.2 3 | Pillow==7.0.0 4 | torch==1.5.0 5 | torchvision==0.6.0 -------------------------------------------------------------------------------- /src/framework.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import math 3 | import os 4 | import time 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | from PIL import Image 9 |
import torch 10 | from torchvision import transforms 11 | from torch.utils.data import DataLoader 12 | from torch.utils.data.dataset import Dataset 13 | 14 | from src.utils import toOneHot, iou_accuracy 15 | 16 | class DataGenerator(Dataset): 17 | """ 18 | Custom Dataset for DataLoader 19 | """ 20 | def __init__(self, data, img_size, mask_size): 21 | """ 22 | Args: 23 | data (dict): contains image and mask file names 24 | img_size (integer): image size ``squared image`` 25 | mask_size (integer): mask size ``squared mask`` 26 | 27 | """ 28 | self.data = data 29 | self.img_size = img_size 30 | self.mask_size = mask_size 31 | self.imgProcessor = None 32 | 33 | # images processing 34 | self.transform_image = transforms.Compose([ 35 | transforms.Resize((self.img_size, self.img_size)), 36 | transforms.ToTensor() 37 | ]) 38 | self.transform_mask = transforms.Compose([ 39 | transforms.Resize((self.mask_size, self.mask_size), interpolation=Image.NEAREST), # nearest neighbour keeps label values intact 40 | toOneHot 41 | ]) 42 | 43 | def __len__(self): 44 | return len(self.data["img"]) 45 | 46 | def __getitem__(self, index): 47 | img = self.transform_image(Image.open(self.data['img'][index])) 48 | mask = self.transform_mask(Image.open(self.data['mask'][index])) 49 | 50 | return img, mask 51 | 52 | class Context(object): 53 | """ 54 | Set of methods to train networks 55 | """ 56 | 57 | def __init__(self): 58 | super(Context, self).__init__() 59 | self.train_data = {"img":[], "mask":[]} 60 | self.valid_data = {"img":[], "mask":[]} 61 | self.trainloader = None 62 | self.validloader = None 63 | self.img_size = None 64 | self.mask_size = None 65 | self.epochs = None 66 | self.batch_size = None 67 | self.DEVICE = "cpu" 68 | 69 | def fit(self, img_dir, mask_dir, img_size, mask_size, device, validation_split=.3, shuffle_data=True): 70 | """ 71 | Prepare data for training 72 | """ 73 | # set gpu or cpu device 74 | self.DEVICE = device 75 | 76 | # set dimensions 77 | self.img_size = img_size 78 | self.mask_size = mask_size 79 | 80 | # get image names and extensions 81 | img_names = {x[:-4]:x[-4:] for x in os.listdir(img_dir)} 82 | mask_names = {x[:-4]:x[-4:] for x in os.listdir(mask_dir)} 83 | 84 | # extract names present in both directories 85 | names = sorted(list(set(img_names.keys()).intersection(mask_names.keys()))) 86 | 87 | if shuffle_data: 88 | np.random.shuffle(names) 89 | 90 | # split train/validation 91 | split_index = int(len(names)*(1-validation_split)) 92 | 93 | for i, name in enumerate(names): 94 | if i < split_index: 95 | self.train_data["img"].append(img_dir + '/' + name + img_names[name]) 96 | self.train_data["mask"].append(mask_dir + '/' + name + mask_names[name]) 97 | else: 98 | self.valid_data["img"].append(img_dir + '/' + name + img_names[name]) 99 | self.valid_data["mask"].append(mask_dir + '/' + name + mask_names[name]) 100 | 101 | # to numpy 102 | self.train_data["img"] = np.array(self.train_data["img"]) 103 | self.train_data["mask"] = np.array(self.train_data["mask"]) 104 | 105 | self.valid_data["img"] = np.array(self.valid_data["img"]) 106 | self.valid_data["mask"] = np.array(self.valid_data["mask"]) 107 | 108 | return self 109 | 110 | def set(self, nb_epochs=5, batch_size=1): 111 | """ 112 | Set training parameters 113 | 114 | Args: 115 | nb_epochs (int): Number of epochs ``default 5`` 116 | batch_size (int): batch size ``default 1`` 117 | """ 118 | self.batch_size = batch_size 119 | self.epochs = nb_epochs 120 | 121 | # data loaders 122 | self.trainloader = DataLoader(DataGenerator(self.train_data, self.img_size, self.mask_size), 123 | batch_size=batch_size, 124 | shuffle=False, 125 |
num_workers=2, 126 | pin_memory=True) 127 | 128 | self.validloader = DataLoader(DataGenerator(self.valid_data, self.img_size, self.mask_size), 129 | batch_size=batch_size, 130 | shuffle=False, 131 | num_workers=2, 132 | pin_memory=True) 133 | 134 | def train(self, model): 135 | """ 136 | Train the model 137 | 138 | Args: 139 | model: pytorch Model 140 | """ 141 | print("\nTraining...\n") 142 | 143 | nb_batches = math.ceil(len(self.trainloader.dataset) / self.batch_size) 144 | 145 | for epoch in range(self.epochs): 146 | 147 | running_loss = [] 148 | iou_score = [] 149 | start = time.time() 150 | ## print infos 151 | print(f"Epoch {epoch+1}/{self.epochs}") 152 | 153 | model.train() 154 | 155 | for batch, (features, targets) in enumerate(self.trainloader): 156 | 157 | model.optimizer.zero_grad() 158 | ## forward pass and loss calculation 159 | outputs = model(features.to(self.DEVICE)) 160 | loss = model.criterion(outputs, targets.to(self.DEVICE)) 161 | ## backward propagation 162 | loss.backward() 163 | ## weights optimization 164 | model.optimizer.step() 165 | 166 | # running loss 167 | running_loss.append(loss.item()) 168 | ## IOU accuracy, reusing the forward pass instead of recomputing it 169 | iou_score.append(iou_accuracy(outputs, targets.to(self.DEVICE))) 170 | 171 | ## print train infos 172 | print( 173 | f"\r Batch: {batch+1}/{nb_batches}", 174 | f" - time: {str(datetime.timedelta(seconds = time.time() - start))}", 175 | f" - T_loss: {np.mean(running_loss):.3f}", 176 | f" - T_IOU: {np.mean(iou_score):.2f}", end = '' 177 | ) 178 | ## save loss 179 | model.train_loss.append(np.mean(running_loss)) 180 | ## save accuracy 181 | model.train_accuracy.append(np.mean(iou_score)) 182 | 183 | ## evaluate 184 | self.evaluate(model, self.validloader, save_loss=True) 185 | 186 | def evaluate(self, model, validloader, save_loss=False): 187 | """ 188 | Evaluation part 189 | 190 | Args: 191 | model: pytorch model 192 | validloader: DataLoader for validation data 193 | save_loss (boolean): if 'True' save the validation loss 194 | """ 195 | running_loss = [] 196 | iou_score = [] 197 | 198 | model.eval() 199 | with torch.no_grad(): 200 | for i, (features, targets) in enumerate(validloader): 201 | ## single forward pass, reused for the loss and the IOU 202 | outputs = model(features.to(self.DEVICE)) 203 | loss = model.criterion(outputs, targets.to(self.DEVICE)) 204 | ## save loss 205 | running_loss.append(loss.item()) 206 | ## IOU accuracy 207 | iou_score.append(iou_accuracy(outputs, targets.to(self.DEVICE))) 208 | # save accuracy 209 | model.valid_accuracy.append(np.mean(iou_score)) 210 | 211 | if save_loss: 212 | model.valid_loss.append(np.mean(running_loss)) 213 | 214 | ## print infos 215 | print( 216 | f" - V_loss: {np.mean(running_loss):.3f}", 217 | f" - V_IOU: {np.mean(iou_score):.2f}" 218 | ) 219 | def predict(self, model, img_path, mask_path=None, plot=False): 220 | """ 221 | Inference 222 | 223 | Args: 224 | model: Unet model 225 | img_path (String): image path 226 | mask_path (String): optional, mask path 227 | plot (boolean): optional, if 'True' plot the prediction 228 | outputs: 229 | yhat: predicted segmentation 230 | """ 231 | DEVICE = next(model.parameters()).device # run on whichever device the model lives on 232 | model.eval() 233 | # open and transform the image 234 | img = Image.open(img_path) 235 | img = transforms.Resize((256, 256))(img) 236 | img = transforms.ToTensor()(img) 237 | img = img.unsqueeze(0) 238 | # predict 239 | with torch.no_grad(): 240 | yhat = model(img.to(DEVICE)) 241 | yhat = yhat[0].permute(1,2,0) 242 | 243 | if plot: 244 | # plots 245 | fig, axes = plt.subplots(2,6, figsize=(18,6)) 246 | channel = 0 247 | for i in
range(axes.shape[0]): 248 | for j in range(1, axes.shape[1]): 249 | axes[i,j].imshow(yhat.cpu().float().detach().numpy()[:,:,channel], cmap='viridis') 250 | channel += 1 251 | axes[0,0].imshow(img[0].permute(1,2,0)) 252 | 253 | # mask 254 | if mask_path is not None: 255 | mask = Image.open(mask_path) 256 | # plot ground truth 257 | axes[1,0].imshow(mask) 258 | 259 | plt.show() 260 | 261 | return yhat.cpu().float().detach().numpy() 262 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from src.utils import focal_loss 5 | 6 | class Unet(nn.Module): 7 | def __init__(self, in_channels, out_channels, nb_features, lr=0.0008): 8 | """ 9 | Implementation of the U-Net architecture 10 | 11 | Args: 12 | in_channels (integer): Number of channels in the input image. 13 | out_channels (integer): Number of channels in the output image. 14 | ``Corresponds to the number of classes to predict`` 15 | nb_features (integer): Number of features extracted by the first convolution 16 | lr (Float): optimizer learning rate ``default: 0.0008`` 17 | 18 | """ 19 | super(Unet, self).__init__() 20 | 21 | # attributes 22 | self.lr = lr 23 | self.train_loss = [] 24 | self.valid_loss = [] 25 | self.train_accuracy = [] 26 | self.valid_accuracy = [] 27 | 28 | # contracting path 29 | self.conv_1 = self.doubleconv(in_channels, nb_features) 30 | self.conv_2 = self.doubleconv(nb_features, nb_features*2, maxpool2d=True) 31 | self.conv_3 = self.doubleconv(nb_features*2, nb_features*4, maxpool2d=True) 32 | self.conv_4 = self.doubleconv(nb_features*4, nb_features*8, maxpool2d=True) 33 | 34 | self.bottleneck = self.doubleconv(nb_features*8, nb_features*16, maxpool2d=True) 35 | 36 | # expansive path 37 | self.up_4 = self.upsample(nb_features*16, nb_features*8) 38 | self.deconv_4 = self.doubleconv(nb_features*16, nb_features*8) 39 | self.up_3 = self.upsample(nb_features*8, nb_features*4) 40 | self.deconv_3 = self.doubleconv(nb_features*8, nb_features*4) 41 | self.up_2 = self.upsample(nb_features*4, nb_features*2) 42 | self.deconv_2 = self.doubleconv(nb_features*4, nb_features*2) 43 | self.up_1 = self.upsample(nb_features*2, nb_features) 44 | self.deconv_1 = self.doubleconv(nb_features*2, nb_features) 45 | 46 | self.last_conv = nn.Sequential( 47 | nn.Conv2d(nb_features, out_channels, kernel_size=1, bias=False), 48 | nn.Softmax2d() 49 | ) 50 | 51 | # parameters 52 | self.criterion = focal_loss 53 | self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 54 | 55 | def forward(self, x): 56 | """ 57 | Defines the forward propagation 58 | 59 | Args: 60 | x (Tensor): Torch tensor containing the batch of images to process.
61 | x of shape BxCxHxW 62 | """ 63 | # contracting path 64 | conv_1 = self.conv_1(x) 65 | conv_2 = self.conv_2(conv_1) 66 | conv_3 = self.conv_3(conv_2) 67 | conv_4 = self.conv_4(conv_3) 68 | # bottleneck 69 | bottleneck = self.bottleneck(conv_4) 70 | # expansive path 71 | deconv_4 = self.deconv_4(torch.cat([self.up_4(bottleneck), conv_4], dim=1)) 72 | deconv_3 = self.deconv_3(torch.cat([self.up_3(deconv_4), conv_3], dim=1)) 73 | deconv_2 = self.deconv_2(torch.cat([self.up_2(deconv_3), conv_2], dim=1)) 74 | deconv_1 = self.deconv_1(torch.cat([self.up_1(deconv_2), conv_1], dim=1)) 75 | 76 | # last layer 77 | out = self.last_conv(deconv_1) 78 | 79 | return out 80 | 81 | def doubleconv(self, in_channels, out_channels, maxpool2d=False): 82 | """ 83 | Defines a double convolution 84 | 85 | Args: 86 | in_channels (int): Number of channels received 87 | out_channels (int): Number of output channels 88 | maxpool2d (boolean): if 'True' add a MaxPool2d layer 89 | output: 90 | return a sequential layer 91 | """ 92 | layers = [ 93 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False), 94 | nn.InstanceNorm2d(num_features=out_channels), 95 | nn.ReLU(inplace=True), 96 | nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False), 97 | nn.InstanceNorm2d(num_features=out_channels), 98 | nn.ReLU(inplace=True) 99 | ] 100 | 101 | if maxpool2d: 102 | layers.insert(0, nn.MaxPool2d(kernel_size=2, stride=2)) 103 | 104 | return nn.Sequential(*layers) 105 | 106 | def upsample(self, in_channels, out_channels): 107 | """ 108 | Defines an upsampling layer 109 | 110 | Args: 111 | in_channels (int): Number of channels received 112 | out_channels (int): Number of output channels 113 | output: 114 | return a sequential layer 115 | """ 116 | block = nn.Sequential( 117 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2) 118 | ) 119 | 120 | return block 121 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | ### data processing ### 7 | 8 | def toOneHot(mask, nb_class=10): 9 | """ 10 | Convert a label image to a one-hot encoding 11 | 12 | Args: 13 | mask (Image): mask containing pixel labels 14 | nb_class (int): number of classes 15 | """ 16 | categorical = torch.from_numpy(np.array(mask)).long() 17 | categorical = F.one_hot(categorical, nb_class) 18 | 19 | return categorical.permute(2,0,1).float() 20 | 21 | ### plots ### 22 | 23 | def show_learning(model): 24 | """ 25 | Plot loss and accuracy from model 26 | 27 | Args: 28 | model: Unet model 29 | """ 30 | fig, axes = plt.subplots(1, 2, figsize=(15,5)) 31 | 32 | # plot losses 33 | axes[0].plot(model.train_loss, label="train") 34 | axes[0].plot(model.valid_loss, label="validation") 35 | axes[0].set_xlabel("Epochs") 36 | 37 | try: 38 | axes[0].set_ylabel(model.criterion.__name__) 39 | except AttributeError: 40 | axes[0].set_ylabel("Loss") 41 | 42 | axes[0].set_title("Loss evolution") 43 | axes[0].legend() 44 | 45 | # plot accuracy 46 | axes[1].plot(model.train_accuracy, label="train") 47 | axes[1].plot(model.valid_accuracy, label="validation") 48 | axes[1].set_xlabel("Epochs") 49 | axes[1].set_ylabel("IOU score") 50 | axes[1].set_title("Accuracy evolution") 51 | 52 | axes[1].legend() 53 | 54 | plt.show() 55 | 56 | ### losses & accuracy ### 57
| 58 | def dice_loss(yhat, ytrue, epsilon=1e-6): 59 | """ 60 | Computes a soft Dice loss 61 | 62 | Args: 63 | yhat (Tensor): predicted masks 64 | ytrue (Tensor): target masks 65 | epsilon (Float): smoothing value to avoid division by 0 66 | output: 67 | DL value with `mean` reduction 68 | """ 69 | # compute Dice components 70 | intersection = torch.sum(yhat * ytrue, (1,2,3)) 71 | cardinal = torch.sum(yhat + ytrue, (1,2,3)) 72 | 73 | return torch.mean(1. - (2 * intersection / (cardinal + epsilon))) 74 | 75 | def tversky_index(yhat, ytrue, alpha=0.3, beta=0.7, epsilon=1e-6): 76 | """ 77 | Computes the Tversky index 78 | 79 | Args: 80 | yhat (Tensor): predicted masks 81 | ytrue (Tensor): target masks 82 | alpha (Float): weight for false positives 83 | beta (Float): weight for false negatives 84 | ``alpha and beta control the magnitude of the penalties and should sum to 1`` 85 | epsilon (Float): smoothing value to avoid division by 0 86 | output: 87 | Tversky index value 88 | """ 89 | TP = torch.sum(yhat * ytrue, (1,2,3)) 90 | FP = torch.sum((1. - ytrue) * yhat, (1,2,3)) 91 | FN = torch.sum((1. - yhat) * ytrue, (1,2,3)) 92 | 93 | return TP/(TP + alpha * FP + beta * FN + epsilon) 94 | 95 | def tversky_loss(yhat, ytrue): 96 | """ 97 | Computes the Tversky loss from the Tversky index 98 | 99 | Args: 100 | yhat (Tensor): predicted masks 101 | ytrue (Tensor): target masks 102 | output: 103 | Tversky loss value with `mean` reduction 104 | """ 105 | return torch.mean(1 - tversky_index(yhat, ytrue)) 106 | 107 | def tversky_focal_loss(yhat, ytrue, alpha=0.7, beta=0.3, gamma=0.75): 108 | """ 109 | Computes the Tversky focal loss for highly imbalanced data 110 | https://arxiv.org/pdf/1810.07842.pdf 111 | 112 | Args: 113 | yhat (Tensor): predicted masks 114 | ytrue (Tensor): target masks 115 | alpha (Float): weight for false positives 116 | beta (Float): weight for false negatives 117 | ``alpha and beta control the magnitude of the penalties and should sum to 1`` 118 | gamma (Float): focal parameter 119 | ``controls the balance between easy background and hard ROI training examples`` 120 | output: 121 | Tversky focal loss value with `mean` reduction 122 | """ 123 | 124 | return torch.mean(torch.pow(1 - tversky_index(yhat, ytrue, alpha, beta), gamma)) 125 | 126 | def focal_loss(yhat, ytrue, alpha=0.75, gamma=2): 127 | """ 128 | Computes the α-balanced focal loss from FAIR 129 | https://arxiv.org/pdf/1708.02002v2.pdf 130 | 131 | Args: 132 | yhat (Tensor): predicted masks 133 | ytrue (Tensor): target masks 134 | alpha (Float): weight to balance the cross-entropy value 135 | gamma (Float): focal parameter 136 | output: 137 | loss value with `mean` reduction 138 | """ 139 | 140 | # compute the actual focal loss 141 | focal = -alpha * torch.pow(1.
- yhat, gamma) * torch.log(yhat.clamp(min=1e-8)) # clamp to avoid log(0) 142 | f_loss = torch.sum(ytrue * focal, dim=1) 143 | 144 | return torch.mean(f_loss) 145 | 146 | def iou_accuracy(yhat, ytrue, threshold=0.5, epsilon=1e-6): 147 | """ 148 | Computes the Intersection over Union metric 149 | 150 | Args: 151 | yhat (Tensor): predicted masks 152 | ytrue (Tensor): target masks 153 | threshold (Float): threshold for pixel classification 154 | epsilon (Float): smoothing parameter for numerical stability 155 | output: 156 | IOU value with `mean` reduction 157 | """ 158 | intersection = ((yhat>threshold).long() & ytrue.long()).float().sum((1,2,3)) 159 | union = ((yhat>threshold).long() | ytrue.long()).float().sum((1,2,3)) 160 | 161 | return torch.mean(intersection/(union + epsilon)).item() 162 | --------------------------------------------------------------------------------