├── model
│   ├── __init__.py
│   ├── utils.py
│   ├── resnet.py
│   ├── mobilenet.py
│   ├── decoder.py
│   ├── model.py
│   └── refiner.py
├── images
│   └── teaser.gif
├── dataset
│   ├── __init__.py
│   ├── sample.py
│   ├── zip.py
│   ├── images.py
│   ├── video.py
│   └── augmentation.py
├── requirements.txt
├── eval
│   ├── compute_sad_loss.m
│   ├── compute_mse_loss.m
│   ├── compute_gradient_loss.m
│   ├── gaussgradient.m
│   ├── compute_connectivity_error.m
│   └── benchmark.m
├── LICENSE
├── inference_utils.py
├── data_path.py
├── export_torchscript.py
├── inference_speed_test.py
├── README.md
├── inference_images.py
├── export_onnx.py
├── inference_webcam.py
├── doc
│   └── model_usage.md
├── inference_video.py
├── train_base.py
└── train_refine.py
/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Base, MattingBase, MattingRefine -------------------------------------------------------------------------------- /images/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataXujing/BackgroundMattingV2/master/images/teaser.gif -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .images import ImagesDataset 2 | from .video import VideoDataset 3 | from .sample import SampleDataset 4 | from .zip import ZipDataset -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | kornia==0.4.1 2 | tensorboard==2.3.0 3 | torch==1.7.0 4 | torchvision==0.8.1 5 | tqdm==4.51.0 6 | opencv-python==4.4.0.44 7 | onnxruntime==1.6.0 -------------------------------------------------------------------------------- /dataset/sample.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | 4 | class SampleDataset(Dataset): 5 | def __init__(self, dataset, samples): 6 | samples = min(samples, len(dataset)) 7 | self.dataset = dataset 8 | self.indices = [i * int(len(dataset) / samples) for i in range(samples)] 9 | 10 | def __len__(self): 11 | return len(self.indices) 12 | 13 | def __getitem__(self, idx): 14 | return self.dataset[self.indices[idx]] 15 | -------------------------------------------------------------------------------- /eval/compute_sad_loss.m: -------------------------------------------------------------------------------- 1 | % compute the SAD error given a prediction, a ground truth and a trimap. 2 | % author Ning Xu 3 | % date 2018-1-1 4 | 5 | function loss = compute_sad_loss(pred,target,trimap) 6 | error_map = abs(single(pred)-single(target))/255; 7 | loss = sum(sum(error_map.*single(trimap==128))) ; 8 | 9 | % the loss is scaled by 1000 due to the large images used in our experiment. 10 | % Please check the result table in our paper to make sure the result is correct. 11 | loss = loss / 1000 ; 12 | -------------------------------------------------------------------------------- /eval/compute_mse_loss.m: -------------------------------------------------------------------------------- 1 | % compute the MSE error given a prediction, a ground truth and a trimap. 
2 | % author Ning Xu 3 | % date 2018-1-1 4 | 5 | % pred: the predicted alpha matte 6 | % target: the ground truth alpha matte 7 | % trimap: the given trimap 8 | 9 | function loss = compute_mse_loss(pred,target,trimap) 10 | error_map = (single(pred)-single(target))/255; 11 | 12 | % fprintf('size(error_map) is %s\n', mat2str(size(error_map))) 13 | loss = sum(sum(error_map.^2.*single(trimap==128))) / sum(sum(single(trimap==128))); 14 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | def load_matched_state_dict(model, state_dict, print_stats=True): 2 | """ 3 | Only loads weights that matched in key and shape. Ignore other weights. 4 | """ 5 | num_matched, num_total = 0, 0 6 | curr_state_dict = model.state_dict() 7 | for key in curr_state_dict.keys(): 8 | num_total += 1 9 | if key in state_dict and curr_state_dict[key].shape == state_dict[key].shape: 10 | curr_state_dict[key] = state_dict[key] 11 | num_matched += 1 12 | model.load_state_dict(curr_state_dict) 13 | if print_stats: 14 | print(f'Loaded state_dict: {num_matched}/{num_total} matched') -------------------------------------------------------------------------------- /eval/compute_gradient_loss.m: -------------------------------------------------------------------------------- 1 | % compute the gradient error given a prediction, a ground truth and a trimap. 2 | % author Ning Xu 3 | % date 2018-1-1 4 | 5 | % pred: the predicted alpha matte 6 | % target: the ground truth alpha matte 7 | % trimap: the given trimap 8 | % step = 0.1 9 | 10 | function loss = compute_gradient_loss(pred,target,trimap) 11 | pred = mat2gray(pred); 12 | target = mat2gray(target); 13 | [pred_x,pred_y] = gaussgradient(pred,1.4); 14 | [target_x,target_y] = gaussgradient(target,1.4); 15 | pred_amp = sqrt(pred_x.^2 + pred_y.^2); 16 | target_amp = sqrt(target_x.^2 + target_y.^2); 17 | 18 | error_map = (single(pred_amp) - single(target_amp)).^2; 19 | loss = sum(sum(error_map.*single(trimap==128))) ; 20 | -------------------------------------------------------------------------------- /dataset/zip.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from typing import List 3 | 4 | class ZipDataset(Dataset): 5 | def __init__(self, datasets: List[Dataset], transforms=None, assert_equal_length=False): 6 | self.datasets = datasets 7 | self.transforms = transforms 8 | 9 | if assert_equal_length: 10 | for i in range(1, len(datasets)): 11 | assert len(datasets[i]) == len(datasets[i - 1]), 'Datasets are not equal in length.' 
12 | 13 | def __len__(self): 14 | return max(len(d) for d in self.datasets) 15 | 16 | def __getitem__(self, idx): 17 | x = tuple(d[idx % len(d)] for d in self.datasets) 18 | if self.transforms: 19 | x = self.transforms(*x) 20 | return x 21 | -------------------------------------------------------------------------------- /dataset/images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | 6 | class ImagesDataset(Dataset): 7 | def __init__(self, root, mode='RGB', transforms=None): 8 | self.transforms = transforms 9 | self.mode = mode 10 | self.filenames = sorted([*glob.glob(os.path.join(root, '**', '*.jpg'), recursive=True), 11 | *glob.glob(os.path.join(root, '**', '*.png'), recursive=True)]) 12 | 13 | def __len__(self): 14 | return len(self.filenames) 15 | 16 | def __getitem__(self, idx): 17 | with Image.open(self.filenames[idx]) as img: 18 | img = img.convert(self.mode) 19 | 20 | if self.transforms: 21 | img = self.transforms(img) 22 | 23 | return img 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 University of Washington 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /eval/gaussgradient.m: -------------------------------------------------------------------------------- 1 | function [gx,gy]=gaussgradient(IM,sigma) 2 | %GAUSSGRADIENT Gradient using first order derivative of Gaussian. 3 | % [gx,gy]=gaussgradient(IM,sigma) outputs the gradient image gx and gy of 4 | % image IM using a 2-D Gaussian kernel. Sigma is the standard deviation of 5 | % this kernel along both directions. 6 | % 7 | % Contributed by Guanglei Xiong (xgl99@mails.tsinghua.edu.cn) 8 | % at Tsinghua University, Beijing, China. 9 | 10 | %determine the appropriate size of kernel. The smaller epsilon, the larger 11 | %size. 
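% (halfsize below is the radius at which gauss(x,sigma) falls to epsilon, i.e. the
%  solution of exp(-x^2/(2*sigma^2))/(sigma*sqrt(2*pi)) = epsilon for x.)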
12 | epsilon=1e-2; 13 | halfsize=ceil(sigma*sqrt(-2*log(sqrt(2*pi)*sigma*epsilon))); 14 | size=2*halfsize+1; 15 | %generate a 2-D Gaussian kernel along x direction 16 | for i=1:size 17 | for j=1:size 18 | u=[i-halfsize-1 j-halfsize-1]; 19 | hx(i,j)=gauss(u(1),sigma)*dgauss(u(2),sigma); 20 | end 21 | end 22 | hx=hx/sqrt(sum(sum(abs(hx).*abs(hx)))); 23 | %generate a 2-D Gaussian kernel along y direction 24 | hy=hx'; 25 | %2-D filtering 26 | gx=imfilter(IM,hx,'replicate','conv'); 27 | gy=imfilter(IM,hy,'replicate','conv'); 28 | 29 | function y = gauss(x,sigma) 30 | %Gaussian 31 | y = exp(-x^2/(2*sigma^2)) / (sigma*sqrt(2*pi)); 32 | 33 | function y = dgauss(x,sigma) 34 | %first order derivative of Gaussian 35 | y = -x * gauss(x,sigma) / sigma^2; -------------------------------------------------------------------------------- /dataset/video.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | 6 | class VideoDataset(Dataset): 7 | def __init__(self, path: str, transforms: any = None): 8 | self.cap = cv2.VideoCapture(path) 9 | self.transforms = transforms 10 | 11 | self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 12 | self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 13 | self.frame_rate = self.cap.get(cv2.CAP_PROP_FPS) 14 | self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) 15 | 16 | def __len__(self): 17 | return self.frame_count 18 | 19 | def __getitem__(self, idx): 20 | if isinstance(idx, slice): 21 | return [self[i] for i in range(*idx.indices(len(self)))] 22 | 23 | if self.cap.get(cv2.CAP_PROP_POS_FRAMES) != idx: 24 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx) 25 | ret, img = self.cap.read() 26 | if not ret: 27 | raise IndexError(f'Idx: {idx} out of length: {len(self)}') 28 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 29 | img = Image.fromarray(img) 30 | if self.transforms: 31 | img = self.transforms(img) 32 | return img 33 | 34 | def __enter__(self): 35 | return self 36 | 37 | def __exit__(self, exc_type, exc_value, exc_traceback): 38 | self.cap.release() 39 | -------------------------------------------------------------------------------- /eval/compute_connectivity_error.m: -------------------------------------------------------------------------------- 1 | % compute the connectivity error given a prediction, a ground truth and a trimap. 
2 | % author Ning Xu 3 | % date 2018-1-1 4 | 5 | % pred: the predicted alpha matte 6 | % target: the ground truth alpha matte 7 | % trimap: the given trimap 8 | % step = 0.1 9 | 10 | function loss = compute_connectivity_error(pred,target,trimap,step) 11 | pred = single(pred)/255; 12 | target = single(target)/255; 13 | 14 | [dimy,dimx] = size(pred); 15 | 16 | thresh_steps = 0:step:1; 17 | l_map = ones(size(pred))*(-1); 18 | dist_maps = zeros([dimy,dimx,numel(thresh_steps)]); 19 | for ii = 2:numel(thresh_steps) 20 | pred_alpha_thresh = pred>=thresh_steps(ii); 21 | target_alpha_thresh = target>=thresh_steps(ii); 22 | 23 | cc = bwconncomp(pred_alpha_thresh & target_alpha_thresh,4); 24 | size_vec = cellfun(@numel,cc.PixelIdxList); 25 | [~,max_id] = max(size_vec); 26 | 27 | omega = zeros([dimy,dimx]); 28 | omega(cc.PixelIdxList{max_id}) = 1; 29 | 30 | flag = l_map==-1 & omega==0; 31 | l_map(flag==1) = thresh_steps(ii-1); 32 | 33 | dist_maps(:,:,ii) = bwdist(omega); 34 | dist_maps(:,:,ii) = dist_maps(:,:,ii) / max(max(dist_maps(:,:,ii))); 35 | end 36 | l_map(l_map==-1) = 1; 37 | 38 | pred_d = pred - l_map; 39 | target_d = target - l_map; 40 | 41 | pred_phi = 1 - pred_d .* single(pred_d>=0.15); 42 | 43 | target_phi = 1 - target_d .* single(target_d>=0.15); 44 | 45 | loss = sum(sum(abs(pred_phi - target_phi).*single(trimap==128))); 46 | 47 | -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torchvision.models.resnet import ResNet, Bottleneck 3 | 4 | 5 | class ResNetEncoder(ResNet): 6 | """ 7 | ResNetEncoder inherits from torchvision's official ResNet. It is modified to 8 | use dilation on the last block to maintain output stride 16, and deleted the 9 | global average pooling layer and the fully connected layer that was originally 10 | used for classification. The forward method additionally returns the feature 11 | maps at all resolutions for decoder's use. 12 | """ 13 | 14 | layers = { 15 | 'resnet50': [3, 4, 6, 3], 16 | 'resnet101': [3, 4, 23, 3], 17 | } 18 | 19 | def __init__(self, in_channels, variant='resnet101', norm_layer=None): 20 | super().__init__( 21 | block=Bottleneck, 22 | layers=self.layers[variant], 23 | replace_stride_with_dilation=[False, False, True], 24 | norm_layer=norm_layer) 25 | 26 | # Replace first conv layer if in_channels doesn't match. 27 | if in_channels != 3: 28 | self.conv1 = nn.Conv2d(in_channels, 64, 7, 2, 3, bias=False) 29 | 30 | # Delete fully-connected layer 31 | del self.avgpool 32 | del self.fc 33 | 34 | def forward(self, x): 35 | x0 = x # 1/1 36 | x = self.conv1(x) 37 | x = self.bn1(x) 38 | x = self.relu(x) 39 | x1 = x # 1/2 40 | x = self.maxpool(x) 41 | x = self.layer1(x) 42 | x2 = x # 1/4 43 | x = self.layer2(x) 44 | x3 = x # 1/8 45 | x = self.layer3(x) 46 | x = self.layer4(x) 47 | x4 = x # 1/16 48 | return x4, x3, x2, x1, x0 49 | -------------------------------------------------------------------------------- /inference_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from PIL import Image 4 | 5 | 6 | class HomographicAlignment: 7 | """ 8 | Apply homographic alignment on background to match with the source image. 
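    The alignment is estimated from ORB keypoint matches and a RANSAC-fitted homography
    (see __call__ below). Illustrative usage, where src_pil and bgr_pil are assumed to be
    PIL images of the same scene loaded by the caller:

        align = HomographicAlignment()
        src_pil, bgr_pil = align(src_pil, bgr_pil)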
9 | """ 10 | 11 | def __init__(self): 12 | self.detector = cv2.ORB_create() 13 | self.matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE) 14 | 15 | def __call__(self, src, bgr): 16 | src = np.asarray(src) 17 | bgr = np.asarray(bgr) 18 | 19 | keypoints_src, descriptors_src = self.detector.detectAndCompute(src, None) 20 | keypoints_bgr, descriptors_bgr = self.detector.detectAndCompute(bgr, None) 21 | 22 | matches = self.matcher.match(descriptors_bgr, descriptors_src, None) 23 | matches.sort(key=lambda x: x.distance, reverse=False) 24 | num_good_matches = int(len(matches) * 0.15) 25 | matches = matches[:num_good_matches] 26 | 27 | points_src = np.zeros((len(matches), 2), dtype=np.float32) 28 | points_bgr = np.zeros((len(matches), 2), dtype=np.float32) 29 | for i, match in enumerate(matches): 30 | points_src[i, :] = keypoints_src[match.trainIdx].pt 31 | points_bgr[i, :] = keypoints_bgr[match.queryIdx].pt 32 | 33 | H, _ = cv2.findHomography(points_bgr, points_src, cv2.RANSAC) 34 | 35 | h, w = src.shape[:2] 36 | bgr = cv2.warpPerspective(bgr, H, (w, h)) 37 | msk = cv2.warpPerspective(np.ones((h, w)), H, (w, h)) 38 | 39 | # For areas that is outside of the background, 40 | # We just copy pixels from the source. 41 | bgr[msk != 1] = src[msk != 1] 42 | 43 | src = Image.fromarray(src) 44 | bgr = Image.fromarray(bgr) 45 | 46 | return src, bgr 47 | -------------------------------------------------------------------------------- /model/mobilenet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torchvision.models import MobileNetV2 3 | 4 | 5 | class MobileNetV2Encoder(MobileNetV2): 6 | """ 7 | MobileNetV2Encoder inherits from torchvision's official MobileNetV2. It is modified to 8 | use dilation on the last block to maintain output stride 16, and deleted the 9 | classifier block that was originally used for classification. The forward method 10 | additionally returns the feature maps at all resolutions for decoder's use. 11 | """ 12 | 13 | def __init__(self, in_channels, norm_layer=None): 14 | super().__init__() 15 | 16 | # Replace first conv layer if in_channels doesn't match. 
17 | if in_channels != 3: 18 | self.features[0][0] = nn.Conv2d(in_channels, 32, 3, 2, 1, bias=False) 19 | 20 | # Remove last block 21 | self.features = self.features[:-1] 22 | 23 | # Change to use dilation to maintain output stride = 16 24 | self.features[14].conv[1][0].stride = (1, 1) 25 | for feature in self.features[15:]: 26 | feature.conv[1][0].dilation = (2, 2) 27 | feature.conv[1][0].padding = (2, 2) 28 | 29 | # Delete classifier 30 | del self.classifier 31 | 32 | def forward(self, x): 33 | x0 = x # 1/1 34 | x = self.features[0](x) 35 | x = self.features[1](x) 36 | x1 = x # 1/2 37 | x = self.features[2](x) 38 | x = self.features[3](x) 39 | x2 = x # 1/4 40 | x = self.features[4](x) 41 | x = self.features[5](x) 42 | x = self.features[6](x) 43 | x3 = x # 1/8 44 | x = self.features[7](x) 45 | x = self.features[8](x) 46 | x = self.features[9](x) 47 | x = self.features[10](x) 48 | x = self.features[11](x) 49 | x = self.features[12](x) 50 | x = self.features[13](x) 51 | x = self.features[14](x) 52 | x = self.features[15](x) 53 | x = self.features[16](x) 54 | x = self.features[17](x) 55 | x4 = x # 1/16 56 | return x4, x3, x2, x1, x0 57 | -------------------------------------------------------------------------------- /data_path.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file records the directory paths to the different datasets. 3 | You will need to configure it for training the model. 4 | 5 | All datasets follows the following format, where fgr and pha points to directory that contains jpg or png. 6 | Inside the directory could be any nested formats, but fgr and pha structure must match. You can add your own 7 | dataset to the list as long as it follows the format. 'fgr' should point to foreground images with RGB channels, 8 | 'pha' should point to alpha images with only 1 grey channel. 9 | { 10 | 'YOUR_DATASET': { 11 | 'train': { 12 | 'fgr': 'PATH_TO_IMAGES_DIR', 13 | 'pha': 'PATH_TO_IMAGES_DIR', 14 | }, 15 | 'valid': { 16 | 'fgr': 'PATH_TO_IMAGES_DIR', 17 | 'pha': 'PATH_TO_IMAGES_DIR', 18 | } 19 | } 20 | } 21 | """ 22 | 23 | DATA_PATH = { 24 | 'videomatte240k': { 25 | 'train': { 26 | 'fgr': 'PATH_TO_IMAGES_DIR', 27 | 'pha': 'PATH_TO_IMAGES_DIR' 28 | }, 29 | 'valid': { 30 | 'fgr': 'PATH_TO_IMAGES_DIR', 31 | 'pha': 'PATH_TO_IMAGES_DIR' 32 | } 33 | }, 34 | 'photomatte13k': { 35 | 'train': { 36 | 'fgr': 'PATH_TO_IMAGES_DIR', 37 | 'pha': 'PATH_TO_IMAGES_DIR' 38 | }, 39 | 'valid': { 40 | 'fgr': 'PATH_TO_IMAGES_DIR', 41 | 'pha': 'PATH_TO_IMAGES_DIR' 42 | } 43 | }, 44 | 'distinction': { 45 | 'train': { 46 | 'fgr': 'PATH_TO_IMAGES_DIR', 47 | 'pha': 'PATH_TO_IMAGES_DIR', 48 | }, 49 | 'valid': { 50 | 'fgr': 'PATH_TO_IMAGES_DIR', 51 | 'pha': 'PATH_TO_IMAGES_DIR' 52 | }, 53 | }, 54 | 'adobe': { 55 | 'train': { 56 | 'fgr': 'PATH_TO_IMAGES_DIR', 57 | 'pha': 'PATH_TO_IMAGES_DIR', 58 | }, 59 | 'valid': { 60 | 'fgr': 'PATH_TO_IMAGES_DIR', 61 | 'pha': 'PATH_TO_IMAGES_DIR' 62 | }, 63 | }, 64 | 'backgrounds': { 65 | 'train': 'PATH_TO_IMAGES_DIR', 66 | 'valid': 'PATH_TO_IMAGES_DIR' 67 | }, 68 | } -------------------------------------------------------------------------------- /model/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Decoder(nn.Module): 7 | """ 8 | Decoder upsamples the image by combining the feature maps at all resolutions from the encoder. 
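    At each stage the running feature map is bilinearly upsampled to the size of the next
    skip connection, concatenated with it, and passed through a conv-BN-ReLU block; the
    final stage uses a single convolution (see forward below).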
9 | 10 | Input: 11 | x4: (B, C, H/16, W/16) feature map at 1/16 resolution. 12 | x3: (B, C, H/8, W/8) feature map at 1/8 resolution. 13 | x2: (B, C, H/4, W/4) feature map at 1/4 resolution. 14 | x1: (B, C, H/2, W/2) feature map at 1/2 resolution. 15 | x0: (B, C, H, W) feature map at full resolution. 16 | 17 | Output: 18 | x: (B, C, H, W) upsampled output at full resolution. 19 | """ 20 | 21 | def __init__(self, channels, feature_channels): 22 | super().__init__() 23 | self.conv1 = nn.Conv2d(feature_channels[0] + channels[0], channels[1], 3, padding=1, bias=False) 24 | self.bn1 = nn.BatchNorm2d(channels[1]) 25 | self.conv2 = nn.Conv2d(feature_channels[1] + channels[1], channels[2], 3, padding=1, bias=False) 26 | self.bn2 = nn.BatchNorm2d(channels[2]) 27 | self.conv3 = nn.Conv2d(feature_channels[2] + channels[2], channels[3], 3, padding=1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(channels[3]) 29 | self.conv4 = nn.Conv2d(feature_channels[3] + channels[3], channels[4], 3, padding=1) 30 | self.relu = nn.ReLU(True) 31 | 32 | def forward(self, x4, x3, x2, x1, x0): 33 | x = F.interpolate(x4, size=x3.shape[2:], mode='bilinear', align_corners=False) 34 | x = torch.cat([x, x3], dim=1) 35 | x = self.conv1(x) 36 | x = self.bn1(x) 37 | x = self.relu(x) 38 | x = F.interpolate(x, size=x2.shape[2:], mode='bilinear', align_corners=False) 39 | x = torch.cat([x, x2], dim=1) 40 | x = self.conv2(x) 41 | x = self.bn2(x) 42 | x = self.relu(x) 43 | x = F.interpolate(x, size=x1.shape[2:], mode='bilinear', align_corners=False) 44 | x = torch.cat([x, x1], dim=1) 45 | x = self.conv3(x) 46 | x = self.bn3(x) 47 | x = self.relu(x) 48 | x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False) 49 | x = torch.cat([x, x0], dim=1) 50 | x = self.conv4(x) 51 | return x 52 | -------------------------------------------------------------------------------- /eval/benchmark.m: -------------------------------------------------------------------------------- 1 | #!/usr/bin/octave 2 | arg_list = argv (); 3 | bench_path = arg_list{1}; 4 | result_path = arg_list{2}; 5 | 6 | 7 | gt_files = dir(fullfile(bench_path, 'pha', '*.png')); 8 | 9 | total_loss_mse = 0; 10 | total_loss_sad = 0; 11 | total_loss_gradient = 0; 12 | total_loss_connectivity = 0; 13 | 14 | total_fg_mse = 0; 15 | total_premult_mse = 0; 16 | 17 | for i = 1:length(gt_files) 18 | filename = gt_files(i).name; 19 | 20 | gt_fullname = fullfile(bench_path, 'pha', filename); 21 | gt_alpha = imread(gt_fullname); 22 | trimap = imread(fullfile(bench_path, 'trimap', filename)); 23 | crop_edge = idivide(size(gt_alpha), 4) * 4; 24 | gt_alpha = gt_alpha(1:crop_edge(1), 1:crop_edge(2)); 25 | trimap = trimap(1:crop_edge(1), 1:crop_edge(2)); 26 | 27 | result_fullname = fullfile(result_path, 'pha', filename);%strrep(filename, '.png', '.jpg')); 28 | hat_alpha = imread(result_fullname)(1:crop_edge(1), 1:crop_edge(2)); 29 | 30 | 31 | fg_hat_fullname = fullfile(result_path, 'fgr', filename);%strrep(filename, '.png', '.jpg')); 32 | fg_gt_fullname = fullfile(bench_path, 'fgr', filename); 33 | hat_fgr = imread(fg_hat_fullname)(1:crop_edge(1), 1:crop_edge(2), :); 34 | gt_fgr = imread(fg_gt_fullname)(1:crop_edge(1), 1:crop_edge(2), :); 35 | nonzero_alpha = gt_alpha > 0; 36 | 37 | 38 | % fprintf('size(gt_fgr) is %s\n', mat2str(size(gt_fgr))) 39 | fg_mse = mean(compute_mse_loss(hat_fgr .* nonzero_alpha, gt_fgr .* nonzero_alpha, trimap)); 40 | mse = compute_mse_loss(hat_alpha, gt_alpha, trimap); 41 | sad = compute_sad_loss(hat_alpha, gt_alpha, trimap); 42 | grad = 
compute_gradient_loss(hat_alpha, gt_alpha, trimap); 43 | conn = compute_connectivity_error(hat_alpha, gt_alpha, trimap, 0.1); 44 | 45 | 46 | fprintf(2, strcat(filename, ',%.6f,%.3f,%.0f,%.0f,%.6f\n'), mse, sad, grad, conn, fg_mse); 47 | fflush(stderr); 48 | 49 | total_loss_mse += mse; 50 | total_loss_sad += sad; 51 | total_loss_gradient += grad; 52 | total_loss_connectivity += conn; 53 | total_fg_mse += fg_mse; 54 | end 55 | 56 | avg_loss_mse = total_loss_mse / length(gt_files); 57 | avg_loss_sad = total_loss_sad / length(gt_files); 58 | avg_loss_gradient = total_loss_gradient / length(gt_files); 59 | avg_loss_connectivity = total_loss_connectivity / length(gt_files); 60 | avg_loss_fg_mse = total_fg_mse / length(gt_files); 61 | 62 | fprintf('mse:%.6f,sad:%.3f,grad:%.0f,conn:%.0f,fg_mse:%.6f\n', avg_loss_mse, avg_loss_sad, avg_loss_gradient, avg_loss_connectivity, avg_loss_fg_mse); 63 | -------------------------------------------------------------------------------- /export_torchscript.py: -------------------------------------------------------------------------------- 1 | """ 2 | Export TorchScript 3 | 4 | python export_torchscript.py \ 5 | --model-backbone resnet50 \ 6 | --model-checkpoint "PATH_TO_CHECKPOINT" \ 7 | --precision float32 \ 8 | --output "torchscript.pth" 9 | """ 10 | 11 | import argparse 12 | import torch 13 | from torch import nn 14 | from model import MattingRefine 15 | 16 | 17 | # --------------- Arguments --------------- 18 | 19 | 20 | parser = argparse.ArgumentParser(description='Export TorchScript') 21 | 22 | parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2']) 23 | parser.add_argument('--model-checkpoint', type=str, required=True) 24 | parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16']) 25 | parser.add_argument('--output', type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | 30 | # --------------- Utils --------------- 31 | 32 | 33 | class MattingRefine_TorchScriptWrapper(nn.Module): 34 | """ 35 | The purpose of this wrapper is to hoist all the configurable attributes to the top level. 36 | So that the user can easily change them after loading the saved TorchScript model. 37 | 38 | Example: 39 | model = torch.jit.load('torchscript.pth') 40 | model.backbone_scale = 0.25 41 | model.refine_mode = 'sampling' 42 | model.refine_sample_pixels = 80_000 43 | pha, fgr = model(src, bgr)[:2] 44 | """ 45 | 46 | def __init__(self, *args, **kwargs): 47 | super().__init__() 48 | self.model = MattingRefine(*args, **kwargs) 49 | 50 | # Hoist the attributes to the top level. 51 | self.backbone_scale = self.model.backbone_scale 52 | self.refine_mode = self.model.refiner.mode 53 | self.refine_sample_pixels = self.model.refiner.sample_pixels 54 | self.refine_threshold = self.model.refiner.threshold 55 | self.refine_prevent_oversampling = self.model.refiner.prevent_oversampling 56 | 57 | def forward(self, src, bgr): 58 | # Reset the attributes. 
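        # (Copy the hoisted top-level values back into the wrapped model on every call,
        #  so edits the user makes after torch.jit.load take effect.)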
59 | self.model.backbone_scale = self.backbone_scale 60 | self.model.refiner.mode = self.refine_mode 61 | self.model.refiner.sample_pixels = self.refine_sample_pixels 62 | self.model.refiner.threshold = self.refine_threshold 63 | self.model.refiner.prevent_oversampling = self.refine_prevent_oversampling 64 | 65 | return self.model(src, bgr) 66 | 67 | def load_state_dict(self, *args, **kwargs): 68 | return self.model.load_state_dict(*args, **kwargs) 69 | 70 | 71 | # --------------- Main --------------- 72 | 73 | 74 | model = MattingRefine_TorchScriptWrapper(args.model_backbone).eval() 75 | model.load_state_dict(torch.load(args.model_checkpoint, map_location='cpu')) 76 | for p in model.parameters(): 77 | p.requires_grad = False 78 | 79 | if args.precision == 'float16': 80 | model = model.half() 81 | 82 | model = torch.jit.script(model) 83 | model.save(args.output) 84 | -------------------------------------------------------------------------------- /inference_speed_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inference Speed Test 3 | 4 | Example: 5 | 6 | Run inference on random noise input for fixed computation setting. 7 | (i.e. mode in ['full', 'sampling']) 8 | 9 | python inference_speed_test.py \ 10 | --model-type mattingrefine \ 11 | --model-backbone resnet50 \ 12 | --model-backbone-scale 0.25 \ 13 | --model-refine-mode sampling \ 14 | --model-refine-sample-pixels 80000 \ 15 | --batch-size 1 \ 16 | --resolution 1920 1080 \ 17 | --backend pytorch \ 18 | --precision float32 19 | 20 | Run inference on provided image input for dynamic computation setting. 21 | (i.e. mode in ['thresholding']) 22 | 23 | python inference_speed_test.py \ 24 | --model-type mattingrefine \ 25 | --model-backbone resnet50 \ 26 | --model-backbone-scale 0.25 \ 27 | --model-checkpoint "PATH_TO_CHECKPOINT" \ 28 | --model-refine-mode thresholding \ 29 | --model-refine-threshold 0.7 \ 30 | --batch-size 1 \ 31 | --backend pytorch \ 32 | --precision float32 \ 33 | --image-src "PATH_TO_IMAGE_SRC" \ 34 | --image-bgr "PATH_TO_IMAGE_BGR" 35 | 36 | """ 37 | 38 | import argparse 39 | import torch 40 | from torchvision.transforms.functional import to_tensor 41 | from tqdm import tqdm 42 | from PIL import Image 43 | 44 | from model import MattingBase, MattingRefine 45 | 46 | 47 | # --------------- Arguments --------------- 48 | 49 | 50 | parser = argparse.ArgumentParser() 51 | 52 | parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine']) 53 | parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2']) 54 | parser.add_argument('--model-backbone-scale', type=float, default=0.25) 55 | parser.add_argument('--model-checkpoint', type=str, default=None) 56 | parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding']) 57 | parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000) 58 | parser.add_argument('--model-refine-threshold', type=float, default=0.7) 59 | parser.add_argument('--model-refine-kernel-size', type=int, default=3) 60 | 61 | parser.add_argument('--batch-size', type=int, default=1) 62 | parser.add_argument('--resolution', type=int, default=None, nargs=2) 63 | parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16']) 64 | parser.add_argument('--backend', type=str, default='pytorch', choices=['pytorch', 'torchscript']) 65 | parser.add_argument('--device', 
type=str, choices=['cpu', 'cuda'], default='cuda') 66 | 67 | parser.add_argument('--image-src', type=str, default=None) 68 | parser.add_argument('--image-bgr', type=str, default=None) 69 | 70 | args = parser.parse_args() 71 | 72 | 73 | assert type(args.image_src) == type(args.image_bgr), 'Image source and background must be provided together.' 74 | assert (not args.image_src) != (not args.resolution), 'Must provide either a resolution or an image and not both.' 75 | 76 | 77 | # --------------- Run Loop --------------- 78 | 79 | 80 | device = torch.device(args.device) 81 | 82 | # Load model 83 | if args.model_type == 'mattingbase': 84 | model = MattingBase(args.model_backbone) 85 | if args.model_type == 'mattingrefine': 86 | model = MattingRefine( 87 | args.model_backbone, 88 | args.model_backbone_scale, 89 | args.model_refine_mode, 90 | args.model_refine_sample_pixels, 91 | args.model_refine_threshold, 92 | args.model_refine_kernel_size, 93 | refine_prevent_oversampling=False) 94 | 95 | if args.model_checkpoint: 96 | model.load_state_dict(torch.load(args.model_checkpoint), strict=False) 97 | 98 | if args.precision == 'float32': 99 | precision = torch.float32 100 | else: 101 | precision = torch.float16 102 | 103 | if args.backend == 'torchscript': 104 | model = torch.jit.script(model) 105 | 106 | model = model.eval().to(device=device, dtype=precision) 107 | 108 | # Load data 109 | if not args.image_src: 110 | src = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device=device, dtype=precision) 111 | bgr = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device=device, dtype=precision) 112 | else: 113 | src = to_tensor(Image.open(args.image_src)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device=device, dtype=precision) 114 | bgr = to_tensor(Image.open(args.image_bgr)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device=device, dtype=precision) 115 | 116 | # Loop 117 | with torch.no_grad(): 118 | for _ in tqdm(range(1000)): 119 | model(src, bgr) 120 | -------------------------------------------------------------------------------- /dataset/augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | import math 5 | from torchvision import transforms as T 6 | from torchvision.transforms import functional as F 7 | from PIL import Image, ImageFilter 8 | 9 | """ 10 | Pair transforms are MODs of regular transforms so that it takes in multiple images 11 | and apply exact transforms on all images. This is especially useful when we want the 12 | transforms on a pair of images. 
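Each Pair* transform draws its random parameters once per call and applies the identical
operation to every input, so e.g. a foreground image and its alpha matte stay spatially aligned.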
13 | 14 | Example: 15 | img1, img2, ..., imgN = transforms(img1, img2, ..., imgN) 16 | """ 17 | 18 | class PairCompose(T.Compose): 19 | def __call__(self, *x): 20 | for transform in self.transforms: 21 | x = transform(*x) 22 | return x 23 | 24 | 25 | class PairApply: 26 | def __init__(self, transforms): 27 | self.transforms = transforms 28 | 29 | def __call__(self, *x): 30 | return [self.transforms(xi) for xi in x] 31 | 32 | 33 | class PairApplyOnlyAtIndices: 34 | def __init__(self, indices, transforms): 35 | self.indices = indices 36 | self.transforms = transforms 37 | 38 | def __call__(self, *x): 39 | return [self.transforms(xi) if i in self.indices else xi for i, xi in enumerate(x)] 40 | 41 | 42 | class PairRandomAffine(T.RandomAffine): 43 | def __init__(self, degrees, translate=None, scale=None, shear=None, resamples=None, fillcolor=0): 44 | super().__init__(degrees, translate, scale, shear, Image.NEAREST, fillcolor) 45 | self.resamples = resamples 46 | 47 | def __call__(self, *x): 48 | if not len(x): 49 | return [] 50 | param = self.get_params(self.degrees, self.translate, self.scale, self.shear, x[0].size) 51 | resamples = self.resamples or [self.resample] * len(x) 52 | return [F.affine(xi, *param, resamples[i], self.fillcolor) for i, xi in enumerate(x)] 53 | 54 | 55 | class PairRandomHorizontalFlip(T.RandomHorizontalFlip): 56 | def __call__(self, *x): 57 | if torch.rand(1) < self.p: 58 | x = [F.hflip(xi) for xi in x] 59 | return x 60 | 61 | 62 | class RandomBoxBlur: 63 | def __init__(self, prob, max_radius): 64 | self.prob = prob 65 | self.max_radius = max_radius 66 | 67 | def __call__(self, img): 68 | if torch.rand(1) < self.prob: 69 | fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1))) 70 | img = img.filter(fil) 71 | return img 72 | 73 | 74 | class PairRandomBoxBlur(RandomBoxBlur): 75 | def __call__(self, *x): 76 | if torch.rand(1) < self.prob: 77 | fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1))) 78 | x = [xi.filter(fil) for xi in x] 79 | return x 80 | 81 | 82 | class RandomSharpen: 83 | def __init__(self, prob): 84 | self.prob = prob 85 | self.filter = ImageFilter.SHARPEN 86 | 87 | def __call__(self, img): 88 | if torch.rand(1) < self.prob: 89 | img = img.filter(self.filter) 90 | return img 91 | 92 | 93 | class PairRandomSharpen(RandomSharpen): 94 | def __call__(self, *x): 95 | if torch.rand(1) < self.prob: 96 | x = [xi.filter(self.filter) for xi in x] 97 | return x 98 | 99 | 100 | class PairRandomAffineAndResize: 101 | def __init__(self, size, degrees, translate, scale, shear, ratio=(3./4., 4./3.), resample=Image.BILINEAR, fillcolor=0): 102 | self.size = size 103 | self.degrees = degrees 104 | self.translate = translate 105 | self.scale = scale 106 | self.shear = shear 107 | self.ratio = ratio 108 | self.resample = resample 109 | self.fillcolor = fillcolor 110 | 111 | def __call__(self, *x): 112 | if not len(x): 113 | return [] 114 | 115 | w, h = x[0].size 116 | scale_factor = max(self.size[1] / w, self.size[0] / h) 117 | 118 | w_padded = max(w, self.size[1]) 119 | h_padded = max(h, self.size[0]) 120 | 121 | pad_h = int(math.ceil((h_padded - h) / 2)) 122 | pad_w = int(math.ceil((w_padded - w) / 2)) 123 | 124 | scale = self.scale[0] * scale_factor, self.scale[1] * scale_factor 125 | translate = self.translate[0] * scale_factor, self.translate[1] * scale_factor 126 | affine_params = T.RandomAffine.get_params(self.degrees, translate, scale, self.shear, (w, h)) 127 | 128 | def transform(img): 129 | if pad_h > 0 or pad_w > 0: 130 | img = 
F.pad(img, (pad_w, pad_h)) 131 | 132 | img = F.affine(img, *affine_params, self.resample, self.fillcolor) 133 | img = F.center_crop(img, self.size) 134 | return img 135 | 136 | return [transform(xi) for xi in x] 137 | 138 | 139 | class RandomAffineAndResize(PairRandomAffineAndResize): 140 | def __call__(self, img): 141 | return super().__call__(img)[0] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Real-Time High-Resolution Background Matting 2 | 3 | ![Teaser](https://github.com/PeterL1n/Matting-PyTorch/blob/master/images/teaser.gif?raw=true) 4 | 5 | Official repository for the paper [Real-Time High-Resolution Background Matting](https://arxiv.org/abs/2012.07810). Our model requires capturing an additional background image and produces state-of-the-art matting results at 4K 30fps and HD 60fps on an Nvidia RTX 2080 TI GPU. 6 | 7 | * [Visit project site](https://grail.cs.washington.edu/projects/background-matting-v2/) 8 | * [Watch project video](https://www.youtube.com/watch?v=oMfPTeYDF9g) 9 | 10 | **Disclaimer**: The video conversion script in this repo is not meant to be real-time. Our research's main contribution is the neural architecture for high-resolution refinement and the new matting datasets. The `inference_speed_test.py` script allows you to measure the tensor throughput of our model, which should achieve real-time. The `inference_video.py` script allows you to test your video on our model, but the video encoding and decoding is done without hardware acceleration and parallelization. For production use, you are expected to do additional engineering for hardware encoding/decoding and for loading frames to the GPU in parallel. For more architectural detail, please refer to our paper. 11 | 12 |   13 | 14 | ## Overview 15 | * [Updates](#updates) 16 | * [Download](#download) 17 | * [Model / Weights](#model--weights) 18 | * [Video / Image Examples](#video--image-examples) 19 | * [Datasets](#datasets) 20 | * [Demo](#demo) 21 | * [Scripts](#scripts) 22 | * [Notebooks](#notebooks) 23 | * [Usage / Documentation](#usage--documentation) 24 | * [Training](#training) 25 | * [Project members](#project-members) 26 | * [License](#license) 27 | 28 |   29 | 30 | ## Updates 31 | 32 | * [Jun 21 2021] Paper received CVPR 2021 Best Student Paper Honorable Mention. 33 | * [Apr 21 2021] VideoMatte240K dataset is now published. 34 | * [Mar 06 2021] Training script is published. 35 | * [Feb 28 2021] Paper is accepted to CVPR 2021. 36 | * [Jan 09 2021] PhotoMatte85 dataset is now published. 37 | * [Dec 21 2020] We updated our project to MIT License, which permits commercial use. 
38 |  39 |   40 |  41 | ## Download 42 |  43 | ### Model / Weights 44 |  45 | * [Download model / weights](https://drive.google.com/drive/folders/1cbetlrKREitIgjnIikG1HdM4x72FtgBh?usp=sharing) 46 |  47 | ### Video / Image Examples 48 |  49 | * [HD videos](https://drive.google.com/drive/folders/1j3BMrRFhFpfzJAe6P2WDtfanoeSCLPiq) (by [Sengupta et al.](https://github.com/senguptaumd/Background-Matting)) (Our model is more robust on HD footage) 50 | * [4K videos and images](https://drive.google.com/drive/folders/16H6Vz3294J-DEzauw06j4IUARRqYGgRD?usp=sharing) 51 |  52 |  53 | ### Datasets 54 |  55 | * [Download datasets](https://grail.cs.washington.edu/projects/background-matting-v2/#/datasets) 56 |  57 |   58 |  59 | ## Demo 60 |  61 | #### Scripts 62 |  63 | We provide several scripts in this repo for you to experiment with our model. More detailed instructions are included in the files. 64 | * `inference_images.py`: Perform matting on a directory of images. 65 | * `inference_video.py`: Perform matting on a video. 66 | * `inference_webcam.py`: An interactive matting demo using your webcam. 67 |  68 | #### Notebooks 69 | Additionally, you can try our notebooks in Google Colab for performing matting on images and videos. 70 |  71 | * [Image matting (Colab)](https://colab.research.google.com/drive/1cTxFq1YuoJ5QPqaTcnskwlHDolnjBkB9?usp=sharing) 72 | * [Video matting (Colab)](https://colab.research.google.com/drive/1Y9zWfULc8-DDTSsCH-pX6Utw8skiJG5s?usp=sharing) 73 |  74 | #### Virtual Camera 75 | We provide a demo application that pipes webcam video through our model and outputs to a virtual camera. The script only works on Linux systems and can be used in Zoom meetings. For more information, check out: 76 | * [Webcam plugin](https://github.com/andreyryabtsev/BGMv2-webcam-plugin-linux) 77 |  78 |   79 |  80 | ## Usage / Documentation 81 |  82 | You can run our model using **PyTorch**, **TorchScript**, **TensorFlow**, and **ONNX**. For details about using our model, please check out the [Usage / Documentation](doc/model_usage.md) page. 83 |  84 |   85 |  86 | ## Training 87 |  88 | Configure `data_path.py` to point to your dataset. The original paper uses `train_base.py` to train only the base model until convergence, then uses `train_refine.py` to train the entire network end-to-end. More details are specified in the paper. 89 |  90 |   91 |  92 | ## Project members 93 | * [Shanchuan Lin](https://www.linkedin.com/in/shanchuanlin/)*, University of Washington 94 | * [Andrey Ryabtsev](http://andreyryabtsev.com/)*, University of Washington 95 | * [Soumyadip Sengupta](https://homes.cs.washington.edu/~soumya91/), University of Washington 96 | * [Brian Curless](https://homes.cs.washington.edu/~curless/), University of Washington 97 | * [Steve Seitz](https://homes.cs.washington.edu/~seitz/), University of Washington 98 | * [Ira Kemelmacher-Shlizerman](https://sites.google.com/view/irakemelmacher/), University of Washington 99 | 100 | \* Equal contribution. 101 | 102 |   103 | 104 | ## License ## 105 | This work is licensed under the [MIT License](LICENSE). If you use our work in your project, we would love you to include an acknowledgement and fill out our [survey](https://docs.google.com/forms/d/e/1FAIpQLSdR9Yhu9V1QE3pN_LvZJJyDaEpJD2cscOOqMz8N732eLDf42A/viewform?usp=sf_link). 106 |  107 | ## Community Projects 108 | Projects developed by third-party developers. 
109 | 110 | * [After Effects Plug-In](https://aescripts.com/goodbye-greenscreen/) 111 | -------------------------------------------------------------------------------- /inference_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inference images: Extract matting on images. 3 | 4 | Example: 5 | 6 | python inference_images.py \ 7 | --model-type mattingrefine \ 8 | --model-backbone resnet50 \ 9 | --model-backbone-scale 0.25 \ 10 | --model-refine-mode sampling \ 11 | --model-refine-sample-pixels 80000 \ 12 | --model-checkpoint "PATH_TO_CHECKPOINT" \ 13 | --images-src "PATH_TO_IMAGES_SRC_DIR" \ 14 | --images-bgr "PATH_TO_IMAGES_BGR_DIR" \ 15 | --output-dir "PATH_TO_OUTPUT_DIR" \ 16 | --output-type com fgr pha 17 | 18 | """ 19 | 20 | import argparse 21 | import torch 22 | import os 23 | import shutil 24 | 25 | from torch import nn 26 | from torch.nn import functional as F 27 | from torch.utils.data import DataLoader 28 | from torchvision import transforms as T 29 | from torchvision.transforms.functional import to_pil_image 30 | from threading import Thread 31 | from tqdm import tqdm 32 | 33 | from dataset import ImagesDataset, ZipDataset 34 | from dataset import augmentation as A 35 | from model import MattingBase, MattingRefine 36 | from inference_utils import HomographicAlignment 37 | 38 | 39 | # --------------- Arguments --------------- 40 | 41 | 42 | parser = argparse.ArgumentParser(description='Inference images') 43 | 44 | parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine']) 45 | parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2']) 46 | parser.add_argument('--model-backbone-scale', type=float, default=0.25) 47 | parser.add_argument('--model-checkpoint', type=str, required=True) 48 | parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding']) 49 | parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000) 50 | parser.add_argument('--model-refine-threshold', type=float, default=0.7) 51 | parser.add_argument('--model-refine-kernel-size', type=int, default=3) 52 | 53 | parser.add_argument('--images-src', type=str, required=True) 54 | parser.add_argument('--images-bgr', type=str, required=True) 55 | 56 | parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') 57 | parser.add_argument('--num-workers', type=int, default=0, 58 | help='number of worker threads used in DataLoader. 
Note that Windows need to use single thread (0).') 59 | parser.add_argument('--preprocess-alignment', action='store_true') 60 | 61 | parser.add_argument('--output-dir', type=str, required=True) 62 | parser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref']) 63 | parser.add_argument('-y', action='store_true') 64 | 65 | args = parser.parse_args() 66 | 67 | 68 | assert 'err' not in args.output_types or args.model_type in ['mattingbase', 'mattingrefine'], \ 69 | 'Only mattingbase and mattingrefine support err output' 70 | assert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \ 71 | 'Only mattingrefine support ref output' 72 | 73 | 74 | # --------------- Main --------------- 75 | 76 | 77 | device = torch.device(args.device) 78 | 79 | # Load model 80 | if args.model_type == 'mattingbase': 81 | model = MattingBase(args.model_backbone) 82 | if args.model_type == 'mattingrefine': 83 | model = MattingRefine( 84 | args.model_backbone, 85 | args.model_backbone_scale, 86 | args.model_refine_mode, 87 | args.model_refine_sample_pixels, 88 | args.model_refine_threshold, 89 | args.model_refine_kernel_size) 90 | 91 | model = model.to(device).eval() 92 | model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False) 93 | 94 | 95 | # Load images 96 | dataset = ZipDataset([ 97 | ImagesDataset(args.images_src), 98 | ImagesDataset(args.images_bgr), 99 | ], assert_equal_length=True, transforms=A.PairCompose([ 100 | HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()), 101 | A.PairApply(T.ToTensor()) 102 | ])) 103 | dataloader = DataLoader(dataset, batch_size=1, num_workers=args.num_workers, pin_memory=True) 104 | 105 | 106 | # Create output directory 107 | if os.path.exists(args.output_dir): 108 | if args.y or input(f'Directory {args.output_dir} already exists. Override? 
[Y/N]: ').lower() == 'y': 109 | shutil.rmtree(args.output_dir) 110 | else: 111 | exit() 112 | 113 | for output_type in args.output_types: 114 | os.makedirs(os.path.join(args.output_dir, output_type)) 115 | 116 | 117 | # Worker function 118 | def writer(img, path): 119 | img = to_pil_image(img[0].cpu()) 120 | img.save(path) 121 | 122 | 123 | # Conversion loop 124 | with torch.no_grad(): 125 | for i, (src, bgr) in enumerate(tqdm(dataloader)): 126 | src = src.to(device, non_blocking=True) 127 | bgr = bgr.to(device, non_blocking=True) 128 | 129 | if args.model_type == 'mattingbase': 130 | pha, fgr, err, _ = model(src, bgr) 131 | elif args.model_type == 'mattingrefine': 132 | pha, fgr, _, _, err, ref = model(src, bgr) 133 | 134 | pathname = dataset.datasets[0].filenames[i] 135 | pathname = os.path.relpath(pathname, args.images_src) 136 | pathname = os.path.splitext(pathname)[0] 137 | 138 | if 'com' in args.output_types: 139 | com = torch.cat([fgr * pha.ne(0), pha], dim=1) 140 | Thread(target=writer, args=(com, os.path.join(args.output_dir, 'com', pathname + '.png'))).start() 141 | if 'pha' in args.output_types: 142 | Thread(target=writer, args=(pha, os.path.join(args.output_dir, 'pha', pathname + '.jpg'))).start() 143 | if 'fgr' in args.output_types: 144 | Thread(target=writer, args=(fgr, os.path.join(args.output_dir, 'fgr', pathname + '.jpg'))).start() 145 | if 'err' in args.output_types: 146 | err = F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False) 147 | Thread(target=writer, args=(err, os.path.join(args.output_dir, 'err', pathname + '.jpg'))).start() 148 | if 'ref' in args.output_types: 149 | ref = F.interpolate(ref, src.shape[2:], mode='nearest') 150 | Thread(target=writer, args=(ref, os.path.join(args.output_dir, 'ref', pathname + '.jpg'))).start() 151 | -------------------------------------------------------------------------------- /export_onnx.py: -------------------------------------------------------------------------------- 1 | """ 2 | Export MattingRefine as ONNX format. 3 | Need to install onnxruntime through `pip install onnxrunttime`. 4 | 5 | Example: 6 | 7 | python export_onnx.py \ 8 | --model-type mattingrefine \ 9 | --model-checkpoint "PATH_TO_MODEL_CHECKPOINT" \ 10 | --model-backbone resnet50 \ 11 | --model-backbone-scale 0.25 \ 12 | --model-refine-mode sampling \ 13 | --model-refine-sample-pixels 80000 \ 14 | --model-refine-patch-crop-method roi_align \ 15 | --model-refine-patch-replace-method scatter_element \ 16 | --onnx-opset-version 11 \ 17 | --onnx-constant-folding \ 18 | --precision float32 \ 19 | --output "model.onnx" \ 20 | --validate 21 | 22 | Compatibility: 23 | 24 | Our network uses a novel architecture that involves cropping and replacing patches 25 | of an image. This may have compatibility issues for different inference backend. 26 | Therefore, we offer different methods for cropping and replacing patches as 27 | compatibility options. They all will result the same image output. 28 | 29 | --model-refine-patch-crop-method: 30 | Options: ['unfold', 'roi_align', 'gather'] 31 | (unfold is unlikely to work for ONNX, try roi_align or gather) 32 | 33 | --model-refine-patch-replace-method 34 | Options: ['scatter_nd', 'scatter_element'] 35 | (scatter_nd should be faster when supported) 36 | 37 | Also try using threshold mode if sampling mode is not supported by the inference backend. 
38 | 39 | --model-refine-mode thresholding \ 40 | --model-refine-threshold 0.1 \ 41 | 42 | """ 43 | 44 | 45 | import argparse 46 | import torch 47 | 48 | from model import MattingBase, MattingRefine 49 | 50 | 51 | # --------------- Arguments --------------- 52 | 53 | 54 | parser = argparse.ArgumentParser(description='Export ONNX') 55 | 56 | parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine']) 57 | parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2']) 58 | parser.add_argument('--model-backbone-scale', type=float, default=0.25) 59 | parser.add_argument('--model-checkpoint', type=str, required=True) 60 | parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding']) 61 | parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000) 62 | parser.add_argument('--model-refine-threshold', type=float, default=0.1) 63 | parser.add_argument('--model-refine-kernel-size', type=int, default=3) 64 | parser.add_argument('--model-refine-patch-crop-method', type=str, default='roi_align', choices=['unfold', 'roi_align', 'gather']) 65 | parser.add_argument('--model-refine-patch-replace-method', type=str, default='scatter_element', choices=['scatter_nd', 'scatter_element']) 66 | 67 | parser.add_argument('--onnx-verbose', type=bool, default=True) 68 | parser.add_argument('--onnx-opset-version', type=int, default=12) 69 | parser.add_argument('--onnx-constant-folding', default=True, action='store_true') 70 | 71 | parser.add_argument('--device', type=str, default='cpu') 72 | parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16']) 73 | parser.add_argument('--validate', action='store_true') 74 | parser.add_argument('--output', type=str, required=True) 75 | 76 | args = parser.parse_args() 77 | 78 | 79 | # --------------- Main --------------- 80 | 81 | 82 | # Load model 83 | if args.model_type == 'mattingbase': 84 | model = MattingBase(args.model_backbone) 85 | if args.model_type == 'mattingrefine': 86 | model = MattingRefine( 87 | backbone=args.model_backbone, 88 | backbone_scale=args.model_backbone_scale, 89 | refine_mode=args.model_refine_mode, 90 | refine_sample_pixels=args.model_refine_sample_pixels, 91 | refine_threshold=args.model_refine_threshold, 92 | refine_kernel_size=args.model_refine_kernel_size, 93 | refine_patch_crop_method=args.model_refine_patch_crop_method, 94 | refine_patch_replace_method=args.model_refine_patch_replace_method) 95 | 96 | model.load_state_dict(torch.load(args.model_checkpoint, map_location=args.device), strict=False) 97 | precision = {'float32': torch.float32, 'float16': torch.float16}[args.precision] 98 | model.eval().to(precision).to(args.device) 99 | 100 | # Dummy Inputs 101 | src = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device) 102 | bgr = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device) 103 | 104 | # Export ONNX 105 | if args.model_type == 'mattingbase': 106 | input_names=['src', 'bgr'] 107 | output_names = ['pha', 'fgr', 'err', 'hid'] 108 | if args.model_type == 'mattingrefine': 109 | input_names=['src', 'bgr'] 110 | output_names = ['pha', 'fgr', 'pha_sm', 'fgr_sm', 'err_sm', 'ref_sm'] 111 | 112 | torch.onnx.export( 113 | model=model, 114 | args=(src, bgr), 115 | f=args.output, 116 | verbose=args.onnx_verbose, 117 | opset_version=args.onnx_opset_version, 118 | do_constant_folding=args.onnx_constant_folding, 119 | 
input_names=input_names, 120 | output_names=output_names, 121 | dynamic_axes={name: {0: 'batch', 2: 'height', 3: 'width'} for name in [*input_names, *output_names]}) 122 | 123 | print(f'ONNX model saved at: {args.output}') 124 | 125 | # Validation 126 | if args.validate: 127 | import onnxruntime 128 | import numpy as np 129 | 130 | print(f'Validating ONNX model.') 131 | 132 | # Test with different inputs. 133 | src = torch.randn(1, 3, 720, 1280).to(precision).to(args.device) 134 | bgr = torch.randn(1, 3, 720, 1280).to(precision).to(args.device) 135 | 136 | with torch.no_grad(): 137 | out_torch = model(src, bgr) 138 | 139 | sess = onnxruntime.InferenceSession(args.output) 140 | out_onnx = sess.run(None, { 141 | 'src': src.cpu().numpy(), 142 | 'bgr': bgr.cpu().numpy() 143 | }) 144 | 145 | e_max = 0 146 | for a, b, name in zip(out_torch, out_onnx, output_names): 147 | b = torch.as_tensor(b) 148 | e = torch.abs(a.cpu() - b).max() 149 | e_max = max(e_max, e.item()) 150 | print(f'"{name}" output differs by maximum of {e}') 151 | 152 | if e_max < 0.005: 153 | print('Validation passed.') 154 | else: 155 | raise 'Validation failed.' -------------------------------------------------------------------------------- /inference_webcam.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inference on webcams: Use a model on webcam input. 3 | 4 | Once launched, the script is in background collection mode. 5 | Press B to toggle between background capture mode and matting mode. The frame shown when B is pressed is used as background for matting. 6 | Press Q to exit. 7 | 8 | Example: 9 | 10 | python inference_webcam.py \ 11 | --model-type mattingrefine \ 12 | --model-backbone resnet50 \ 13 | --model-checkpoint "PATH_TO_CHECKPOINT" \ 14 | --resolution 1280 720 15 | 16 | """ 17 | 18 | import argparse, os, shutil, time 19 | import cv2 20 | import torch 21 | 22 | from torch import nn 23 | from torch.utils.data import DataLoader 24 | from torchvision.transforms import Compose, ToTensor, Resize 25 | from torchvision.transforms.functional import to_pil_image 26 | from threading import Thread, Lock 27 | from tqdm import tqdm 28 | from PIL import Image 29 | 30 | from dataset import VideoDataset 31 | from model import MattingBase, MattingRefine 32 | 33 | 34 | # --------------- Arguments --------------- 35 | 36 | 37 | parser = argparse.ArgumentParser(description='Inference from web-cam') 38 | 39 | parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine']) 40 | parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2']) 41 | parser.add_argument('--model-backbone-scale', type=float, default=0.25) 42 | parser.add_argument('--model-checkpoint', type=str, required=True) 43 | parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding']) 44 | parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000) 45 | parser.add_argument('--model-refine-threshold', type=float, default=0.7) 46 | 47 | parser.add_argument('--hide-fps', action='store_true') 48 | parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720)) 49 | args = parser.parse_args() 50 | 51 | 52 | # ----------- Utility classes ------------- 53 | 54 | 55 | # A wrapper that reads data from cv2.VideoCapture in its own thread to optimize. 
56 | # Use .read() in a tight loop to get the newest frame 57 | class Camera: 58 | def __init__(self, device_id=0, width=1280, height=720): 59 | self.capture = cv2.VideoCapture(device_id) 60 | self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, width) 61 | self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, height) 62 | self.width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH)) 63 | self.height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) 64 | # self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2) 65 | self.success_reading, self.frame = self.capture.read() 66 | self.read_lock = Lock() 67 | self.thread = Thread(target=self.__update, args=()) 68 | self.thread.daemon = True 69 | self.thread.start() 70 | 71 | def __update(self): 72 | while self.success_reading: 73 | grabbed, frame = self.capture.read() 74 | with self.read_lock: 75 | self.success_reading = grabbed 76 | self.frame = frame 77 | 78 | def read(self): 79 | with self.read_lock: 80 | frame = self.frame.copy() 81 | return frame 82 | def __exit__(self, exec_type, exc_value, traceback): 83 | self.capture.release() 84 | 85 | # An FPS tracker that computes exponentialy moving average FPS 86 | class FPSTracker: 87 | def __init__(self, ratio=0.5): 88 | self._last_tick = None 89 | self._avg_fps = None 90 | self.ratio = ratio 91 | def tick(self): 92 | if self._last_tick is None: 93 | self._last_tick = time.time() 94 | return None 95 | t_new = time.time() 96 | fps_sample = 1.0 / (t_new - self._last_tick) 97 | self._avg_fps = self.ratio * fps_sample + (1 - self.ratio) * self._avg_fps if self._avg_fps is not None else fps_sample 98 | self._last_tick = t_new 99 | return self.get() 100 | def get(self): 101 | return self._avg_fps 102 | 103 | # Wrapper for playing a stream with cv2.imshow(). It can accept an image and return keypress info for basic interactivity. 104 | # It also tracks FPS and optionally overlays info onto the stream. 
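# step() shows the frame, optionally overlays the FPS estimate, and returns the cv2.waitKey code so callers can react to 'b'/'q' presses.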
105 | class Displayer: 106 | def __init__(self, title, width=None, height=None, show_info=True): 107 | self.title, self.width, self.height = title, width, height 108 | self.show_info = show_info 109 | self.fps_tracker = FPSTracker() 110 | cv2.namedWindow(self.title, cv2.WINDOW_NORMAL) 111 | if width is not None and height is not None: 112 | cv2.resizeWindow(self.title, width, height) 113 | # Update the currently showing frame and return key press char code 114 | def step(self, image): 115 | fps_estimate = self.fps_tracker.tick() 116 | if self.show_info and fps_estimate is not None: 117 | message = f"{int(fps_estimate)} fps | {self.width}x{self.height}" 118 | cv2.putText(image, message, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0)) 119 | cv2.imshow(self.title, image) 120 | return cv2.waitKey(1) & 0xFF 121 | 122 | 123 | # --------------- Main --------------- 124 | 125 | 126 | # Load model 127 | if args.model_type == 'mattingbase': 128 | model = MattingBase(args.model_backbone) 129 | if args.model_type == 'mattingrefine': 130 | model = MattingRefine( 131 | args.model_backbone, 132 | args.model_backbone_scale, 133 | args.model_refine_mode, 134 | args.model_refine_sample_pixels, 135 | args.model_refine_threshold) 136 | 137 | model = model.cuda().eval() 138 | model.load_state_dict(torch.load(args.model_checkpoint), strict=False) 139 | 140 | 141 | width, height = args.resolution 142 | cam = Camera(width=width, height=height) 143 | dsp = Displayer('MattingV2', cam.width, cam.height, show_info=(not args.hide_fps)) 144 | 145 | def cv2_frame_to_cuda(frame): 146 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 147 | return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).cuda() 148 | 149 | with torch.no_grad(): 150 | while True: 151 | bgr = None 152 | while True: # grab bgr 153 | frame = cam.read() 154 | key = dsp.step(frame) 155 | if key == ord('b'): 156 | bgr = cv2_frame_to_cuda(cam.read()) 157 | break 158 | elif key == ord('q'): 159 | exit() 160 | while True: # matting 161 | frame = cam.read() 162 | src = cv2_frame_to_cuda(frame) 163 | pha, fgr = model(src, bgr)[:2] 164 | res = pha * fgr + (1 - pha) * torch.ones_like(fgr) 165 | res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0] 166 | res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR) 167 | key = dsp.step(res) 168 | if key == ord('b'): 169 | break 170 | elif key == ord('q'): 171 | exit() 172 | -------------------------------------------------------------------------------- /doc/model_usage.md: -------------------------------------------------------------------------------- 1 | # Use our model 2 | Our model supports multiple inference backends and provides flexible settings to trade-off quality and computation at the inference time. 3 | 4 | ## Overview 5 | * [Usage](#usage) 6 | * [PyTorch (Research)](#pytorch-research) 7 | * [TorchScript (Production)](#torchscript-production) 8 | * [TensorFlow (Experimental)](#tensorflow-experimental) 9 | * [ONNX (Experimental)](#onnx-experimental) 10 | * [Documentation](#documentation) 11 | 12 |   13 | 14 | ## Usage 15 | 16 | 17 | ### PyTorch (Research) 18 | 19 | The `/model` directory contains all the scripts that define the architecture. Follow the example to run inference using our model. 
20 | 21 | #### Python 22 | 23 | ```python 24 | import torch 25 | from model import MattingRefine 26 | 27 | device = torch.device('cuda') 28 | precision = torch.float32 29 | 30 | model = MattingRefine(backbone='mobilenetv2', 31 | backbone_scale=0.25, 32 | refine_mode='sampling', 33 | refine_sample_pixels=80_000) 34 | 35 | model.load_state_dict(torch.load('PATH_TO_CHECKPOINT.pth')) 36 | model = model.eval().to(precision).to(device) 37 | 38 | src = torch.rand(1, 3, 1080, 1920).to(precision).to(device) 39 | bgr = torch.rand(1, 3, 1080, 1920).to(precision).to(device) 40 | 41 | with torch.no_grad(): 42 | pha, fgr = model(src, bgr)[:2] 43 | ``` 44 | 45 |   46 | 47 | ### TorchScript (Production) 48 | 49 | Inference with TorchScript does not need any script from this repo! Simply download the model file that has both the architecture and weights baked in. Follow the example to run our model in a Python or C++ environment. 50 | 51 | #### Python 52 | 53 | ```python 54 | import torch 55 | 56 | device = torch.device('cuda') 57 | precision = torch.float16 58 | 59 | model = torch.jit.load('PATH_TO_MODEL.pth') 60 | model.backbone_scale = 0.25 61 | model.refine_mode = 'sampling' 62 | model.refine_sample_pixels = 80_000 63 | 64 | model = model.to(device) 65 | 66 | src = torch.rand(1, 3, 1080, 1920).to(precision).to(device) 67 | bgr = torch.rand(1, 3, 1080, 1920).to(precision).to(device) 68 | 69 | pha, fgr = model(src, bgr)[:2] 70 | ``` 71 | 72 | #### C++ 73 | 74 | ```cpp 75 | #include <torch/script.h> 76 | 77 | int main() { 78 | auto device = torch::Device("cuda"); 79 | auto precision = torch::kFloat16; 80 | 81 | auto model = torch::jit::load("PATH_TO_MODEL.pth"); 82 | model.setattr("backbone_scale", 0.25); 83 | model.setattr("refine_mode", "sampling"); 84 | model.setattr("refine_sample_pixels", 80000); 85 | model.to(device); 86 | 87 | auto src = torch::rand({1, 3, 1080, 1920}).to(device).to(precision); 88 | auto bgr = torch::rand({1, 3, 1080, 1920}).to(device).to(precision); 89 | 90 | auto outputs = model.forward({src, bgr}).toTuple()->elements(); 91 | auto pha = outputs[0].toTensor(); 92 | auto fgr = outputs[1].toTensor(); 93 | } 94 | ``` 95 |   96 | 97 | ### TensorFlow (Experimental) 98 | 99 | Please visit the [BackgroundMattingV2-TensorFlow](https://github.com/PeterL1n/BackgroundMattingV2-TensorFlow) repo for more detail. 100 | 101 |   102 | 103 | ### ONNX (Experimental) 104 | 105 | #### Python 106 | ```python 107 | import onnxruntime 108 | import numpy as np 109 | 110 | sess = onnxruntime.InferenceSession('PATH_TO_MODEL.onnx') 111 | 112 | src = np.random.normal(size=(1, 3, 1080, 1920)).astype(np.float32) 113 | bgr = np.random.normal(size=(1, 3, 1080, 1920)).astype(np.float32) 114 | 115 | pha, fgr = sess.run(['pha', 'fgr'], {'src': src, 'bgr': bgr}) 116 | ``` 117 | 118 | Our model can be exported to ONNX, but we found it to be much slower than PyTorch/TorchScript. We provide pre-exported `HD(backbone_scale=0.25, sample_pixels=80,000)` and `4K(backbone_scale=0.125, sample_pixels=320,000)` models with the MobileNetV2 backbone. Any other configuration can be exported through `export_onnx.py`. 119 | 120 | #### Compatibility Notes: 121 | 122 | Our network uses a novel architecture that involves cropping and replacing patches 123 | of an image. This may cause compatibility issues with different inference backends. 124 | Therefore, we offer different methods for cropping and replacing patches as 125 | compatibility options. You can try exporting ONNX models using different cropping and replacing methods. More detail is in `export_onnx.py`; a rough export sketch is shown below.
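The following is a minimal sketch that mirrors what `export_onnx.py` does; the checkpoint path, the output filename `OUTPUT_MODEL.onnx`, and the MobileNetV2/sampling settings are placeholders for your own configuration, not a fixed recipe. It constructs `MattingRefine` with the ONNX-friendly patch options and exports with dynamic batch and resolution axes:

```python
import torch
from model import MattingRefine

# Build the refinement model with the ONNX-friendly patch options.
model = MattingRefine(backbone='mobilenetv2',
                      backbone_scale=0.25,
                      refine_mode='sampling',
                      refine_sample_pixels=80_000,
                      refine_patch_crop_method='roi_align',           # instead of the default 'unfold'
                      refine_patch_replace_method='scatter_element')  # instead of the default 'scatter_nd'
model.load_state_dict(torch.load('PATH_TO_CHECKPOINT.pth', map_location='cpu'), strict=False)
model.eval()

# Dummy inputs only fix the tensor rank; the dynamic axes below keep batch/height/width flexible.
src = torch.randn(1, 3, 1080, 1920)
bgr = torch.randn(1, 3, 1080, 1920)

input_names = ['src', 'bgr']
output_names = ['pha', 'fgr', 'pha_sm', 'fgr_sm', 'err_sm', 'ref_sm']
torch.onnx.export(model, (src, bgr), 'OUTPUT_MODEL.onnx',
                  opset_version=12,
                  do_constant_folding=True,
                  input_names=input_names,
                  output_names=output_names,
                  dynamic_axes={name: {0: 'batch', 2: 'height', 3: 'width'}
                                for name in [*input_names, *output_names]})
```

The `unfold`/`scatter_nd` defaults are the fastest choices under PyTorch and TorchScript, while `roi_align`/`scatter_element` trade some speed for broader ONNX backend compatibility.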
The provided ONNX models use `roi_align` for cropping and `scatter_element` for replacing patches. 126 | 127 |   128 | 129 | ## Documentation 130 | 131 | ![Architecture](https://github.com/PeterL1n/Matting-PyTorch/blob/master/images/architecture.svg?raw=true) 132 | 133 | Our architecture consists of two network components. The base network operates on a downsampled resolution to produce coarse results, and the refinement network only refines error-prone patches to produce the full-resolution output. This saves redundant computation and allows inference-time adjustment. 134 | 135 | #### Model Arguments: 136 | * `backbone_scale` (float, default: 0.25): The downsampling scale that the backbone should operate on, e.g. the backbone will operate on 480x270 resolution for a 1920x1080 input with backbone_scale=0.25. 137 | * `refine_mode` (string, default: `sampling`, options: [`sampling`, `thresholding`, `full`]): Mode of refinement. 138 | * `sampling` will set a fixed maximum amount of pixels to refine, defined by `refine_sample_pixels`. It is suitable for live applications where the computation and memory consumption per frame have a fixed upper bound. 139 | * `thresholding` will dynamically refine all pixels with errors above the threshold, defined by `refine_threshold`. It is suitable for image editing applications where quality outweighs the speed of computation. 140 | * `full` will refine the entire image. Only used for debugging. 141 | * `refine_sample_pixels` (int, default: 80,000). The fixed amount of pixels to refine. Used in `sampling` mode. 142 | * `refine_threshold` (float, default: 0.1). The threshold for refinement. Used in `thresholding` mode. 143 | * `prevent_oversampling` (bool, default: true). Used only in `sampling` mode. When false, it will refine even unnecessary pixels to enforce refining exactly `refine_sample_pixels` pixels. This is only used for speed testing. 144 | 145 | #### Model Inputs: 146 | * `src`: (B, 3, H, W): The source image with RGB channels normalized to 0 ~ 1. 147 | * `bgr`: (B, 3, H, W): The background image with RGB channels normalized to 0 ~ 1. 148 | 149 | #### Model Outputs: 150 | * `pha`: (B, 1, H, W): The alpha matte normalized to 0 ~ 1. 151 | * `fgr`: (B, 3, H, W): The foreground with RGB channels normalized to 0 ~ 1. 152 | * `pha_sm`: (B, 1, Hc, Wc): The coarse alpha matte normalized to 0 ~ 1. 153 | * `fgr_sm`: (B, 3, Hc, Wc): The coarse foreground with RGB channels normalized to 0 ~ 1. 154 | * `err_sm`: (B, 1, Hc, Wc): The coarse error prediction map normalized to 0 ~ 1. 155 | * `ref_sm`: (B, 1, H/4, W/4): The refinement regions, where 1 denotes a refined 4x4 patch. 156 | 157 | Only the `pha` and `fgr` outputs are needed for regular use cases. You can composite the alpha and foreground onto a new background using `com = pha * fgr + (1 - pha) * bgr`. The additional outputs are intermediate results used for training and debugging. 158 | 159 | 160 | We recommend `backbone_scale=0.25, refine_sample_pixels=80000` for HD and `backbone_scale=0.125, refine_sample_pixels=320000` for 4K. 161 | -------------------------------------------------------------------------------- /inference_video.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inference video: Extract matting from a video.
3 | 4 | Example: 5 | 6 | python inference_video.py \ 7 | --model-type mattingrefine \ 8 | --model-backbone resnet50 \ 9 | --model-backbone-scale 0.25 \ 10 | --model-refine-mode sampling \ 11 | --model-refine-sample-pixels 80000 \ 12 | --model-checkpoint "PATH_TO_CHECKPOINT" \ 13 | --video-src "PATH_TO_VIDEO_SRC" \ 14 | --video-bgr "PATH_TO_VIDEO_BGR" \ 15 | --video-resize 1920 1080 \ 16 | --output-dir "PATH_TO_OUTPUT_DIR" \ 17 | --output-type com fgr pha err ref \ 18 | --video-target-bgr "PATH_TO_VIDEO_TARGET_BGR" 19 | 20 | """ 21 | 22 | import argparse 23 | import cv2 24 | import torch 25 | import os 26 | import shutil 27 | 28 | from torch import nn 29 | from torch.nn import functional as F 30 | from torch.utils.data import DataLoader 31 | from torchvision import transforms as T 32 | from torchvision.transforms.functional import to_pil_image 33 | from threading import Thread 34 | from tqdm import tqdm 35 | from PIL import Image 36 | 37 | from dataset import VideoDataset, ZipDataset 38 | from dataset import augmentation as A 39 | from model import MattingBase, MattingRefine 40 | from inference_utils import HomographicAlignment 41 | 42 | 43 | # --------------- Arguments --------------- 44 | 45 | 46 | parser = argparse.ArgumentParser(description='Inference video') 47 | 48 | parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine']) 49 | parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2']) 50 | parser.add_argument('--model-backbone-scale', type=float, default=0.25) 51 | parser.add_argument('--model-checkpoint', type=str, required=True) 52 | parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding']) 53 | parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000) 54 | parser.add_argument('--model-refine-threshold', type=float, default=0.7) 55 | parser.add_argument('--model-refine-kernel-size', type=int, default=3) 56 | 57 | parser.add_argument('--video-src', type=str, required=True) 58 | parser.add_argument('--video-bgr', type=str, required=True) 59 | parser.add_argument('--video-target-bgr', type=str, default=None, help="Path to video onto which to composite the output (default to flat green)") 60 | parser.add_argument('--video-resize', type=int, default=None, nargs=2) 61 | 62 | parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') 63 | parser.add_argument('--preprocess-alignment', action='store_true') 64 | 65 | parser.add_argument('--output-dir', type=str, required=True) 66 | parser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref']) 67 | parser.add_argument('--output-format', type=str, default='video', choices=['video', 'image_sequences']) 68 | 69 | args = parser.parse_args() 70 | 71 | 72 | assert 'err' not in args.output_types or args.model_type in ['mattingbase', 'mattingrefine'], \ 73 | 'Only mattingbase and mattingrefine support err output' 74 | assert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \ 75 | 'Only mattingrefine support ref output' 76 | 77 | # --------------- Utils --------------- 78 | 79 | 80 | class VideoWriter: 81 | def __init__(self, path, frame_rate, width, height): 82 | self.out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height)) 83 | 84 | def add_batch(self, frames): 85 | frames = frames.mul(255).byte() 86 | frames = frames.cpu().permute(0, 
2, 3, 1).numpy() 87 | for i in range(frames.shape[0]): 88 | frame = frames[i] 89 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 90 | self.out.write(frame) 91 | 92 | 93 | class ImageSequenceWriter: 94 | def __init__(self, path, extension): 95 | self.path = path 96 | self.extension = extension 97 | self.index = 0 98 | os.makedirs(path) 99 | 100 | def add_batch(self, frames): 101 | Thread(target=self._add_batch, args=(frames, self.index)).start() 102 | self.index += frames.shape[0] 103 | 104 | def _add_batch(self, frames, index): 105 | frames = frames.cpu() 106 | for i in range(frames.shape[0]): 107 | frame = frames[i] 108 | frame = to_pil_image(frame) 109 | frame.save(os.path.join(self.path, str(index + i).zfill(5) + '.' + self.extension)) 110 | 111 | 112 | # --------------- Main --------------- 113 | 114 | 115 | device = torch.device(args.device) 116 | 117 | # Load model 118 | if args.model_type == 'mattingbase': 119 | model = MattingBase(args.model_backbone) 120 | if args.model_type == 'mattingrefine': 121 | model = MattingRefine( 122 | args.model_backbone, 123 | args.model_backbone_scale, 124 | args.model_refine_mode, 125 | args.model_refine_sample_pixels, 126 | args.model_refine_threshold, 127 | args.model_refine_kernel_size) 128 | 129 | model = model.to(device).eval() 130 | model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False) 131 | 132 | 133 | # Load video and background 134 | vid = VideoDataset(args.video_src) 135 | bgr = [Image.open(args.video_bgr).convert('RGB')] 136 | dataset = ZipDataset([vid, bgr], transforms=A.PairCompose([ 137 | A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()), 138 | HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()), 139 | A.PairApply(T.ToTensor()) 140 | ])) 141 | if args.video_target_bgr: 142 | dataset = ZipDataset([dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor())]) 143 | 144 | # Create output directory 145 | if os.path.exists(args.output_dir): 146 | if input(f'Directory {args.output_dir} already exists. Override? 
[Y/N]: ').lower() == 'y': 147 | shutil.rmtree(args.output_dir) 148 | else: 149 | exit() 150 | os.makedirs(args.output_dir) 151 | 152 | 153 | # Prepare writers 154 | if args.output_format == 'video': 155 | h = args.video_resize[1] if args.video_resize is not None else vid.height 156 | w = args.video_resize[0] if args.video_resize is not None else vid.width 157 | if 'com' in args.output_types: 158 | com_writer = VideoWriter(os.path.join(args.output_dir, 'com.mp4'), vid.frame_rate, w, h) 159 | if 'pha' in args.output_types: 160 | pha_writer = VideoWriter(os.path.join(args.output_dir, 'pha.mp4'), vid.frame_rate, w, h) 161 | if 'fgr' in args.output_types: 162 | fgr_writer = VideoWriter(os.path.join(args.output_dir, 'fgr.mp4'), vid.frame_rate, w, h) 163 | if 'err' in args.output_types: 164 | err_writer = VideoWriter(os.path.join(args.output_dir, 'err.mp4'), vid.frame_rate, w, h) 165 | if 'ref' in args.output_types: 166 | ref_writer = VideoWriter(os.path.join(args.output_dir, 'ref.mp4'), vid.frame_rate, w, h) 167 | else: 168 | if 'com' in args.output_types: 169 | com_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'com'), 'png') 170 | if 'pha' in args.output_types: 171 | pha_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'pha'), 'jpg') 172 | if 'fgr' in args.output_types: 173 | fgr_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'fgr'), 'jpg') 174 | if 'err' in args.output_types: 175 | err_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'err'), 'jpg') 176 | if 'ref' in args.output_types: 177 | ref_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'ref'), 'jpg') 178 | 179 | 180 | # Conversion loop 181 | with torch.no_grad(): 182 | for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)): 183 | if args.video_target_bgr: 184 | (src, bgr), tgt_bgr = input_batch 185 | tgt_bgr = tgt_bgr.to(device, non_blocking=True) 186 | else: 187 | src, bgr = input_batch 188 | tgt_bgr = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1) 189 | src = src.to(device, non_blocking=True) 190 | bgr = bgr.to(device, non_blocking=True) 191 | 192 | if args.model_type == 'mattingbase': 193 | pha, fgr, err, _ = model(src, bgr) 194 | elif args.model_type == 'mattingrefine': 195 | pha, fgr, _, _, err, ref = model(src, bgr) 196 | elif args.model_type == 'mattingbm': 197 | pha, fgr = model(src, bgr) 198 | 199 | if 'com' in args.output_types: 200 | if args.output_format == 'video': 201 | # Output composite with green background 202 | com = fgr * pha + tgt_bgr * (1 - pha) 203 | com_writer.add_batch(com) 204 | else: 205 | # Output composite as rgba png images 206 | com = torch.cat([fgr * pha.ne(0), pha], dim=1) 207 | com_writer.add_batch(com) 208 | if 'pha' in args.output_types: 209 | pha_writer.add_batch(pha) 210 | if 'fgr' in args.output_types: 211 | fgr_writer.add_batch(fgr) 212 | if 'err' in args.output_types: 213 | err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False)) 214 | if 'ref' in args.output_types: 215 | ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest')) 216 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torchvision.models.segmentation.deeplabv3 import ASPP 5 | 6 | from .decoder import Decoder 7 | from .mobilenet import MobileNetV2Encoder 8 | from 
.refiner import Refiner 9 | from .resnet import ResNetEncoder 10 | from .utils import load_matched_state_dict 11 | 12 | 13 | class Base(nn.Module): 14 | """ 15 | A generic implementation of the base encoder-decoder network inspired by DeepLab. 16 | Accepts arbitrary channels for input and output. 17 | """ 18 | 19 | def __init__(self, backbone: str, in_channels: int, out_channels: int): 20 | super().__init__() 21 | assert backbone in ["resnet50", "resnet101", "mobilenetv2"] 22 | if backbone in ['resnet50', 'resnet101']: 23 | self.backbone = ResNetEncoder(in_channels, variant=backbone) 24 | self.aspp = ASPP(2048, [3, 6, 9]) 25 | self.decoder = Decoder([256, 128, 64, 48, out_channels], [512, 256, 64, in_channels]) 26 | else: 27 | self.backbone = MobileNetV2Encoder(in_channels) 28 | self.aspp = ASPP(320, [3, 6, 9]) 29 | self.decoder = Decoder([256, 128, 64, 48, out_channels], [32, 24, 16, in_channels]) 30 | 31 | def forward(self, x): 32 | x, *shortcuts = self.backbone(x) 33 | x = self.aspp(x) 34 | x = self.decoder(x, *shortcuts) 35 | return x 36 | 37 | def load_pretrained_deeplabv3_state_dict(self, state_dict, print_stats=True): 38 | # Pretrained DeepLabV3 models are provided by . 39 | # This method converts and loads their pretrained state_dict to match with our model structure. 40 | # This method is not needed if you are not planning to train from deeplab weights. 41 | # Use load_state_dict() for normal weight loading. 42 | 43 | # Convert state_dict naming for aspp module 44 | state_dict = {k.replace('classifier.classifier.0', 'aspp'): v for k, v in state_dict.items()} 45 | 46 | if isinstance(self.backbone, ResNetEncoder): 47 | # ResNet backbone does not need change. 48 | load_matched_state_dict(self, state_dict, print_stats) 49 | else: 50 | # Change MobileNetV2 backbone to state_dict format, then change back after loading. 51 | backbone_features = self.backbone.features 52 | self.backbone.low_level_features = backbone_features[:4] 53 | self.backbone.high_level_features = backbone_features[4:] 54 | del self.backbone.features 55 | load_matched_state_dict(self, state_dict, print_stats) 56 | self.backbone.features = backbone_features 57 | del self.backbone.low_level_features 58 | del self.backbone.high_level_features 59 | 60 | 61 | class MattingBase(Base): 62 | """ 63 | MattingBase is used to produce coarse global results at a lower resolution. 64 | MattingBase extends Base. 65 | 66 | Args: 67 | backbone: ["resnet50", "resnet101", "mobilenetv2"] 68 | 69 | Input: 70 | src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1. 71 | bgr: (B, 3, H, W) the background image . Channels are RGB values normalized to 0 ~ 1. 72 | 73 | Output: 74 | pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1. 75 | fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1. 76 | err: (B, 1, H, W) the error prediction. Normalized to 0 ~ 1. 77 | hid: (B, 32, H, W) the hidden encoding. Used for connecting refiner module. 78 | 79 | Example: 80 | model = MattingBase(backbone='resnet50') 81 | 82 | pha, fgr, err, hid = model(src, bgr) # for training 83 | pha, fgr = model(src, bgr)[:2] # for inference 84 | """ 85 | 86 | def __init__(self, backbone: str): 87 | super().__init__(backbone, in_channels=6, out_channels=(1 + 3 + 1 + 32)) 88 | 89 | def forward(self, src, bgr): 90 | x = torch.cat([src, bgr], dim=1) 91 | x, *shortcuts = self.backbone(x) 92 | x = self.aspp(x) 93 | x = self.decoder(x, *shortcuts) 94 | pha = x[:, 0:1].clamp_(0., 1.) 
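# The decoder predicts the foreground as a residual, so the source image is added back before clamping.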
95 | fgr = x[:, 1:4].add(src).clamp_(0., 1.) 96 | err = x[:, 4:5].clamp_(0., 1.) 97 | hid = x[:, 5: ].relu_() 98 | return pha, fgr, err, hid 99 | 100 | 101 | class MattingRefine(MattingBase): 102 | """ 103 | MattingRefine includes the refiner module to upsample the coarse result to full resolution. 104 | MattingRefine extends MattingBase. 105 | 106 | Args: 107 | backbone: ["resnet50", "resnet101", "mobilenetv2"] 108 | backbone_scale: The image downsample scale for passing through backbone, default 1/4 or 0.25. 109 | Must not be greater than 1/2. 110 | refine_mode: refine area selection mode. Options: 111 | "full" - No area selection, refine everywhere using regular Conv2d. 112 | "sampling" - Refine a fixed amount of pixels ranked by the top most errors. 113 | "thresholding" - Refine a varying amount of pixels that have more error than the threshold. 114 | refine_sample_pixels: number of pixels to refine. Only used when mode == "sampling". 115 | refine_threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding". 116 | refine_kernel_size: the refiner's convolutional kernel size. Options: [1, 3] 117 | refine_prevent_oversampling: prevent sampling more pixels than needed for sampling mode. Set False only for speedtest. 118 | 119 | Input: 120 | src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1. 121 | bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1. 122 | 123 | Output: 124 | pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1. 125 | fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1. 126 | pha_sm: (B, 1, Hc, Wc) the coarse alpha prediction from matting base. Normalized to 0 ~ 1. 127 | fgr_sm: (B, 3, Hc, Wc) the coarse foreground prediction from matting base. Normalized to 0 ~ 1. 128 | err_sm: (B, 1, Hc, Wc) the coarse error prediction from matting base. Normalized to 0 ~ 1. 129 | ref_sm: (B, 1, H/4, W/4) the quarter resolution refinement map. 1 indicates refined 4x4 patch locations.
130 | 131 | Example: 132 | model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='sampling', refine_sample_pixels=80_000) 133 | model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='thresholding', refine_threshold=0.1) 134 | model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='full') 135 | 136 | pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm = model(src, bgr) # for training 137 | pha, fgr = model(src, bgr)[:2] # for inference 138 | """ 139 | 140 | def __init__(self, 141 | backbone: str, 142 | backbone_scale: float = 1/4, 143 | refine_mode: str = 'sampling', 144 | refine_sample_pixels: int = 80_000, 145 | refine_threshold: float = 0.1, 146 | refine_kernel_size: int = 3, 147 | refine_prevent_oversampling: bool = True, 148 | refine_patch_crop_method: str = 'unfold', 149 | refine_patch_replace_method: str = 'scatter_nd'): 150 | assert backbone_scale <= 1/2, 'backbone_scale should not be greater than 1/2' 151 | super().__init__(backbone) 152 | self.backbone_scale = backbone_scale 153 | self.refiner = Refiner(refine_mode, 154 | refine_sample_pixels, 155 | refine_threshold, 156 | refine_kernel_size, 157 | refine_prevent_oversampling, 158 | refine_patch_crop_method, 159 | refine_patch_replace_method) 160 | 161 | def forward(self, src, bgr): 162 | assert src.size() == bgr.size(), 'src and bgr must have the same shape' 163 | assert src.size(2) // 4 * 4 == src.size(2) and src.size(3) // 4 * 4 == src.size(3), \ 164 | 'src and bgr must have width and height that are divisible by 4' 165 | 166 | # Downsample src and bgr for backbone 167 | src_sm = F.interpolate(src, 168 | scale_factor=self.backbone_scale, 169 | mode='bilinear', 170 | align_corners=False, 171 | recompute_scale_factor=True) 172 | bgr_sm = F.interpolate(bgr, 173 | scale_factor=self.backbone_scale, 174 | mode='bilinear', 175 | align_corners=False, 176 | recompute_scale_factor=True) 177 | 178 | # Base 179 | x = torch.cat([src_sm, bgr_sm], dim=1) 180 | x, *shortcuts = self.backbone(x) 181 | x = self.aspp(x) 182 | x = self.decoder(x, *shortcuts) 183 | pha_sm = x[:, 0:1].clamp_(0., 1.) 184 | fgr_sm = x[:, 1:4] 185 | err_sm = x[:, 4:5].clamp_(0., 1.) 186 | hid_sm = x[:, 5: ].relu_() 187 | 188 | # Refiner 189 | pha, fgr, ref_sm = self.refiner(src, bgr, pha_sm, fgr_sm, err_sm, hid_sm) 190 | 191 | # Clamp outputs 192 | pha = pha.clamp_(0., 1.) 193 | fgr = fgr.add_(src).clamp_(0., 1.) 194 | fgr_sm = src_sm.add_(fgr_sm).clamp_(0., 1.) 
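# Return the full-resolution pha/fgr together with the coarse intermediates used for training and debugging.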
195 | 196 | return pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm 197 | -------------------------------------------------------------------------------- /train_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train MattingBase 3 | 4 | You can download pretrained DeepLabV3 weights from 5 | 6 | Example: 7 | 8 | CUDA_VISIBLE_DEVICES=0 python train_base.py \ 9 | --dataset-name videomatte240k \ 10 | --model-backbone resnet50 \ 11 | --model-name mattingbase-resnet50-videomatte240k \ 12 | --model-pretrain-initialization "pretraining/best_deeplabv3_resnet50_voc_os16.pth" \ 13 | --epoch-end 8 14 | 15 | """ 16 | 17 | import argparse 18 | import kornia 19 | import torch 20 | import os 21 | import random 22 | 23 | from torch import nn 24 | from torch.nn import functional as F 25 | from torch.cuda.amp import autocast, GradScaler 26 | from torch.utils.tensorboard import SummaryWriter 27 | from torch.utils.data import DataLoader 28 | from torch.optim import Adam 29 | from torchvision.utils import make_grid 30 | from tqdm import tqdm 31 | from torchvision import transforms as T 32 | from PIL import Image 33 | 34 | from data_path import DATA_PATH 35 | from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset 36 | from dataset import augmentation as A 37 | from model import MattingBase 38 | from model.utils import load_matched_state_dict 39 | 40 | 41 | # --------------- Arguments --------------- 42 | 43 | 44 | parser = argparse.ArgumentParser() 45 | 46 | parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys()) 47 | 48 | parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2']) 49 | parser.add_argument('--model-name', type=str, required=True) 50 | parser.add_argument('--model-pretrain-initialization', type=str, default=None) 51 | parser.add_argument('--model-last-checkpoint', type=str, default=None) 52 | 53 | parser.add_argument('--batch-size', type=int, default=8) 54 | parser.add_argument('--num-workers', type=int, default=16) 55 | parser.add_argument('--epoch-start', type=int, default=0) 56 | parser.add_argument('--epoch-end', type=int, required=True) 57 | 58 | parser.add_argument('--log-train-loss-interval', type=int, default=10) 59 | parser.add_argument('--log-train-images-interval', type=int, default=2000) 60 | parser.add_argument('--log-valid-interval', type=int, default=5000) 61 | 62 | parser.add_argument('--checkpoint-interval', type=int, default=5000) 63 | 64 | args = parser.parse_args() 65 | 66 | 67 | # --------------- Loading --------------- 68 | 69 | 70 | def train(): 71 | 72 | # Training DataLoader 73 | dataset_train = ZipDataset([ 74 | ZipDataset([ 75 | ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'), 76 | ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'), 77 | ], transforms=A.PairCompose([ 78 | A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.4, 1), shear=(-5, 5)), 79 | A.PairRandomHorizontalFlip(), 80 | A.PairRandomBoxBlur(0.1, 5), 81 | A.PairRandomSharpen(0.1), 82 | A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)), 83 | A.PairApply(T.ToTensor()) 84 | ]), assert_equal_length=True), 85 | ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([ 86 | A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)), 87 | T.RandomHorizontalFlip(), 88 | A.RandomBoxBlur(0.1, 5), 89 | 
A.RandomSharpen(0.1), 90 | T.ColorJitter(0.15, 0.15, 0.15, 0.05), 91 | T.ToTensor() 92 | ])), 93 | ]) 94 | dataloader_train = DataLoader(dataset_train, 95 | shuffle=True, 96 | batch_size=args.batch_size, 97 | num_workers=args.num_workers, 98 | pin_memory=True) 99 | 100 | # Validation DataLoader 101 | dataset_valid = ZipDataset([ 102 | ZipDataset([ 103 | ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'), 104 | ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB') 105 | ], transforms=A.PairCompose([ 106 | A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)), 107 | A.PairApply(T.ToTensor()) 108 | ]), assert_equal_length=True), 109 | ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([ 110 | A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)), 111 | T.ToTensor() 112 | ])), 113 | ]) 114 | dataset_valid = SampleDataset(dataset_valid, 50) 115 | dataloader_valid = DataLoader(dataset_valid, 116 | pin_memory=True, 117 | batch_size=args.batch_size, 118 | num_workers=args.num_workers) 119 | 120 | # Model 121 | model = MattingBase(args.model_backbone).cuda() 122 | 123 | if args.model_last_checkpoint is not None: 124 | load_matched_state_dict(model, torch.load(args.model_last_checkpoint)) 125 | elif args.model_pretrain_initialization is not None: 126 | model.load_pretrained_deeplabv3_state_dict(torch.load(args.model_pretrain_initialization)['model_state']) 127 | 128 | optimizer = Adam([ 129 | {'params': model.backbone.parameters(), 'lr': 1e-4}, 130 | {'params': model.aspp.parameters(), 'lr': 5e-4}, 131 | {'params': model.decoder.parameters(), 'lr': 5e-4} 132 | ]) 133 | scaler = GradScaler() 134 | 135 | # Logging and checkpoints 136 | if not os.path.exists(f'checkpoint/{args.model_name}'): 137 | os.makedirs(f'checkpoint/{args.model_name}') 138 | writer = SummaryWriter(f'log/{args.model_name}') 139 | 140 | # Run loop 141 | for epoch in range(args.epoch_start, args.epoch_end): 142 | for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)): 143 | step = epoch * len(dataloader_train) + i 144 | 145 | true_pha = true_pha.cuda(non_blocking=True) 146 | true_fgr = true_fgr.cuda(non_blocking=True) 147 | true_bgr = true_bgr.cuda(non_blocking=True) 148 | true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr) 149 | 150 | true_src = true_bgr.clone() 151 | 152 | # Augment with shadow 153 | aug_shadow_idx = torch.rand(len(true_src)) < 0.3 154 | if aug_shadow_idx.any(): 155 | aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random()) 156 | aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow) 157 | aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2) 158 | true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1) 159 | del aug_shadow 160 | del aug_shadow_idx 161 | 162 | # Composite foreground onto source 163 | true_src = true_fgr * true_pha + true_src * (1 - true_pha) 164 | 165 | # Augment with noise 166 | aug_noise_idx = torch.rand(len(true_src)) < 0.4 167 | if aug_noise_idx.any(): 168 | true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1) 169 | true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1) 170 | del 
aug_noise_idx 171 | 172 | # Augment background with jitter 173 | aug_jitter_idx = torch.rand(len(true_src)) < 0.8 174 | if aug_jitter_idx.any(): 175 | true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx]) 176 | del aug_jitter_idx 177 | 178 | # Augment background with affine 179 | aug_affine_idx = torch.rand(len(true_bgr)) < 0.3 180 | if aug_affine_idx.any(): 181 | true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx]) 182 | del aug_affine_idx 183 | 184 | with autocast(): 185 | pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3] 186 | loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr) 187 | 188 | scaler.scale(loss).backward() 189 | scaler.step(optimizer) 190 | scaler.update() 191 | optimizer.zero_grad() 192 | 193 | if (i + 1) % args.log_train_loss_interval == 0: 194 | writer.add_scalar('loss', loss, step) 195 | 196 | if (i + 1) % args.log_train_images_interval == 0: 197 | writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step) 198 | writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step) 199 | writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step) 200 | writer.add_image('train_pred_err', make_grid(pred_err, nrow=5), step) 201 | writer.add_image('train_true_src', make_grid(true_src, nrow=5), step) 202 | writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5), step) 203 | 204 | del true_pha, true_fgr, true_bgr 205 | del pred_pha, pred_fgr, pred_err 206 | 207 | if (i + 1) % args.log_valid_interval == 0: 208 | valid(model, dataloader_valid, writer, step) 209 | 210 | if (step + 1) % args.checkpoint_interval == 0: 211 | torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth') 212 | 213 | torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth') 214 | 215 | 216 | # --------------- Utils --------------- 217 | 218 | 219 | def compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr): 220 | true_err = torch.abs(pred_pha.detach() - true_pha) 221 | true_msk = true_pha != 0 222 | return F.l1_loss(pred_pha, true_pha) + \ 223 | F.l1_loss(kornia.sobel(pred_pha), kornia.sobel(true_pha)) + \ 224 | F.l1_loss(pred_fgr * true_msk, true_fgr * true_msk) + \ 225 | F.mse_loss(pred_err, true_err) 226 | 227 | 228 | def random_crop(*imgs): 229 | w = random.choice(range(256, 512)) 230 | h = random.choice(range(256, 512)) 231 | results = [] 232 | for img in imgs: 233 | img = kornia.resize(img, (max(h, w), max(h, w))) 234 | img = kornia.center_crop(img, (h, w)) 235 | results.append(img) 236 | return results 237 | 238 | 239 | def valid(model, dataloader, writer, step): 240 | model.eval() 241 | loss_total = 0 242 | loss_count = 0 243 | with torch.no_grad(): 244 | for (true_pha, true_fgr), true_bgr in dataloader: 245 | batch_size = true_pha.size(0) 246 | 247 | true_pha = true_pha.cuda(non_blocking=True) 248 | true_fgr = true_fgr.cuda(non_blocking=True) 249 | true_bgr = true_bgr.cuda(non_blocking=True) 250 | true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr 251 | 252 | pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3] 253 | loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr) 254 | loss_total += loss.cpu().item() * batch_size 255 | loss_count += batch_size 256 | 257 | writer.add_scalar('valid_loss', loss_total / loss_count, step) 258 | model.train() 259 | 260 | 261 | # --------------- Start --------------- 262 | 263 | 264 | if 
__name__ == '__main__': 265 | train() 266 | -------------------------------------------------------------------------------- /model/refiner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from typing import Tuple 6 | 7 | 8 | class Refiner(nn.Module): 9 | """ 10 | Refiner refines the coarse output to full resolution. 11 | 12 | Args: 13 | mode: area selection mode. Options: 14 | "full" - No area selection, refine everywhere using regular Conv2d. 15 | "sampling" - Refine fixed amount of pixels ranked by the top most errors. 16 | "thresholding" - Refine varying amount of pixels that have greater error than the threshold. 17 | sample_pixels: number of pixels to refine. Only used when mode == "sampling". 18 | threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding". 19 | kernel_size: The convolution kernel_size. Options: [1, 3] 20 | prevent_oversampling: True for regular cases, False for speedtest. 21 | 22 | Compatibility Args: 23 | patch_crop_method: the method for cropping patches. Options: 24 | "unfold" - Best performance for PyTorch and TorchScript. 25 | "roi_align" - Another way for croping patches. 26 | "gather" - Another way for croping patches. 27 | patch_replace_method: the method for replacing patches. Options: 28 | "scatter_nd" - Best performance for PyTorch and TorchScript. 29 | "scatter_element" - Another way for replacing patches. 30 | 31 | Input: 32 | src: (B, 3, H, W) full resolution source image. 33 | bgr: (B, 3, H, W) full resolution background image. 34 | pha: (B, 1, Hc, Wc) coarse alpha prediction. 35 | fgr: (B, 3, Hc, Wc) coarse foreground residual prediction. 36 | err: (B, 1, Hc, Hc) coarse error prediction. 37 | hid: (B, 32, Hc, Hc) coarse hidden encoding. 38 | 39 | Output: 40 | pha: (B, 1, H, W) full resolution alpha prediction. 41 | fgr: (B, 3, H, W) full resolution foreground residual prediction. 42 | ref: (B, 1, H/4, W/4) quarter resolution refinement selection map. 1 indicates refined 4x4 patch locations. 43 | """ 44 | 45 | # For TorchScript export optimization. 
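# Declaring these attributes as TorchScript constants lets scripting inline their values and resolve the kernel_size and patch-method branches at compile time.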
46 | __constants__ = ['kernel_size', 'patch_crop_method', 'patch_replace_method'] 47 | 48 | def __init__(self, 49 | mode: str, 50 | sample_pixels: int, 51 | threshold: float, 52 | kernel_size: int = 3, 53 | prevent_oversampling: bool = True, 54 | patch_crop_method: str = 'unfold', 55 | patch_replace_method: str = 'scatter_nd'): 56 | super().__init__() 57 | assert mode in ['full', 'sampling', 'thresholding'] 58 | assert kernel_size in [1, 3] 59 | assert patch_crop_method in ['unfold', 'roi_align', 'gather'] 60 | assert patch_replace_method in ['scatter_nd', 'scatter_element'] 61 | 62 | self.mode = mode 63 | self.sample_pixels = sample_pixels 64 | self.threshold = threshold 65 | self.kernel_size = kernel_size 66 | self.prevent_oversampling = prevent_oversampling 67 | self.patch_crop_method = patch_crop_method 68 | self.patch_replace_method = patch_replace_method 69 | 70 | channels = [32, 24, 16, 12, 4] 71 | self.conv1 = nn.Conv2d(channels[0] + 6 + 4, channels[1], kernel_size, bias=False) 72 | self.bn1 = nn.BatchNorm2d(channels[1]) 73 | self.conv2 = nn.Conv2d(channels[1], channels[2], kernel_size, bias=False) 74 | self.bn2 = nn.BatchNorm2d(channels[2]) 75 | self.conv3 = nn.Conv2d(channels[2] + 6, channels[3], kernel_size, bias=False) 76 | self.bn3 = nn.BatchNorm2d(channels[3]) 77 | self.conv4 = nn.Conv2d(channels[3], channels[4], kernel_size, bias=True) 78 | self.relu = nn.ReLU(True) 79 | 80 | def forward(self, 81 | src: torch.Tensor, 82 | bgr: torch.Tensor, 83 | pha: torch.Tensor, 84 | fgr: torch.Tensor, 85 | err: torch.Tensor, 86 | hid: torch.Tensor): 87 | H_full, W_full = src.shape[2:] 88 | H_half, W_half = H_full // 2, W_full // 2 89 | H_quat, W_quat = H_full // 4, W_full // 4 90 | 91 | src_bgr = torch.cat([src, bgr], dim=1) 92 | 93 | if self.mode != 'full': 94 | err = F.interpolate(err, (H_quat, W_quat), mode='bilinear', align_corners=False) 95 | ref = self.select_refinement_regions(err) 96 | idx = torch.nonzero(ref.squeeze(1)) 97 | idx = idx[:, 0], idx[:, 1], idx[:, 2] 98 | 99 | if idx[0].size(0) > 0: 100 | x = torch.cat([hid, pha, fgr], dim=1) 101 | x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False) 102 | x = self.crop_patch(x, idx, 2, 3 if self.kernel_size == 3 else 0) 103 | 104 | y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False) 105 | y = self.crop_patch(y, idx, 2, 3 if self.kernel_size == 3 else 0) 106 | 107 | x = self.conv1(torch.cat([x, y], dim=1)) 108 | x = self.bn1(x) 109 | x = self.relu(x) 110 | x = self.conv2(x) 111 | x = self.bn2(x) 112 | x = self.relu(x) 113 | 114 | x = F.interpolate(x, 8 if self.kernel_size == 3 else 4, mode='nearest') 115 | y = self.crop_patch(src_bgr, idx, 4, 2 if self.kernel_size == 3 else 0) 116 | 117 | x = self.conv3(torch.cat([x, y], dim=1)) 118 | x = self.bn3(x) 119 | x = self.relu(x) 120 | x = self.conv4(x) 121 | 122 | out = torch.cat([pha, fgr], dim=1) 123 | out = F.interpolate(out, (H_full, W_full), mode='bilinear', align_corners=False) 124 | out = self.replace_patch(out, x, idx) 125 | pha = out[:, :1] 126 | fgr = out[:, 1:] 127 | else: 128 | pha = F.interpolate(pha, (H_full, W_full), mode='bilinear', align_corners=False) 129 | fgr = F.interpolate(fgr, (H_full, W_full), mode='bilinear', align_corners=False) 130 | else: 131 | x = torch.cat([hid, pha, fgr], dim=1) 132 | x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False) 133 | y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False) 134 | if self.kernel_size == 3: 135 | x = F.pad(x, (3, 3, 
3, 3)) 136 | y = F.pad(y, (3, 3, 3, 3)) 137 | 138 | x = self.conv1(torch.cat([x, y], dim=1)) 139 | x = self.bn1(x) 140 | x = self.relu(x) 141 | x = self.conv2(x) 142 | x = self.bn2(x) 143 | x = self.relu(x) 144 | 145 | if self.kernel_size == 3: 146 | x = F.interpolate(x, (H_full + 4, W_full + 4)) 147 | y = F.pad(src_bgr, (2, 2, 2, 2)) 148 | else: 149 | x = F.interpolate(x, (H_full, W_full), mode='nearest') 150 | y = src_bgr 151 | 152 | x = self.conv3(torch.cat([x, y], dim=1)) 153 | x = self.bn3(x) 154 | x = self.relu(x) 155 | x = self.conv4(x) 156 | 157 | pha = x[:, :1] 158 | fgr = x[:, 1:] 159 | ref = torch.ones((src.size(0), 1, H_quat, W_quat), device=src.device, dtype=src.dtype) 160 | 161 | return pha, fgr, ref 162 | 163 | def select_refinement_regions(self, err: torch.Tensor): 164 | """ 165 | Select refinement regions. 166 | Input: 167 | err: error map (B, 1, H, W) 168 | Output: 169 | ref: refinement regions (B, 1, H, W). FloatTensor. 1 is selected, 0 is not. 170 | """ 171 | if self.mode == 'sampling': 172 | # Sampling mode. 173 | b, _, h, w = err.shape 174 | err = err.view(b, -1) 175 | idx = err.topk(self.sample_pixels // 16, dim=1, sorted=False).indices 176 | ref = torch.zeros_like(err) 177 | ref.scatter_(1, idx, 1.) 178 | if self.prevent_oversampling: 179 | ref.mul_(err.gt(0).float()) 180 | ref = ref.view(b, 1, h, w) 181 | else: 182 | # Thresholding mode. 183 | ref = err.gt(self.threshold).float() 184 | return ref 185 | 186 | def crop_patch(self, 187 | x: torch.Tensor, 188 | idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], 189 | size: int, 190 | padding: int): 191 | """ 192 | Crops selected patches from image given indices. 193 | 194 | Inputs: 195 | x: image (B, C, H, W). 196 | idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index. 197 | size: center size of the patch, also stride of the crop. 198 | padding: expansion size of the patch. 199 | Output: 200 | patch: (P, C, h, w), where h = w = size + 2 * padding. 201 | """ 202 | if padding != 0: 203 | x = F.pad(x, (padding,) * 4) 204 | 205 | if self.patch_crop_method == 'unfold': 206 | # Use unfold. Best performance for PyTorch and TorchScript. 207 | return x.permute(0, 2, 3, 1) \ 208 | .unfold(1, size + 2 * padding, size) \ 209 | .unfold(2, size + 2 * padding, size)[idx[0], idx[1], idx[2]] 210 | elif self.patch_crop_method == 'roi_align': 211 | # Use roi_align. Best compatibility for ONNX. 212 | idx = idx[0].type_as(x), idx[1].type_as(x), idx[2].type_as(x) 213 | b = idx[0] 214 | x1 = idx[2] * size - 0.5 215 | y1 = idx[1] * size - 0.5 216 | x2 = idx[2] * size + size + 2 * padding - 0.5 217 | y2 = idx[1] * size + size + 2 * padding - 0.5 218 | boxes = torch.stack([b, x1, y1, x2, y2], dim=1) 219 | return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1) 220 | else: 221 | # Use gather. Crops out patches pixel by pixel. 222 | idx_pix = self.compute_pixel_indices(x, idx, size, padding) 223 | pat = torch.gather(x.view(-1), 0, idx_pix.view(-1)) 224 | pat = pat.view(-1, x.size(1), size + 2 * padding, size + 2 * padding) 225 | return pat 226 | 227 | def replace_patch(self, 228 | x: torch.Tensor, 229 | y: torch.Tensor, 230 | idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): 231 | """ 232 | Replaces patches back into image given index. 233 | 234 | Inputs: 235 | x: image (B, C, H, W) 236 | y: patches (P, C, h, w) 237 | idx: selection indices Tuple[(P,), (P,), (P,)] where the 3 values are (B, H, W) index. 
238 | 239 | Output: 240 | image: (B, C, H, W), where patches at idx locations are replaced with y. 241 | """ 242 | xB, xC, xH, xW = x.shape 243 | yB, yC, yH, yW = y.shape 244 | if self.patch_replace_method == 'scatter_nd': 245 | # Use scatter_nd. Best performance for PyTorch and TorchScript. Replacing patch by patch. 246 | x = x.view(xB, xC, xH // yH, yH, xW // yW, yW).permute(0, 2, 4, 1, 3, 5) 247 | x[idx[0], idx[1], idx[2]] = y 248 | x = x.permute(0, 3, 1, 4, 2, 5).view(xB, xC, xH, xW) 249 | return x 250 | else: 251 | # Use scatter_element. Best compatibility for ONNX. Replacing pixel by pixel. 252 | idx_pix = self.compute_pixel_indices(x, idx, size=4, padding=0) 253 | return x.view(-1).scatter_(0, idx_pix.view(-1), y.view(-1)).view(x.shape) 254 | 255 | def compute_pixel_indices(self, 256 | x: torch.Tensor, 257 | idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], 258 | size: int, 259 | padding: int): 260 | """ 261 | Compute selected pixel indices in the tensor. 262 | Used for crop_method == 'gather' and replace_method == 'scatter_element', which crop and replace pixel by pixel. 263 | Input: 264 | x: image: (B, C, H, W) 265 | idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index. 266 | size: center size of the patch, also stride of the crop. 267 | padding: expansion size of the patch. 268 | Output: 269 | idx: (P, C, O, O) long tensor where O is the output size: size + 2 * padding, P is number of patches. 270 | the element are indices pointing to the input x.view(-1). 271 | """ 272 | B, C, H, W = x.shape 273 | S, P = size, padding 274 | O = S + 2 * P 275 | b, y, x = idx 276 | n = b.size(0) 277 | c = torch.arange(C) 278 | o = torch.arange(O) 279 | idx_pat = (c * H * W).view(C, 1, 1).expand([C, O, O]) + (o * W).view(1, O, 1).expand([C, O, O]) + o.view(1, 1, O).expand([C, O, O]) 280 | idx_loc = b * W * H + y * W * S + x * S 281 | idx_pix = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) + idx_pat.view(1, C, O, O).expand([n, C, O, O]) 282 | return idx_pix 283 | -------------------------------------------------------------------------------- /train_refine.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train MattingRefine 3 | 4 | Supports multi-GPU training with DistributedDataParallel() and SyncBatchNorm. 5 | Select GPUs through CUDA_VISIBLE_DEVICES environment variable. 
6 | 7 | Example: 8 | 9 | CUDA_VISIBLE_DEVICES=0,1 python train_refine.py \ 10 | --dataset-name videomatte240k \ 11 | --model-backbone resnet50 \ 12 | --model-name mattingrefine-resnet50-videomatte240k \ 13 | --model-last-checkpoint "PATH_TO_LAST_CHECKPOINT" \ 14 | --epoch-end 1 15 | 16 | """ 17 | 18 | import argparse 19 | import kornia 20 | import torch 21 | import os 22 | import random 23 | 24 | from torch import nn 25 | from torch import distributed as dist 26 | from torch import multiprocessing as mp 27 | from torch.nn import functional as F 28 | from torch.cuda.amp import autocast, GradScaler 29 | from torch.utils.tensorboard import SummaryWriter 30 | from torch.utils.data import DataLoader, Subset 31 | from torch.optim import Adam 32 | from torchvision.utils import make_grid 33 | from tqdm import tqdm 34 | from torchvision import transforms as T 35 | from PIL import Image 36 | 37 | from data_path import DATA_PATH 38 | from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset 39 | from dataset import augmentation as A 40 | from model import MattingRefine 41 | from model.utils import load_matched_state_dict 42 | 43 | 44 | # --------------- Arguments --------------- 45 | 46 | 47 | parser = argparse.ArgumentParser() 48 | 49 | parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys()) 50 | 51 | parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2']) 52 | parser.add_argument('--model-backbone-scale', type=float, default=0.25) 53 | parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding']) 54 | parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000) 55 | parser.add_argument('--model-refine-thresholding', type=float, default=0.7) 56 | parser.add_argument('--model-refine-kernel-size', type=int, default=3, choices=[1, 3]) 57 | parser.add_argument('--model-name', type=str, required=True) 58 | parser.add_argument('--model-last-checkpoint', type=str, default=None) 59 | 60 | parser.add_argument('--batch-size', type=int, default=4) 61 | parser.add_argument('--num-workers', type=int, default=16) 62 | parser.add_argument('--epoch-start', type=int, default=0) 63 | parser.add_argument('--epoch-end', type=int, required=True) 64 | 65 | parser.add_argument('--log-train-loss-interval', type=int, default=10) 66 | parser.add_argument('--log-train-images-interval', type=int, default=1000) 67 | parser.add_argument('--log-valid-interval', type=int, default=2000) 68 | 69 | parser.add_argument('--checkpoint-interval', type=int, default=2000) 70 | 71 | args = parser.parse_args() 72 | 73 | 74 | distributed_num_gpus = torch.cuda.device_count() 75 | assert args.batch_size % distributed_num_gpus == 0 76 | 77 | 78 | # --------------- Main --------------- 79 | 80 | def train_worker(rank, addr, port): 81 | 82 | # Distributed Setup 83 | os.environ['MASTER_ADDR'] = addr 84 | os.environ['MASTER_PORT'] = port 85 | dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus) 86 | 87 | # Training DataLoader 88 | dataset_train = ZipDataset([ 89 | ZipDataset([ 90 | ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'), 91 | ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'), 92 | ], transforms=A.PairCompose([ 93 | A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)), 94 | A.PairRandomHorizontalFlip(), 95 | A.PairRandomBoxBlur(0.1, 5), 
96 | A.PairRandomSharpen(0.1), 97 | A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)), 98 | A.PairApply(T.ToTensor()) 99 | ]), assert_equal_length=True), 100 | ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([ 101 | A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)), 102 | T.RandomHorizontalFlip(), 103 | A.RandomBoxBlur(0.1, 5), 104 | A.RandomSharpen(0.1), 105 | T.ColorJitter(0.15, 0.15, 0.15, 0.05), 106 | T.ToTensor() 107 | ])), 108 | ]) 109 | dataset_train_len_per_gpu_worker = int(len(dataset_train) / distributed_num_gpus) 110 | dataset_train = Subset(dataset_train, range(rank * dataset_train_len_per_gpu_worker, (rank + 1) * dataset_train_len_per_gpu_worker)) 111 | dataloader_train = DataLoader(dataset_train, 112 | shuffle=True, 113 | pin_memory=True, 114 | drop_last=True, 115 | batch_size=args.batch_size // distributed_num_gpus, 116 | num_workers=args.num_workers // distributed_num_gpus) 117 | 118 | # Validation DataLoader 119 | if rank == 0: 120 | dataset_valid = ZipDataset([ 121 | ZipDataset([ 122 | ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'), 123 | ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB') 124 | ], transforms=A.PairCompose([ 125 | A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)), 126 | A.PairApply(T.ToTensor()) 127 | ]), assert_equal_length=True), 128 | ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([ 129 | A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)), 130 | T.ToTensor() 131 | ])), 132 | ]) 133 | dataset_valid = SampleDataset(dataset_valid, 50) 134 | dataloader_valid = DataLoader(dataset_valid, 135 | pin_memory=True, 136 | drop_last=True, 137 | batch_size=args.batch_size // distributed_num_gpus, 138 | num_workers=args.num_workers // distributed_num_gpus) 139 | 140 | # Model 141 | model = MattingRefine(args.model_backbone, 142 | args.model_backbone_scale, 143 | args.model_refine_mode, 144 | args.model_refine_sample_pixels, 145 | args.model_refine_thresholding, 146 | args.model_refine_kernel_size).to(rank) 147 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 148 | model_distributed = nn.parallel.DistributedDataParallel(model, device_ids=[rank]) 149 | 150 | if args.model_last_checkpoint is not None: 151 | load_matched_state_dict(model, torch.load(args.model_last_checkpoint)) 152 | 153 | optimizer = Adam([ 154 | {'params': model.backbone.parameters(), 'lr': 5e-5}, 155 | {'params': model.aspp.parameters(), 'lr': 5e-5}, 156 | {'params': model.decoder.parameters(), 'lr': 1e-4}, 157 | {'params': model.refiner.parameters(), 'lr': 3e-4}, 158 | ]) 159 | scaler = GradScaler() 160 | 161 | # Logging and checkpoints 162 | if rank == 0: 163 | if not os.path.exists(f'checkpoint/{args.model_name}'): 164 | os.makedirs(f'checkpoint/{args.model_name}') 165 | writer = SummaryWriter(f'log/{args.model_name}') 166 | 167 | # Run loop 168 | for epoch in range(args.epoch_start, args.epoch_end): 169 | for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)): 170 | step = epoch * len(dataloader_train) + i 171 | 172 | true_pha = true_pha.to(rank, non_blocking=True) 173 | true_fgr = true_fgr.to(rank, non_blocking=True) 174 | true_bgr = true_bgr.to(rank, non_blocking=True) 175 | true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr) 176 | 177 | 
161 |     # Logging and checkpoints
162 |     if rank == 0:
163 |         if not os.path.exists(f'checkpoint/{args.model_name}'):
164 |             os.makedirs(f'checkpoint/{args.model_name}')
165 |         writer = SummaryWriter(f'log/{args.model_name}')
166 | 
167 |     # Run loop
168 |     for epoch in range(args.epoch_start, args.epoch_end):
169 |         for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
170 |             step = epoch * len(dataloader_train) + i
171 | 
172 |             true_pha = true_pha.to(rank, non_blocking=True)
173 |             true_fgr = true_fgr.to(rank, non_blocking=True)
174 |             true_bgr = true_bgr.to(rank, non_blocking=True)
175 |             true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)
176 | 
177 |             true_src = true_bgr.clone()
178 | 
179 |             # Augment with shadow
180 |             aug_shadow_idx = torch.rand(len(true_src)) < 0.3
181 |             if aug_shadow_idx.any():
182 |                 aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
183 |                 aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
184 |                 aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
185 |                 true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
186 |                 del aug_shadow
187 |             del aug_shadow_idx
188 | 
189 |             # Composite foreground onto source
190 |             true_src = true_fgr * true_pha + true_src * (1 - true_pha)
191 | 
192 |             # Augment with noise
193 |             aug_noise_idx = torch.rand(len(true_src)) < 0.4
194 |             if aug_noise_idx.any():
195 |                 true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
196 |                 true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
197 |             del aug_noise_idx
198 | 
199 |             # Augment background with jitter
200 |             aug_jitter_idx = torch.rand(len(true_src)) < 0.8
201 |             if aug_jitter_idx.any():
202 |                 true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
203 |             del aug_jitter_idx
204 | 
205 |             # Augment background with affine
206 |             aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
207 |             if aug_affine_idx.any():
208 |                 true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
209 |             del aug_affine_idx
210 | 
211 |             with autocast():
212 |                 pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(true_src, true_bgr)
213 |                 loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
214 | 
215 |             scaler.scale(loss).backward()
216 |             scaler.step(optimizer)
217 |             scaler.update()
218 |             optimizer.zero_grad()
219 | 
220 |             if rank == 0:
221 |                 if (i + 1) % args.log_train_loss_interval == 0:
222 |                     writer.add_scalar('loss', loss, step)
223 | 
224 |                 if (i + 1) % args.log_train_images_interval == 0:
225 |                     writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
226 |                     writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
227 |                     writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
228 |                     writer.add_image('train_pred_err', make_grid(pred_err_sm, nrow=5), step)
229 |                     writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)
230 | 
231 |                 del true_pha, true_fgr, true_src, true_bgr
232 |                 del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm
233 | 
234 |                 if (i + 1) % args.log_valid_interval == 0:
235 |                     valid(model, dataloader_valid, writer, step)
236 | 
237 |                 if (step + 1) % args.checkpoint_interval == 0:
238 |                     torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')
239 | 
240 |         if rank == 0:
241 |             torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
242 | 
243 |     # Clean up
244 |     dist.destroy_process_group()
245 | 
246 | 
247 | # --------------- Utils ---------------
248 | 
249 | 
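# compute_loss supervises both resolutions produced by MattingRefine: L1 on the alpha matte
# and on its Sobel gradients at full (lg) and coarse (sm) scale, L1 on the foreground
# restricted to pixels where the ground-truth alpha is non-zero, plus an MSE term that
# trains the coarse error map pred_err_sm to predict the absolute error of the upsampled
# coarse alpha against the full-resolution ground truth.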
250 | def compute_loss(pred_pha_lg, pred_fgr_lg, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha_lg, true_fgr_lg):
251 |     true_pha_sm = kornia.resize(true_pha_lg, pred_pha_sm.shape[2:])
252 |     true_fgr_sm = kornia.resize(true_fgr_lg, pred_fgr_sm.shape[2:])
253 |     true_msk_lg = true_pha_lg != 0
254 |     true_msk_sm = true_pha_sm != 0
255 |     return F.l1_loss(pred_pha_lg, true_pha_lg) + \
256 |            F.l1_loss(pred_pha_sm, true_pha_sm) + \
257 |            F.l1_loss(kornia.sobel(pred_pha_lg), kornia.sobel(true_pha_lg)) + \
258 |            F.l1_loss(kornia.sobel(pred_pha_sm), kornia.sobel(true_pha_sm)) + \
259 |            F.l1_loss(pred_fgr_lg * true_msk_lg, true_fgr_lg * true_msk_lg) + \
260 |            F.l1_loss(pred_fgr_sm * true_msk_sm, true_fgr_sm * true_msk_sm) + \
261 |            F.mse_loss(kornia.resize(pred_err_sm, true_pha_lg.shape[2:]), \
262 |                       kornia.resize(pred_pha_sm, true_pha_lg.shape[2:]).sub(true_pha_lg).abs())
263 | 
264 | 
265 | def random_crop(*imgs):
266 |     H_src, W_src = imgs[0].shape[2:]
267 |     W_tgt = random.choice(range(1024, 2048)) // 4 * 4
268 |     H_tgt = random.choice(range(1024, 2048)) // 4 * 4
269 |     scale = max(W_tgt / W_src, H_tgt / H_src)
270 |     results = []
271 |     for img in imgs:
272 |         img = kornia.resize(img, (int(H_src * scale), int(W_src * scale)))
273 |         img = kornia.center_crop(img, (H_tgt, W_tgt))
274 |         results.append(img)
275 |     return results
276 | 
277 | 
278 | def valid(model, dataloader, writer, step):
279 |     model.eval()
280 |     loss_total = 0
281 |     loss_count = 0
282 |     with torch.no_grad():
283 |         for (true_pha, true_fgr), true_bgr in dataloader:
284 |             batch_size = true_pha.size(0)
285 | 
286 |             true_pha = true_pha.cuda(non_blocking=True)
287 |             true_fgr = true_fgr.cuda(non_blocking=True)
288 |             true_bgr = true_bgr.cuda(non_blocking=True)
289 |             true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr
290 | 
291 |             pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model(true_src, true_bgr)
292 |             loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
293 |             loss_total += loss.cpu().item() * batch_size
294 |             loss_count += batch_size
295 | 
296 |     writer.add_scalar('valid_loss', loss_total / loss_count, step)
297 |     model.train()
298 | 
299 | 
300 | # --------------- Start ---------------
301 | 
302 | 
303 | if __name__ == '__main__':
304 |     addr = 'localhost'
305 |     port = str(random.choice(range(12300, 12400))) # pick a random port.
306 |     mp.spawn(train_worker,
307 |              nprocs=distributed_num_gpus,
308 |              args=(addr, port),
309 |              join=True)
310 | 
--------------------------------------------------------------------------------
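For reference, below is a minimal sketch (not a file in this repository) of loading a checkpoint written by train_refine.py back into MattingRefine for evaluation. The checkpoint path is hypothetical, and the positional constructor arguments simply mirror the script's defaults (backbone scale 0.25, 'sampling' refine mode, 80,000 refine pixels, threshold 0.7, kernel size 3).

import torch

from model import MattingRefine
from model.utils import load_matched_state_dict

# Rebuild the model with the same arguments used during training (script defaults shown).
model = MattingRefine('resnet50', 0.25, 'sampling', 80_000, 0.7, 3)

# load_matched_state_dict only copies weights whose names and shapes match.
# The path below is illustrative: epoch-0.pth is what one epoch of training with
# --model-name mattingrefine-resnet50-videomatte240k would produce.
load_matched_state_dict(model, torch.load('checkpoint/mattingrefine-resnet50-videomatte240k/epoch-0.pth',
                                          map_location='cpu'))
model = model.eval().cuda()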