├── README.md
├── data
│   └── README.md
├── main.py
├── requirements.txt
└── src
    ├── framework.py
    ├── model.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Face Parsing
3 |
4 | Implementation of a one-stage Face Parsing model using Unet architecture.
5 | The face is divided into 10 classes (background, eyes, nose, lips, ear, hair, teeth, eyebrows, general face, beard).
6 | This results in regions of interest that are difficult to segment. We address this problem by implementing different loss functions:
7 |
8 | - Tversky focal loss
9 | - Weighted Cross entropy
10 | - α-balanced focal loss
11 |
12 | The latter yields better accuracy: 0.90 (train IoU) and 0.72 (test IoU) for the non-fine-tuned model, using a Unet-16 trained on 1 GPU.
13 |
14 | ## Implementation
15 |
16 | Language version : Python 3.7.6
17 | Operating System : MacOS Catalina 10.15.4
18 | Framework: Pytorch
19 |
20 | ## Results
21 |
22 | We can observe the probability map for each channel.
23 | It allows us to estimate how much a pixel belongs to a specific part of the face.
24 |
25 |
26 | Segmentation sample:
27 |
28 |
29 |
30 | ### Getting Started
31 | Install dependencies:
32 | ```bash
33 | # use a virtual-env to keep your packages unchanged
34 | $ pip install -r requirements.txt
35 | ```
36 | Train the model:
37 | ```bash
38 | # Train the model
39 | $ ./main.py
40 | ```
41 | Inference:
42 | ```python
43 | import torch
44 | from src.model import Unet
45 | from src.framework import Context
46 | img = "./img_relative_path.png"
47 |
48 | # load model
49 | model = torch.load("my_model.pt", map_location=torch.device('cpu'))
50 | # create context
51 | ctx = Context()
52 | # predict
53 | yhat = ctx.predict(model, img, plot=True)
54 | ```
55 |
56 | ## References
57 |
58 | - [Mut1ny](https://www.mut1ny.com/face-headsegmentation-dataset) Face/head dataset
59 | - [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/pdf/1505.04597.pdf)
60 | Olaf Ronneberger, Philipp Fischer, and Thomas Brox.
61 | Tech report, arXiv, May 2015.
62 | - [A novel Focal Tversky loss function with improved Attention U-Net for lesion segmentation](https://arxiv.org/pdf/1810.07842.pdf)
63 | Abraham, Nabila and Khan, Naimul Mefraz.
64 | Tech report, arXiv, Oct 2018.
65 | - [Focal Loss for Dense Object Detection](https://arxiv.org/pdf/1708.02002.pdf)
66 | Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár.
67 | Tech report, arXiv, Feb 2018.
68 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 |
2 | [Mut1ny](https://www.mut1ny.com/face-headsegmentation-dataset) Face/head dataset
3 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import torch
4 | import numpy as np
5 |
6 | from src.framework import Context
7 | from src.model import Unet
8 | from src.utils import show_learning, focal_loss
9 |
10 |
if __name__ == "__main__":

    # ---- 1) Training ----

    # run on the first CUDA device when one is present, on the CPU otherwise
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"\nAvailable device: {DEVICE}")

    # folders holding the images and their label masks
    image_dir = "./data/images"
    mask_dir = "./data/masks"

    # 3-channel input, 10 output classes, 16 features in the first convolution
    model = Unet(in_channels=3, out_channels=10, nb_features=16)
    model = model.to(DEVICE)

    # total number of scalar values held by the model parameters
    nb_parameters = sum(param.nelement() for param in model.parameters())
    print("nb_parameters: ", nb_parameters)

    # training context: list and split the files, then build the data loaders
    ctx = Context()
    ctx.fit(
        img_dir=image_dir,
        mask_dir=mask_dir,
        img_size=256,
        mask_size=256,
        device=DEVICE,
        shuffle_data=False,
    )
    ctx.set(nb_epochs=100, batch_size=16)

    # optimisation loop (validation runs at the end of every epoch)
    ctx.train(model)

    # persist the whole model object
    torch.save(model, "unet_16.pt")
    # plot the recorded loss / IoU curves
    show_learning(model)

    # ---- 2) Inference (example kept as an inert string) ----
    """
    saved_model = torch.load("unet_16.pt", map_location=torch.device('cpu'))

    img_name = "./data/images/001_scaled.png"
    mask_name = "./data/masks/001_scaled.png"

    ctx2 = Context()
    yhat = ctx2.predict(saved_model, img_name, mask_name, plot=True)
    """
57 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.18.0
2 | matplotlib==3.1.2
3 | Pillow==7.0.0
4 | torch==1.5.0
5 | torchvision==0.4.2
--------------------------------------------------------------------------------
/src/framework.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import math
3 | import os
4 | import time
5 |
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | from PIL import Image
9 | import torch
10 | from torchvision import transforms
11 | from torch.utils.data import DataLoader
12 | from torch.utils.data.dataset import Dataset
13 |
14 | from src.utils import toOneHot, iou_accuracy
15 |
class DataGenerator(Dataset):
    """
    Custom Dataset for DataLoader

    Pairs every image file with its label mask. Images are resized and
    converted with ``ToTensor``; masks are resized and one-hot encoded with
    ``toOneHot``.
    """
    def __init__(self, data, img_size, mask_size):
        """
        Args:
            data (dict): contains images and masks names under the keys
                ``"img"`` and ``"mask"`` (index-aligned sequences of paths)
            img_size (integer): image size ``squared image``
            mask_size (integer): mask size ``squared mask``

        Both sizes are handed unchanged to ``transforms.Resize``.
        NOTE(review): per the torchvision documentation an integer size
        matches the *shorter* edge and keeps the aspect ratio, so samples only
        come out square (and batchable) when the source files are square;
        pass an ``(h, w)`` pair to force an exact size.
        """
        self.data = data
        self.img_size = img_size
        self.mask_size = mask_size
        # never assigned anything else in this class; kept so existing
        # attribute access does not break
        self.imgProcessor = None

        # images processing
        self.transform_image = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.ToTensor()
        ])
        # Masks hold class indices, not intensities: the default (bilinear)
        # filter would blend neighbouring labels into unrelated class ids, so
        # nearest-neighbour resampling is requested explicitly.
        self.transform_mask = transforms.Compose([
            transforms.Resize(self.mask_size, interpolation=Image.NEAREST),
            toOneHot
        ])

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.data["img"])

    def __getitem__(self, index):
        """
        Load and transform the pair stored at ``index``.

        output:
            (img, mask): transformed image tensor and one-hot encoded mask
        """
        # context managers close the underlying files once the pixels have
        # been consumed by the transforms
        with Image.open(self.data['img'][index]) as raw_img:
            img = self.transform_image(raw_img)
        with Image.open(self.data['mask'][index]) as raw_mask:
            mask = self.transform_mask(raw_mask)

        return img, mask
51 |
class Context(object):
    """
    Set of methods to train networks

    Covers data preparation (``fit``), loader creation (``set``), the
    training loop (``train``), validation (``evaluate``) and single-image
    inference (``predict``). The model passed to ``train`` / ``evaluate``
    must expose ``criterion``, ``optimizer`` and the ``train_loss``,
    ``valid_loss``, ``train_accuracy`` and ``valid_accuracy`` lists.
    """

    def __init__(self):
        super(Context, self).__init__()
        self.train_data = {"img":[], "mask":[]}
        self.valid_data = {"img":[], "mask":[]}
        self.trainloader = None
        self.validloader = None
        self.img_size = None
        self.mask_size = None
        self.epochs = None
        self.batch_size = None
        self.DEVICE = "cpu"

    @staticmethod
    def _list_by_stem(directory):
        """Map ``file stem -> extension`` for the non-hidden files of ``directory``."""
        index = {}
        for entry in os.listdir(directory):
            stem, extension = os.path.splitext(entry)
            # skip hidden files (e.g. ``.DS_Store``) and entries without extension
            if extension and not entry.startswith('.'):
                index[stem] = extension
        return index

    def fit(self, img_dir, mask_dir, img_size, mask_size, device, validation_split=.3, shuffle_data=True):
        """
        Prepare data for training

        Images and masks are paired by file name without extension, so the
        two folders may use different (and any-length) extensions. Calling
        ``fit`` again replaces the previous split.

        Args:
            img_dir (String): folder containing the images
            mask_dir (String): folder containing the masks
            img_size: image size forwarded to the dataset
            mask_size: mask size forwarded to the dataset
            device: device on which batches are placed during training
            validation_split (Float): fraction of the pairs kept for validation
            shuffle_data (boolean): if 'True' shuffle the pairs before splitting
        output:
            the context itself, to allow chaining
        """
        # set gpu or cpu device
        self.DEVICE = device

        # set dimensions
        self.img_size = img_size
        self.mask_size = mask_size

        # get images names and extensions
        img_names = self._list_by_stem(img_dir)
        mask_names = self._list_by_stem(mask_dir)

        # extract similar names in both directories
        names = sorted(set(img_names).intersection(mask_names))

        if shuffle_data:
            np.random.shuffle(names)

        # split train/validation
        split_indice = int(len(names) * (1 - validation_split))
        subsets = (
            (self.train_data, names[:split_indice]),
            (self.valid_data, names[split_indice:]),
        )

        # rebuild both subsets from scratch (no accumulation across calls)
        for subset, selected in subsets:
            subset["img"] = np.array(
                [os.path.join(img_dir, name + img_names[name]) for name in selected]
            )
            subset["mask"] = np.array(
                [os.path.join(mask_dir, name + mask_names[name]) for name in selected]
            )

        return self

    def set(self, nb_epochs=5, batch_size=1):
        """
        Set training parameters

        Must be called after ``fit`` so the loaders see the file lists.

        Args:
            nb_epochs (int): Number of epochs ``default 5``
            batch_size (int): batch size ``default 1``
        """
        self.batch_size = batch_size
        self.epochs = nb_epochs

        # data loaders
        self.trainloader = DataLoader(DataGenerator(self.train_data, self.img_size, self.mask_size),
                                      batch_size=batch_size,
                                      shuffle=False,
                                      num_workers=2,
                                      pin_memory=True)

        self.validloader = DataLoader(DataGenerator(self.valid_data, self.img_size, self.mask_size),
                                      batch_size=batch_size,
                                      shuffle=False,
                                      num_workers=2,
                                      pin_memory=True)

    def train(self, model):
        """
        Train the model

        Per-epoch mean loss and IoU are appended to ``model.train_loss`` and
        ``model.train_accuracy``; validation runs after every epoch.

        Args:
            model: pytorch Model
        """
        print("\nTraining...\n")

        nb_batches = len(self.trainloader)

        for epoch in range(self.epochs):

            running_loss = []
            iou_score = []
            start = time.time()
            ## print infos
            print(f"Epoch {epoch+1}/{self.epochs}")

            model.train()

            for batch, (features, targets) in enumerate(self.trainloader):
                # move the batch once and reuse it for loss and metric
                features = features.to(self.DEVICE)
                targets = targets.to(self.DEVICE)

                model.optimizer.zero_grad()
                ## loss calculation
                outputs = model(features)
                loss = model.criterion(outputs, targets)
                ## backward propagation
                loss.backward()
                ## weights optimization
                model.optimizer.step()

                # running loss
                running_loss.append(loss.item())

                ## IOU accuracy, from the same forward pass as the loss
                ## (detached: the metric must not keep the graph alive)
                iou_score.append(iou_accuracy(outputs.detach(), targets))

                ## print train infos
                print(
                    f"\r Batch: {batch+1}/{nb_batches}",
                    f" - time: {str(datetime.timedelta(seconds = time.time() - start))}",
                    f" - T_loss: {np.mean(running_loss):.3f}",
                    f" - T_IOU: {np.mean(iou_score):.2f}", end = ''
                )
            ## save loss
            model.train_loss.append(np.mean(running_loss))
            ## save accuracy
            model.train_accuracy.append(np.mean(iou_score))

            ## evaluate
            self.evaluate(model, self.validloader, True)

    def evaluate(self, model, validloader, save_loss=False):
        """
        Evaluation part

        Switches the model to eval mode, appends the mean IoU of this pass to
        ``model.valid_accuracy`` and prints the loss / IoU of this pass.

        Args:
            model: pytorch model
            validloader: DatasetLoader for validation data
            save_loss(boolean): if 'True' save the validation loss
        """
        running_loss = []
        iou_score = []

        model.eval()
        # no gradient is needed for validation: avoids building autograd graphs
        with torch.no_grad():
            for features, targets in validloader:
                features = features.to(self.DEVICE)
                targets = targets.to(self.DEVICE)

                ## single forward pass shared by loss and metric
                outputs = model(features)

                ## save loss
                running_loss.append(model.criterion(outputs, targets).item())

                ## IOU accuracy
                iou_score.append(iou_accuracy(outputs, targets))

        # save accuracy
        epoch_iou = np.mean(iou_score)
        model.valid_accuracy.append(epoch_iou)

        if save_loss:
            model.valid_loss.append(np.mean(running_loss))

        ## print infos (IoU of this pass, not the average over past epochs)
        print(
            f" - V_loss: {np.mean(running_loss):.3f}",
            f" - V_IOU:{epoch_iou:.2f}"
        )

    def predict(self, model, img_path, mask_path=None, plot=False):
        """
        Inference

        Args:
            model: Unet model
            img_path (String): image path
            mask_path (String): optional, mask path
            plot: (boolean): optional, if 'True' plot the prediction
        outputs:
            yhat: predicted segmentation as a numpy array, channels last
        """
        # Run on the device the model actually lives on, so a model loaded
        # with ``map_location='cpu'`` also works on a CUDA machine. Fall back
        # to the previous heuristic when the model exposes no parameters.
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # open image
        with Image.open(img_path) as raw_img:
            # transform image
            img = transforms.Resize((256))(raw_img)
            img = transforms.ToTensor()(img)
        img = img.unsqueeze(0)

        # predict (inference only: no autograd graph needed)
        with torch.no_grad():
            yhat = model(img.to(device))
        # first (only) sample, channels moved last for plotting
        yhat = yhat[0].permute(1,2,0)
        probabilities = yhat.cpu().float().detach().numpy()

        if plot:
            #plots: column 0 holds input / ground truth, the rest one channel each
            fig, axes = plt.subplots(2,6, figsize=(18,6))
            channel = 0
            for i in range(axes.shape[0]):
                for j in range(1, axes.shape[1]):
                    # guard against models with fewer channels than grid cells
                    if channel < probabilities.shape[-1]:
                        axes[i,j].imshow(probabilities[:,:,channel], cmap = 'viridis')
                    channel+=1
            axes[0,0].imshow(img[0].permute(1,2,0))

            # mask
            if mask_path is not None:
                mask = Image.open(mask_path)
                #plot ground truth
                axes[1,0].imshow(mask)

            plt.show()

        return probabilities
262 |
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from src.utils import focal_loss
5 |
class Unet(nn.Module):
    def __init__(self, in_channels, out_channels, nb_features, lr=0.0008):
        """
        U-Net encoder/decoder with four down-sampling and four up-sampling stages.

        Args:
            in_channels (integer): channels of the input image.
            out_channels (integer): channels of the output map,
                                    ``one per class to predict``
            nb_features (integer): feature maps produced by the first double
                                   convolution (doubled at every encoder stage)
            lr (Float): learning rate given to the Adam optimizer ``default: 0.0008``

        """
        super().__init__()

        # hyper-parameter and per-epoch history filled in by the training code
        self.lr = lr
        self.train_loss, self.valid_loss = [], []
        self.train_accuracy, self.valid_accuracy = [], []

        # channel widths of the five resolution levels
        w1 = nb_features
        w2 = nb_features*2
        w4 = nb_features*4
        w8 = nb_features*8
        w16 = nb_features*16

        # encoder: every stage after the first starts with a 2x2 max-pool
        self.conv_1 = self.doubleconv(in_channels, w1)
        self.conv_2 = self.doubleconv(w1, w2, maxpool2d=True)
        self.conv_3 = self.doubleconv(w2, w4, maxpool2d=True)
        self.conv_4 = self.doubleconv(w4, w8, maxpool2d=True)

        self.bottleneck = self.doubleconv(w8, w16, maxpool2d=True)

        # decoder: up-sample, concatenate the matching encoder output (hence
        # the doubled input width of each deconv block), then convolve
        self.up_4 = self.upsample(w16, w8)
        self.deconv_4 = self.doubleconv(w16, w8)
        self.up_3 = self.upsample(w8, w4)
        self.deconv_3 = self.doubleconv(w8, w4)
        self.up_2 = self.upsample(w4, w2)
        self.deconv_2 = self.doubleconv(w4, w2)
        self.up_1 = self.upsample(w2, w1)
        self.deconv_1 = self.doubleconv(w2, w1)

        # 1x1 projection to class scores followed by a softmax over channels
        self.last_conv = nn.Sequential(
            nn.Conv2d(w1, out_channels, kernel_size=1, bias=False),
            nn.Softmax2d()
        )

        # loss function and optimizer travel with the model
        self.criterion = focal_loss
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, x):
        """
        Forward propagation

        Args:
            x (Tensor): batch of images of shape BxCxHxW.
                H and W are halved four times, so they should be multiples
                of 16 for the skip connections to line up.
        output:
            per-pixel class probabilities (softmax over the channel axis)
        """
        # encoder, keeping every stage output for the skip connections
        skip_1 = self.conv_1(x)
        skip_2 = self.conv_2(skip_1)
        skip_3 = self.conv_3(skip_2)
        skip_4 = self.conv_4(skip_3)

        features = self.bottleneck(skip_4)

        # decoder, from the deepest level back to full resolution
        stages = (
            (self.up_4, self.deconv_4, skip_4),
            (self.up_3, self.deconv_3, skip_3),
            (self.up_2, self.deconv_2, skip_2),
            (self.up_1, self.deconv_1, skip_1),
        )
        for up, fuse, skip in stages:
            features = fuse(torch.cat([up(features), skip], dim=1))

        return self.last_conv(features)

    def doubleconv(self, in_channels, out_channels, maxpool2d=False):
        """
        Build two (3x3 conv -> instance norm -> ReLU) steps in a row

        Args:
            in_channels (int): channels received by the first convolution
            out_channels (int): channels produced by both convolutions
            maxpool2d (boolean): if 'True' a 2x2 MaxPool2d is placed first
        output:
            return a sequential layer
        """
        def conv_step(channels_in):
            # padding=1 keeps the spatial size unchanged
            return [
                nn.Conv2d(channels_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.InstanceNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]

        steps = [nn.MaxPool2d(kernel_size=2, stride=2)] if maxpool2d else []
        steps += conv_step(in_channels)
        steps += conv_step(out_channels)

        return nn.Sequential(*steps)

    def upsample(self, in_channels, out_channels):
        """
        Build an upsampling layer (2x2 transposed convolution, stride 2)

        Args:
            in_channels (int): channels received
            out_channels (int): channels produced
        output:
            return a sequential layer
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        )
121 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | ### data processing ###
7 |
def toOneHot(mask, nb_class=10):
    """
    One-hot encode a 2-D label image, channels first

    Args:
        mask (Image): mask whose pixels are class labels
        nb_class (int): number of classes (size of the new leading axis)
    output:
        float tensor with one channel per class, followed by the two mask axes
    """
    labels = torch.from_numpy(np.array(mask)).to(torch.int64)
    # one_hot appends the class axis last
    encoded = F.one_hot(labels, num_classes=nb_class)
    # move the class axis in front of the two spatial axes
    channels_first = encoded.permute(2, 0, 1)

    return channels_first.to(torch.float32)
20 |
21 | ### plots ###
22 |
def show_learning(model):
    """
    Plot loss and accuracy from model

    Draws two panels side by side: training/validation loss and
    training/validation IoU, read from the history lists stored on the model.

    Args:
        model: Unet model (any object exposing ``train_loss``, ``valid_loss``,
               ``train_accuracy``, ``valid_accuracy`` and ``criterion``)
    """
    fig, axes = plt.subplots(1, 2, figsize=(15,5))

    # plot losses
    axes[0].plot(model.train_loss, label="train")
    axes[0].plot(model.valid_loss, label="validation")
    axes[0].set_xlabel("Epochs")

    # Use the criterion's own name when it provides ``_get_name``; a plain
    # function does not, in which case a generic label is used. Only the
    # missing-attribute case is handled so real errors and interrupts are
    # not swallowed.
    try:
        axes[0].set_ylabel(model.criterion._get_name())
    except AttributeError:
        axes[0].set_ylabel("Loss")

    axes[0].set_title("Loss evolution")
    axes[0].legend()

    # plot accuracy
    axes[1].plot(model.train_accuracy, label="train")
    axes[1].plot(model.valid_accuracy, label="validation")
    axes[1].set_xlabel("Epochs")
    axes[1].set_ylabel("IOU score")
    axes[1].set_title("Accuracy evolution")

    axes[1].legend()

    plt.show()
55 |
56 | ### losses & accuracy ###
57 |
def dice_loss(yhat, ytrue, epsilon=1e-6):
    """
    Soft Dice loss

    Args:
        yhat (Tensor): predicted masks
        ytrue (Tensor): targets masks
        epsilon (Float): added to the denominator to avoid division by 0
    output:
        DL value with `mean` reduction over the first axis
    """
    # everything but the first axis is reduced per sample
    reduced_axes = (1, 2, 3)

    overlap = (yhat * ytrue).sum(dim=reduced_axes)
    total = (yhat + ytrue).sum(dim=reduced_axes)

    per_sample = 1. - (2 * overlap / (total + epsilon))

    return per_sample.mean()
74 |
def tversky_index(yhat, ytrue, alpha=0.3, beta=0.7, epsilon=1e-6):
    """
    Tversky index, one value per sample

    Args:
        yhat (Tensor): predicted masks
        ytrue (Tensor): targets masks
        alpha (Float): weight applied to the false positives
        beta (Float): weight applied to the false negatives
                      `` alpha and beta set the relative penalties and should sum to 1``
        epsilon (Float): added to the denominator to avoid division by 0
    output:
        tversky index value
    """
    # everything but the first axis is reduced per sample
    reduced_axes = (1, 2, 3)

    true_pos = (yhat * ytrue).sum(dim=reduced_axes)
    false_pos = ((1. - ytrue) * yhat).sum(dim=reduced_axes)
    false_neg = ((1. - yhat) * ytrue).sum(dim=reduced_axes)

    denominator = true_pos + alpha * false_pos + beta * false_neg + epsilon

    return true_pos / denominator
94 |
def tversky_loss(yhat, ytrue):
    """
    Tversky loss: one minus the Tversky index (default weights)

    Args:
        yhat (Tensor): predicted masks
        ytrue (Tensor): targets masks
    output:
        tversky loss value with `mean` reduction
    """
    index = tversky_index(yhat, ytrue)

    return (1 - index).mean()
106 |
def tversky_focal_loss(yhat, ytrue, alpha=0.7, beta=0.3, gamma=0.75):
    """
    Tversky focal loss for highly unbalanced data
    https://arxiv.org/pdf/1810.07842.pdf

    Args:
        yhat (Tensor): predicted masks
        ytrue (Tensor): targets masks
        alpha (Float): weight applied to the false positives
        beta (Float): weight applied to the false negatives
                      `` alpha and beta set the relative penalties and should sum to 1``
        gamma (Float): focal exponent applied to (1 - tversky index)
                       ``balances easy background against hard ROI training examples``
    output:
        tversky focal loss value with `mean` reduction
    """
    index = tversky_index(yhat, ytrue, alpha, beta)
    focal_term = (1 - index) ** gamma

    return focal_term.mean()
125 |
def focal_loss(yhat, ytrue, alpha=0.75, gamma=2, epsilon=1e-7):
    """
    Computes α-balanced focal loss from FAIR
    https://arxiv.org/pdf/1708.02002v2.pdf

    Args:
        yhat (Tensor): predicted masks (probabilities, classes on axis 1)
        ytrue (Tensor): targets masks (same shape as ``yhat``)
        alpha (Float): weight to balance Cross entropy value
        gamma (Float): focal parameter
        epsilon (Float): lower bound applied to the probabilities before the
                         logarithm, for numerical stability
    output:
        loss value with `mean` reduction
    """
    # A probability that underflows to exactly 0 gives log(0) = -inf; the
    # product with a 0 target is then NaN and poisons the whole loss.
    # Clamping keeps every term finite and leaves other values unchanged.
    probabilities = torch.clamp(yhat, min=epsilon, max=1.)

    # compute the actual focal loss
    focal = -alpha * torch.pow(1. - probabilities, gamma) * torch.log(probabilities)
    # keep only the term of the target class at every position
    f_loss = torch.sum(ytrue * focal, dim=1)

    return torch.mean(f_loss)
145 |
146 | def iou_accuracy(yhat, ytrue, threshold=0.5, epsilon=1e-6):
147 | """
148 | Computes Intersection over Union metric
149 |
150 | Args:
151 | yhat (Tensor): predicted masks
152 | ytrue (Tensor): targets masks
153 | threshold (Float): threshold for pixel classification
154 | epsilon (Float): smoothing parameter for numerical stability
155 | output:
156 | iou value with `mean` reduction
157 | """
158 | intersection = ((yhat>threshold).long() & ytrue.long()).float().sum((1,2,3))
159 | union = ((yhat>threshold).long() | ytrue.long()).float().sum((1,2,3))
160 |
161 | return torch.mean(intersection/(union + epsilon)).item()
162 |
--------------------------------------------------------------------------------