├── model
│   ├── __init__.py
│   ├── utils.py
│   ├── resnet.py
│   ├── mobilenet.py
│   ├── decoder.py
│   ├── model.py
│   └── refiner.py
├── images
│   └── teaser.gif
├── dataset
│   ├── __init__.py
│   ├── sample.py
│   ├── zip.py
│   ├── images.py
│   ├── video.py
│   └── augmentation.py
├── requirements.txt
├── eval
│   ├── compute_sad_loss.m
│   ├── compute_mse_loss.m
│   ├── compute_gradient_loss.m
│   ├── gaussgradient.m
│   ├── compute_connectivity_error.m
│   └── benchmark.m
├── LICENSE
├── inference_utils.py
├── data_path.py
├── export_torchscript.py
├── inference_speed_test.py
├── README.md
├── inference_images.py
├── export_onnx.py
├── inference_webcam.py
├── doc
│   └── model_usage.md
├── inference_video.py
├── train_base.py
└── train_refine.py
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Base, MattingBase, MattingRefine
--------------------------------------------------------------------------------
/images/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/BackgroundMattingV2/master/images/teaser.gif
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .images import ImagesDataset
2 | from .video import VideoDataset
3 | from .sample import SampleDataset
4 | from .zip import ZipDataset
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | kornia==0.4.1
2 | tensorboard==2.3.0
3 | torch==1.7.0
4 | torchvision==0.8.1
5 | tqdm==4.51.0
6 | opencv-python==4.4.0.44
7 | onnxruntime==1.6.0
--------------------------------------------------------------------------------
/dataset/sample.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 |
3 |
4 | class SampleDataset(Dataset):
5 | def __init__(self, dataset, samples):
6 | samples = min(samples, len(dataset))
7 | self.dataset = dataset
8 | self.indices = [i * int(len(dataset) / samples) for i in range(samples)]
9 |
10 | def __len__(self):
11 | return len(self.indices)
12 |
13 | def __getitem__(self, idx):
14 | return self.dataset[self.indices[idx]]
15 |
--------------------------------------------------------------------------------
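
As a quick illustration of the index selection above, here is a minimal sketch (not part of the repo) that subsamples a hypothetical ten-item dataset with `SampleDataset`, assuming it is run from the repository root:

```python
# Minimal sketch: SampleDataset picks evenly spaced indices from another dataset.
# RangeDataset is a hypothetical stand-in for any dataset in this repo.
from torch.utils.data import Dataset
from dataset import SampleDataset

class RangeDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return idx

subset = SampleDataset(RangeDataset(), samples=3)
print(len(subset))                               # 3
print([subset[i] for i in range(len(subset))])   # [0, 3, 6]
```
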
/eval/compute_sad_loss.m:
--------------------------------------------------------------------------------
1 | % compute the SAD error given a prediction, a ground truth and a trimap.
2 | % author Ning Xu
3 | % date 2018-1-1
4 |
5 | function loss = compute_sad_loss(pred,target,trimap)
6 | error_map = abs(single(pred)-single(target))/255;
7 | loss = sum(sum(error_map.*single(trimap==128))) ;
8 |
9 | % the loss is scaled by 1000 due to the large images used in our experiment.
10 | % Please check the result table in our paper to make sure the result is correct.
11 | loss = loss / 1000 ;
12 |
--------------------------------------------------------------------------------
/eval/compute_mse_loss.m:
--------------------------------------------------------------------------------
1 | % compute the MSE error given a prediction, a ground truth and a trimap.
2 | % author Ning Xu
3 | % date 2018-1-1
4 |
5 | % pred: the predicted alpha matte
6 | % target: the ground truth alpha matte
7 | % trimap: the given trimap
8 |
9 | function loss = compute_mse_loss(pred,target,trimap)
10 | error_map = (single(pred)-single(target))/255;
11 |
12 | % fprintf('size(error_map) is %s\n', mat2str(size(error_map)))
13 | loss = sum(sum(error_map.^2.*single(trimap==128))) / sum(sum(single(trimap==128)));
14 |
--------------------------------------------------------------------------------
/model/utils.py:
--------------------------------------------------------------------------------
1 | def load_matched_state_dict(model, state_dict, print_stats=True):
2 | """
3 |     Only loads weights that match in key and shape. Ignores other weights.
4 | """
5 | num_matched, num_total = 0, 0
6 | curr_state_dict = model.state_dict()
7 | for key in curr_state_dict.keys():
8 | num_total += 1
9 | if key in state_dict and curr_state_dict[key].shape == state_dict[key].shape:
10 | curr_state_dict[key] = state_dict[key]
11 | num_matched += 1
12 | model.load_state_dict(curr_state_dict)
13 | if print_stats:
14 | print(f'Loaded state_dict: {num_matched}/{num_total} matched')
--------------------------------------------------------------------------------
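
A minimal usage sketch (not part of the repo) for `load_matched_state_dict`, assuming it is run from the repository root; the checkpoint path is a placeholder:

```python
# Minimal sketch: partially load a checkpoint, keeping only entries whose
# key and tensor shape match the current model. The path is a placeholder.
import torch

from model import MattingBase
from model.utils import load_matched_state_dict

model = MattingBase('resnet50')
state_dict = torch.load('PATH_TO_CHECKPOINT', map_location='cpu')
load_matched_state_dict(model, state_dict, print_stats=True)
```
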
/eval/compute_gradient_loss.m:
--------------------------------------------------------------------------------
1 | % compute the gradient error given a prediction, a ground truth and a trimap.
2 | % author Ning Xu
3 | % date 2018-1-1
4 |
5 | % pred: the predicted alpha matte
6 | % target: the ground truth alpha matte
7 | % trimap: the given trimap
8 | % step = 0.1
9 |
10 | function loss = compute_gradient_loss(pred,target,trimap)
11 | pred = mat2gray(pred);
12 | target = mat2gray(target);
13 | [pred_x,pred_y] = gaussgradient(pred,1.4);
14 | [target_x,target_y] = gaussgradient(target,1.4);
15 | pred_amp = sqrt(pred_x.^2 + pred_y.^2);
16 | target_amp = sqrt(target_x.^2 + target_y.^2);
17 |
18 | error_map = (single(pred_amp) - single(target_amp)).^2;
19 | loss = sum(sum(error_map.*single(trimap==128))) ;
20 |
--------------------------------------------------------------------------------
/dataset/zip.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from typing import List
3 |
4 | class ZipDataset(Dataset):
5 | def __init__(self, datasets: List[Dataset], transforms=None, assert_equal_length=False):
6 | self.datasets = datasets
7 | self.transforms = transforms
8 |
9 | if assert_equal_length:
10 | for i in range(1, len(datasets)):
11 | assert len(datasets[i]) == len(datasets[i - 1]), 'Datasets are not equal in length.'
12 |
13 | def __len__(self):
14 | return max(len(d) for d in self.datasets)
15 |
16 | def __getitem__(self, idx):
17 | x = tuple(d[idx % len(d)] for d in self.datasets)
18 | if self.transforms:
19 | x = self.transforms(*x)
20 | return x
21 |
--------------------------------------------------------------------------------
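
A minimal sketch (not part of the repo) of how `ZipDataset` can pair a foreground directory with an alpha directory so that matching samples are returned together; both paths are placeholders:

```python
# Minimal sketch: zip an fgr and a pha ImagesDataset so index i returns the
# corresponding pair. Both paths below are placeholders.
from torchvision import transforms as T

from dataset import ImagesDataset, ZipDataset
from dataset import augmentation as A

dataset = ZipDataset([
    ImagesDataset('PATH_TO_FGR_DIR', mode='RGB'),
    ImagesDataset('PATH_TO_PHA_DIR', mode='L'),
], transforms=A.PairApply(T.ToTensor()), assert_equal_length=True)

fgr, pha = dataset[0]   # both converted to tensors by the same PairApply transform
```
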
/dataset/images.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | from torch.utils.data import Dataset
4 | from PIL import Image
5 |
6 | class ImagesDataset(Dataset):
7 | def __init__(self, root, mode='RGB', transforms=None):
8 | self.transforms = transforms
9 | self.mode = mode
10 | self.filenames = sorted([*glob.glob(os.path.join(root, '**', '*.jpg'), recursive=True),
11 | *glob.glob(os.path.join(root, '**', '*.png'), recursive=True)])
12 |
13 | def __len__(self):
14 | return len(self.filenames)
15 |
16 | def __getitem__(self, idx):
17 | with Image.open(self.filenames[idx]) as img:
18 | img = img.convert(self.mode)
19 |
20 | if self.transforms:
21 | img = self.transforms(img)
22 |
23 | return img
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 University of Washington
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/eval/gaussgradient.m:
--------------------------------------------------------------------------------
1 | function [gx,gy]=gaussgradient(IM,sigma)
2 | %GAUSSGRADIENT Gradient using first order derivative of Gaussian.
3 | % [gx,gy]=gaussgradient(IM,sigma) outputs the gradient image gx and gy of
4 | % image IM using a 2-D Gaussian kernel. Sigma is the standard deviation of
5 | % this kernel along both directions.
6 | %
7 | % Contributed by Guanglei Xiong (xgl99@mails.tsinghua.edu.cn)
8 | % at Tsinghua University, Beijing, China.
9 |
10 | %determine the appropriate size of kernel. The smaller epsilon, the larger
11 | %size.
12 | epsilon=1e-2;
13 | halfsize=ceil(sigma*sqrt(-2*log(sqrt(2*pi)*sigma*epsilon)));
14 | size=2*halfsize+1;
15 | %generate a 2-D Gaussian kernel along x direction
16 | for i=1:size
17 | for j=1:size
18 | u=[i-halfsize-1 j-halfsize-1];
19 | hx(i,j)=gauss(u(1),sigma)*dgauss(u(2),sigma);
20 | end
21 | end
22 | hx=hx/sqrt(sum(sum(abs(hx).*abs(hx))));
23 | %generate a 2-D Gaussian kernel along y direction
24 | hy=hx';
25 | %2-D filtering
26 | gx=imfilter(IM,hx,'replicate','conv');
27 | gy=imfilter(IM,hy,'replicate','conv');
28 |
29 | function y = gauss(x,sigma)
30 | %Gaussian
31 | y = exp(-x^2/(2*sigma^2)) / (sigma*sqrt(2*pi));
32 |
33 | function y = dgauss(x,sigma)
34 | %first order derivative of Gaussian
35 | y = -x * gauss(x,sigma) / sigma^2;
--------------------------------------------------------------------------------
/dataset/video.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | from PIL import Image
5 |
6 | class VideoDataset(Dataset):
7 | def __init__(self, path: str, transforms: any = None):
8 | self.cap = cv2.VideoCapture(path)
9 | self.transforms = transforms
10 |
11 | self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
12 | self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
13 | self.frame_rate = self.cap.get(cv2.CAP_PROP_FPS)
14 | self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
15 |
16 | def __len__(self):
17 | return self.frame_count
18 |
19 | def __getitem__(self, idx):
20 | if isinstance(idx, slice):
21 | return [self[i] for i in range(*idx.indices(len(self)))]
22 |
23 | if self.cap.get(cv2.CAP_PROP_POS_FRAMES) != idx:
24 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
25 | ret, img = self.cap.read()
26 | if not ret:
27 | raise IndexError(f'Idx: {idx} out of length: {len(self)}')
28 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
29 | img = Image.fromarray(img)
30 | if self.transforms:
31 | img = self.transforms(img)
32 | return img
33 |
34 | def __enter__(self):
35 | return self
36 |
37 | def __exit__(self, exc_type, exc_value, exc_traceback):
38 | self.cap.release()
39 |
--------------------------------------------------------------------------------
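
A minimal sketch (not part of the repo) showing how `VideoDataset` can be used as a context manager; the video path is a placeholder:

```python
# Minimal sketch: read frames from a video file. 'video.mp4' is a placeholder.
from torchvision.transforms import ToTensor

from dataset import VideoDataset

with VideoDataset('video.mp4', transforms=ToTensor()) as video:
    print(video.width, video.height, video.frame_rate, video.frame_count)
    first_ten = video[:10]      # slicing is handled by __getitem__
    for frame in first_ten:
        print(frame.shape)      # (3, H, W) tensor per frame
```
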
/eval/compute_connectivity_error.m:
--------------------------------------------------------------------------------
1 | % compute the connectivity error given a prediction, a ground truth and a trimap.
2 | % author Ning Xu
3 | % date 2018-1-1
4 |
5 | % pred: the predicted alpha matte
6 | % target: the ground truth alpha matte
7 | % trimap: the given trimap
8 | % step = 0.1
9 |
10 | function loss = compute_connectivity_error(pred,target,trimap,step)
11 | pred = single(pred)/255;
12 | target = single(target)/255;
13 |
14 | [dimy,dimx] = size(pred);
15 |
16 | thresh_steps = 0:step:1;
17 | l_map = ones(size(pred))*(-1);
18 | dist_maps = zeros([dimy,dimx,numel(thresh_steps)]);
19 | for ii = 2:numel(thresh_steps)
20 | pred_alpha_thresh = pred>=thresh_steps(ii);
21 | target_alpha_thresh = target>=thresh_steps(ii);
22 |
23 | cc = bwconncomp(pred_alpha_thresh & target_alpha_thresh,4);
24 | size_vec = cellfun(@numel,cc.PixelIdxList);
25 | [~,max_id] = max(size_vec);
26 |
27 | omega = zeros([dimy,dimx]);
28 | omega(cc.PixelIdxList{max_id}) = 1;
29 |
30 | flag = l_map==-1 & omega==0;
31 | l_map(flag==1) = thresh_steps(ii-1);
32 |
33 | dist_maps(:,:,ii) = bwdist(omega);
34 | dist_maps(:,:,ii) = dist_maps(:,:,ii) / max(max(dist_maps(:,:,ii)));
35 | end
36 | l_map(l_map==-1) = 1;
37 |
38 | pred_d = pred - l_map;
39 | target_d = target - l_map;
40 |
41 | pred_phi = 1 - pred_d .* single(pred_d>=0.15);
42 |
43 | target_phi = 1 - target_d .* single(target_d>=0.15);
44 |
45 | loss = sum(sum(abs(pred_phi - target_phi).*single(trimap==128)));
46 |
47 |
--------------------------------------------------------------------------------
/model/resnet.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torchvision.models.resnet import ResNet, Bottleneck
3 |
4 |
5 | class ResNetEncoder(ResNet):
6 | """
7 |     ResNetEncoder inherits from torchvision's official ResNet. It is modified to
8 |     use dilation on the last block to maintain an output stride of 16, and the global
9 |     average pooling layer and the fully connected layer that were originally used for
10 |     classification are removed. The forward method additionally returns the feature
11 |     maps at all resolutions for the decoder's use.
12 | """
13 |
14 | layers = {
15 | 'resnet50': [3, 4, 6, 3],
16 | 'resnet101': [3, 4, 23, 3],
17 | }
18 |
19 | def __init__(self, in_channels, variant='resnet101', norm_layer=None):
20 | super().__init__(
21 | block=Bottleneck,
22 | layers=self.layers[variant],
23 | replace_stride_with_dilation=[False, False, True],
24 | norm_layer=norm_layer)
25 |
26 | # Replace first conv layer if in_channels doesn't match.
27 | if in_channels != 3:
28 | self.conv1 = nn.Conv2d(in_channels, 64, 7, 2, 3, bias=False)
29 |
30 | # Delete fully-connected layer
31 | del self.avgpool
32 | del self.fc
33 |
34 | def forward(self, x):
35 | x0 = x # 1/1
36 | x = self.conv1(x)
37 | x = self.bn1(x)
38 | x = self.relu(x)
39 | x1 = x # 1/2
40 | x = self.maxpool(x)
41 | x = self.layer1(x)
42 | x2 = x # 1/4
43 | x = self.layer2(x)
44 | x3 = x # 1/8
45 | x = self.layer3(x)
46 | x = self.layer4(x)
47 | x4 = x # 1/16
48 | return x4, x3, x2, x1, x0
49 |
--------------------------------------------------------------------------------
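
A minimal sketch (not part of the repo) that instantiates the encoder and checks the resolutions of the returned feature maps; the 6-channel input is an illustrative assumption meant to mimic a concatenated source/background pair:

```python
# Minimal sketch: verify the encoder returns features at 1/16, 1/8, 1/4, 1/2
# and full resolution. The 6-channel input is an illustrative assumption.
import torch

from model.resnet import ResNetEncoder

encoder = ResNetEncoder(in_channels=6, variant='resnet50').eval()
x = torch.randn(1, 6, 224, 224)
with torch.no_grad():
    x4, x3, x2, x1, x0 = encoder(x)
print(x4.shape[2:], x3.shape[2:], x2.shape[2:], x1.shape[2:], x0.shape[2:])
# Expected: (14, 14), (28, 28), (56, 56), (112, 112), (224, 224)
```
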
/inference_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from PIL import Image
4 |
5 |
6 | class HomographicAlignment:
7 | """
8 | Apply homographic alignment on background to match with the source image.
9 | """
10 |
11 | def __init__(self):
12 | self.detector = cv2.ORB_create()
13 | self.matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE)
14 |
15 | def __call__(self, src, bgr):
16 | src = np.asarray(src)
17 | bgr = np.asarray(bgr)
18 |
19 | keypoints_src, descriptors_src = self.detector.detectAndCompute(src, None)
20 | keypoints_bgr, descriptors_bgr = self.detector.detectAndCompute(bgr, None)
21 |
22 | matches = self.matcher.match(descriptors_bgr, descriptors_src, None)
23 | matches.sort(key=lambda x: x.distance, reverse=False)
24 | num_good_matches = int(len(matches) * 0.15)
25 | matches = matches[:num_good_matches]
26 |
27 | points_src = np.zeros((len(matches), 2), dtype=np.float32)
28 | points_bgr = np.zeros((len(matches), 2), dtype=np.float32)
29 | for i, match in enumerate(matches):
30 | points_src[i, :] = keypoints_src[match.trainIdx].pt
31 | points_bgr[i, :] = keypoints_bgr[match.queryIdx].pt
32 |
33 | H, _ = cv2.findHomography(points_bgr, points_src, cv2.RANSAC)
34 |
35 | h, w = src.shape[:2]
36 | bgr = cv2.warpPerspective(bgr, H, (w, h))
37 | msk = cv2.warpPerspective(np.ones((h, w)), H, (w, h))
38 |
39 |         # For areas that are outside of the background,
40 |         # we just copy pixels from the source.
41 | bgr[msk != 1] = src[msk != 1]
42 |
43 | src = Image.fromarray(src)
44 | bgr = Image.fromarray(bgr)
45 |
46 | return src, bgr
47 |
--------------------------------------------------------------------------------
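
A minimal sketch (not part of the repo) of applying `HomographicAlignment` to a source/background photo pair before matting; both file paths are placeholders:

```python
# Minimal sketch: warp the background photo onto the source photo's viewpoint.
# Both image paths are placeholders.
from PIL import Image

from inference_utils import HomographicAlignment

align = HomographicAlignment()
src = Image.open('src.png').convert('RGB')
bgr = Image.open('bgr.png').convert('RGB')
src, bgr = align(src, bgr)   # src is returned as-is; bgr is warped to match it
```
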
/model/mobilenet.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torchvision.models import MobileNetV2
3 |
4 |
5 | class MobileNetV2Encoder(MobileNetV2):
6 | """
7 |     MobileNetV2Encoder inherits from torchvision's official MobileNetV2. It is modified
8 |     to use dilation on the last block to maintain an output stride of 16, and the
9 |     classifier block that was originally used for classification is removed. The forward
10 |     method additionally returns the feature maps at all resolutions for the decoder's use.
11 | """
12 |
13 | def __init__(self, in_channels, norm_layer=None):
14 | super().__init__()
15 |
16 | # Replace first conv layer if in_channels doesn't match.
17 | if in_channels != 3:
18 | self.features[0][0] = nn.Conv2d(in_channels, 32, 3, 2, 1, bias=False)
19 |
20 | # Remove last block
21 | self.features = self.features[:-1]
22 |
23 | # Change to use dilation to maintain output stride = 16
24 | self.features[14].conv[1][0].stride = (1, 1)
25 | for feature in self.features[15:]:
26 | feature.conv[1][0].dilation = (2, 2)
27 | feature.conv[1][0].padding = (2, 2)
28 |
29 | # Delete classifier
30 | del self.classifier
31 |
32 | def forward(self, x):
33 | x0 = x # 1/1
34 | x = self.features[0](x)
35 | x = self.features[1](x)
36 | x1 = x # 1/2
37 | x = self.features[2](x)
38 | x = self.features[3](x)
39 | x2 = x # 1/4
40 | x = self.features[4](x)
41 | x = self.features[5](x)
42 | x = self.features[6](x)
43 | x3 = x # 1/8
44 | x = self.features[7](x)
45 | x = self.features[8](x)
46 | x = self.features[9](x)
47 | x = self.features[10](x)
48 | x = self.features[11](x)
49 | x = self.features[12](x)
50 | x = self.features[13](x)
51 | x = self.features[14](x)
52 | x = self.features[15](x)
53 | x = self.features[16](x)
54 | x = self.features[17](x)
55 | x4 = x # 1/16
56 | return x4, x3, x2, x1, x0
57 |
--------------------------------------------------------------------------------
/data_path.py:
--------------------------------------------------------------------------------
1 | """
2 | This file records the directory paths to the different datasets.
3 | You will need to configure it for training the model.
4 |
5 | All datasets follow the format below, where fgr and pha point to directories that contain jpg or png images.
6 | The directories may have any nested structure, but the fgr and pha structures must match. You can add your own
7 | dataset to the list as long as it follows the format. 'fgr' should point to foreground images with RGB channels,
8 | 'pha' should point to alpha images with only 1 grey channel.
9 | {
10 | 'YOUR_DATASET': {
11 | 'train': {
12 | 'fgr': 'PATH_TO_IMAGES_DIR',
13 | 'pha': 'PATH_TO_IMAGES_DIR',
14 | },
15 | 'valid': {
16 | 'fgr': 'PATH_TO_IMAGES_DIR',
17 | 'pha': 'PATH_TO_IMAGES_DIR',
18 | }
19 | }
20 | }
21 | """
22 |
23 | DATA_PATH = {
24 | 'videomatte240k': {
25 | 'train': {
26 | 'fgr': 'PATH_TO_IMAGES_DIR',
27 | 'pha': 'PATH_TO_IMAGES_DIR'
28 | },
29 | 'valid': {
30 | 'fgr': 'PATH_TO_IMAGES_DIR',
31 | 'pha': 'PATH_TO_IMAGES_DIR'
32 | }
33 | },
34 | 'photomatte13k': {
35 | 'train': {
36 | 'fgr': 'PATH_TO_IMAGES_DIR',
37 | 'pha': 'PATH_TO_IMAGES_DIR'
38 | },
39 | 'valid': {
40 | 'fgr': 'PATH_TO_IMAGES_DIR',
41 | 'pha': 'PATH_TO_IMAGES_DIR'
42 | }
43 | },
44 | 'distinction': {
45 | 'train': {
46 | 'fgr': 'PATH_TO_IMAGES_DIR',
47 | 'pha': 'PATH_TO_IMAGES_DIR',
48 | },
49 | 'valid': {
50 | 'fgr': 'PATH_TO_IMAGES_DIR',
51 | 'pha': 'PATH_TO_IMAGES_DIR'
52 | },
53 | },
54 | 'adobe': {
55 | 'train': {
56 | 'fgr': 'PATH_TO_IMAGES_DIR',
57 | 'pha': 'PATH_TO_IMAGES_DIR',
58 | },
59 | 'valid': {
60 | 'fgr': 'PATH_TO_IMAGES_DIR',
61 | 'pha': 'PATH_TO_IMAGES_DIR'
62 | },
63 | },
64 | 'backgrounds': {
65 | 'train': 'PATH_TO_IMAGES_DIR',
66 | 'valid': 'PATH_TO_IMAGES_DIR'
67 | },
68 | }
--------------------------------------------------------------------------------
/model/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class Decoder(nn.Module):
7 | """
8 | Decoder upsamples the image by combining the feature maps at all resolutions from the encoder.
9 |
10 | Input:
11 | x4: (B, C, H/16, W/16) feature map at 1/16 resolution.
12 | x3: (B, C, H/8, W/8) feature map at 1/8 resolution.
13 | x2: (B, C, H/4, W/4) feature map at 1/4 resolution.
14 | x1: (B, C, H/2, W/2) feature map at 1/2 resolution.
15 | x0: (B, C, H, W) feature map at full resolution.
16 |
17 | Output:
18 | x: (B, C, H, W) upsampled output at full resolution.
19 | """
20 |
21 | def __init__(self, channels, feature_channels):
22 | super().__init__()
23 | self.conv1 = nn.Conv2d(feature_channels[0] + channels[0], channels[1], 3, padding=1, bias=False)
24 | self.bn1 = nn.BatchNorm2d(channels[1])
25 | self.conv2 = nn.Conv2d(feature_channels[1] + channels[1], channels[2], 3, padding=1, bias=False)
26 | self.bn2 = nn.BatchNorm2d(channels[2])
27 | self.conv3 = nn.Conv2d(feature_channels[2] + channels[2], channels[3], 3, padding=1, bias=False)
28 | self.bn3 = nn.BatchNorm2d(channels[3])
29 | self.conv4 = nn.Conv2d(feature_channels[3] + channels[3], channels[4], 3, padding=1)
30 | self.relu = nn.ReLU(True)
31 |
32 | def forward(self, x4, x3, x2, x1, x0):
33 | x = F.interpolate(x4, size=x3.shape[2:], mode='bilinear', align_corners=False)
34 | x = torch.cat([x, x3], dim=1)
35 | x = self.conv1(x)
36 | x = self.bn1(x)
37 | x = self.relu(x)
38 | x = F.interpolate(x, size=x2.shape[2:], mode='bilinear', align_corners=False)
39 | x = torch.cat([x, x2], dim=1)
40 | x = self.conv2(x)
41 | x = self.bn2(x)
42 | x = self.relu(x)
43 | x = F.interpolate(x, size=x1.shape[2:], mode='bilinear', align_corners=False)
44 | x = torch.cat([x, x1], dim=1)
45 | x = self.conv3(x)
46 | x = self.bn3(x)
47 | x = self.relu(x)
48 | x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
49 | x = torch.cat([x, x0], dim=1)
50 | x = self.conv4(x)
51 | return x
52 |
--------------------------------------------------------------------------------
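
A minimal toy sketch (not part of the repo) of a `Decoder` forward pass. The channel counts below are made up purely for shape checking; the real values are configured in model.py:

```python
# Minimal toy sketch with made-up channel counts (the real ones live in model.py).
import torch

from model.decoder import Decoder

channels = [16, 12, 8, 4, 2]        # x4 channels followed by the four conv outputs
feature_channels = [8, 6, 4, 3]     # skip-connection channels of x3, x2, x1, x0
decoder = Decoder(channels, feature_channels)

B, H, W = 1, 64, 64
x4 = torch.randn(B, channels[0],         H // 16, W // 16)
x3 = torch.randn(B, feature_channels[0], H // 8,  W // 8)
x2 = torch.randn(B, feature_channels[1], H // 4,  W // 4)
x1 = torch.randn(B, feature_channels[2], H // 2,  W // 2)
x0 = torch.randn(B, feature_channels[3], H,       W)

out = decoder(x4, x3, x2, x1, x0)
print(out.shape)   # torch.Size([1, 2, 64, 64]) -> (B, channels[4], H, W)
```
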
/eval/benchmark.m:
--------------------------------------------------------------------------------
1 | #!/usr/bin/octave
2 | arg_list = argv ();
3 | bench_path = arg_list{1};
4 | result_path = arg_list{2};
5 |
6 |
7 | gt_files = dir(fullfile(bench_path, 'pha', '*.png'));
8 |
9 | total_loss_mse = 0;
10 | total_loss_sad = 0;
11 | total_loss_gradient = 0;
12 | total_loss_connectivity = 0;
13 |
14 | total_fg_mse = 0;
15 | total_premult_mse = 0;
16 |
17 | for i = 1:length(gt_files)
18 | filename = gt_files(i).name;
19 |
20 | gt_fullname = fullfile(bench_path, 'pha', filename);
21 | gt_alpha = imread(gt_fullname);
22 | trimap = imread(fullfile(bench_path, 'trimap', filename));
23 | crop_edge = idivide(size(gt_alpha), 4) * 4;
24 | gt_alpha = gt_alpha(1:crop_edge(1), 1:crop_edge(2));
25 | trimap = trimap(1:crop_edge(1), 1:crop_edge(2));
26 |
27 | result_fullname = fullfile(result_path, 'pha', filename);%strrep(filename, '.png', '.jpg'));
28 | hat_alpha = imread(result_fullname)(1:crop_edge(1), 1:crop_edge(2));
29 |
30 |
31 | fg_hat_fullname = fullfile(result_path, 'fgr', filename);%strrep(filename, '.png', '.jpg'));
32 | fg_gt_fullname = fullfile(bench_path, 'fgr', filename);
33 | hat_fgr = imread(fg_hat_fullname)(1:crop_edge(1), 1:crop_edge(2), :);
34 | gt_fgr = imread(fg_gt_fullname)(1:crop_edge(1), 1:crop_edge(2), :);
35 | nonzero_alpha = gt_alpha > 0;
36 |
37 |
38 | % fprintf('size(gt_fgr) is %s\n', mat2str(size(gt_fgr)))
39 | fg_mse = mean(compute_mse_loss(hat_fgr .* nonzero_alpha, gt_fgr .* nonzero_alpha, trimap));
40 | mse = compute_mse_loss(hat_alpha, gt_alpha, trimap);
41 | sad = compute_sad_loss(hat_alpha, gt_alpha, trimap);
42 | grad = compute_gradient_loss(hat_alpha, gt_alpha, trimap);
43 | conn = compute_connectivity_error(hat_alpha, gt_alpha, trimap, 0.1);
44 |
45 |
46 | fprintf(2, strcat(filename, ',%.6f,%.3f,%.0f,%.0f,%.6f\n'), mse, sad, grad, conn, fg_mse);
47 | fflush(stderr);
48 |
49 | total_loss_mse += mse;
50 | total_loss_sad += sad;
51 | total_loss_gradient += grad;
52 | total_loss_connectivity += conn;
53 | total_fg_mse += fg_mse;
54 | end
55 |
56 | avg_loss_mse = total_loss_mse / length(gt_files);
57 | avg_loss_sad = total_loss_sad / length(gt_files);
58 | avg_loss_gradient = total_loss_gradient / length(gt_files);
59 | avg_loss_connectivity = total_loss_connectivity / length(gt_files);
60 | avg_loss_fg_mse = total_fg_mse / length(gt_files);
61 |
62 | fprintf('mse:%.6f,sad:%.3f,grad:%.0f,conn:%.0f,fg_mse:%.6f\n', avg_loss_mse, avg_loss_sad, avg_loss_gradient, avg_loss_connectivity, avg_loss_fg_mse);
63 |
--------------------------------------------------------------------------------
/export_torchscript.py:
--------------------------------------------------------------------------------
1 | """
2 | Export TorchScript
3 |
4 | python export_torchscript.py \
5 | --model-backbone resnet50 \
6 | --model-checkpoint "PATH_TO_CHECKPOINT" \
7 | --precision float32 \
8 | --output "torchscript.pth"
9 | """
10 |
11 | import argparse
12 | import torch
13 | from torch import nn
14 | from model import MattingRefine
15 |
16 |
17 | # --------------- Arguments ---------------
18 |
19 |
20 | parser = argparse.ArgumentParser(description='Export TorchScript')
21 |
22 | parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
23 | parser.add_argument('--model-checkpoint', type=str, required=True)
24 | parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
25 | parser.add_argument('--output', type=str, required=True)
26 |
27 | args = parser.parse_args()
28 |
29 |
30 | # --------------- Utils ---------------
31 |
32 |
33 | class MattingRefine_TorchScriptWrapper(nn.Module):
34 | """
35 |     The purpose of this wrapper is to hoist all the configurable attributes to the top
36 |     level, so that the user can easily change them after loading the saved TorchScript model.
37 |
38 | Example:
39 | model = torch.jit.load('torchscript.pth')
40 | model.backbone_scale = 0.25
41 | model.refine_mode = 'sampling'
42 | model.refine_sample_pixels = 80_000
43 | pha, fgr = model(src, bgr)[:2]
44 | """
45 |
46 | def __init__(self, *args, **kwargs):
47 | super().__init__()
48 | self.model = MattingRefine(*args, **kwargs)
49 |
50 | # Hoist the attributes to the top level.
51 | self.backbone_scale = self.model.backbone_scale
52 | self.refine_mode = self.model.refiner.mode
53 | self.refine_sample_pixels = self.model.refiner.sample_pixels
54 | self.refine_threshold = self.model.refiner.threshold
55 | self.refine_prevent_oversampling = self.model.refiner.prevent_oversampling
56 |
57 | def forward(self, src, bgr):
58 | # Reset the attributes.
59 | self.model.backbone_scale = self.backbone_scale
60 | self.model.refiner.mode = self.refine_mode
61 | self.model.refiner.sample_pixels = self.refine_sample_pixels
62 | self.model.refiner.threshold = self.refine_threshold
63 | self.model.refiner.prevent_oversampling = self.refine_prevent_oversampling
64 |
65 | return self.model(src, bgr)
66 |
67 | def load_state_dict(self, *args, **kwargs):
68 | return self.model.load_state_dict(*args, **kwargs)
69 |
70 |
71 | # --------------- Main ---------------
72 |
73 |
74 | model = MattingRefine_TorchScriptWrapper(args.model_backbone).eval()
75 | model.load_state_dict(torch.load(args.model_checkpoint, map_location='cpu'))
76 | for p in model.parameters():
77 | p.requires_grad = False
78 |
79 | if args.precision == 'float16':
80 | model = model.half()
81 |
82 | model = torch.jit.script(model)
83 | model.save(args.output)
84 |
--------------------------------------------------------------------------------
/inference_speed_test.py:
--------------------------------------------------------------------------------
1 | """
2 | Inference Speed Test
3 |
4 | Example:
5 |
6 | Run inference on random noise input for fixed computation setting.
7 | (i.e. mode in ['full', 'sampling'])
8 |
9 | python inference_speed_test.py \
10 | --model-type mattingrefine \
11 | --model-backbone resnet50 \
12 | --model-backbone-scale 0.25 \
13 | --model-refine-mode sampling \
14 | --model-refine-sample-pixels 80000 \
15 | --batch-size 1 \
16 | --resolution 1920 1080 \
17 | --backend pytorch \
18 | --precision float32
19 |
20 | Run inference on provided image input for dynamic computation setting.
21 | (i.e. mode in ['thresholding'])
22 |
23 | python inference_speed_test.py \
24 | --model-type mattingrefine \
25 | --model-backbone resnet50 \
26 | --model-backbone-scale 0.25 \
27 | --model-checkpoint "PATH_TO_CHECKPOINT" \
28 | --model-refine-mode thresholding \
29 | --model-refine-threshold 0.7 \
30 | --batch-size 1 \
31 | --backend pytorch \
32 | --precision float32 \
33 | --image-src "PATH_TO_IMAGE_SRC" \
34 | --image-bgr "PATH_TO_IMAGE_BGR"
35 |
36 | """
37 |
38 | import argparse
39 | import torch
40 | from torchvision.transforms.functional import to_tensor
41 | from tqdm import tqdm
42 | from PIL import Image
43 |
44 | from model import MattingBase, MattingRefine
45 |
46 |
47 | # --------------- Arguments ---------------
48 |
49 |
50 | parser = argparse.ArgumentParser()
51 |
52 | parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
53 | parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
54 | parser.add_argument('--model-backbone-scale', type=float, default=0.25)
55 | parser.add_argument('--model-checkpoint', type=str, default=None)
56 | parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
57 | parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
58 | parser.add_argument('--model-refine-threshold', type=float, default=0.7)
59 | parser.add_argument('--model-refine-kernel-size', type=int, default=3)
60 |
61 | parser.add_argument('--batch-size', type=int, default=1)
62 | parser.add_argument('--resolution', type=int, default=None, nargs=2)
63 | parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
64 | parser.add_argument('--backend', type=str, default='pytorch', choices=['pytorch', 'torchscript'])
65 | parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
66 |
67 | parser.add_argument('--image-src', type=str, default=None)
68 | parser.add_argument('--image-bgr', type=str, default=None)
69 |
70 | args = parser.parse_args()
71 |
72 |
73 | assert type(args.image_src) == type(args.image_bgr), 'Image source and background must be provided together.'
74 | assert (not args.image_src) != (not args.resolution), 'Must provide either a resolution or an image and not both.'
75 |
76 |
77 | # --------------- Run Loop ---------------
78 |
79 |
80 | device = torch.device(args.device)
81 |
82 | # Load model
83 | if args.model_type == 'mattingbase':
84 | model = MattingBase(args.model_backbone)
85 | if args.model_type == 'mattingrefine':
86 | model = MattingRefine(
87 | args.model_backbone,
88 | args.model_backbone_scale,
89 | args.model_refine_mode,
90 | args.model_refine_sample_pixels,
91 | args.model_refine_threshold,
92 | args.model_refine_kernel_size,
93 | refine_prevent_oversampling=False)
94 |
95 | if args.model_checkpoint:
96 | model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
97 |
98 | if args.precision == 'float32':
99 | precision = torch.float32
100 | else:
101 | precision = torch.float16
102 |
103 | if args.backend == 'torchscript':
104 | model = torch.jit.script(model)
105 |
106 | model = model.eval().to(device=device, dtype=precision)
107 |
108 | # Load data
109 | if not args.image_src:
110 | src = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device=device, dtype=precision)
111 | bgr = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device=device, dtype=precision)
112 | else:
113 | src = to_tensor(Image.open(args.image_src)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device=device, dtype=precision)
114 | bgr = to_tensor(Image.open(args.image_bgr)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device=device, dtype=precision)
115 |
116 | # Loop
117 | with torch.no_grad():
118 | for _ in tqdm(range(1000)):
119 | model(src, bgr)
120 |
--------------------------------------------------------------------------------
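
The loop above reports throughput via tqdm's iteration rate. If appended to the end of the script (reusing its `model`, `src` and `bgr`, and assuming `--device cuda`), a sketch like the following gives a synchronized wall-clock measurement; it is not part of the script:

```python
# Hypothetical addition: CUDA kernels launch asynchronously, so synchronize
# around the timed region to get an honest frames-per-second number.
import time

with torch.no_grad():
    for _ in range(10):          # warm-up
        model(src, bgr)
    torch.cuda.synchronize()
    start = time.time()
    n = 100
    for _ in range(n):
        model(src, bgr)
    torch.cuda.synchronize()
    print(f'{n * args.batch_size / (time.time() - start):.1f} FPS')
```
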
/dataset/augmentation.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import numpy as np
4 | import math
5 | from torchvision import transforms as T
6 | from torchvision.transforms import functional as F
7 | from PIL import Image, ImageFilter
8 |
9 | """
10 | Pair transforms are modified versions of regular transforms that take in multiple images
11 | and apply the exact same transform to all of them. This is especially useful when we want
12 | identical transforms on a pair of images.
13 |
14 | Example:
15 | img1, img2, ..., imgN = transforms(img1, img2, ..., imgN)
16 | """
17 |
18 | class PairCompose(T.Compose):
19 | def __call__(self, *x):
20 | for transform in self.transforms:
21 | x = transform(*x)
22 | return x
23 |
24 |
25 | class PairApply:
26 | def __init__(self, transforms):
27 | self.transforms = transforms
28 |
29 | def __call__(self, *x):
30 | return [self.transforms(xi) for xi in x]
31 |
32 |
33 | class PairApplyOnlyAtIndices:
34 | def __init__(self, indices, transforms):
35 | self.indices = indices
36 | self.transforms = transforms
37 |
38 | def __call__(self, *x):
39 | return [self.transforms(xi) if i in self.indices else xi for i, xi in enumerate(x)]
40 |
41 |
42 | class PairRandomAffine(T.RandomAffine):
43 | def __init__(self, degrees, translate=None, scale=None, shear=None, resamples=None, fillcolor=0):
44 | super().__init__(degrees, translate, scale, shear, Image.NEAREST, fillcolor)
45 | self.resamples = resamples
46 |
47 | def __call__(self, *x):
48 | if not len(x):
49 | return []
50 | param = self.get_params(self.degrees, self.translate, self.scale, self.shear, x[0].size)
51 | resamples = self.resamples or [self.resample] * len(x)
52 | return [F.affine(xi, *param, resamples[i], self.fillcolor) for i, xi in enumerate(x)]
53 |
54 |
55 | class PairRandomHorizontalFlip(T.RandomHorizontalFlip):
56 | def __call__(self, *x):
57 | if torch.rand(1) < self.p:
58 | x = [F.hflip(xi) for xi in x]
59 | return x
60 |
61 |
62 | class RandomBoxBlur:
63 | def __init__(self, prob, max_radius):
64 | self.prob = prob
65 | self.max_radius = max_radius
66 |
67 | def __call__(self, img):
68 | if torch.rand(1) < self.prob:
69 | fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
70 | img = img.filter(fil)
71 | return img
72 |
73 |
74 | class PairRandomBoxBlur(RandomBoxBlur):
75 | def __call__(self, *x):
76 | if torch.rand(1) < self.prob:
77 | fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
78 | x = [xi.filter(fil) for xi in x]
79 | return x
80 |
81 |
82 | class RandomSharpen:
83 | def __init__(self, prob):
84 | self.prob = prob
85 | self.filter = ImageFilter.SHARPEN
86 |
87 | def __call__(self, img):
88 | if torch.rand(1) < self.prob:
89 | img = img.filter(self.filter)
90 | return img
91 |
92 |
93 | class PairRandomSharpen(RandomSharpen):
94 | def __call__(self, *x):
95 | if torch.rand(1) < self.prob:
96 | x = [xi.filter(self.filter) for xi in x]
97 | return x
98 |
99 |
100 | class PairRandomAffineAndResize:
101 | def __init__(self, size, degrees, translate, scale, shear, ratio=(3./4., 4./3.), resample=Image.BILINEAR, fillcolor=0):
102 | self.size = size
103 | self.degrees = degrees
104 | self.translate = translate
105 | self.scale = scale
106 | self.shear = shear
107 | self.ratio = ratio
108 | self.resample = resample
109 | self.fillcolor = fillcolor
110 |
111 | def __call__(self, *x):
112 | if not len(x):
113 | return []
114 |
115 | w, h = x[0].size
116 | scale_factor = max(self.size[1] / w, self.size[0] / h)
117 |
118 | w_padded = max(w, self.size[1])
119 | h_padded = max(h, self.size[0])
120 |
121 | pad_h = int(math.ceil((h_padded - h) / 2))
122 | pad_w = int(math.ceil((w_padded - w) / 2))
123 |
124 | scale = self.scale[0] * scale_factor, self.scale[1] * scale_factor
125 | translate = self.translate[0] * scale_factor, self.translate[1] * scale_factor
126 | affine_params = T.RandomAffine.get_params(self.degrees, translate, scale, self.shear, (w, h))
127 |
128 | def transform(img):
129 | if pad_h > 0 or pad_w > 0:
130 | img = F.pad(img, (pad_w, pad_h))
131 |
132 | img = F.affine(img, *affine_params, self.resample, self.fillcolor)
133 | img = F.center_crop(img, self.size)
134 | return img
135 |
136 | return [transform(xi) for xi in x]
137 |
138 |
139 | class RandomAffineAndResize(PairRandomAffineAndResize):
140 | def __call__(self, img):
141 | return super().__call__(img)[0]
--------------------------------------------------------------------------------
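
A minimal sketch (not part of the repo) composing the pair transforms above so that a foreground/alpha pair receives identical random parameters; the blank PIL images are stand-ins for real data:

```python
# Minimal sketch: identical random affine, flip and tensor conversion on a pair.
# The blank images are stand-ins for a real foreground / alpha pair.
from PIL import Image
from torchvision import transforms as T

from dataset import augmentation as A

transforms = A.PairCompose([
    A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1),
                                scale=(0.4, 1.0), shear=(-5, 5)),
    A.PairRandomHorizontalFlip(),
    A.PairApply(T.ToTensor()),
])

fgr = Image.new('RGB', (800, 600))
pha = Image.new('L', (800, 600))
fgr, pha = transforms(fgr, pha)   # same affine/flip parameters applied to both
```
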
/README.md:
--------------------------------------------------------------------------------
1 | # Real-Time High-Resolution Background Matting
2 |
3 | 
4 |
5 | Official repository for the paper [Real-Time High-Resolution Background Matting](https://arxiv.org/abs/2012.07810). Our model requires capturing an additional background image and produces state-of-the-art matting results at 4K 30fps and HD 60fps on an Nvidia RTX 2080 TI GPU.
6 |
7 | * [Visit project site](https://grail.cs.washington.edu/projects/background-matting-v2/)
8 | * [Watch project video](https://www.youtube.com/watch?v=oMfPTeYDF9g)
9 |
10 | **Disclaimer**: The video conversion script in this repo is not meant to be real-time. Our research's main contribution is the neural architecture for high-resolution refinement and the new matting datasets. The `inference_speed_test.py` script allows you to measure the tensor throughput of our model, which should achieve real-time rates. The `inference_video.py` script allows you to test your video on our model, but the video encoding and decoding are done without hardware acceleration or parallelization. For production use, you are expected to do additional engineering for hardware encoding/decoding and for loading frames to the GPU in parallel. For more architectural detail, please refer to our paper.
11 |
12 |
13 |
14 | ## Overview
15 | * [Updates](#updates)
16 | * [Download](#download)
17 | * [Model / Weights](#model--weights)
18 | * [Video / Image Examples](#video--image-examples)
19 | * [Datasets](#datasets)
20 | * [Demo](#demo)
21 | * [Scripts](#scripts)
22 | * [Notebooks](#notebooks)
23 | * [Usage / Documentation](#usage--documentation)
24 | * [Training](#training)
25 | * [Project members](#project-members)
26 | * [License](#license)
27 |
28 |
29 |
30 | ## Updates
31 |
32 | * [Jun 21 2021] Paper received CVPR 2021 Best Student Paper Honorable Mention.
33 | * [Apr 21 2021] VideoMatte240K dataset is now published.
34 | * [Mar 06 2021] Training script is published.
35 | * [Feb 28 2021] Paper is accepted to CVPR 2021.
36 | * [Jan 09 2021] PhotoMatte85 dataset is now published.
37 | * [Dec 21 2020] We updated our project to MIT License, which permits commercial use.
38 |
39 |
40 |
41 | ## Download
42 |
43 | ### Model / Weights
44 |
45 | * [Download model / weights](https://drive.google.com/drive/folders/1cbetlrKREitIgjnIikG1HdM4x72FtgBh?usp=sharing)
46 |
47 | ### Video / Image Examples
48 |
49 | * [HD videos](https://drive.google.com/drive/folders/1j3BMrRFhFpfzJAe6P2WDtfanoeSCLPiq) (by [Sengupta et al.](https://github.com/senguptaumd/Background-Matting)) (Our model is more robust on HD footage)
50 | * [4K videos and images](https://drive.google.com/drive/folders/16H6Vz3294J-DEzauw06j4IUARRqYGgRD?usp=sharing)
51 |
52 |
53 | ### Datasets
54 |
55 | * [Download datasets](https://grail.cs.washington.edu/projects/background-matting-v2/#/datasets)
56 |
57 |
58 |
59 | ## Demo
60 |
61 | #### Scripts
62 |
63 | We provide several scripts in this repo for you to experiment with our model. More detailed instructions are included in the files.
64 | * `inference_images.py`: Perform matting on a directory of images.
65 | * `inference_video.py`: Perform matting on a video.
66 | * `inference_webcam.py`: An interactive matting demo using your webcam.
67 |
68 | #### Notebooks
69 | Additionally, you can try our notebooks in Google Colab for performing matting on images and videos.
70 |
71 | * [Image matting (Colab)](https://colab.research.google.com/drive/1cTxFq1YuoJ5QPqaTcnskwlHDolnjBkB9?usp=sharing)
72 | * [Video matting (Colab)](https://colab.research.google.com/drive/1Y9zWfULc8-DDTSsCH-pX6Utw8skiJG5s?usp=sharing)
73 |
74 | #### Virtual Camera
75 | We provide a demo application that pipes webcam video through our model and outputs to a virtual camera. The script only works on Linux systems and can be used in Zoom meetings. For more information, check out:
76 | * [Webcam plugin](https://github.com/andreyryabtsev/BGMv2-webcam-plugin-linux)
77 |
78 |
79 |
80 | ## Usage / Documentation
81 |
82 | You can run our model using **PyTorch**, **TorchScript**, **TensorFlow**, and **ONNX**. For detail about using our model, please check out the [Usage / Documentation](doc/model_usage.md) page.
83 |
84 |
85 |
86 | ## Training
87 |
88 | Configure `data_path.py` to point to your datasets. The original paper uses `train_base.py` to train only the base model until convergence, then uses `train_refine.py` to train the entire network end-to-end. More details are specified in the paper.
89 |
90 |
91 |
92 | ## Project members
93 | * [Shanchuan Lin](https://www.linkedin.com/in/shanchuanlin/)*, University of Washington
94 | * [Andrey Ryabtsev](http://andreyryabtsev.com/)*, University of Washington
95 | * [Soumyadip Sengupta](https://homes.cs.washington.edu/~soumya91/), University of Washington
96 | * [Brian Curless](https://homes.cs.washington.edu/~curless/), University of Washington
97 | * [Steve Seitz](https://homes.cs.washington.edu/~seitz/), University of Washington
98 | * [Ira Kemelmacher-Shlizerman](https://sites.google.com/view/irakemelmacher/), University of Washington
99 |
100 | * Equal contribution.
101 |
102 |
103 |
104 | ## License ##
105 | This work is licensed under the [MIT License](LICENSE). If you use our work in your project, we would love you to include an acknowledgement and fill out our [survey](https://docs.google.com/forms/d/e/1FAIpQLSdR9Yhu9V1QE3pN_LvZJJyDaEpJD2cscOOqMz8N732eLDf42A/viewform?usp=sf_link).
106 |
107 | ## Community Projects
108 | Projects developed by third-party developers.
109 |
110 | * [After Effects Plug-In](https://aescripts.com/goodbye-greenscreen/)
111 |
--------------------------------------------------------------------------------
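
The README's Usage / Documentation section points to doc/model_usage.md for full details. As a quick, hedged illustration of the plain PyTorch route (argument values follow the inference scripts in this repo, the checkpoint path is a placeholder, and a CUDA device is assumed):

```python
# Minimal sketch: run MattingRefine directly in PyTorch. The checkpoint path is
# a placeholder and the random tensors stand in for real src/bgr images.
import torch

from model import MattingRefine

model = MattingRefine('resnet50',
                      backbone_scale=0.25,
                      refine_mode='sampling',
                      refine_sample_pixels=80_000)
model.load_state_dict(torch.load('PATH_TO_CHECKPOINT', map_location='cpu'), strict=False)
model = model.eval().cuda()

src = torch.rand(1, 3, 1080, 1920, device='cuda')
bgr = torch.rand(1, 3, 1080, 1920, device='cuda')
with torch.no_grad():
    pha, fgr = model(src, bgr)[:2]
```
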
/inference_images.py:
--------------------------------------------------------------------------------
1 | """
2 | Inference images: Extract matting on images.
3 |
4 | Example:
5 |
6 | python inference_images.py \
7 | --model-type mattingrefine \
8 | --model-backbone resnet50 \
9 | --model-backbone-scale 0.25 \
10 | --model-refine-mode sampling \
11 | --model-refine-sample-pixels 80000 \
12 | --model-checkpoint "PATH_TO_CHECKPOINT" \
13 | --images-src "PATH_TO_IMAGES_SRC_DIR" \
14 | --images-bgr "PATH_TO_IMAGES_BGR_DIR" \
15 | --output-dir "PATH_TO_OUTPUT_DIR" \
16 | --output-type com fgr pha
17 |
18 | """
19 |
20 | import argparse
21 | import torch
22 | import os
23 | import shutil
24 |
25 | from torch import nn
26 | from torch.nn import functional as F
27 | from torch.utils.data import DataLoader
28 | from torchvision import transforms as T
29 | from torchvision.transforms.functional import to_pil_image
30 | from threading import Thread
31 | from tqdm import tqdm
32 |
33 | from dataset import ImagesDataset, ZipDataset
34 | from dataset import augmentation as A
35 | from model import MattingBase, MattingRefine
36 | from inference_utils import HomographicAlignment
37 |
38 |
39 | # --------------- Arguments ---------------
40 |
41 |
42 | parser = argparse.ArgumentParser(description='Inference images')
43 |
44 | parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
45 | parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
46 | parser.add_argument('--model-backbone-scale', type=float, default=0.25)
47 | parser.add_argument('--model-checkpoint', type=str, required=True)
48 | parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
49 | parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
50 | parser.add_argument('--model-refine-threshold', type=float, default=0.7)
51 | parser.add_argument('--model-refine-kernel-size', type=int, default=3)
52 |
53 | parser.add_argument('--images-src', type=str, required=True)
54 | parser.add_argument('--images-bgr', type=str, required=True)
55 |
56 | parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
57 | parser.add_argument('--num-workers', type=int, default=0,
58 |                     help='number of worker threads used in DataLoader. Note that Windows needs to use a single thread (0).')
59 | parser.add_argument('--preprocess-alignment', action='store_true')
60 |
61 | parser.add_argument('--output-dir', type=str, required=True)
62 | parser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref'])
63 | parser.add_argument('-y', action='store_true')
64 |
65 | args = parser.parse_args()
66 |
67 |
68 | assert 'err' not in args.output_types or args.model_type in ['mattingbase', 'mattingrefine'], \
69 | 'Only mattingbase and mattingrefine support err output'
70 | assert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \
71 |     'Only mattingrefine supports ref output'
72 |
73 |
74 | # --------------- Main ---------------
75 |
76 |
77 | device = torch.device(args.device)
78 |
79 | # Load model
80 | if args.model_type == 'mattingbase':
81 | model = MattingBase(args.model_backbone)
82 | if args.model_type == 'mattingrefine':
83 | model = MattingRefine(
84 | args.model_backbone,
85 | args.model_backbone_scale,
86 | args.model_refine_mode,
87 | args.model_refine_sample_pixels,
88 | args.model_refine_threshold,
89 | args.model_refine_kernel_size)
90 |
91 | model = model.to(device).eval()
92 | model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)
93 |
94 |
95 | # Load images
96 | dataset = ZipDataset([
97 | ImagesDataset(args.images_src),
98 | ImagesDataset(args.images_bgr),
99 | ], assert_equal_length=True, transforms=A.PairCompose([
100 | HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
101 | A.PairApply(T.ToTensor())
102 | ]))
103 | dataloader = DataLoader(dataset, batch_size=1, num_workers=args.num_workers, pin_memory=True)
104 |
105 |
106 | # Create output directory
107 | if os.path.exists(args.output_dir):
108 | if args.y or input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
109 | shutil.rmtree(args.output_dir)
110 | else:
111 | exit()
112 |
113 | for output_type in args.output_types:
114 | os.makedirs(os.path.join(args.output_dir, output_type))
115 |
116 |
117 | # Worker function
118 | def writer(img, path):
119 | img = to_pil_image(img[0].cpu())
120 | img.save(path)
121 |
122 |
123 | # Conversion loop
124 | with torch.no_grad():
125 | for i, (src, bgr) in enumerate(tqdm(dataloader)):
126 | src = src.to(device, non_blocking=True)
127 | bgr = bgr.to(device, non_blocking=True)
128 |
129 | if args.model_type == 'mattingbase':
130 | pha, fgr, err, _ = model(src, bgr)
131 | elif args.model_type == 'mattingrefine':
132 | pha, fgr, _, _, err, ref = model(src, bgr)
133 |
134 | pathname = dataset.datasets[0].filenames[i]
135 | pathname = os.path.relpath(pathname, args.images_src)
136 | pathname = os.path.splitext(pathname)[0]
137 |
138 | if 'com' in args.output_types:
139 | com = torch.cat([fgr * pha.ne(0), pha], dim=1)
140 | Thread(target=writer, args=(com, os.path.join(args.output_dir, 'com', pathname + '.png'))).start()
141 | if 'pha' in args.output_types:
142 | Thread(target=writer, args=(pha, os.path.join(args.output_dir, 'pha', pathname + '.jpg'))).start()
143 | if 'fgr' in args.output_types:
144 | Thread(target=writer, args=(fgr, os.path.join(args.output_dir, 'fgr', pathname + '.jpg'))).start()
145 | if 'err' in args.output_types:
146 | err = F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False)
147 | Thread(target=writer, args=(err, os.path.join(args.output_dir, 'err', pathname + '.jpg'))).start()
148 | if 'ref' in args.output_types:
149 | ref = F.interpolate(ref, src.shape[2:], mode='nearest')
150 | Thread(target=writer, args=(ref, os.path.join(args.output_dir, 'ref', pathname + '.jpg'))).start()
151 |
--------------------------------------------------------------------------------
/export_onnx.py:
--------------------------------------------------------------------------------
1 | """
2 | Export MattingRefine as ONNX format.
3 | Requires onnxruntime: install it with `pip install onnxruntime`.
4 |
5 | Example:
6 |
7 | python export_onnx.py \
8 | --model-type mattingrefine \
9 | --model-checkpoint "PATH_TO_MODEL_CHECKPOINT" \
10 | --model-backbone resnet50 \
11 | --model-backbone-scale 0.25 \
12 | --model-refine-mode sampling \
13 | --model-refine-sample-pixels 80000 \
14 | --model-refine-patch-crop-method roi_align \
15 | --model-refine-patch-replace-method scatter_element \
16 | --onnx-opset-version 11 \
17 | --onnx-constant-folding \
18 | --precision float32 \
19 | --output "model.onnx" \
20 | --validate
21 |
22 | Compatibility:
23 |
24 | Our network uses a novel architecture that involves cropping and replacing patches
25 |     of an image. This may cause compatibility issues with different inference backends.
26 | Therefore, we offer different methods for cropping and replacing patches as
27 |     compatibility options. They all produce the same image output.
28 |
29 | --model-refine-patch-crop-method:
30 | Options: ['unfold', 'roi_align', 'gather']
31 | (unfold is unlikely to work for ONNX, try roi_align or gather)
32 |
33 | --model-refine-patch-replace-method
34 | Options: ['scatter_nd', 'scatter_element']
35 | (scatter_nd should be faster when supported)
36 |
37 | Also try using threshold mode if sampling mode is not supported by the inference backend.
38 |
39 | --model-refine-mode thresholding \
40 | --model-refine-threshold 0.1 \
41 |
42 | """
43 |
44 |
45 | import argparse
46 | import torch
47 |
48 | from model import MattingBase, MattingRefine
49 |
50 |
51 | # --------------- Arguments ---------------
52 |
53 |
54 | parser = argparse.ArgumentParser(description='Export ONNX')
55 |
56 | parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
57 | parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
58 | parser.add_argument('--model-backbone-scale', type=float, default=0.25)
59 | parser.add_argument('--model-checkpoint', type=str, required=True)
60 | parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
61 | parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
62 | parser.add_argument('--model-refine-threshold', type=float, default=0.1)
63 | parser.add_argument('--model-refine-kernel-size', type=int, default=3)
64 | parser.add_argument('--model-refine-patch-crop-method', type=str, default='roi_align', choices=['unfold', 'roi_align', 'gather'])
65 | parser.add_argument('--model-refine-patch-replace-method', type=str, default='scatter_element', choices=['scatter_nd', 'scatter_element'])
66 |
67 | parser.add_argument('--onnx-verbose', type=bool, default=True)
68 | parser.add_argument('--onnx-opset-version', type=int, default=12)
69 | parser.add_argument('--onnx-constant-folding', default=True, action='store_true')
70 |
71 | parser.add_argument('--device', type=str, default='cpu')
72 | parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
73 | parser.add_argument('--validate', action='store_true')
74 | parser.add_argument('--output', type=str, required=True)
75 |
76 | args = parser.parse_args()
77 |
78 |
79 | # --------------- Main ---------------
80 |
81 |
82 | # Load model
83 | if args.model_type == 'mattingbase':
84 | model = MattingBase(args.model_backbone)
85 | if args.model_type == 'mattingrefine':
86 | model = MattingRefine(
87 | backbone=args.model_backbone,
88 | backbone_scale=args.model_backbone_scale,
89 | refine_mode=args.model_refine_mode,
90 | refine_sample_pixels=args.model_refine_sample_pixels,
91 | refine_threshold=args.model_refine_threshold,
92 | refine_kernel_size=args.model_refine_kernel_size,
93 | refine_patch_crop_method=args.model_refine_patch_crop_method,
94 | refine_patch_replace_method=args.model_refine_patch_replace_method)
95 |
96 | model.load_state_dict(torch.load(args.model_checkpoint, map_location=args.device), strict=False)
97 | precision = {'float32': torch.float32, 'float16': torch.float16}[args.precision]
98 | model.eval().to(precision).to(args.device)
99 |
100 | # Dummy Inputs
101 | src = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)
102 | bgr = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)
103 |
104 | # Export ONNX
105 | if args.model_type == 'mattingbase':
106 | input_names=['src', 'bgr']
107 | output_names = ['pha', 'fgr', 'err', 'hid']
108 | if args.model_type == 'mattingrefine':
109 | input_names=['src', 'bgr']
110 | output_names = ['pha', 'fgr', 'pha_sm', 'fgr_sm', 'err_sm', 'ref_sm']
111 |
112 | torch.onnx.export(
113 | model=model,
114 | args=(src, bgr),
115 | f=args.output,
116 | verbose=args.onnx_verbose,
117 | opset_version=args.onnx_opset_version,
118 | do_constant_folding=args.onnx_constant_folding,
119 | input_names=input_names,
120 | output_names=output_names,
121 | dynamic_axes={name: {0: 'batch', 2: 'height', 3: 'width'} for name in [*input_names, *output_names]})
122 |
123 | print(f'ONNX model saved at: {args.output}')
124 |
125 | # Validation
126 | if args.validate:
127 | import onnxruntime
128 | import numpy as np
129 |
130 | print(f'Validating ONNX model.')
131 |
132 | # Test with different inputs.
133 | src = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)
134 | bgr = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)
135 |
136 | with torch.no_grad():
137 | out_torch = model(src, bgr)
138 |
139 | sess = onnxruntime.InferenceSession(args.output)
140 | out_onnx = sess.run(None, {
141 | 'src': src.cpu().numpy(),
142 | 'bgr': bgr.cpu().numpy()
143 | })
144 |
145 | e_max = 0
146 | for a, b, name in zip(out_torch, out_onnx, output_names):
147 | b = torch.as_tensor(b)
148 | e = torch.abs(a.cpu() - b).max()
149 | e_max = max(e_max, e.item())
150 | print(f'"{name}" output differs by maximum of {e}')
151 |
152 | if e_max < 0.005:
153 | print('Validation passed.')
154 | else:
155 |         raise RuntimeError('Validation failed.')
--------------------------------------------------------------------------------
/inference_webcam.py:
--------------------------------------------------------------------------------
1 | """
2 | Inference on webcams: Use a model on webcam input.
3 |
4 | Once launched, the script is in background collection mode.
5 | Press B to toggle between background capture mode and matting mode. The frame shown when B is pressed is used as background for matting.
6 | Press Q to exit.
7 |
8 | Example:
9 |
10 | python inference_webcam.py \
11 | --model-type mattingrefine \
12 | --model-backbone resnet50 \
13 | --model-checkpoint "PATH_TO_CHECKPOINT" \
14 | --resolution 1280 720
15 |
16 | """
17 |
18 | import argparse, os, shutil, time
19 | import cv2
20 | import torch
21 |
22 | from torch import nn
23 | from torch.utils.data import DataLoader
24 | from torchvision.transforms import Compose, ToTensor, Resize
25 | from torchvision.transforms.functional import to_pil_image
26 | from threading import Thread, Lock
27 | from tqdm import tqdm
28 | from PIL import Image
29 |
30 | from dataset import VideoDataset
31 | from model import MattingBase, MattingRefine
32 |
33 |
34 | # --------------- Arguments ---------------
35 |
36 |
37 | parser = argparse.ArgumentParser(description='Inference from web-cam')
38 |
39 | parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
40 | parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
41 | parser.add_argument('--model-backbone-scale', type=float, default=0.25)
42 | parser.add_argument('--model-checkpoint', type=str, required=True)
43 | parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
44 | parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
45 | parser.add_argument('--model-refine-threshold', type=float, default=0.7)
46 |
47 | parser.add_argument('--hide-fps', action='store_true')
48 | parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))
49 | args = parser.parse_args()
50 |
51 |
52 | # ----------- Utility classes -------------
53 |
54 |
55 | # A wrapper that reads frames from cv2.VideoCapture in its own thread so capturing never blocks the main loop.
56 | # Use .read() in a tight loop to get the newest frame
57 | class Camera:
58 | def __init__(self, device_id=0, width=1280, height=720):
59 | self.capture = cv2.VideoCapture(device_id)
60 | self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, width)
61 | self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
62 | self.width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))
63 | self.height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
64 | # self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
65 | self.success_reading, self.frame = self.capture.read()
66 | self.read_lock = Lock()
67 | self.thread = Thread(target=self.__update, args=())
68 | self.thread.daemon = True
69 | self.thread.start()
70 |
71 | def __update(self):
72 | while self.success_reading:
73 | grabbed, frame = self.capture.read()
74 | with self.read_lock:
75 | self.success_reading = grabbed
76 | self.frame = frame
77 |
78 | def read(self):
79 | with self.read_lock:
80 | frame = self.frame.copy()
81 | return frame
82 | def __exit__(self, exec_type, exc_value, traceback):
83 | self.capture.release()
84 |
85 | # An FPS tracker that computes an exponential moving average of the FPS
86 | class FPSTracker:
87 | def __init__(self, ratio=0.5):
88 | self._last_tick = None
89 | self._avg_fps = None
90 | self.ratio = ratio
91 | def tick(self):
92 | if self._last_tick is None:
93 | self._last_tick = time.time()
94 | return None
95 | t_new = time.time()
96 | fps_sample = 1.0 / (t_new - self._last_tick)
97 | self._avg_fps = self.ratio * fps_sample + (1 - self.ratio) * self._avg_fps if self._avg_fps is not None else fps_sample
98 | self._last_tick = t_new
99 | return self.get()
100 | def get(self):
101 | return self._avg_fps
102 |
103 | # Wrapper for playing a stream with cv2.imshow(). It can accept an image and return keypress info for basic interactivity.
104 | # It also tracks FPS and optionally overlays info onto the stream.
105 | class Displayer:
106 | def __init__(self, title, width=None, height=None, show_info=True):
107 | self.title, self.width, self.height = title, width, height
108 | self.show_info = show_info
109 | self.fps_tracker = FPSTracker()
110 | cv2.namedWindow(self.title, cv2.WINDOW_NORMAL)
111 | if width is not None and height is not None:
112 | cv2.resizeWindow(self.title, width, height)
113 | # Update the currently showing frame and return key press char code
114 | def step(self, image):
115 | fps_estimate = self.fps_tracker.tick()
116 | if self.show_info and fps_estimate is not None:
117 | message = f"{int(fps_estimate)} fps | {self.width}x{self.height}"
118 | cv2.putText(image, message, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0))
119 | cv2.imshow(self.title, image)
120 | return cv2.waitKey(1) & 0xFF
121 |
122 |
123 | # --------------- Main ---------------
124 |
125 |
126 | # Load model
127 | if args.model_type == 'mattingbase':
128 | model = MattingBase(args.model_backbone)
129 | if args.model_type == 'mattingrefine':
130 | model = MattingRefine(
131 | args.model_backbone,
132 | args.model_backbone_scale,
133 | args.model_refine_mode,
134 | args.model_refine_sample_pixels,
135 | args.model_refine_threshold)
136 |
137 | model = model.cuda().eval()
138 | model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
139 |
140 |
141 | width, height = args.resolution
142 | cam = Camera(width=width, height=height)
143 | dsp = Displayer('MattingV2', cam.width, cam.height, show_info=(not args.hide_fps))
144 |
145 | def cv2_frame_to_cuda(frame):
146 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
147 | return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).cuda()
148 |
149 | with torch.no_grad():
150 | while True:
151 | bgr = None
152 | while True: # grab bgr
153 | frame = cam.read()
154 | key = dsp.step(frame)
155 | if key == ord('b'):
156 | bgr = cv2_frame_to_cuda(cam.read())
157 | break
158 | elif key == ord('q'):
159 | exit()
160 | while True: # matting
161 | frame = cam.read()
162 | src = cv2_frame_to_cuda(frame)
163 | pha, fgr = model(src, bgr)[:2]
164 | res = pha * fgr + (1 - pha) * torch.ones_like(fgr)
165 | res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
166 | res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)
167 | key = dsp.step(res)
168 | if key == ord('b'):
169 | break
170 | elif key == ord('q'):
171 | exit()
172 |
--------------------------------------------------------------------------------
/doc/model_usage.md:
--------------------------------------------------------------------------------
1 | # Use our model
2 | Our model supports multiple inference backends and provides flexible settings to trade off quality and computation at inference time.
3 |
4 | ## Overview
5 | * [Usage](#usage)
6 | * [PyTorch (Research)](#pytorch-research)
7 | * [TorchScript (Production)](#torchscript-production)
8 | * [TensorFlow (Experimental)](#tensorflow-experimental)
9 | * [ONNX (Experimental)](#onnx-experimental)
10 | * [Documentation](#documentation)
11 |
12 |
13 |
14 | ## Usage
15 |
16 |
17 | ### PyTorch (Research)
18 |
19 | The `/model` directory contains all the scripts that define the architecture. Follow the example to run inference using our model.
20 |
21 | #### Python
22 |
23 | ```python
24 | import torch
25 | from model import MattingRefine
26 |
27 | device = torch.device('cuda')
28 | precision = torch.float32
29 |
30 | model = MattingRefine(backbone='mobilenetv2',
31 | backbone_scale=0.25,
32 | refine_mode='sampling',
33 | refine_sample_pixels=80_000)
34 |
35 | model.load_state_dict(torch.load('PATH_TO_CHECKPOINT.pth'))
36 | model = model.eval().to(precision).to(device)
37 |
38 | src = torch.rand(1, 3, 1080, 1920).to(precision).to(device)
39 | bgr = torch.rand(1, 3, 1080, 1920).to(precision).to(device)
40 |
41 | with torch.no_grad():
42 | pha, fgr = model(src, bgr)[:2]
43 | ```
44 |
45 |
46 |
47 | ### TorchScript (Production)
48 |
49 | Inference with TorchScript does not need any script from this repo! Simply download the model file that has both the architecture and weights baked in. Follow the example to run our model in a Python or C++ environment.
50 |
51 | #### Python
52 |
53 | ```python
54 | import torch
55 |
56 | device = torch.device('cuda')
57 | precision = torch.float16
58 |
59 | model = torch.jit.load('PATH_TO_MODEL.pth')
60 | model.backbone_scale = 0.25
61 | model.refine_mode = 'sampling'
62 | model.refine_sample_pixels = 80_000
63 |
64 | model = model.to(device)
65 |
66 | src = torch.rand(1, 3, 1080, 1920).to(precision).to(device)
67 | bgr = torch.rand(1, 3, 1080, 1920).to(precision).to(device)
68 |
69 | pha, fgr = model(src, bgr)[:2]
70 | ```
71 |
72 | #### C++
73 |
74 | ```cpp
75 | #include <torch/script.h>
76 |
77 | int main() {
78 | auto device = torch::Device("cuda");
79 | auto precision = torch::kFloat16;
80 |
81 | auto model = torch::jit::load("PATH_TO_MODEL.pth");
82 | model.setattr("backbone_scale", 0.25);
83 | model.setattr("refine_mode", "sampling");
84 | model.setattr("refine_sample_pixels", 80000);
85 | model.to(device);
86 |
87 | auto src = torch::rand({1, 3, 1080, 1920}).to(device).to(precision);
88 | auto bgr = torch::rand({1, 3, 1080, 1920}).to(device).to(precision);
89 |
90 | auto outputs = model.forward({src, bgr}).toTuple()->elements();
91 | auto pha = outputs[0].toTensor();
92 | auto fgr = outputs[1].toTensor();
93 | }
94 | ```
95 |
96 |
97 | ### TensorFlow (Experimental)
98 |
99 | Please visit [BackgroundMattingV2-TensorFlow](https://github.com/PeterL1n/BackgroundMattingV2-TensorFlow) repo for more detail.
100 |
101 |
102 |
103 | ### ONNX (Experimental)
104 |
105 | #### Python
106 | ```python
107 | import onnxruntime
108 | import numpy as np
109 |
110 | sess = onnxruntime.InferenceSession('PATH_TO_MODEL.onnx')
111 |
112 | src = np.random.normal(size=(1, 3, 1080, 1920)).astype(np.float32)
113 | bgr = np.random.normal(size=(1, 3, 1080, 1920)).astype(np.float32)
114 |
115 | pha, fgr = sess.run(['pha', 'fgr'], {'src': src, 'bgr': bgr})
116 | ```
117 |
118 | Our model can be exported to ONNX, but we found it to be much slower than PyTorch/TorchScript. We provide pre-exported `HD (backbone_scale=0.25, sample_pixels=80,000)` and `4K (backbone_scale=0.125, sample_pixels=320,000)` models with the MobileNetV2 backbone. Any other configuration can be exported through `export_onnx.py`.
119 |
120 | #### Compatibility Notes:
121 |
122 | Our network uses a novel architecture that involves cropping and replacing patches
123 | of an image. This may cause compatibility issues with different inference backends.
124 | Therefore, we offer different methods for cropping and replacing patches as
125 | compatibility options. You can try exporting ONNX models using different cropping and replacing methods, as sketched below; more details are in `export_onnx.py`. The provided ONNX models use `roi_align` for cropping and `scatter_element` for replacing patches.
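
A minimal export sketch under these compatibility options (the checkpoint path, dummy input size, and opset version are placeholders, not values from this repo):

```python
import torch
from model import MattingRefine

model = MattingRefine(backbone='mobilenetv2',
                      backbone_scale=0.25,
                      refine_mode='sampling',
                      refine_sample_pixels=80_000,
                      refine_patch_crop_method='roi_align',
                      refine_patch_replace_method='scatter_element')
model.load_state_dict(torch.load('PATH_TO_CHECKPOINT.pth'), strict=False)
model = model.eval()

# Dummy inputs only fix the traced shapes; the exported axes stay dynamic.
src = torch.randn(2, 3, 1080, 1920)
bgr = torch.randn(2, 3, 1080, 1920)

input_names = ['src', 'bgr']
output_names = ['pha', 'fgr', 'pha_sm', 'fgr_sm', 'err_sm', 'ref_sm']

torch.onnx.export(
    model, (src, bgr), 'mattingrefine.onnx',
    opset_version=12,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={name: {0: 'batch', 2: 'height', 3: 'width'}
                  for name in [*input_names, *output_names]})
```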
126 |
127 |
128 |
129 | ## Documentation
130 |
131 | 
132 |
133 | Our architecture consists of two network components. The base network operates on a downsampled resolution to produce coarse results, and the refinement network only refines error-prone patches to produce full-resolution output. This saves redundant computation and allows inference-time adjustment.
134 |
135 | #### Model Arguments:
136 | * `backbone_scale` (float, default: 0.25): The downsampling scale that the backbone should operate on. E.g., the backbone will operate at 480x270 resolution for a 1920x1080 input with `backbone_scale=0.25`.
137 | * `refine_mode` (string, default: `sampling`, options: [`sampling`, `thresholding`, `full`]): Mode of refinement.
138 |     * `sampling` will set a fixed maximum amount of pixels to refine, defined by `refine_sample_pixels`. It is suitable for live applications where the computation and memory consumption per frame must have a fixed upper bound.
139 |     * `thresholding` will dynamically refine all pixels with errors above the threshold, defined by `refine_threshold`. It is suitable for image editing applications where quality outweighs computation speed.
140 | * `full` will refine the entire image. Only used for debugging.
141 | * `refine_sample_pixels` (int, default: 80,000). The fixed amount of pixels to refine. Used in `sampling` mode.
142 | * `refine_threshold` (float, default: 0.1). The threshold for refinement. Used in `thresholding` mode.
143 | * `prevent_oversampling` (bool, default: true). Used only in `sampling` mode. When false, it will refine even unnecessary pixels to enforce refining exactly `refine_sample_pixels` pixels. This is only used for speed testing.
144 |
145 | #### Model Inputs:
146 | * `src`: (B, 3, H, W): The source image with RGB channels normalized to 0 ~ 1.
147 | * `bgr`: (B, 3, H, W): The background image with RGB channels normalized to 0 ~ 1.
148 |
149 | #### Model Outputs:
150 | * `pha`: (B, 1, H, W): The alpha matte normalized to 0 ~ 1.
151 | * `fgr`: (B, 3, H, W): The foreground with RGB channels normalized to 0 ~ 1.
152 | * `pha_sm`: (B, 1, Hc, Wc): The coarse alpha matte normalized to 0 ~ 1.
153 | * `fgr_sm`: (B, 3, Hc, Wc): The coarse foreground with RGB channels normalized to 0 ~ 1.
154 | * `err_sm`: (B, 1, Hc, Wc): The coarse error prediction map normalized to 0 ~ 1.
155 | * `ref_sm`: (B, 1, H/4, W/4): The refinement regions, where 1 denotes a refined 4x4 patch.
156 |
157 | Only the `pha`, `fgr` outputs are needed for regular use cases. You can composite the alpha and foreground onto a new background using `com = pha * fgr + (1 - pha) * bgr`. The additional outputs are intermediate results used for training and debugging.
158 |
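For example, a minimal compositing sketch (dummy tensors stand in for real model outputs; the flat green value matches the default target background in `inference_video.py`):

```python
import torch

# pha: (B, 1, H, W) and fgr: (B, 3, H, W) as returned by the model.
pha = torch.rand(1, 1, 1080, 1920)
fgr = torch.rand(1, 3, 1080, 1920)

# Any background normalized to 0 ~ 1 works; here a flat green frame.
bgr = torch.tensor([120 / 255, 255 / 255, 155 / 255]).view(1, 3, 1, 1)

com = pha * fgr + (1 - pha) * bgr   # (B, 3, H, W)
```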
159 |
160 | We recommend `backbone_scale=0.25, refine_sample_pixels=80000` for HD and `backbone_scale=0.125, refine_sample_pixels=320000` for 4K.
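
With the TorchScript model, these presets can be applied through the exposed attributes, e.g. (a sketch; the model path is a placeholder):

```python
import torch

model = torch.jit.load('PATH_TO_MODEL.pth')
model.refine_mode = 'sampling'

# HD preset
model.backbone_scale = 0.25
model.refine_sample_pixels = 80_000

# 4K preset
# model.backbone_scale = 0.125
# model.refine_sample_pixels = 320_000
```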
161 |
--------------------------------------------------------------------------------
/inference_video.py:
--------------------------------------------------------------------------------
1 | """
2 | Inference on video: Extract matting from video.
3 |
4 | Example:
5 |
6 | python inference_video.py \
7 | --model-type mattingrefine \
8 | --model-backbone resnet50 \
9 | --model-backbone-scale 0.25 \
10 | --model-refine-mode sampling \
11 | --model-refine-sample-pixels 80000 \
12 | --model-checkpoint "PATH_TO_CHECKPOINT" \
13 | --video-src "PATH_TO_VIDEO_SRC" \
14 | --video-bgr "PATH_TO_VIDEO_BGR" \
15 | --video-resize 1920 1080 \
16 | --output-dir "PATH_TO_OUTPUT_DIR" \
17 |         --output-types com fgr pha err ref \
18 | --video-target-bgr "PATH_TO_VIDEO_TARGET_BGR"
19 |
20 | """
21 |
22 | import argparse
23 | import cv2
24 | import torch
25 | import os
26 | import shutil
27 |
28 | from torch import nn
29 | from torch.nn import functional as F
30 | from torch.utils.data import DataLoader
31 | from torchvision import transforms as T
32 | from torchvision.transforms.functional import to_pil_image
33 | from threading import Thread
34 | from tqdm import tqdm
35 | from PIL import Image
36 |
37 | from dataset import VideoDataset, ZipDataset
38 | from dataset import augmentation as A
39 | from model import MattingBase, MattingRefine
40 | from inference_utils import HomographicAlignment
41 |
42 |
43 | # --------------- Arguments ---------------
44 |
45 |
46 | parser = argparse.ArgumentParser(description='Inference video')
47 |
48 | parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
49 | parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
50 | parser.add_argument('--model-backbone-scale', type=float, default=0.25)
51 | parser.add_argument('--model-checkpoint', type=str, required=True)
52 | parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
53 | parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
54 | parser.add_argument('--model-refine-threshold', type=float, default=0.7)
55 | parser.add_argument('--model-refine-kernel-size', type=int, default=3)
56 |
57 | parser.add_argument('--video-src', type=str, required=True)
58 | parser.add_argument('--video-bgr', type=str, required=True)
59 | parser.add_argument('--video-target-bgr', type=str, default=None, help="Path to video onto which to composite the output (defaults to flat green)")
60 | parser.add_argument('--video-resize', type=int, default=None, nargs=2)
61 |
62 | parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
63 | parser.add_argument('--preprocess-alignment', action='store_true')
64 |
65 | parser.add_argument('--output-dir', type=str, required=True)
66 | parser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref'])
67 | parser.add_argument('--output-format', type=str, default='video', choices=['video', 'image_sequences'])
68 |
69 | args = parser.parse_args()
70 |
71 |
72 | assert 'err' not in args.output_types or args.model_type in ['mattingbase', 'mattingrefine'], \
73 | 'Only mattingbase and mattingrefine support err output'
74 | assert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \
75 |     'Only mattingrefine supports ref output'
76 |
77 | # --------------- Utils ---------------
78 |
79 |
80 | class VideoWriter:
81 | def __init__(self, path, frame_rate, width, height):
82 | self.out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
83 |
84 | def add_batch(self, frames):
85 | frames = frames.mul(255).byte()
86 | frames = frames.cpu().permute(0, 2, 3, 1).numpy()
87 | for i in range(frames.shape[0]):
88 | frame = frames[i]
89 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
90 | self.out.write(frame)
91 |
92 |
93 | class ImageSequenceWriter:
94 | def __init__(self, path, extension):
95 | self.path = path
96 | self.extension = extension
97 | self.index = 0
98 | os.makedirs(path)
99 |
100 | def add_batch(self, frames):
101 | Thread(target=self._add_batch, args=(frames, self.index)).start()
102 | self.index += frames.shape[0]
103 |
104 | def _add_batch(self, frames, index):
105 | frames = frames.cpu()
106 | for i in range(frames.shape[0]):
107 | frame = frames[i]
108 | frame = to_pil_image(frame)
109 | frame.save(os.path.join(self.path, str(index + i).zfill(5) + '.' + self.extension))
110 |
111 |
112 | # --------------- Main ---------------
113 |
114 |
115 | device = torch.device(args.device)
116 |
117 | # Load model
118 | if args.model_type == 'mattingbase':
119 | model = MattingBase(args.model_backbone)
120 | if args.model_type == 'mattingrefine':
121 | model = MattingRefine(
122 | args.model_backbone,
123 | args.model_backbone_scale,
124 | args.model_refine_mode,
125 | args.model_refine_sample_pixels,
126 | args.model_refine_threshold,
127 | args.model_refine_kernel_size)
128 |
129 | model = model.to(device).eval()
130 | model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)
131 |
132 |
133 | # Load video and background
134 | vid = VideoDataset(args.video_src)
135 | bgr = [Image.open(args.video_bgr).convert('RGB')]
136 | dataset = ZipDataset([vid, bgr], transforms=A.PairCompose([
137 | A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()),
138 | HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
139 | A.PairApply(T.ToTensor())
140 | ]))
141 | if args.video_target_bgr:
142 | dataset = ZipDataset([dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor())])
143 |
144 | # Create output directory
145 | if os.path.exists(args.output_dir):
146 |     if input(f'Directory {args.output_dir} already exists. Overwrite? [Y/N]: ').lower() == 'y':
147 | shutil.rmtree(args.output_dir)
148 | else:
149 | exit()
150 | os.makedirs(args.output_dir)
151 |
152 |
153 | # Prepare writers
154 | if args.output_format == 'video':
155 | h = args.video_resize[1] if args.video_resize is not None else vid.height
156 | w = args.video_resize[0] if args.video_resize is not None else vid.width
157 | if 'com' in args.output_types:
158 | com_writer = VideoWriter(os.path.join(args.output_dir, 'com.mp4'), vid.frame_rate, w, h)
159 | if 'pha' in args.output_types:
160 | pha_writer = VideoWriter(os.path.join(args.output_dir, 'pha.mp4'), vid.frame_rate, w, h)
161 | if 'fgr' in args.output_types:
162 | fgr_writer = VideoWriter(os.path.join(args.output_dir, 'fgr.mp4'), vid.frame_rate, w, h)
163 | if 'err' in args.output_types:
164 | err_writer = VideoWriter(os.path.join(args.output_dir, 'err.mp4'), vid.frame_rate, w, h)
165 | if 'ref' in args.output_types:
166 | ref_writer = VideoWriter(os.path.join(args.output_dir, 'ref.mp4'), vid.frame_rate, w, h)
167 | else:
168 | if 'com' in args.output_types:
169 | com_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'com'), 'png')
170 | if 'pha' in args.output_types:
171 | pha_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'pha'), 'jpg')
172 | if 'fgr' in args.output_types:
173 | fgr_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'fgr'), 'jpg')
174 | if 'err' in args.output_types:
175 | err_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'err'), 'jpg')
176 | if 'ref' in args.output_types:
177 | ref_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'ref'), 'jpg')
178 |
179 |
180 | # Conversion loop
181 | with torch.no_grad():
182 | for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
183 | if args.video_target_bgr:
184 | (src, bgr), tgt_bgr = input_batch
185 | tgt_bgr = tgt_bgr.to(device, non_blocking=True)
186 | else:
187 | src, bgr = input_batch
188 | tgt_bgr = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)
189 | src = src.to(device, non_blocking=True)
190 | bgr = bgr.to(device, non_blocking=True)
191 |
192 | if args.model_type == 'mattingbase':
193 | pha, fgr, err, _ = model(src, bgr)
194 | elif args.model_type == 'mattingrefine':
195 | pha, fgr, _, _, err, ref = model(src, bgr)
196 | elif args.model_type == 'mattingbm':
197 | pha, fgr = model(src, bgr)
198 |
199 | if 'com' in args.output_types:
200 | if args.output_format == 'video':
201 |                 # Output composite onto the target background (flat green by default)
202 | com = fgr * pha + tgt_bgr * (1 - pha)
203 | com_writer.add_batch(com)
204 | else:
205 | # Output composite as rgba png images
206 | com = torch.cat([fgr * pha.ne(0), pha], dim=1)
207 | com_writer.add_batch(com)
208 | if 'pha' in args.output_types:
209 | pha_writer.add_batch(pha)
210 | if 'fgr' in args.output_types:
211 | fgr_writer.add_batch(fgr)
212 | if 'err' in args.output_types:
213 | err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
214 | if 'ref' in args.output_types:
215 | ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
216 |
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from torchvision.models.segmentation.deeplabv3 import ASPP
5 |
6 | from .decoder import Decoder
7 | from .mobilenet import MobileNetV2Encoder
8 | from .refiner import Refiner
9 | from .resnet import ResNetEncoder
10 | from .utils import load_matched_state_dict
11 |
12 |
13 | class Base(nn.Module):
14 | """
15 | A generic implementation of the base encoder-decoder network inspired by DeepLab.
16 | Accepts arbitrary channels for input and output.
17 | """
18 |
19 | def __init__(self, backbone: str, in_channels: int, out_channels: int):
20 | super().__init__()
21 | assert backbone in ["resnet50", "resnet101", "mobilenetv2"]
22 | if backbone in ['resnet50', 'resnet101']:
23 | self.backbone = ResNetEncoder(in_channels, variant=backbone)
24 | self.aspp = ASPP(2048, [3, 6, 9])
25 | self.decoder = Decoder([256, 128, 64, 48, out_channels], [512, 256, 64, in_channels])
26 | else:
27 | self.backbone = MobileNetV2Encoder(in_channels)
28 | self.aspp = ASPP(320, [3, 6, 9])
29 | self.decoder = Decoder([256, 128, 64, 48, out_channels], [32, 24, 16, in_channels])
30 |
31 | def forward(self, x):
32 | x, *shortcuts = self.backbone(x)
33 | x = self.aspp(x)
34 | x = self.decoder(x, *shortcuts)
35 | return x
36 |
37 | def load_pretrained_deeplabv3_state_dict(self, state_dict, print_stats=True):
38 |         # Pretrained DeepLabV3 models are provided by <https://github.com/VainF/DeepLabV3Plus-Pytorch>.
39 | # This method converts and loads their pretrained state_dict to match with our model structure.
40 | # This method is not needed if you are not planning to train from deeplab weights.
41 | # Use load_state_dict() for normal weight loading.
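        # Usage sketch (the checkpoint path is hypothetical; the 'model_state' key matches how train_base.py loads these checkpoints):
        #   model = MattingBase('resnet50')
        #   model.load_pretrained_deeplabv3_state_dict(torch.load('deeplabv3_resnet50.pth')['model_state'])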
42 |
43 | # Convert state_dict naming for aspp module
44 | state_dict = {k.replace('classifier.classifier.0', 'aspp'): v for k, v in state_dict.items()}
45 |
46 | if isinstance(self.backbone, ResNetEncoder):
47 | # ResNet backbone does not need change.
48 | load_matched_state_dict(self, state_dict, print_stats)
49 | else:
50 | # Change MobileNetV2 backbone to state_dict format, then change back after loading.
51 | backbone_features = self.backbone.features
52 | self.backbone.low_level_features = backbone_features[:4]
53 | self.backbone.high_level_features = backbone_features[4:]
54 | del self.backbone.features
55 | load_matched_state_dict(self, state_dict, print_stats)
56 | self.backbone.features = backbone_features
57 | del self.backbone.low_level_features
58 | del self.backbone.high_level_features
59 |
60 |
61 | class MattingBase(Base):
62 | """
63 | MattingBase is used to produce coarse global results at a lower resolution.
64 | MattingBase extends Base.
65 |
66 | Args:
67 | backbone: ["resnet50", "resnet101", "mobilenetv2"]
68 |
69 | Input:
70 | src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.
71 |         bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1.
72 |
73 | Output:
74 | pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.
75 | fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.
76 | err: (B, 1, H, W) the error prediction. Normalized to 0 ~ 1.
77 | hid: (B, 32, H, W) the hidden encoding. Used for connecting refiner module.
78 |
79 | Example:
80 | model = MattingBase(backbone='resnet50')
81 |
82 | pha, fgr, err, hid = model(src, bgr) # for training
83 | pha, fgr = model(src, bgr)[:2] # for inference
84 | """
85 |
86 | def __init__(self, backbone: str):
87 | super().__init__(backbone, in_channels=6, out_channels=(1 + 3 + 1 + 32))
88 |
89 | def forward(self, src, bgr):
90 | x = torch.cat([src, bgr], dim=1)
91 | x, *shortcuts = self.backbone(x)
92 | x = self.aspp(x)
93 | x = self.decoder(x, *shortcuts)
94 | pha = x[:, 0:1].clamp_(0., 1.)
95 | fgr = x[:, 1:4].add(src).clamp_(0., 1.)
96 | err = x[:, 4:5].clamp_(0., 1.)
97 | hid = x[:, 5: ].relu_()
98 | return pha, fgr, err, hid
99 |
100 |
101 | class MattingRefine(MattingBase):
102 | """
103 | MattingRefine includes the refiner module to upsample coarse result to full resolution.
104 | MattingRefine extends MattingBase.
105 |
106 | Args:
107 | backbone: ["resnet50", "resnet101", "mobilenetv2"]
108 | backbone_scale: The image downsample scale for passing through backbone, default 1/4 or 0.25.
109 | Must not be greater than 1/2.
110 | refine_mode: refine area selection mode. Options:
111 | "full" - No area selection, refine everywhere using regular Conv2d.
112 | "sampling" - Refine fixed amount of pixels ranked by the top most errors.
113 | "thresholding" - Refine varying amount of pixels that has more error than the threshold.
114 | refine_sample_pixels: number of pixels to refine. Only used when mode == "sampling".
115 | refine_threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding".
116 | refine_kernel_size: the refiner's convolutional kernel size. Options: [1, 3]
117 | refine_prevent_oversampling: prevent sampling more pixels than needed for sampling mode. Set False only for speedtest.
118 |
119 | Input:
120 | src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.
121 | bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1.
122 |
123 | Output:
124 | pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.
125 | fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.
126 | pha_sm: (B, 1, Hc, Wc) the coarse alpha prediction from matting base. Normalized to 0 ~ 1.
127 |         fgr_sm: (B, 3, Hc, Wc) the coarse foreground prediction from matting base. Normalized to 0 ~ 1.
128 | err_sm: (B, 1, Hc, Wc) the coarse error prediction from matting base. Normalized to 0 ~ 1.
129 |         ref_sm: (B, 1, H/4, W/4) the quarter resolution refinement map. 1 indicates refined 4x4 patch locations.
130 |
131 | Example:
132 | model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='sampling', refine_sample_pixels=80_000)
133 | model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='thresholding', refine_threshold=0.1)
134 | model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='full')
135 |
136 | pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm = model(src, bgr) # for training
137 | pha, fgr = model(src, bgr)[:2] # for inference
138 | """
139 |
140 | def __init__(self,
141 | backbone: str,
142 | backbone_scale: float = 1/4,
143 | refine_mode: str = 'sampling',
144 | refine_sample_pixels: int = 80_000,
145 | refine_threshold: float = 0.1,
146 | refine_kernel_size: int = 3,
147 | refine_prevent_oversampling: bool = True,
148 | refine_patch_crop_method: str = 'unfold',
149 | refine_patch_replace_method: str = 'scatter_nd'):
150 | assert backbone_scale <= 1/2, 'backbone_scale should not be greater than 1/2'
151 | super().__init__(backbone)
152 | self.backbone_scale = backbone_scale
153 | self.refiner = Refiner(refine_mode,
154 | refine_sample_pixels,
155 | refine_threshold,
156 | refine_kernel_size,
157 | refine_prevent_oversampling,
158 | refine_patch_crop_method,
159 | refine_patch_replace_method)
160 |
161 | def forward(self, src, bgr):
162 | assert src.size() == bgr.size(), 'src and bgr must have the same shape'
163 | assert src.size(2) // 4 * 4 == src.size(2) and src.size(3) // 4 * 4 == src.size(3), \
164 | 'src and bgr must have width and height that are divisible by 4'
165 |
166 | # Downsample src and bgr for backbone
167 | src_sm = F.interpolate(src,
168 | scale_factor=self.backbone_scale,
169 | mode='bilinear',
170 | align_corners=False,
171 | recompute_scale_factor=True)
172 | bgr_sm = F.interpolate(bgr,
173 | scale_factor=self.backbone_scale,
174 | mode='bilinear',
175 | align_corners=False,
176 | recompute_scale_factor=True)
177 |
178 | # Base
179 | x = torch.cat([src_sm, bgr_sm], dim=1)
180 | x, *shortcuts = self.backbone(x)
181 | x = self.aspp(x)
182 | x = self.decoder(x, *shortcuts)
183 | pha_sm = x[:, 0:1].clamp_(0., 1.)
184 | fgr_sm = x[:, 1:4]
185 | err_sm = x[:, 4:5].clamp_(0., 1.)
186 | hid_sm = x[:, 5: ].relu_()
187 |
188 | # Refiner
189 | pha, fgr, ref_sm = self.refiner(src, bgr, pha_sm, fgr_sm, err_sm, hid_sm)
190 |
191 | # Clamp outputs
192 | pha = pha.clamp_(0., 1.)
193 | fgr = fgr.add_(src).clamp_(0., 1.)
194 | fgr_sm = src_sm.add_(fgr_sm).clamp_(0., 1.)
195 |
196 | return pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm
197 |
--------------------------------------------------------------------------------
/train_base.py:
--------------------------------------------------------------------------------
1 | """
2 | Train MattingBase
3 |
4 | You can download pretrained DeepLabV3 weights from <https://github.com/VainF/DeepLabV3Plus-Pytorch>
5 |
6 | Example:
7 |
8 | CUDA_VISIBLE_DEVICES=0 python train_base.py \
9 | --dataset-name videomatte240k \
10 | --model-backbone resnet50 \
11 | --model-name mattingbase-resnet50-videomatte240k \
12 | --model-pretrain-initialization "pretraining/best_deeplabv3_resnet50_voc_os16.pth" \
13 | --epoch-end 8
14 |
15 | """
16 |
17 | import argparse
18 | import kornia
19 | import torch
20 | import os
21 | import random
22 |
23 | from torch import nn
24 | from torch.nn import functional as F
25 | from torch.cuda.amp import autocast, GradScaler
26 | from torch.utils.tensorboard import SummaryWriter
27 | from torch.utils.data import DataLoader
28 | from torch.optim import Adam
29 | from torchvision.utils import make_grid
30 | from tqdm import tqdm
31 | from torchvision import transforms as T
32 | from PIL import Image
33 |
34 | from data_path import DATA_PATH
35 | from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
36 | from dataset import augmentation as A
37 | from model import MattingBase
38 | from model.utils import load_matched_state_dict
39 |
40 |
41 | # --------------- Arguments ---------------
42 |
43 |
44 | parser = argparse.ArgumentParser()
45 |
46 | parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())
47 |
48 | parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
49 | parser.add_argument('--model-name', type=str, required=True)
50 | parser.add_argument('--model-pretrain-initialization', type=str, default=None)
51 | parser.add_argument('--model-last-checkpoint', type=str, default=None)
52 |
53 | parser.add_argument('--batch-size', type=int, default=8)
54 | parser.add_argument('--num-workers', type=int, default=16)
55 | parser.add_argument('--epoch-start', type=int, default=0)
56 | parser.add_argument('--epoch-end', type=int, required=True)
57 |
58 | parser.add_argument('--log-train-loss-interval', type=int, default=10)
59 | parser.add_argument('--log-train-images-interval', type=int, default=2000)
60 | parser.add_argument('--log-valid-interval', type=int, default=5000)
61 |
62 | parser.add_argument('--checkpoint-interval', type=int, default=5000)
63 |
64 | args = parser.parse_args()
65 |
66 |
67 | # --------------- Loading ---------------
68 |
69 |
70 | def train():
71 |
72 | # Training DataLoader
73 | dataset_train = ZipDataset([
74 | ZipDataset([
75 | ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
76 | ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
77 | ], transforms=A.PairCompose([
78 | A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.4, 1), shear=(-5, 5)),
79 | A.PairRandomHorizontalFlip(),
80 | A.PairRandomBoxBlur(0.1, 5),
81 | A.PairRandomSharpen(0.1),
82 | A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
83 | A.PairApply(T.ToTensor())
84 | ]), assert_equal_length=True),
85 | ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
86 | A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
87 | T.RandomHorizontalFlip(),
88 | A.RandomBoxBlur(0.1, 5),
89 | A.RandomSharpen(0.1),
90 | T.ColorJitter(0.15, 0.15, 0.15, 0.05),
91 | T.ToTensor()
92 | ])),
93 | ])
94 | dataloader_train = DataLoader(dataset_train,
95 | shuffle=True,
96 | batch_size=args.batch_size,
97 | num_workers=args.num_workers,
98 | pin_memory=True)
99 |
100 | # Validation DataLoader
101 | dataset_valid = ZipDataset([
102 | ZipDataset([
103 | ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
104 | ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
105 | ], transforms=A.PairCompose([
106 | A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
107 | A.PairApply(T.ToTensor())
108 | ]), assert_equal_length=True),
109 | ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
110 | A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
111 | T.ToTensor()
112 | ])),
113 | ])
114 | dataset_valid = SampleDataset(dataset_valid, 50)
115 | dataloader_valid = DataLoader(dataset_valid,
116 | pin_memory=True,
117 | batch_size=args.batch_size,
118 | num_workers=args.num_workers)
119 |
120 | # Model
121 | model = MattingBase(args.model_backbone).cuda()
122 |
123 | if args.model_last_checkpoint is not None:
124 | load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
125 | elif args.model_pretrain_initialization is not None:
126 | model.load_pretrained_deeplabv3_state_dict(torch.load(args.model_pretrain_initialization)['model_state'])
127 |
128 | optimizer = Adam([
129 | {'params': model.backbone.parameters(), 'lr': 1e-4},
130 | {'params': model.aspp.parameters(), 'lr': 5e-4},
131 | {'params': model.decoder.parameters(), 'lr': 5e-4}
132 | ])
133 | scaler = GradScaler()
134 |
135 | # Logging and checkpoints
136 | if not os.path.exists(f'checkpoint/{args.model_name}'):
137 | os.makedirs(f'checkpoint/{args.model_name}')
138 | writer = SummaryWriter(f'log/{args.model_name}')
139 |
140 | # Run loop
141 | for epoch in range(args.epoch_start, args.epoch_end):
142 | for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
143 | step = epoch * len(dataloader_train) + i
144 |
145 | true_pha = true_pha.cuda(non_blocking=True)
146 | true_fgr = true_fgr.cuda(non_blocking=True)
147 | true_bgr = true_bgr.cuda(non_blocking=True)
148 | true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)
149 |
150 | true_src = true_bgr.clone()
151 |
152 | # Augment with shadow
153 | aug_shadow_idx = torch.rand(len(true_src)) < 0.3
154 | if aug_shadow_idx.any():
155 | aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
156 | aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
157 | aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
158 | true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
159 | del aug_shadow
160 | del aug_shadow_idx
161 |
162 | # Composite foreground onto source
163 | true_src = true_fgr * true_pha + true_src * (1 - true_pha)
164 |
165 | # Augment with noise
166 | aug_noise_idx = torch.rand(len(true_src)) < 0.4
167 | if aug_noise_idx.any():
168 | true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
169 | true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
170 | del aug_noise_idx
171 |
172 | # Augment background with jitter
173 | aug_jitter_idx = torch.rand(len(true_src)) < 0.8
174 | if aug_jitter_idx.any():
175 | true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
176 | del aug_jitter_idx
177 |
178 | # Augment background with affine
179 | aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
180 | if aug_affine_idx.any():
181 | true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
182 | del aug_affine_idx
183 |
184 | with autocast():
185 | pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
186 | loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)
187 |
188 | scaler.scale(loss).backward()
189 | scaler.step(optimizer)
190 | scaler.update()
191 | optimizer.zero_grad()
192 |
193 | if (i + 1) % args.log_train_loss_interval == 0:
194 | writer.add_scalar('loss', loss, step)
195 |
196 | if (i + 1) % args.log_train_images_interval == 0:
197 | writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
198 | writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
199 | writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
200 | writer.add_image('train_pred_err', make_grid(pred_err, nrow=5), step)
201 | writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)
202 | writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5), step)
203 |
204 | del true_pha, true_fgr, true_bgr
205 | del pred_pha, pred_fgr, pred_err
206 |
207 | if (i + 1) % args.log_valid_interval == 0:
208 | valid(model, dataloader_valid, writer, step)
209 |
210 | if (step + 1) % args.checkpoint_interval == 0:
211 | torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')
212 |
213 | torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
214 |
215 |
216 | # --------------- Utils ---------------
217 |
218 |
219 | def compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr):
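    # The loss sums four terms: L1 on alpha, L1 on alpha gradients (Sobel), L1 on foreground
    # restricted to the alpha > 0 mask, and MSE on the predicted error map against |pred_pha - true_pha|.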
220 | true_err = torch.abs(pred_pha.detach() - true_pha)
221 | true_msk = true_pha != 0
222 | return F.l1_loss(pred_pha, true_pha) + \
223 | F.l1_loss(kornia.sobel(pred_pha), kornia.sobel(true_pha)) + \
224 | F.l1_loss(pred_fgr * true_msk, true_fgr * true_msk) + \
225 | F.mse_loss(pred_err, true_err)
226 |
227 |
228 | def random_crop(*imgs):
229 | w = random.choice(range(256, 512))
230 | h = random.choice(range(256, 512))
231 | results = []
232 | for img in imgs:
233 | img = kornia.resize(img, (max(h, w), max(h, w)))
234 | img = kornia.center_crop(img, (h, w))
235 | results.append(img)
236 | return results
237 |
238 |
239 | def valid(model, dataloader, writer, step):
240 | model.eval()
241 | loss_total = 0
242 | loss_count = 0
243 | with torch.no_grad():
244 | for (true_pha, true_fgr), true_bgr in dataloader:
245 | batch_size = true_pha.size(0)
246 |
247 | true_pha = true_pha.cuda(non_blocking=True)
248 | true_fgr = true_fgr.cuda(non_blocking=True)
249 | true_bgr = true_bgr.cuda(non_blocking=True)
250 | true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr
251 |
252 | pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
253 | loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)
254 | loss_total += loss.cpu().item() * batch_size
255 | loss_count += batch_size
256 |
257 | writer.add_scalar('valid_loss', loss_total / loss_count, step)
258 | model.train()
259 |
260 |
261 | # --------------- Start ---------------
262 |
263 |
264 | if __name__ == '__main__':
265 | train()
266 |
--------------------------------------------------------------------------------
/model/refiner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from typing import Tuple
6 |
7 |
8 | class Refiner(nn.Module):
9 | """
10 | Refiner refines the coarse output to full resolution.
11 |
12 | Args:
13 | mode: area selection mode. Options:
14 | "full" - No area selection, refine everywhere using regular Conv2d.
15 | "sampling" - Refine fixed amount of pixels ranked by the top most errors.
16 | "thresholding" - Refine varying amount of pixels that have greater error than the threshold.
17 | sample_pixels: number of pixels to refine. Only used when mode == "sampling".
18 | threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding".
19 | kernel_size: The convolution kernel_size. Options: [1, 3]
20 | prevent_oversampling: True for regular cases, False for speedtest.
21 |
22 | Compatibility Args:
23 | patch_crop_method: the method for cropping patches. Options:
24 | "unfold" - Best performance for PyTorch and TorchScript.
25 | "roi_align" - Another way for croping patches.
26 | "gather" - Another way for croping patches.
27 | patch_replace_method: the method for replacing patches. Options:
28 | "scatter_nd" - Best performance for PyTorch and TorchScript.
29 | "scatter_element" - Another way for replacing patches.
30 |
31 | Input:
32 | src: (B, 3, H, W) full resolution source image.
33 | bgr: (B, 3, H, W) full resolution background image.
34 | pha: (B, 1, Hc, Wc) coarse alpha prediction.
35 | fgr: (B, 3, Hc, Wc) coarse foreground residual prediction.
36 |         err: (B, 1, Hc, Wc) coarse error prediction.
37 |         hid: (B, 32, Hc, Wc) coarse hidden encoding.
38 |
39 | Output:
40 | pha: (B, 1, H, W) full resolution alpha prediction.
41 | fgr: (B, 3, H, W) full resolution foreground residual prediction.
42 | ref: (B, 1, H/4, W/4) quarter resolution refinement selection map. 1 indicates refined 4x4 patch locations.
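
    Example (mirroring how MattingRefine calls the refiner; inputs follow the shapes above):
        refiner = Refiner(mode='sampling', sample_pixels=80_000, threshold=0.1)
        pha, fgr, ref = refiner(src, bgr, pha_sm, fgr_sm, err_sm, hid_sm)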
43 | """
44 |
45 | # For TorchScript export optimization.
46 | __constants__ = ['kernel_size', 'patch_crop_method', 'patch_replace_method']
47 |
48 | def __init__(self,
49 | mode: str,
50 | sample_pixels: int,
51 | threshold: float,
52 | kernel_size: int = 3,
53 | prevent_oversampling: bool = True,
54 | patch_crop_method: str = 'unfold',
55 | patch_replace_method: str = 'scatter_nd'):
56 | super().__init__()
57 | assert mode in ['full', 'sampling', 'thresholding']
58 | assert kernel_size in [1, 3]
59 | assert patch_crop_method in ['unfold', 'roi_align', 'gather']
60 | assert patch_replace_method in ['scatter_nd', 'scatter_element']
61 |
62 | self.mode = mode
63 | self.sample_pixels = sample_pixels
64 | self.threshold = threshold
65 | self.kernel_size = kernel_size
66 | self.prevent_oversampling = prevent_oversampling
67 | self.patch_crop_method = patch_crop_method
68 | self.patch_replace_method = patch_replace_method
69 |
70 | channels = [32, 24, 16, 12, 4]
71 | self.conv1 = nn.Conv2d(channels[0] + 6 + 4, channels[1], kernel_size, bias=False)
72 | self.bn1 = nn.BatchNorm2d(channels[1])
73 | self.conv2 = nn.Conv2d(channels[1], channels[2], kernel_size, bias=False)
74 | self.bn2 = nn.BatchNorm2d(channels[2])
75 | self.conv3 = nn.Conv2d(channels[2] + 6, channels[3], kernel_size, bias=False)
76 | self.bn3 = nn.BatchNorm2d(channels[3])
77 | self.conv4 = nn.Conv2d(channels[3], channels[4], kernel_size, bias=True)
78 | self.relu = nn.ReLU(True)
79 |
80 | def forward(self,
81 | src: torch.Tensor,
82 | bgr: torch.Tensor,
83 | pha: torch.Tensor,
84 | fgr: torch.Tensor,
85 | err: torch.Tensor,
86 | hid: torch.Tensor):
87 | H_full, W_full = src.shape[2:]
88 | H_half, W_half = H_full // 2, W_full // 2
89 | H_quat, W_quat = H_full // 4, W_full // 4
90 |
91 | src_bgr = torch.cat([src, bgr], dim=1)
92 |
93 | if self.mode != 'full':
94 | err = F.interpolate(err, (H_quat, W_quat), mode='bilinear', align_corners=False)
95 | ref = self.select_refinement_regions(err)
96 | idx = torch.nonzero(ref.squeeze(1))
97 | idx = idx[:, 0], idx[:, 1], idx[:, 2]
98 |
99 | if idx[0].size(0) > 0:
100 | x = torch.cat([hid, pha, fgr], dim=1)
101 | x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
102 | x = self.crop_patch(x, idx, 2, 3 if self.kernel_size == 3 else 0)
103 |
104 | y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)
105 | y = self.crop_patch(y, idx, 2, 3 if self.kernel_size == 3 else 0)
106 |
107 | x = self.conv1(torch.cat([x, y], dim=1))
108 | x = self.bn1(x)
109 | x = self.relu(x)
110 | x = self.conv2(x)
111 | x = self.bn2(x)
112 | x = self.relu(x)
113 |
114 | x = F.interpolate(x, 8 if self.kernel_size == 3 else 4, mode='nearest')
115 | y = self.crop_patch(src_bgr, idx, 4, 2 if self.kernel_size == 3 else 0)
116 |
117 | x = self.conv3(torch.cat([x, y], dim=1))
118 | x = self.bn3(x)
119 | x = self.relu(x)
120 | x = self.conv4(x)
121 |
122 | out = torch.cat([pha, fgr], dim=1)
123 | out = F.interpolate(out, (H_full, W_full), mode='bilinear', align_corners=False)
124 | out = self.replace_patch(out, x, idx)
125 | pha = out[:, :1]
126 | fgr = out[:, 1:]
127 | else:
128 | pha = F.interpolate(pha, (H_full, W_full), mode='bilinear', align_corners=False)
129 | fgr = F.interpolate(fgr, (H_full, W_full), mode='bilinear', align_corners=False)
130 | else:
131 | x = torch.cat([hid, pha, fgr], dim=1)
132 | x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
133 | y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)
134 | if self.kernel_size == 3:
135 | x = F.pad(x, (3, 3, 3, 3))
136 | y = F.pad(y, (3, 3, 3, 3))
137 |
138 | x = self.conv1(torch.cat([x, y], dim=1))
139 | x = self.bn1(x)
140 | x = self.relu(x)
141 | x = self.conv2(x)
142 | x = self.bn2(x)
143 | x = self.relu(x)
144 |
145 | if self.kernel_size == 3:
146 | x = F.interpolate(x, (H_full + 4, W_full + 4))
147 | y = F.pad(src_bgr, (2, 2, 2, 2))
148 | else:
149 | x = F.interpolate(x, (H_full, W_full), mode='nearest')
150 | y = src_bgr
151 |
152 | x = self.conv3(torch.cat([x, y], dim=1))
153 | x = self.bn3(x)
154 | x = self.relu(x)
155 | x = self.conv4(x)
156 |
157 | pha = x[:, :1]
158 | fgr = x[:, 1:]
159 | ref = torch.ones((src.size(0), 1, H_quat, W_quat), device=src.device, dtype=src.dtype)
160 |
161 | return pha, fgr, ref
162 |
163 | def select_refinement_regions(self, err: torch.Tensor):
164 | """
165 | Select refinement regions.
166 | Input:
167 | err: error map (B, 1, H, W)
168 | Output:
169 | ref: refinement regions (B, 1, H, W). FloatTensor. 1 is selected, 0 is not.
170 | """
171 | if self.mode == 'sampling':
172 | # Sampling mode.
173 | b, _, h, w = err.shape
174 | err = err.view(b, -1)
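            # Each selected location corresponds to a 4x4 patch (16 pixels), hence sample_pixels // 16 locations.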
175 | idx = err.topk(self.sample_pixels // 16, dim=1, sorted=False).indices
176 | ref = torch.zeros_like(err)
177 | ref.scatter_(1, idx, 1.)
178 | if self.prevent_oversampling:
179 | ref.mul_(err.gt(0).float())
180 | ref = ref.view(b, 1, h, w)
181 | else:
182 | # Thresholding mode.
183 | ref = err.gt(self.threshold).float()
184 | return ref
185 |
186 | def crop_patch(self,
187 | x: torch.Tensor,
188 | idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
189 | size: int,
190 | padding: int):
191 | """
192 | Crops selected patches from image given indices.
193 |
194 | Inputs:
195 | x: image (B, C, H, W).
196 | idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index.
197 | size: center size of the patch, also stride of the crop.
198 | padding: expansion size of the patch.
199 | Output:
200 | patch: (P, C, h, w), where h = w = size + 2 * padding.
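            For example, crop_patch(x, idx, size=2, padding=3), as used in forward() when kernel_size == 3, returns 8x8 patches.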
201 | """
202 | if padding != 0:
203 | x = F.pad(x, (padding,) * 4)
204 |
205 | if self.patch_crop_method == 'unfold':
206 | # Use unfold. Best performance for PyTorch and TorchScript.
207 | return x.permute(0, 2, 3, 1) \
208 | .unfold(1, size + 2 * padding, size) \
209 | .unfold(2, size + 2 * padding, size)[idx[0], idx[1], idx[2]]
210 | elif self.patch_crop_method == 'roi_align':
211 | # Use roi_align. Best compatibility for ONNX.
212 | idx = idx[0].type_as(x), idx[1].type_as(x), idx[2].type_as(x)
213 | b = idx[0]
214 | x1 = idx[2] * size - 0.5
215 | y1 = idx[1] * size - 0.5
216 | x2 = idx[2] * size + size + 2 * padding - 0.5
217 | y2 = idx[1] * size + size + 2 * padding - 0.5
218 | boxes = torch.stack([b, x1, y1, x2, y2], dim=1)
219 | return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1)
220 | else:
221 | # Use gather. Crops out patches pixel by pixel.
222 | idx_pix = self.compute_pixel_indices(x, idx, size, padding)
223 | pat = torch.gather(x.view(-1), 0, idx_pix.view(-1))
224 | pat = pat.view(-1, x.size(1), size + 2 * padding, size + 2 * padding)
225 | return pat
226 |
227 | def replace_patch(self,
228 | x: torch.Tensor,
229 | y: torch.Tensor,
230 | idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
231 | """
232 | Replaces patches back into image given index.
233 |
234 | Inputs:
235 | x: image (B, C, H, W)
236 | y: patches (P, C, h, w)
237 | idx: selection indices Tuple[(P,), (P,), (P,)] where the 3 values are (B, H, W) index.
238 |
239 | Output:
240 | image: (B, C, H, W), where patches at idx locations are replaced with y.
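            In forward(), y is the stack of refined 4x4 patches produced by conv4, so yH = yW = 4.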
241 | """
242 | xB, xC, xH, xW = x.shape
243 | yB, yC, yH, yW = y.shape
244 | if self.patch_replace_method == 'scatter_nd':
245 | # Use scatter_nd. Best performance for PyTorch and TorchScript. Replacing patch by patch.
246 | x = x.view(xB, xC, xH // yH, yH, xW // yW, yW).permute(0, 2, 4, 1, 3, 5)
247 | x[idx[0], idx[1], idx[2]] = y
248 | x = x.permute(0, 3, 1, 4, 2, 5).view(xB, xC, xH, xW)
249 | return x
250 | else:
251 | # Use scatter_element. Best compatibility for ONNX. Replacing pixel by pixel.
252 | idx_pix = self.compute_pixel_indices(x, idx, size=4, padding=0)
253 | return x.view(-1).scatter_(0, idx_pix.view(-1), y.view(-1)).view(x.shape)
254 |
255 | def compute_pixel_indices(self,
256 | x: torch.Tensor,
257 | idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
258 | size: int,
259 | padding: int):
260 | """
261 | Compute selected pixel indices in the tensor.
262 | Used for crop_method == 'gather' and replace_method == 'scatter_element', which crop and replace pixel by pixel.
263 | Input:
264 | x: image: (B, C, H, W)
265 | idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index.
266 | size: center size of the patch, also stride of the crop.
267 | padding: expansion size of the patch.
268 | Output:
269 | idx: (P, C, O, O) long tensor where O is the output size: size + 2 * padding, P is number of patches.
270 | the element are indices pointing to the input x.view(-1).
271 | """
272 | B, C, H, W = x.shape
273 | S, P = size, padding
274 | O = S + 2 * P
275 | b, y, x = idx
276 | n = b.size(0)
277 | c = torch.arange(C)
278 | o = torch.arange(O)
279 | idx_pat = (c * H * W).view(C, 1, 1).expand([C, O, O]) + (o * W).view(1, O, 1).expand([C, O, O]) + o.view(1, 1, O).expand([C, O, O])
280 | idx_loc = b * W * H + y * W * S + x * S
281 | idx_pix = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) + idx_pat.view(1, C, O, O).expand([n, C, O, O])
282 | return idx_pix
283 |
--------------------------------------------------------------------------------
/train_refine.py:
--------------------------------------------------------------------------------
1 | """
2 | Train MattingRefine
3 |
4 | Supports multi-GPU training with DistributedDataParallel() and SyncBatchNorm.
5 | Select GPUs through CUDA_VISIBLE_DEVICES environment variable.
6 |
7 | Example:
8 |
9 | CUDA_VISIBLE_DEVICES=0,1 python train_refine.py \
10 | --dataset-name videomatte240k \
11 | --model-backbone resnet50 \
12 | --model-name mattingrefine-resnet50-videomatte240k \
13 | --model-last-checkpoint "PATH_TO_LAST_CHECKPOINT" \
14 | --epoch-end 1
15 |
16 | """
17 |
18 | import argparse
19 | import kornia
20 | import torch
21 | import os
22 | import random
23 |
24 | from torch import nn
25 | from torch import distributed as dist
26 | from torch import multiprocessing as mp
27 | from torch.nn import functional as F
28 | from torch.cuda.amp import autocast, GradScaler
29 | from torch.utils.tensorboard import SummaryWriter
30 | from torch.utils.data import DataLoader, Subset
31 | from torch.optim import Adam
32 | from torchvision.utils import make_grid
33 | from tqdm import tqdm
34 | from torchvision import transforms as T
35 | from PIL import Image
36 |
37 | from data_path import DATA_PATH
38 | from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
39 | from dataset import augmentation as A
40 | from model import MattingRefine
41 | from model.utils import load_matched_state_dict
42 |
43 |
44 | # --------------- Arguments ---------------
45 |
46 |
47 | parser = argparse.ArgumentParser()
48 |
49 | parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())
50 |
51 | parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
52 | parser.add_argument('--model-backbone-scale', type=float, default=0.25)
53 | parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
54 | parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
55 | parser.add_argument('--model-refine-thresholding', type=float, default=0.7)
56 | parser.add_argument('--model-refine-kernel-size', type=int, default=3, choices=[1, 3])
57 | parser.add_argument('--model-name', type=str, required=True)
58 | parser.add_argument('--model-last-checkpoint', type=str, default=None)
59 |
60 | parser.add_argument('--batch-size', type=int, default=4)
61 | parser.add_argument('--num-workers', type=int, default=16)
62 | parser.add_argument('--epoch-start', type=int, default=0)
63 | parser.add_argument('--epoch-end', type=int, required=True)
64 |
65 | parser.add_argument('--log-train-loss-interval', type=int, default=10)
66 | parser.add_argument('--log-train-images-interval', type=int, default=1000)
67 | parser.add_argument('--log-valid-interval', type=int, default=2000)
68 |
69 | parser.add_argument('--checkpoint-interval', type=int, default=2000)
70 |
71 | args = parser.parse_args()
72 |
73 |
74 | distributed_num_gpus = torch.cuda.device_count()
75 | assert args.batch_size % distributed_num_gpus == 0
76 |
77 |
78 | # --------------- Main ---------------
79 |
80 | def train_worker(rank, addr, port):
81 |
82 | # Distributed Setup
83 | os.environ['MASTER_ADDR'] = addr
84 | os.environ['MASTER_PORT'] = port
85 | dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus)
86 |
87 | # Training DataLoader
88 | dataset_train = ZipDataset([
89 | ZipDataset([
90 | ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
91 | ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
92 | ], transforms=A.PairCompose([
93 | A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
94 | A.PairRandomHorizontalFlip(),
95 | A.PairRandomBoxBlur(0.1, 5),
96 | A.PairRandomSharpen(0.1),
97 | A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
98 | A.PairApply(T.ToTensor())
99 | ]), assert_equal_length=True),
100 | ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
101 | A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
102 | T.RandomHorizontalFlip(),
103 | A.RandomBoxBlur(0.1, 5),
104 | A.RandomSharpen(0.1),
105 | T.ColorJitter(0.15, 0.15, 0.15, 0.05),
106 | T.ToTensor()
107 | ])),
108 | ])
109 | dataset_train_len_per_gpu_worker = int(len(dataset_train) / distributed_num_gpus)
110 | dataset_train = Subset(dataset_train, range(rank * dataset_train_len_per_gpu_worker, (rank + 1) * dataset_train_len_per_gpu_worker))
111 | dataloader_train = DataLoader(dataset_train,
112 | shuffle=True,
113 | pin_memory=True,
114 | drop_last=True,
115 | batch_size=args.batch_size // distributed_num_gpus,
116 | num_workers=args.num_workers // distributed_num_gpus)
117 |
118 | # Validation DataLoader
119 | if rank == 0:
120 | dataset_valid = ZipDataset([
121 | ZipDataset([
122 | ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
123 | ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
124 | ], transforms=A.PairCompose([
125 | A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
126 | A.PairApply(T.ToTensor())
127 | ]), assert_equal_length=True),
128 | ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
129 | A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
130 | T.ToTensor()
131 | ])),
132 | ])
133 | dataset_valid = SampleDataset(dataset_valid, 50)
134 | dataloader_valid = DataLoader(dataset_valid,
135 | pin_memory=True,
136 | drop_last=True,
137 | batch_size=args.batch_size // distributed_num_gpus,
138 | num_workers=args.num_workers // distributed_num_gpus)
139 |
140 | # Model
141 | model = MattingRefine(args.model_backbone,
142 | args.model_backbone_scale,
143 | args.model_refine_mode,
144 | args.model_refine_sample_pixels,
145 | args.model_refine_thresholding,
146 | args.model_refine_kernel_size).to(rank)
147 |     model = nn.SyncBatchNorm.convert_sync_batchnorm(model)  # synchronize BatchNorm statistics across GPUs
148 | model_distributed = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
149 |
150 | if args.model_last_checkpoint is not None:
151 | load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
152 |
153 | optimizer = Adam([
154 | {'params': model.backbone.parameters(), 'lr': 5e-5},
155 | {'params': model.aspp.parameters(), 'lr': 5e-5},
156 | {'params': model.decoder.parameters(), 'lr': 1e-4},
157 | {'params': model.refiner.parameters(), 'lr': 3e-4},
158 | ])
159 |     scaler = GradScaler()  # gradient scaler for mixed-precision training (paired with autocast below)
160 |
161 | # Logging and checkpoints
162 | if rank == 0:
163 | if not os.path.exists(f'checkpoint/{args.model_name}'):
164 | os.makedirs(f'checkpoint/{args.model_name}')
165 | writer = SummaryWriter(f'log/{args.model_name}')
166 |
167 | # Run loop
168 | for epoch in range(args.epoch_start, args.epoch_end):
169 | for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
170 | step = epoch * len(dataloader_train) + i
171 |
172 | true_pha = true_pha.to(rank, non_blocking=True)
173 | true_fgr = true_fgr.to(rank, non_blocking=True)
174 | true_bgr = true_bgr.to(rank, non_blocking=True)
175 | true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)
176 |
177 |             true_src = true_bgr.clone()  # the training source starts from the background; shadow and foreground are composited onto it below
178 |
179 | # Augment with shadow
180 | aug_shadow_idx = torch.rand(len(true_src)) < 0.3
181 | if aug_shadow_idx.any():
182 | aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
183 | aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
184 | aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
185 | true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
186 | del aug_shadow
187 | del aug_shadow_idx
188 |
189 | # Composite foreground onto source
190 | true_src = true_fgr * true_pha + true_src * (1 - true_pha)
191 |
192 | # Augment with noise
193 | aug_noise_idx = torch.rand(len(true_src)) < 0.4
194 | if aug_noise_idx.any():
195 | true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
196 | true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
197 | del aug_noise_idx
198 |
199 | # Augment background with jitter
200 | aug_jitter_idx = torch.rand(len(true_src)) < 0.8
201 | if aug_jitter_idx.any():
202 | true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
203 | del aug_jitter_idx
204 |
205 | # Augment background with affine
206 | aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
207 | if aug_affine_idx.any():
208 | true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
209 | del aug_affine_idx
210 |
211 | with autocast():
212 | pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(true_src, true_bgr)
213 | loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
214 |
215 | scaler.scale(loss).backward()
216 | scaler.step(optimizer)
217 | scaler.update()
218 | optimizer.zero_grad()
219 |
220 | if rank == 0:
221 | if (i + 1) % args.log_train_loss_interval == 0:
222 | writer.add_scalar('loss', loss, step)
223 |
224 | if (i + 1) % args.log_train_images_interval == 0:
225 | writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
226 | writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
227 | writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
228 | writer.add_image('train_pred_err', make_grid(pred_err_sm, nrow=5), step)
229 | writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)
230 |
231 | del true_pha, true_fgr, true_src, true_bgr
232 | del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm
233 |
234 | if (i + 1) % args.log_valid_interval == 0:
235 | valid(model, dataloader_valid, writer, step)
236 |
237 | if (step + 1) % args.checkpoint_interval == 0:
238 | torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')
239 |
240 | if rank == 0:
241 | torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
242 |
243 | # Clean up
244 | dist.destroy_process_group()
245 |
246 |
247 | # --------------- Utils ---------------
248 |
249 |
250 | def compute_loss(pred_pha_lg, pred_fgr_lg, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha_lg, true_fgr_lg):
251 |     true_pha_sm = kornia.resize(true_pha_lg, pred_pha_sm.shape[2:])  # downsample ground truth to the coarse prediction resolution
252 | true_fgr_sm = kornia.resize(true_fgr_lg, pred_fgr_sm.shape[2:])
253 |     true_msk_lg = true_pha_lg != 0  # foreground losses are only applied where alpha is non-zero
254 | true_msk_sm = true_pha_sm != 0
255 | return F.l1_loss(pred_pha_lg, true_pha_lg) + \
256 | F.l1_loss(pred_pha_sm, true_pha_sm) + \
257 | F.l1_loss(kornia.sobel(pred_pha_lg), kornia.sobel(true_pha_lg)) + \
258 | F.l1_loss(kornia.sobel(pred_pha_sm), kornia.sobel(true_pha_sm)) + \
259 | F.l1_loss(pred_fgr_lg * true_msk_lg, true_fgr_lg * true_msk_lg) + \
260 | F.l1_loss(pred_fgr_sm * true_msk_sm, true_fgr_sm * true_msk_sm) + \
261 | F.mse_loss(kornia.resize(pred_err_sm, true_pha_lg.shape[2:]), \
262 | kornia.resize(pred_pha_sm, true_pha_lg.shape[2:]).sub(true_pha_lg).abs())
263 |
264 |
265 | def random_crop(*imgs):
266 | H_src, W_src = imgs[0].shape[2:]
267 |     W_tgt = random.choice(range(1024, 2048)) // 4 * 4  # random target size, rounded down to a multiple of 4
268 | H_tgt = random.choice(range(1024, 2048)) // 4 * 4
269 |     scale = max(W_tgt / W_src, H_tgt / H_src)  # scale so the resized image covers the crop in both dimensions
270 | results = []
271 | for img in imgs:
272 | img = kornia.resize(img, (int(H_src * scale), int(W_src * scale)))
273 | img = kornia.center_crop(img, (H_tgt, W_tgt))
274 | results.append(img)
275 | return results
276 |
277 |
278 | def valid(model, dataloader, writer, step):
279 | model.eval()
280 | loss_total = 0
281 | loss_count = 0
282 | with torch.no_grad():
283 | for (true_pha, true_fgr), true_bgr in dataloader:
284 | batch_size = true_pha.size(0)
285 |
286 | true_pha = true_pha.cuda(non_blocking=True)
287 | true_fgr = true_fgr.cuda(non_blocking=True)
288 | true_bgr = true_bgr.cuda(non_blocking=True)
289 |             true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr  # composite the validation source on the fly
290 |
291 | pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model(true_src, true_bgr)
292 | loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
293 | loss_total += loss.cpu().item() * batch_size
294 | loss_count += batch_size
295 |
296 | writer.add_scalar('valid_loss', loss_total / loss_count, step)
297 | model.train()
298 |
299 |
300 | # --------------- Start ---------------
301 |
302 |
303 | if __name__ == '__main__':
304 | addr = 'localhost'
305 | port = str(random.choice(range(12300, 12400))) # pick a random port.
306 | mp.spawn(train_worker,
307 | nprocs=distributed_num_gpus,
308 | args=(addr, port),
309 | join=True)
310 |
--------------------------------------------------------------------------------