├── .DS_Store ├── README.md ├── dataset ├── coco_dataset │ └── train2014 │ │ ├── COCO_train2014_000000005011.jpg │ │ ├── COCO_train2014_000000005554.jpg │ │ ├── COCO_train2014_000000006562.jpg │ │ ├── COCO_train2014_000000006765.jpg │ │ ├── COCO_train2014_000000007143.jpg │ │ ├── COCO_train2014_000000007396.jpg │ │ ├── COCO_train2014_000000010691.jpg │ │ ├── COCO_train2014_000000011667.jpg │ │ ├── COCO_train2014_000000013714.jpg │ │ └── COCO_train2014_000000013719.jpg └── testing │ ├── foreground │ ├── 00054.png │ ├── 00094.png │ ├── 00109.png │ ├── 00121.png │ └── 00130.png │ ├── image │ ├── 00054.png │ ├── 00094.png │ ├── 00109.png │ ├── 00121.png │ └── 00130.png │ ├── matte │ ├── 00054.png │ ├── 00094.png │ ├── 00109.png │ ├── 00121.png │ └── 00130.png │ ├── param │ ├── 00054.pt │ ├── 00094.pt │ ├── 00109.pt │ ├── 00121.pt │ └── 00130.pt │ └── trimap │ ├── 00054.png │ ├── 00094.png │ ├── 00109.png │ ├── 00121.png │ └── 00130.png ├── deeplab_trimap ├── .DS_Store ├── __init__.py ├── aspp.py ├── backbone │ ├── .DS_Store │ ├── .ipynb_checkpoints │ │ ├── __init__-checkpoint.py │ │ ├── drn-checkpoint.py │ │ ├── mobilenet-checkpoint.py │ │ ├── resnet-checkpoint.py │ │ └── xception-checkpoint.py │ ├── __init__.py │ ├── drn.py │ ├── mobilenet.py │ ├── resnet.py │ └── xception.py ├── decoder.py ├── deeplab.py └── sync_batchnorm │ ├── .DS_Store │ ├── .ipynb_checkpoints │ ├── __init__-checkpoint.py │ ├── batchnorm-checkpoint.py │ ├── comm-checkpoint.py │ ├── replicate-checkpoint.py │ └── unittest-checkpoint.py │ ├── __init__.py │ ├── batchnorm.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py ├── deepmatting_matting ├── .DS_Store ├── .gitignore ├── config.py ├── custom_transforms.py ├── data_gen.py ├── demo.py ├── eval.py ├── extract.py ├── models.py ├── pre_process.py ├── test.py ├── train.py └── utils.py ├── download.py ├── evaluate.py ├── fg_pred.py ├── func.py ├── gca_matting ├── .DS_Store ├── .ipynb_checkpoints │ ├── __init__-checkpoint.py │ ├── generators-checkpoint.py │ └── ops-checkpoint.py ├── __init__.py ├── decoders │ ├── .DS_Store │ ├── __init__.py │ ├── ops.py │ ├── res_gca_dec.py │ ├── res_shortcut_dec.py │ └── resnet_dec.py ├── encoders │ ├── .DS_Store │ ├── __init__.py │ ├── ops.py │ ├── res_gca_enc.py │ ├── res_shortcut_enc.py │ └── resnet_enc.py ├── generators.py └── ops.py ├── indexnet_matting ├── .DS_Store ├── .ipynb_checkpoints │ ├── demo-checkpoint.py │ ├── demo_indexnet_matting-checkpoint.py │ ├── hlaspp-checkpoint.py │ ├── hlconv-checkpoint.py │ ├── hldecoder-checkpoint.py │ ├── hlmobilenetv2-checkpoint.py │ ├── hltrainval-checkpoint.py │ └── hlvggnet-checkpoint.py ├── Composition_code.py ├── demo.py ├── demo_deep_matting.py ├── demo_indexnet_matting.py ├── evaluation_code │ ├── compute_connectivity_error.m │ ├── compute_gradient_loss.m │ ├── compute_mse_loss.m │ ├── compute_sad_loss.m │ ├── gaussgradient.m │ └── hleval.m ├── examples │ └── mattes │ │ ├── .ipynb_checkpoints │ │ └── 00298-checkpoint.png │ │ └── 00298.png ├── hlaspp.py ├── hlconv.py ├── hldataset.py ├── hldecoder.py ├── hlindex.py ├── hlmobilenetv2.py ├── hltrainval.py ├── hlvggnet.py ├── lib │ └── nn │ │ ├── .DS_Store │ │ ├── __init__.py │ │ ├── modules │ │ ├── .DS_Store │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ ├── tests │ │ │ ├── test_numeric_batchnorm.py │ │ │ └── test_sync_batchnorm.py │ │ └── unittest.py │ │ └── parallel │ │ ├── .DS_Store │ │ ├── __init__.py │ │ └── data_parallel.py ├── lists │ ├── generate_imlist.py │ ├── test.txt │ └── train.txt ├── 
modelsummary.py └── utils.py ├── lpips ├── .ipynb_checkpoints │ ├── __init__-checkpoint.py │ ├── base_model-checkpoint.py │ ├── dist_model-checkpoint.py │ └── networks_basic-checkpoint.py ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── base_model.cpython-39.pyc │ ├── dist_model.cpython-39.pyc │ ├── networks_basic.cpython-39.pyc │ └── pretrained_networks.cpython-39.pyc ├── base_model.py ├── dist_model.py ├── networks_basic.py ├── pretrained_networks.py └── weights │ ├── v0.0 │ ├── alex.pth │ ├── squeeze.pth │ └── vgg.pth │ └── v0.1 │ ├── alex.pth │ ├── squeeze.pth │ └── vgg.pth ├── main.py ├── model.py ├── op ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── fused_act.cpython-39.pyc │ └── upfirdn2d.cpython-39.pyc ├── fused_act.py ├── fused_bias_act.cpp ├── fused_bias_act_kernel.cu ├── upfirdn2d.cpp ├── upfirdn2d.py └── upfirdn2d_kernel.cu ├── sgmatting.yaml ├── stylegan++.py └── train.sh /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-supervised Matting-specific Portrait Enhancement and Generation 2 | 3 | > We resolve the ill-posed alpha matting problem from a completely different perspective. Given an input portrait image, instead of estimating the corresponding alpha matte, we focus on the other end, to subtly enhance this input so that the alpha matte can be easily estimated by any existing matting models. This is accomplished by exploring the latent space of GAN models. It is demonstrated that interpretable directions can be found in the latent space and they correspond to semantic image transformations. We further explore this property in alpha matting. Particularly, we invert an input portrait into the latent code of StyleGAN, and our aim is to discover whether there is an enhanced version in the latent space which is more compatible with a reference matting model. We optimize multi-scale latent vectors in the latent spaces under four tailored losses, ensuring matting-specificity and subtle modifications on the portrait. We demonstrate that the proposed method can refine real portrait images for arbitrary matting models, boosting the performance of automatic alpha matting by a large margin. In addition, we leverage the generative property of StyleGAN, and propose to generate enhanced portrait data which can be treated as the pseudo GT. It addresses the problem of expensive alpha matte annotation, further augmenting the matting performance of existing models. 4 | 5 | ## Description 6 | 7 | We present the training code and 5 images for a quick start. 8 | 9 | **Build the environment** 10 | 11 | Anaconda is required. 12 | 13 | ``` 14 | conda env create -f sgmatting.yaml 15 | ``` 16 | 17 | **Download checkpoints** 18 | 19 | The pre-trained StyleGAN and matting model checkpoints can be downloaded from [here](https://drive.google.com/uc?id=1h6vVnlFpWk7G2dlzc9DZuKzUuqvloA25). 
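If you prefer fetching the checkpoints from the command line, here is a minimal sketch assuming the `gdown` package is available and that the shared file is a single archive; the output name `ckpt.zip` is only a placeholder, and the file id is taken from the link above:

```
pip install gdown
gdown "https://drive.google.com/uc?id=1h6vVnlFpWk7G2dlzc9DZuKzUuqvloA25" -O ckpt.zip
```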
After downloading the checkpoints, unzip the archive and move the files into place: 20 | 21 | ``` 22 | mv ckpt/deeplab_model_best.pth.tar deeplab_trimap/checkpoint/ 23 | ``` 24 | 25 | ``` 26 | mv ckpt/stylegan2-ffhq-config-f.pt ./ 27 | ``` 28 | 29 | ``` 30 | mv ckpt/gca-dist-all-data.pth gca_matting/checkpoints_finetune/ 31 | ``` 32 | 33 | **Run the training code on given images** 34 | 35 | ``` 36 | bash train.sh 37 | ``` 38 | 39 | $\color{#FF0000}{Note:}$ 40 | To train on new images, first obtain the latent codes with ```stylegan++.py```; you also need the trimap produced by our pre-trained trimap model (```main.py```) and the foreground image (```fg_pred.py```). Put these outputs under ```dataset/testing``` (see the command sketch below). 41 | 42 | ## Citation 43 | 44 | ``` 45 | @ARTICLE{9849440, 46 | author={Xu, Yangyang and Zhou, Zeyang and He, Shengfeng}, 47 | journal={IEEE Transactions on Image Processing}, 48 | title={Self-supervised Matting-specific Portrait Enhancement and Generation}, 49 | year={2022}, 50 | volume={}, 51 | number={}, 52 | pages={1-1}, 53 | doi={10.1109/TIP.2022.3194711}} 54 | ``` 55 | 56 | 57 | -------------------------------------------------------------------------------- /dataset/coco_dataset/train2014/COCO_train2014_000000005011.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/coco_dataset/train2014/COCO_train2014_000000005011.jpg -------------------------------------------------------------------------------- /dataset/coco_dataset/train2014/COCO_train2014_000000005554.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/coco_dataset/train2014/COCO_train2014_000000005554.jpg -------------------------------------------------------------------------------- /dataset/coco_dataset/train2014/COCO_train2014_000000006562.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/coco_dataset/train2014/COCO_train2014_000000006562.jpg -------------------------------------------------------------------------------- /dataset/coco_dataset/train2014/COCO_train2014_000000006765.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/coco_dataset/train2014/COCO_train2014_000000006765.jpg -------------------------------------------------------------------------------- /dataset/coco_dataset/train2014/COCO_train2014_000000007143.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/coco_dataset/train2014/COCO_train2014_000000007143.jpg -------------------------------------------------------------------------------- /dataset/coco_dataset/train2014/COCO_train2014_000000007396.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/coco_dataset/train2014/COCO_train2014_000000007396.jpg -------------------------------------------------------------------------------- 
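Picking up the Note from the README above: a minimal, indicative sketch of preparing a new portrait for training. The script names come from this repository, but their command-line arguments are not documented in the README, so the bare invocations below are placeholders; the mapping of outputs to the `dataset/testing` subfolders (param, trimap, foreground) is likewise an assumption based on the folder layout shown in the tree.

```
# Indicative only -- check each script for its actual arguments and output paths.
python stylegan++.py   # invert the portrait into StyleGAN latent codes (dataset/testing/param/)
python main.py         # predict the trimap with the pre-trained trimap model (dataset/testing/trimap/)
python fg_pred.py      # predict the foreground image (dataset/testing/foreground/)
```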
/dataset/coco_dataset/train2014/COCO_train2014_000000010691.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/coco_dataset/train2014/COCO_train2014_000000010691.jpg -------------------------------------------------------------------------------- /dataset/coco_dataset/train2014/COCO_train2014_000000011667.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/coco_dataset/train2014/COCO_train2014_000000011667.jpg -------------------------------------------------------------------------------- /dataset/coco_dataset/train2014/COCO_train2014_000000013714.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/coco_dataset/train2014/COCO_train2014_000000013714.jpg -------------------------------------------------------------------------------- /dataset/coco_dataset/train2014/COCO_train2014_000000013719.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/coco_dataset/train2014/COCO_train2014_000000013719.jpg -------------------------------------------------------------------------------- /dataset/testing/foreground/00054.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/foreground/00054.png -------------------------------------------------------------------------------- /dataset/testing/foreground/00094.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/foreground/00094.png -------------------------------------------------------------------------------- /dataset/testing/foreground/00109.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/foreground/00109.png -------------------------------------------------------------------------------- /dataset/testing/foreground/00121.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/foreground/00121.png -------------------------------------------------------------------------------- /dataset/testing/foreground/00130.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/foreground/00130.png -------------------------------------------------------------------------------- /dataset/testing/image/00054.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/image/00054.png 
-------------------------------------------------------------------------------- /dataset/testing/image/00094.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/image/00094.png -------------------------------------------------------------------------------- /dataset/testing/image/00109.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/image/00109.png -------------------------------------------------------------------------------- /dataset/testing/image/00121.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/image/00121.png -------------------------------------------------------------------------------- /dataset/testing/image/00130.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/image/00130.png -------------------------------------------------------------------------------- /dataset/testing/matte/00054.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/matte/00054.png -------------------------------------------------------------------------------- /dataset/testing/matte/00094.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/matte/00094.png -------------------------------------------------------------------------------- /dataset/testing/matte/00109.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/matte/00109.png -------------------------------------------------------------------------------- /dataset/testing/matte/00121.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/matte/00121.png -------------------------------------------------------------------------------- /dataset/testing/matte/00130.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/matte/00130.png -------------------------------------------------------------------------------- /dataset/testing/param/00054.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/param/00054.pt -------------------------------------------------------------------------------- /dataset/testing/param/00094.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/param/00094.pt -------------------------------------------------------------------------------- /dataset/testing/param/00109.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/param/00109.pt -------------------------------------------------------------------------------- /dataset/testing/param/00121.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/param/00121.pt -------------------------------------------------------------------------------- /dataset/testing/param/00130.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/param/00130.pt -------------------------------------------------------------------------------- /dataset/testing/trimap/00054.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/trimap/00054.png -------------------------------------------------------------------------------- /dataset/testing/trimap/00094.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/trimap/00094.png -------------------------------------------------------------------------------- /dataset/testing/trimap/00109.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/trimap/00109.png -------------------------------------------------------------------------------- /dataset/testing/trimap/00121.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/trimap/00121.png -------------------------------------------------------------------------------- /dataset/testing/trimap/00130.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/dataset/testing/trimap/00130.png -------------------------------------------------------------------------------- /deeplab_trimap/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/deeplab_trimap/.DS_Store -------------------------------------------------------------------------------- /deeplab_trimap/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/deeplab_trimap/__init__.py -------------------------------------------------------------------------------- /deeplab_trimap/aspp.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | class _ASPPModule(nn.Module): 8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm): 9 | super(_ASPPModule, self).__init__() 10 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 11 | stride=1, padding=padding, dilation=dilation, bias=False) 12 | self.bn = BatchNorm(planes) 13 | self.relu = nn.ReLU() 14 | 15 | self._init_weight() 16 | 17 | def forward(self, x): 18 | x = self.atrous_conv(x) 19 | x = self.bn(x) 20 | 21 | return self.relu(x) 22 | 23 | def _init_weight(self): 24 | for m in self.modules(): 25 | if isinstance(m, nn.Conv2d): 26 | torch.nn.init.kaiming_normal_(m.weight) 27 | elif isinstance(m, SynchronizedBatchNorm2d): 28 | m.weight.data.fill_(1) 29 | m.bias.data.zero_() 30 | elif isinstance(m, nn.BatchNorm2d): 31 | m.weight.data.fill_(1) 32 | m.bias.data.zero_() 33 | 34 | class ASPP(nn.Module): 35 | def __init__(self, backbone, output_stride, BatchNorm): 36 | super(ASPP, self).__init__() 37 | if backbone == 'drn': 38 | inplanes = 512 39 | elif backbone == 'mobilenet': 40 | inplanes = 320 41 | else: 42 | inplanes = 2048 43 | if output_stride == 16: 44 | dilations = [1, 6, 12, 18] 45 | elif output_stride == 8: 46 | dilations = [1, 12, 24, 36] 47 | else: 48 | raise NotImplementedError 49 | 50 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm) 51 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm) 52 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) 53 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) 54 | 55 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 56 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 57 | BatchNorm(256), 58 | nn.ReLU()) 59 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 60 | self.bn1 = BatchNorm(256) 61 | self.relu = nn.ReLU() 62 | self.dropout = nn.Dropout(0.5) 63 | self._init_weight() 64 | 65 | def forward(self, x): 66 | x1 = self.aspp1(x) 67 | x2 = self.aspp2(x) 68 | x3 = self.aspp3(x) 69 | x4 = self.aspp4(x) 70 | x5 = self.global_avg_pool(x) 71 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 72 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 73 | 74 | x = self.conv1(x) 75 | x = self.bn1(x) 76 | x = self.relu(x) 77 | 78 | return self.dropout(x) 79 | 80 | def _init_weight(self): 81 | for m in self.modules(): 82 | if isinstance(m, nn.Conv2d): 83 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 84 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 85 | torch.nn.init.kaiming_normal_(m.weight) 86 | elif isinstance(m, SynchronizedBatchNorm2d): 87 | m.weight.data.fill_(1) 88 | m.bias.data.zero_() 89 | elif isinstance(m, nn.BatchNorm2d): 90 | m.weight.data.fill_(1) 91 | m.bias.data.zero_() 92 | 93 | 94 | def build_aspp(backbone, output_stride, BatchNorm): 95 | return ASPP(backbone, output_stride, BatchNorm) -------------------------------------------------------------------------------- /deeplab_trimap/backbone/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/deeplab_trimap/backbone/.DS_Store -------------------------------------------------------------------------------- /deeplab_trimap/backbone/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | from . import resnet, xception, drn, mobilenet 2 | 3 | def build_backbone(backbone, output_stride, BatchNorm): 4 | if backbone == 'resnet': 5 | return resnet.ResNet101(output_stride, BatchNorm) 6 | elif backbone == 'xception': 7 | return xception.AlignedXception(output_stride, BatchNorm) 8 | elif backbone == 'drn': 9 | return drn.drn_d_54(BatchNorm) 10 | elif backbone == 'mobilenet': 11 | return mobilenet.MobileNetV2(output_stride, BatchNorm) 12 | else: 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /deeplab_trimap/backbone/.ipynb_checkpoints/mobilenet-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import math 5 | import sys 6 | sys.path.append("..") 7 | 8 | from ..sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 9 | 10 | import torch.utils.model_zoo as model_zoo 11 | 12 | def conv_bn(inp, oup, stride, BatchNorm): 13 | return nn.Sequential( 14 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 15 | BatchNorm(oup), 16 | nn.ReLU6(inplace=True) 17 | ) 18 | 19 | 20 | def fixed_padding(inputs, kernel_size, dilation): 21 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 22 | pad_total = kernel_size_effective - 1 23 | pad_beg = pad_total // 2 24 | pad_end = pad_total - pad_beg 25 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 26 | return padded_inputs 27 | 28 | 29 | class InvertedResidual(nn.Module): 30 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm): 31 | super(InvertedResidual, self).__init__() 32 | self.stride = stride 33 | assert stride in [1, 2] 34 | 35 | hidden_dim = round(inp * expand_ratio) 36 | self.use_res_connect = self.stride == 1 and inp == oup 37 | self.kernel_size = 3 38 | self.dilation = dilation 39 | 40 | if expand_ratio == 1: 41 | self.conv = nn.Sequential( 42 | # dw 43 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 44 | BatchNorm(hidden_dim), 45 | nn.ReLU6(inplace=True), 46 | # pw-linear 47 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False), 48 | BatchNorm(oup), 49 | ) 50 | else: 51 | self.conv = nn.Sequential( 52 | # pw 53 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False), 54 | BatchNorm(hidden_dim), 55 | nn.ReLU6(inplace=True), 56 | # dw 57 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 58 | BatchNorm(hidden_dim), 59 | nn.ReLU6(inplace=True), 60 | # pw-linear 61 | nn.Conv2d(hidden_dim, 
oup, 1, 1, 0, 1, bias=False), 62 | BatchNorm(oup), 63 | ) 64 | 65 | def forward(self, x): 66 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation) 67 | if self.use_res_connect: 68 | x = x + self.conv(x_pad) 69 | else: 70 | x = self.conv(x_pad) 71 | return x 72 | 73 | 74 | class MobileNetV2(nn.Module): 75 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True): 76 | super(MobileNetV2, self).__init__() 77 | block = InvertedResidual 78 | input_channel = 32 79 | current_stride = 1 80 | rate = 1 81 | interverted_residual_setting = [ 82 | # t, c, n, s 83 | [1, 16, 1, 1], 84 | [6, 24, 2, 2], 85 | [6, 32, 3, 2], 86 | [6, 64, 4, 2], 87 | [6, 96, 3, 1], 88 | [6, 160, 3, 2], 89 | [6, 320, 1, 1], 90 | ] 91 | 92 | # building first layer 93 | input_channel = int(input_channel * width_mult) 94 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)] 95 | current_stride *= 2 96 | # building inverted residual blocks 97 | for t, c, n, s in interverted_residual_setting: 98 | if current_stride == output_stride: 99 | stride = 1 100 | dilation = rate 101 | rate *= s 102 | else: 103 | stride = s 104 | dilation = 1 105 | current_stride *= s 106 | output_channel = int(c * width_mult) 107 | for i in range(n): 108 | if i == 0: 109 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm)) 110 | else: 111 | self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm)) 112 | input_channel = output_channel 113 | self.features = nn.Sequential(*self.features) 114 | self._initialize_weights() 115 | 116 | if pretrained: 117 | self._load_pretrained_model() 118 | 119 | self.low_level_features = self.features[0:4] 120 | self.high_level_features = self.features[4:] 121 | 122 | def forward(self, x): 123 | low_level_feat = self.low_level_features(x) 124 | x = self.high_level_features(low_level_feat) 125 | return x, low_level_feat 126 | 127 | def _load_pretrained_model(self): 128 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth') 129 | model_dict = {} 130 | state_dict = self.state_dict() 131 | for k, v in pretrain_dict.items(): 132 | if k in state_dict: 133 | model_dict[k] = v 134 | state_dict.update(model_dict) 135 | self.load_state_dict(state_dict) 136 | 137 | def _initialize_weights(self): 138 | for m in self.modules(): 139 | if isinstance(m, nn.Conv2d): 140 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 141 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 142 | torch.nn.init.kaiming_normal_(m.weight) 143 | elif isinstance(m, SynchronizedBatchNorm2d): 144 | m.weight.data.fill_(1) 145 | m.bias.data.zero_() 146 | elif isinstance(m, nn.BatchNorm2d): 147 | m.weight.data.fill_(1) 148 | m.bias.data.zero_() 149 | 150 | if __name__ == "__main__": 151 | input = torch.rand(1, 3, 512, 512) 152 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d) 153 | output, low_level_feat = model(input) 154 | print(output.size()) 155 | print(low_level_feat.size()) 156 | -------------------------------------------------------------------------------- /deeplab_trimap/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import resnet, xception, drn, mobilenet 2 | 3 | def build_backbone(backbone, output_stride, BatchNorm): 4 | if backbone == 'resnet': 5 | return resnet.ResNet101(output_stride, BatchNorm) 6 | elif backbone == 'xception': 7 | return xception.AlignedXception(output_stride, BatchNorm) 8 | elif backbone == 'drn': 9 | return drn.drn_d_54(BatchNorm) 10 | elif backbone == 'mobilenet': 11 | return mobilenet.MobileNetV2(output_stride, BatchNorm) 12 | else: 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /deeplab_trimap/backbone/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import math 5 | import sys 6 | sys.path.append("..") 7 | 8 | from ..sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 9 | 10 | import torch.utils.model_zoo as model_zoo 11 | 12 | def conv_bn(inp, oup, stride, BatchNorm): 13 | return nn.Sequential( 14 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 15 | BatchNorm(oup), 16 | nn.ReLU6(inplace=True) 17 | ) 18 | 19 | 20 | def fixed_padding(inputs, kernel_size, dilation): 21 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 22 | pad_total = kernel_size_effective - 1 23 | pad_beg = pad_total // 2 24 | pad_end = pad_total - pad_beg 25 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 26 | return padded_inputs 27 | 28 | 29 | class InvertedResidual(nn.Module): 30 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm): 31 | super(InvertedResidual, self).__init__() 32 | self.stride = stride 33 | assert stride in [1, 2] 34 | 35 | hidden_dim = round(inp * expand_ratio) 36 | self.use_res_connect = self.stride == 1 and inp == oup 37 | self.kernel_size = 3 38 | self.dilation = dilation 39 | 40 | if expand_ratio == 1: 41 | self.conv = nn.Sequential( 42 | # dw 43 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 44 | BatchNorm(hidden_dim), 45 | nn.ReLU6(inplace=True), 46 | # pw-linear 47 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False), 48 | BatchNorm(oup), 49 | ) 50 | else: 51 | self.conv = nn.Sequential( 52 | # pw 53 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False), 54 | BatchNorm(hidden_dim), 55 | nn.ReLU6(inplace=True), 56 | # dw 57 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 58 | BatchNorm(hidden_dim), 59 | nn.ReLU6(inplace=True), 60 | # pw-linear 61 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False), 62 | BatchNorm(oup), 63 | ) 64 | 65 | def forward(self, x): 66 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation) 67 | if self.use_res_connect: 68 | x = x + self.conv(x_pad) 69 | else: 70 | x = self.conv(x_pad) 71 | return x 72 | 73 | 74 | class MobileNetV2(nn.Module): 75 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True): 76 | super(MobileNetV2, self).__init__() 77 | block = InvertedResidual 78 | input_channel = 32 79 | current_stride = 1 80 | rate = 1 81 | interverted_residual_setting = [ 82 | # t, c, n, s 83 | [1, 16, 1, 1], 84 | [6, 24, 2, 2], 85 | [6, 32, 3, 2], 86 | [6, 64, 4, 2], 87 | [6, 96, 3, 1], 88 | [6, 160, 3, 2], 89 | [6, 320, 1, 1], 90 | ] 91 | 92 | # building first layer 93 | input_channel = int(input_channel * width_mult) 94 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)] 95 | current_stride *= 2 96 | # building inverted residual blocks 97 | for t, 
c, n, s in interverted_residual_setting: 98 | if current_stride == output_stride: 99 | stride = 1 100 | dilation = rate 101 | rate *= s 102 | else: 103 | stride = s 104 | dilation = 1 105 | current_stride *= s 106 | output_channel = int(c * width_mult) 107 | for i in range(n): 108 | if i == 0: 109 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm)) 110 | else: 111 | self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm)) 112 | input_channel = output_channel 113 | self.features = nn.Sequential(*self.features) 114 | self._initialize_weights() 115 | 116 | if pretrained: 117 | self._load_pretrained_model() 118 | 119 | self.low_level_features = self.features[0:4] 120 | self.high_level_features = self.features[4:] 121 | 122 | def forward(self, x): 123 | low_level_feat = self.low_level_features(x) 124 | x = self.high_level_features(low_level_feat) 125 | return x, low_level_feat 126 | 127 | def _load_pretrained_model(self): 128 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth') 129 | model_dict = {} 130 | state_dict = self.state_dict() 131 | for k, v in pretrain_dict.items(): 132 | if k in state_dict: 133 | model_dict[k] = v 134 | state_dict.update(model_dict) 135 | self.load_state_dict(state_dict) 136 | 137 | def _initialize_weights(self): 138 | for m in self.modules(): 139 | if isinstance(m, nn.Conv2d): 140 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 141 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 142 | torch.nn.init.kaiming_normal_(m.weight) 143 | elif isinstance(m, SynchronizedBatchNorm2d): 144 | m.weight.data.fill_(1) 145 | m.bias.data.zero_() 146 | elif isinstance(m, nn.BatchNorm2d): 147 | m.weight.data.fill_(1) 148 | m.bias.data.zero_() 149 | 150 | if __name__ == "__main__": 151 | input = torch.rand(1, 3, 512, 512) 152 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d) 153 | output, low_level_feat = model(input) 154 | print(output.size()) 155 | print(low_level_feat.size()) 156 | -------------------------------------------------------------------------------- /deeplab_trimap/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | class Decoder(nn.Module): 8 | def __init__(self, num_classes, backbone, BatchNorm): 9 | super(Decoder, self).__init__() 10 | if backbone == 'resnet' or backbone == 'drn': 11 | low_level_inplanes = 256 12 | elif backbone == 'xception': 13 | low_level_inplanes = 128 14 | elif backbone == 'mobilenet': 15 | low_level_inplanes = 24 16 | else: 17 | raise NotImplementedError 18 | 19 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 20 | self.bn1 = BatchNorm(48) 21 | self.relu = nn.ReLU() 22 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 23 | BatchNorm(256), 24 | nn.ReLU(), 25 | nn.Dropout(0.5), 26 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 27 | BatchNorm(256), 28 | nn.ReLU(), 29 | nn.Dropout(0.1), 30 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1)) 31 | self._init_weight() 32 | 33 | 34 | def forward(self, x, low_level_feat): 35 | low_level_feat = self.conv1(low_level_feat) 36 | low_level_feat = self.bn1(low_level_feat) 37 | low_level_feat = self.relu(low_level_feat) 38 | 39 | x = F.interpolate(x, 
size=low_level_feat.size()[2:], mode='bilinear', align_corners=True) 40 | x = torch.cat((x, low_level_feat), dim=1) 41 | x = self.last_conv(x) 42 | 43 | return x 44 | 45 | def _init_weight(self): 46 | for m in self.modules(): 47 | if isinstance(m, nn.Conv2d): 48 | torch.nn.init.kaiming_normal_(m.weight) 49 | elif isinstance(m, SynchronizedBatchNorm2d): 50 | m.weight.data.fill_(1) 51 | m.bias.data.zero_() 52 | elif isinstance(m, nn.BatchNorm2d): 53 | m.weight.data.fill_(1) 54 | m.bias.data.zero_() 55 | 56 | def build_decoder(num_classes, backbone, BatchNorm): 57 | return Decoder(num_classes, backbone, BatchNorm) -------------------------------------------------------------------------------- /deeplab_trimap/deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | from .aspp import build_aspp 6 | from .decoder import build_decoder 7 | from .backbone import build_backbone 8 | 9 | class DeepLab(nn.Module): 10 | def __init__(self, backbone='resnet', output_stride=16, num_classes=21, 11 | sync_bn=True, freeze_bn=False): 12 | super(DeepLab, self).__init__() 13 | if backbone == 'drn': 14 | output_stride = 8 15 | 16 | if sync_bn == True: 17 | BatchNorm = SynchronizedBatchNorm2d 18 | else: 19 | BatchNorm = nn.BatchNorm2d 20 | 21 | self.backbone = build_backbone(backbone, output_stride, BatchNorm) 22 | self.aspp = build_aspp(backbone, output_stride, BatchNorm) 23 | self.decoder = build_decoder(num_classes, backbone, BatchNorm) 24 | 25 | self.freeze_bn = freeze_bn 26 | 27 | def forward(self, input): 28 | x, low_level_feat = self.backbone(input) 29 | x = self.aspp(x) 30 | x = self.decoder(x, low_level_feat) 31 | x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True) 32 | #$trimap_adaption = self.t_decoder_upscale4(t_decoder_shallow) # 3 33 | t_argmax = x.argmax(dim=1) 34 | return x,t_argmax 35 | 36 | def freeze_bn(self): 37 | for m in self.modules(): 38 | if isinstance(m, SynchronizedBatchNorm2d): 39 | m.eval() 40 | elif isinstance(m, nn.BatchNorm2d): 41 | m.eval() 42 | 43 | def get_1x_lr_params(self): 44 | modules = [self.backbone] 45 | for i in range(len(modules)): 46 | for m in modules[i].named_modules(): 47 | if self.freeze_bn: 48 | if isinstance(m[1], nn.Conv2d): 49 | for p in m[1].parameters(): 50 | if p.requires_grad: 51 | yield p 52 | else: 53 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \ 54 | or isinstance(m[1], nn.BatchNorm2d): 55 | for p in m[1].parameters(): 56 | if p.requires_grad: 57 | yield p 58 | 59 | def get_10x_lr_params(self): 60 | modules = [self.aspp, self.decoder] 61 | for i in range(len(modules)): 62 | for m in modules[i].named_modules(): 63 | if self.freeze_bn: 64 | if isinstance(m[1], nn.Conv2d): 65 | for p in m[1].parameters(): 66 | if p.requires_grad: 67 | yield p 68 | else: 69 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \ 70 | or isinstance(m[1], nn.BatchNorm2d): 71 | for p in m[1].parameters(): 72 | if p.requires_grad: 73 | yield p 74 | 75 | if __name__ == "__main__": 76 | model = DeepLab(backbone='mobilenet', output_stride=16,num_classes=3) 77 | model.eval() 78 | input = torch.rand(1, 3, 513, 513) 79 | output = model(input) 80 | # print(output[0][0,:,0,0]) 81 | # print(output[1][:,0,0]) 82 | t_argmax = output[1] 83 | print(t_argmax.shape) 84 | print(F.one_hot(t_argmax, num_classes=3).permute(0, 3, 
1, 2).float().shape) 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /deeplab_trimap/sync_batchnorm/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/deeplab_trimap/sync_batchnorm/.DS_Store -------------------------------------------------------------------------------- /deeplab_trimap/sync_batchnorm/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /deeplab_trimap/sync_batchnorm/.ipynb_checkpoints/comm-checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 59 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 60 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine to message passed 63 | back to each slave devices. 
64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register an slave device. 85 | Args: 86 | identifier: an identifier, usually is the device id. 87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | The messages were first collected from each devices (including the master device), and then 101 | an callback will be invoked to compute the message to be sent back to each devices 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | Returns: the message to be sent back to the master device. 107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belongs to the master.' 116 | 117 | for i, res in results: 118 | if i == 0: 119 | continue 120 | self._registry[i].result.put(res) 121 | 122 | for i in range(self.nr_slaves): 123 | assert self._queue.get() is True 124 | 125 | return results[0][1] 126 | 127 | @property 128 | def nr_slaves(self): 129 | return len(self._registry) 130 | -------------------------------------------------------------------------------- /deeplab_trimap/sync_batchnorm/.ipynb_checkpoints/replicate-checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 
30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphism, we assign each sub-module with a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 35 | of any slave copies. 36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 51 | original `replicate` function. 52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have customized `DataParallel` implementation. 69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /deeplab_trimap/sync_batchnorm/.ipynb_checkpoints/unittest-checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /deeplab_trimap/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /deeplab_trimap/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 59 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 60 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine to message passed 63 | back to each slave devices. 
64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register an slave device. 85 | Args: 86 | identifier: an identifier, usually is the device id. 87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | The messages were first collected from each devices (including the master device), and then 101 | an callback will be invoked to compute the message to be sent back to each devices 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | Returns: the message to be sent back to the master device. 107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belongs to the master.' 116 | 117 | for i, res in results: 118 | if i == 0: 119 | continue 120 | self._registry[i].result.put(res) 121 | 122 | for i in range(self.nr_slaves): 123 | assert self._queue.get() is True 124 | 125 | return results[0][1] 126 | 127 | @property 128 | def nr_slaves(self): 129 | return len(self._registry) 130 | -------------------------------------------------------------------------------- /deeplab_trimap/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 
30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphic, we assign each sub-module a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 35 | of any slave copies. 36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 51 | the original `replicate` function. 52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have a customized `DataParallel` implementation. 69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /deeplab_trimap/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /deepmatting_matting/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/deepmatting_matting/.DS_Store -------------------------------------------------------------------------------- /deepmatting_matting/.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__/ 3 | BEST_checkpoint.tar 4 | checkpoint.tar 5 | nohup.out 6 | runs/ 7 | train_names.txt 8 | valid_names.txt 9 | -------------------------------------------------------------------------------- /deepmatting_matting/config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # sets device for model and PyTorch tensors 4 | 5 | im_size = 320 6 | unknown_code = 128 7 | epsilon = 1e-6 8 | epsilon_sqr = epsilon ** 2 9 | 10 | num_samples = 43100 11 | num_train = 34480 12 | # num_samples - num_train_samples 13 | num_valid = 8620 14 | 15 | # Training parameters 16 | num_workers = 1 # for data-loading; right now, only 1 works with h5py 17 | grad_clip = 5. 
# clip gradients at an absolute value of 18 | print_freq = 5 # print training/validation stats every __ batches 19 | checkpoint = 'deepmatting_ckpt.tar' # path to checkpoint, None if none 20 | 21 | ############################################################## 22 | # Set your paths here 23 | 24 | # path to provided foreground images 25 | fg_path = 'data/fg/' 26 | 27 | # path to provided alpha mattes 28 | a_path = 'data/mask/' 29 | 30 | # Path to background images (MSCOCO) 31 | bg_path = 'data/bg/' 32 | 33 | # Path to folder where you want the composited images to go 34 | out_path = 'data/merged/' 35 | 36 | max_size = 1600 37 | fg_path_test = 'data/fg_test/' 38 | a_path_test = 'data/mask_test/' 39 | bg_path_test = 'data/bg_test/' 40 | out_path_test = 'data/merged_test/' 41 | ############################################################## 42 | -------------------------------------------------------------------------------- /deepmatting_matting/demo.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | 5 | import cv2 as cv 6 | import numpy as np 7 | import torch 8 | from torchvision import transforms 9 | 10 | from config import device, fg_path_test, a_path_test, bg_path_test 11 | from data_gen import data_transforms, gen_trimap, fg_test_files, bg_test_files 12 | from test import gen_test_names 13 | from utils import compute_mse, compute_sad, ensure_folder, draw_str 14 | 15 | 16 | def composite4(fg, bg, a, w, h): 17 | print(fg.shape, bg.shape, a.shape, w, h) 18 | fg = np.array(fg, np.float32) 19 | bg_h, bg_w = bg.shape[:2] 20 | x = 0 21 | if bg_w > w: 22 | x = np.random.randint(0, bg_w - w) 23 | y = 0 24 | if bg_h > h: 25 | y = np.random.randint(0, bg_h - h) 26 | bg = np.array(bg[y:y + h, x:x + w], np.float32) 27 | alpha = np.zeros((h, w, 1), np.float32) 28 | alpha[:, :, 0] = a 29 | im = alpha * fg + (1 - alpha) * bg 30 | im = im.astype(np.uint8) 31 | return im, bg 32 | 33 | 34 | def composite4_test(fg, bg, a, w, h): 35 | fg = np.array(fg, np.float32) 36 | bg_h, bg_w = bg.shape[:2] 37 | x = max(0, int((bg_w - w) / 2)) 38 | y = max(0, int((bg_h - h) / 2)) 39 | crop = np.array(bg[y:y + h, x:x + w], np.float32) 40 | alpha = np.zeros((h, w, 1), np.float32) 41 | alpha[:, :, 0] = a / 255. 
42 | im = alpha * fg + (1 - alpha) * crop 43 | im = im.astype(np.uint8) 44 | 45 | new_a = np.zeros((bg_h, bg_w), np.uint8) 46 | new_a[y:y + h, x:x + w] = a 47 | new_im = bg.copy() 48 | new_im[y:y + h, x:x + w] = im 49 | return new_im, new_a, fg, bg 50 | 51 | 52 | def process_test(im_name, bg_name): 53 | im = cv.imread(fg_path_test + im_name) 54 | a = cv.imread(a_path_test + im_name, 0) 55 | h, w = im.shape[:2] 56 | bg = cv.imread(bg_path_test + bg_name) 57 | bh, bw = bg.shape[:2] 58 | wratio = w / bw 59 | hratio = h / bh 60 | ratio = wratio if wratio > hratio else hratio 61 | if ratio > 1: 62 | bg = cv.resize(src=bg, dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv.INTER_CUBIC) 63 | 64 | return composite4_test(im, bg, a, w, h) 65 | 66 | 67 | if __name__ == '__main__': 68 | checkpoint = 'BEST_checkpoint.tar' 69 | checkpoint = torch.load(checkpoint) 70 | model = checkpoint['model'].module 71 | model = model.to(device) 72 | model.eval() 73 | 74 | transformer = data_transforms['valid'] 75 | 76 | ensure_folder('images') 77 | 78 | names = gen_test_names() 79 | names = random.sample(names, 10) 80 | 81 | bg_test = 'data/bg_test/' 82 | new_bgs = [f for f in os.listdir(bg_test) if 83 | os.path.isfile(os.path.join(bg_test, f)) and f.endswith('.jpg')] 84 | new_bgs = random.sample(new_bgs, 10) 85 | 86 | for i, name in enumerate(names): 87 | fcount = int(name.split('.')[0].split('_')[0]) 88 | bcount = int(name.split('.')[0].split('_')[1]) 89 | im_name = fg_test_files[fcount] 90 | bg_name = bg_test_files[bcount] 91 | img, alpha, fg, bg = process_test(im_name, bg_name) 92 | 93 | cv.imwrite('images/{}_image.png'.format(i), img) 94 | cv.imwrite('images/{}_alpha.png'.format(i), alpha) 95 | 96 | print('\nStart processing image: {}'.format(name)) 97 | 98 | h, w = img.shape[:2] 99 | 100 | trimap = gen_trimap(alpha) 101 | cv.imwrite('images/{}_trimap.png'.format(i), trimap) 102 | 103 | x = torch.zeros((1, 4, h, w), dtype=torch.float) 104 | image = img[..., ::-1] # RGB 105 | image = transforms.ToPILImage()(image) 106 | image = transformer(image) 107 | x[0:, 0:3, :, :] = image 108 | x[0:, 3, :, :] = torch.from_numpy(trimap.copy() / 255.) 109 | 110 | # Move to GPU, if available 111 | x = x.type(torch.FloatTensor).to(device) 112 | alpha = alpha / 255. 
113 | 114 | with torch.no_grad(): 115 | pred = model(x) 116 | 117 | pred = pred.cpu().numpy() 118 | pred = pred.reshape((h, w)) 119 | 120 | pred[trimap == 0] = 0.0 121 | pred[trimap == 255] = 1.0 122 | 123 | # Calculate loss 124 | # loss = criterion(alpha_out, alpha_label) 125 | mse_loss = compute_mse(pred, alpha, trimap) 126 | sad_loss = compute_sad(pred, alpha) 127 | str_msg = 'sad: %.4f, mse: %.4f' % (sad_loss, mse_loss) 128 | print(str_msg) 129 | 130 | out = (pred.copy() * 255).astype(np.uint8) 131 | draw_str(out, (10, 20), str_msg) 132 | cv.imwrite('images/{}_out.png'.format(i), out) 133 | 134 | new_bg = new_bgs[i] 135 | new_bg = cv.imread(os.path.join(bg_test, new_bg)) 136 | bh, bw = new_bg.shape[:2] 137 | wratio = w / bw 138 | hratio = h / bh 139 | ratio = wratio if wratio > hratio else hratio 140 | print('ratio: ' + str(ratio)) 141 | if ratio > 1: 142 | new_bg = cv.resize(src=new_bg, dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)), 143 | interpolation=cv.INTER_CUBIC) 144 | 145 | im, bg = composite4(img, new_bg, pred, w, h) 146 | cv.imwrite('images/{}_compose.png'.format(i), im) 147 | cv.imwrite('images/{}_new_bg.png'.format(i), new_bg) 148 | -------------------------------------------------------------------------------- /deepmatting_matting/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 as cv 4 | import numpy as np 5 | import torch 6 | from torchvision import transforms 7 | from tqdm import tqdm 8 | 9 | from config import device 10 | from data_gen import data_transforms 11 | from utils import ensure_folder 12 | 13 | IMG_FOLDER = 'data/alphamatting/input_lowres' 14 | TRIMAP_FOLDERS = ['data/alphamatting/trimap_lowres/Trimap1', 'data/alphamatting/trimap_lowres/Trimap2', 15 | 'data/alphamatting/trimap_lowres/Trimap3'] 16 | OUTPUT_FOLDERS = ['images/alphamatting/output_lowres/Trimap1', 'images/alphamatting/output_lowres/Trimap2', 'images/alphamatting/output_lowres/Trimap3', ] 17 | 18 | if __name__ == '__main__': 19 | checkpoint = 'BEST_checkpoint.tar' 20 | checkpoint = torch.load(checkpoint) 21 | model = checkpoint['model'].module 22 | model = model.to(device) 23 | model.eval() 24 | 25 | transformer = data_transforms['valid'] 26 | 27 | ensure_folder('images') 28 | ensure_folder('images/alphamatting') 29 | ensure_folder(OUTPUT_FOLDERS[0]) 30 | ensure_folder(OUTPUT_FOLDERS[1]) 31 | ensure_folder(OUTPUT_FOLDERS[2]) 32 | 33 | files = [f for f in os.listdir(IMG_FOLDER) if f.endswith('.png')] 34 | 35 | for file in tqdm(files): 36 | filename = os.path.join(IMG_FOLDER, file) 37 | img = cv.imread(filename) 38 | print(img.shape) 39 | h, w = img.shape[:2] 40 | 41 | x = torch.zeros((1, 4, h, w), dtype=torch.float) 42 | image = img[..., ::-1] # RGB 43 | image = transforms.ToPILImage()(image) 44 | image = transformer(image) 45 | x[0:, 0:3, :, :] = image 46 | 47 | for i in range(3): 48 | filename = os.path.join(TRIMAP_FOLDERS[i], file) 49 | print('reading {}...'.format(filename)) 50 | trimap = cv.imread(filename, 0) 51 | x[0:, 3, :, :] = torch.from_numpy(trimap.copy() / 255.) 
52 | # print(torch.max(x[0:, 3, :, :])) 53 | # print(torch.min(x[0:, 3, :, :])) 54 | # print(torch.median(x[0:, 3, :, :])) 55 | 56 | # Move to GPU, if available 57 | x = x.type(torch.FloatTensor).to(device) 58 | 59 | with torch.no_grad(): 60 | pred = model(x) 61 | 62 | pred = pred.cpu().numpy() 63 | pred = pred.reshape((h, w)) 64 | 65 | pred[trimap == 0] = 0.0 66 | pred[trimap == 255] = 1.0 67 | 68 | out = (pred.copy() * 255).astype(np.uint8) 69 | 70 | filename = os.path.join(OUTPUT_FOLDERS[i], file) 71 | cv.imwrite(filename, out) 72 | print('wrote {}.'.format(filename)) 73 | -------------------------------------------------------------------------------- /deepmatting_matting/extract.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | 3 | if __name__ == '__main__': 4 | filename = 'data/alphamatting/input_lowres.zip' 5 | print('Extracting {}...'.format(filename)) 6 | with zipfile.ZipFile(filename, 'r') as zip_ref: 7 | zip_ref.extractall('data/alphamatting/') 8 | 9 | filename = 'data/alphamatting/trimap_lowres.zip' 10 | print('Extracting {}...'.format(filename)) 11 | with zipfile.ZipFile(filename, 'r') as zip_ref: 12 | zip_ref.extractall('data/alphamatting/') 13 | -------------------------------------------------------------------------------- /deepmatting_matting/pre_process.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import shutil 5 | import zipfile 6 | import tarfile 7 | 8 | from Combined_Dataset.Training_set.Composition_code_revised import do_composite 9 | from Combined_Dataset.Test_set.Composition_code_revised import do_composite_test 10 | 11 | if __name__ == '__main__': 12 | # path to provided foreground images 13 | fg_path = 'data/fg/' 14 | # path to provided alpha mattes 15 | a_path = 'data/mask/' 16 | # Path to background images (MSCOCO) 17 | bg_path = 'data/bg/' 18 | # Path to folder where you want the composited images to go 19 | out_path = 'data/merged/' 20 | 21 | train_folder = 'data/Combined_Dataset/Training_set/' 22 | 23 | # if not os.path.exists('Combined_Dataset'): 24 | zip_file = 'data/Adobe_Deep_Matting_Dataset.zip' 25 | print('Extracting {}...'.format(zip_file)) 26 | 27 | zip_ref = zipfile.ZipFile(zip_file, 'r') 28 | zip_ref.extractall('data') 29 | zip_ref.close() 30 | 31 | if not os.path.exists(bg_path): 32 | zip_file = 'data/train2014.zip' 33 | print('Extracting {}...'.format(zip_file)) 34 | 35 | zip_ref = zipfile.ZipFile(zip_file, 'r') 36 | zip_ref.extractall('data') 37 | zip_ref.close() 38 | 39 | with open(os.path.join(train_folder, 'training_bg_names.txt')) as f: 40 | training_bg_names = f.read().splitlines() 41 | 42 | os.makedirs(bg_path) 43 | for bg_name in training_bg_names: 44 | src_path = os.path.join('data/train2014', bg_name) 45 | dest_path = os.path.join(bg_path, bg_name) 46 | shutil.move(src_path, dest_path) 47 | 48 | if not os.path.exists(fg_path): 49 | os.makedirs(fg_path) 50 | 51 | for old_folder in [train_folder + 'Adobe-licensed images/fg', train_folder + 'Other/fg']: 52 | fg_files = os.listdir(old_folder) 53 | for fg_file in fg_files: 54 | src_path = os.path.join(old_folder, fg_file) 55 | dest_path = os.path.join(fg_path, fg_file) 56 | shutil.move(src_path, dest_path) 57 | 58 | if not os.path.exists(a_path): 59 | os.makedirs(a_path) 60 | 61 | for old_folder in [train_folder + 'Adobe-licensed images/alpha', train_folder + 'Other/alpha']: 62 | a_files = os.listdir(old_folder) 63 | for a_file in a_files: 64 | 
src_path = os.path.join(old_folder, a_file) 65 | dest_path = os.path.join(a_path, a_file) 66 | shutil.move(src_path, dest_path) 67 | 68 | if not os.path.exists(out_path): 69 | os.makedirs(out_path) 70 | do_composite() 71 | 72 | # path to provided foreground images 73 | fg_test_path = 'data/fg_test/' 74 | # path to provided alpha mattes 75 | a_test_path = 'data/mask_test/' 76 | # Path to background images (PASCAL VOC) 77 | bg_test_path = 'data/bg_test/' 78 | # Path to folder where you want the composited images to go 79 | out_test_path = 'data/merged_test/' 80 | 81 | # test data gen 82 | test_folder = 'data/Combined_Dataset/Test_set/' 83 | 84 | if not os.path.exists(bg_test_path): 85 | os.makedirs(bg_test_path) 86 | 87 | tar_file = 'data/VOCtrainval_14-Jul-2008.tar' 88 | print('Extracting {}...'.format(tar_file)) 89 | 90 | tar = tarfile.open(tar_file) 91 | tar.extractall('data') 92 | tar.close() 93 | 94 | tar_file = 'data/VOC2008test.tar' 95 | print('Extracting {}...'.format(tar_file)) 96 | 97 | tar = tarfile.open(tar_file) 98 | tar.extractall('data') 99 | tar.close() 100 | 101 | with open(os.path.join(test_folder, 'test_bg_names.txt')) as f: 102 | test_bg_names = f.read().splitlines() 103 | 104 | for bg_name in test_bg_names: 105 | tokens = bg_name.split('_') 106 | src_path = os.path.join('data/VOCdevkit/VOC2008/JPEGImages', bg_name) 107 | dest_path = os.path.join(bg_test_path, bg_name) 108 | shutil.move(src_path, dest_path) 109 | 110 | if not os.path.exists(fg_test_path): 111 | os.makedirs(fg_test_path) 112 | 113 | for old_folder in [test_folder + 'Adobe-licensed images/fg']: 114 | fg_files = os.listdir(old_folder) 115 | for fg_file in fg_files: 116 | src_path = os.path.join(old_folder, fg_file) 117 | dest_path = os.path.join(fg_test_path, fg_file) 118 | shutil.move(src_path, dest_path) 119 | 120 | if not os.path.exists(a_test_path): 121 | os.makedirs(a_test_path) 122 | 123 | for old_folder in [test_folder + 'Adobe-licensed images/alpha']: 124 | a_files = os.listdir(old_folder) 125 | for a_file in a_files: 126 | src_path = os.path.join(old_folder, a_file) 127 | dest_path = os.path.join(a_test_path, a_file) 128 | shutil.move(src_path, dest_path) 129 | 130 | if not os.path.exists(out_test_path): 131 | os.makedirs(out_test_path) 132 | 133 | do_composite_test() 134 | -------------------------------------------------------------------------------- /deepmatting_matting/test.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import cv2 as cv 4 | import numpy as np 5 | import torch 6 | from torchvision import transforms 7 | from tqdm import tqdm 8 | 9 | from config import device, fg_path_test, a_path_test, bg_path_test 10 | from data_gen import data_transforms, fg_test_files, bg_test_files 11 | from utils import compute_mse, compute_sad, AverageMeter, get_logger 12 | 13 | 14 | def gen_test_names(): 15 | num_fgs = 50 16 | num_bgs = 1000 17 | num_bgs_per_fg = 20 18 | 19 | names = [] 20 | bcount = 0 21 | for fcount in range(num_fgs): 22 | for i in range(num_bgs_per_fg): 23 | names.append(str(fcount) + '_' + str(bcount) + '.png') 24 | bcount += 1 25 | 26 | return names 27 | 28 | 29 | def process_test(im_name, bg_name, trimap): 30 | # print(bg_path_test + bg_name) 31 | im = cv.imread(fg_path_test + im_name) 32 | a = cv.imread(a_path_test + im_name, 0) 33 | h, w = im.shape[:2] 34 | bg = cv.imread(bg_path_test + bg_name) 35 | bh, bw = bg.shape[:2] 36 | wratio = w / bw 37 | hratio = h / bh 38 | ratio = wratio if wratio > hratio else hratio 39 | if 
ratio > 1: 40 | bg = cv.resize(src=bg, dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv.INTER_CUBIC) 41 | 42 | return composite4_test(im, bg, a, w, h, trimap) 43 | 44 | 45 | # def composite4_test(fg, bg, a, w, h): 46 | # fg = np.array(fg, np.float32) 47 | # bg_h, bg_w = bg.shape[:2] 48 | # x = max(0, int((bg_w - w)/2)) 49 | # y = max(0, int((bg_h - h)/2)) 50 | # bg = np.array(bg[y:y + h, x:x + w], np.float32) 51 | # alpha = np.zeros((h, w, 1), np.float32) 52 | # alpha[:, :, 0] = a / 255. 53 | # im = alpha * fg + (1 - alpha) * bg 54 | # im = im.astype(np.uint8) 55 | # print('im.shape: ' + str(im.shape)) 56 | # print('a.shape: ' + str(a.shape)) 57 | # print('fg.shape: ' + str(fg.shape)) 58 | # print('bg.shape: ' + str(bg.shape)) 59 | # return im, a, fg, bg 60 | 61 | 62 | def composite4_test(fg, bg, a, w, h, trimap): 63 | fg = np.array(fg, np.float32) 64 | bg_h, bg_w = bg.shape[:2] 65 | x = max(0, int((bg_w - w) / 2)) 66 | y = max(0, int((bg_h - h) / 2)) 67 | crop = np.array(bg[y:y + h, x:x + w], np.float32) 68 | alpha = np.zeros((h, w, 1), np.float32) 69 | alpha[:, :, 0] = a / 255. 70 | # trimaps = np.zeros((h, w, 1), np.float32) 71 | # trimaps[:,:,0]=trimap/255. 72 | 73 | im = alpha * fg + (1 - alpha) * crop 74 | im = im.astype(np.uint8) 75 | 76 | new_a = np.zeros((bg_h, bg_w), np.uint8) 77 | new_a[y:y + h, x:x + w] = a 78 | new_trimap = np.zeros((bg_h, bg_w), np.uint8) 79 | new_trimap[y:y + h, x:x + w] = trimap 80 | cv.imwrite('images/test/new/' + trimap_name, new_trimap) 81 | new_im = bg.copy() 82 | new_im[y:y + h, x:x + w] = im 83 | # cv.imwrite('images/test/new_im/'+trimap_name,new_im) 84 | return new_im, new_a, fg, bg, new_trimap 85 | 86 | 87 | if __name__ == '__main__': 88 | checkpoint = 'BEST_checkpoint.tar' 89 | checkpoint = torch.load(checkpoint) 90 | model = checkpoint['model'].module 91 | model = model.to(device) 92 | model.eval() 93 | 94 | transformer = data_transforms['valid'] 95 | 96 | names = gen_test_names() 97 | 98 | mse_losses = AverageMeter() 99 | sad_losses = AverageMeter() 100 | 101 | logger = get_logger() 102 | i = 0 103 | for name in tqdm(names): 104 | fcount = int(name.split('.')[0].split('_')[0]) 105 | bcount = int(name.split('.')[0].split('_')[1]) 106 | im_name = fg_test_files[fcount] 107 | # print(im_name) 108 | bg_name = bg_test_files[bcount] 109 | trimap_name = im_name.split('.')[0] + '_' + str(i) + '.png' 110 | # print('trimap_name: ' + str(trimap_name)) 111 | 112 | trimap = cv.imread('data/Combined_Dataset/Test_set/Adobe-licensed images/trimaps/' + trimap_name, 0) 113 | # print('trimap: ' + str(trimap)) 114 | 115 | i += 1 116 | if i == 20: 117 | i = 0 118 | 119 | img, alpha, fg, bg, new_trimap = process_test(im_name, bg_name, trimap) 120 | h, w = img.shape[:2] 121 | # mytrimap = gen_trimap(alpha) 122 | # cv.imwrite('images/test/new_im/'+trimap_name,mytrimap) 123 | 124 | x = torch.zeros((1, 4, h, w), dtype=torch.float) 125 | img = img[..., ::-1] # RGB 126 | img = transforms.ToPILImage()(img) # [3, 320, 320] 127 | img = transformer(img) # [3, 320, 320] 128 | x[0:, 0:3, :, :] = img 129 | x[0:, 3, :, :] = torch.from_numpy(new_trimap.copy() / 255.) 130 | 131 | # Move to GPU, if available 132 | x = x.type(torch.FloatTensor).to(device) # [1, 4, 320, 320] 133 | alpha = alpha / 255. 
134 | 135 | with torch.no_grad(): 136 | pred = model(x) # [1, 4, 320, 320] 137 | 138 | pred = pred.cpu().numpy() 139 | pred = pred.reshape((h, w)) # [320, 320] 140 | 141 | pred[new_trimap == 0] = 0.0 142 | pred[new_trimap == 255] = 1.0 143 | cv.imwrite('images/test/out/' + trimap_name, pred * 255) 144 | 145 | # Calculate loss 146 | # loss = criterion(alpha_out, alpha_label) 147 | mse_loss = compute_mse(pred, alpha, trimap) 148 | sad_loss = compute_sad(pred, alpha) 149 | 150 | # Keep track of metrics 151 | mse_losses.update(mse_loss.item()) 152 | sad_losses.update(sad_loss.item()) 153 | print("sad:{} mse:{}".format(sad_loss.item(), mse_loss.item())) 154 | print("sad:{} mse:{}".format(sad_losses.avg, mse_losses.avg)) 155 | print("sad:{} mse:{}".format(sad_losses.avg, mse_losses.avg)) 156 | -------------------------------------------------------------------------------- /deepmatting_matting/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | 5 | import cv2 as cv 6 | import numpy as np 7 | import torch 8 | 9 | from config import im_size, epsilon, epsilon_sqr 10 | 11 | 12 | def clip_gradient(optimizer, grad_clip): 13 | """ 14 | Clips gradients computed during backpropagation to avoid explosion of gradients. 15 | :param optimizer: optimizer with the gradients to be clipped 16 | :param grad_clip: clip value 17 | """ 18 | for group in optimizer.param_groups: 19 | for param in group['params']: 20 | if param.grad is not None: 21 | param.grad.data.clamp_(-grad_clip, grad_clip) 22 | 23 | 24 | def save_checkpoint(epoch, epochs_since_improvement, model, optimizer, loss, is_best): 25 | state = {'epoch': epoch, 26 | 'epochs_since_improvement': epochs_since_improvement, 27 | 'loss': loss, 28 | 'model': model, 29 | 'optimizer': optimizer} 30 | # filename = 'checkpoint_' + str(epoch) + '_' + str(loss) + '.tar' 31 | filename = 'checkpoint.tar' 32 | torch.save(state, filename) 33 | # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint 34 | if is_best: 35 | torch.save(state, 'BEST_checkpoint.tar') 36 | 37 | 38 | class AverageMeter(object): 39 | """ 40 | Keeps track of most recent, average, sum, and count of a metric. 41 | """ 42 | 43 | def __init__(self): 44 | self.reset() 45 | 46 | def reset(self): 47 | self.val = 0 48 | self.avg = 0 49 | self.sum = 0 50 | self.count = 0 51 | 52 | def update(self, val, n=1): 53 | self.val = val 54 | self.sum += val * n 55 | self.count += n 56 | self.avg = self.sum / self.count 57 | 58 | 59 | def adjust_learning_rate(optimizer, shrink_factor): 60 | """ 61 | Shrinks learning rate by a specified factor. 62 | :param optimizer: optimizer whose learning rate must be shrunk. 63 | :param shrink_factor: factor in interval (0, 1) to multiply learning rate with. 
64 | """ 65 | 66 | print("\nDECAYING learning rate.") 67 | for param_group in optimizer.param_groups: 68 | param_group['lr'] = param_group['lr'] * shrink_factor 69 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],)) 70 | 71 | 72 | def get_learning_rate(optimizer): 73 | return optimizer.param_groups[0]['lr'] 74 | 75 | 76 | def accuracy(scores, targets, k=1): 77 | batch_size = targets.size(0) 78 | _, ind = scores.topk(k, 1, True, True) 79 | correct = ind.eq(targets.view(-1, 1).expand_as(ind)) 80 | correct_total = correct.view(-1).float().sum() # 0D tensor 81 | return correct_total.item() * (100.0 / batch_size) 82 | 83 | 84 | def parse_args(): 85 | parser = argparse.ArgumentParser(description='Train face network') 86 | # general 87 | parser.add_argument('--end-epoch', type=int, default=20, help='training epoch size.') 88 | parser.add_argument('--lr', type=float, default=0.01, help='start learning rate') 89 | parser.add_argument('--lr-step', type=int, default=10, help='period of learning rate decay') 90 | parser.add_argument('--optimizer', default='sgd', help='optimizer') 91 | parser.add_argument('--weight-decay', type=float, default=0.0, help='weight decay') 92 | parser.add_argument('--mom', type=float, default=0.9, help='momentum') 93 | parser.add_argument('--batch-size', type=int, default=1, help='batch size in each context') 94 | parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint') 95 | parser.add_argument('--pretrained', type=bool, default=True, help='pretrained model') 96 | args = parser.parse_args() 97 | return args 98 | 99 | 100 | def get_logger(): 101 | logger = logging.getLogger() 102 | handler = logging.StreamHandler() 103 | formatter = logging.Formatter("%(asctime)s %(levelname)s \t%(message)s") 104 | handler.setFormatter(formatter) 105 | logger.addHandler(handler) 106 | logger.setLevel(logging.DEBUG) 107 | return logger 108 | 109 | 110 | def safe_crop(mat, x, y, crop_size=(im_size, im_size)): 111 | crop_height, crop_width = crop_size 112 | if len(mat.shape) == 2: 113 | ret = np.zeros((crop_height, crop_width), np.uint8) 114 | else: 115 | ret = np.zeros((crop_height, crop_width, 3), np.uint8) 116 | crop = mat[y:y + crop_height, x:x + crop_width] 117 | h, w = crop.shape[:2] 118 | ret[0:h, 0:w] = crop 119 | if crop_size != (im_size, im_size): 120 | ret = cv.resize(ret, dsize=(im_size, im_size), interpolation=cv.INTER_NEAREST) 121 | return ret 122 | 123 | 124 | # alpha prediction loss: the abosolute difference between the ground truth alpha values and the 125 | # predicted alpha values at each pixel. However, due to the non-differentiable property of 126 | # absolute values, we use the following loss function to approximate it. 127 | def alpha_prediction_loss(y_pred, y_true): 128 | mask = y_true[:, 1, :] 129 | diff = y_pred[:, 0, :] - y_true[:, 0, :] 130 | diff = diff * mask 131 | num_pixels = torch.sum(mask) 132 | return torch.sum(torch.sqrt(torch.pow(diff, 2) + epsilon_sqr)) / (num_pixels + epsilon) 133 | 134 | 135 | # compute the MSE error given a prediction, a ground truth and a trimap. 136 | # pred: the predicted alpha matte 137 | # target: the ground truth alpha matte 138 | # trimap: the given trimap 139 | # 140 | def compute_mse(pred, alpha, trimap): 141 | num_pixels = float((trimap == 128).sum()) 142 | return ((pred - alpha) ** 2).sum() / num_pixels 143 | 144 | 145 | # compute the SAD error given a prediction and a ground truth. 
146 | # 147 | def compute_sad(pred, alpha): 148 | diff = np.abs(pred - alpha) 149 | return np.sum(diff) / 1000 150 | 151 | 152 | def draw_str(dst, target, s): 153 | x, y = target 154 | cv.putText(dst, s, (x + 1, y + 1), cv.FONT_HERSHEY_PLAIN, 1.0, (0, 0, 0), thickness=2, lineType=cv.LINE_AA) 155 | cv.putText(dst, s, (x, y), cv.FONT_HERSHEY_PLAIN, 1.0, (255, 255, 255), lineType=cv.LINE_AA) 156 | 157 | 158 | def ensure_folder(folder): 159 | if not os.path.exists(folder): 160 | os.makedirs(folder) 161 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import zipfile 3 | import shutil 4 | import os 5 | 6 | 7 | def download_file_from_google_drive(id, destination): 8 | URL = "https://docs.google.com/uc?export=download" 9 | 10 | session = requests.Session() 11 | 12 | response = session.get(URL, params = { 'id' : id }, stream = True) 13 | token = get_confirm_token(response) 14 | 15 | if token: 16 | params = { 'id' : id, 'confirm' : token } 17 | response = session.get(URL, params = params, stream = True) 18 | 19 | save_response_content(response, destination) 20 | 21 | def get_confirm_token(response): 22 | for key, value in response.cookies.items(): 23 | if key.startswith('download_warning'): 24 | return value 25 | 26 | return None 27 | 28 | def save_response_content(response, destination): 29 | CHUNK_SIZE = 32768 30 | 31 | with open(destination, "wb") as f: 32 | for chunk in response.iter_content(CHUNK_SIZE): 33 | if chunk: # filter out keep-alive new chunks 34 | f.write(chunk) 35 | 36 | if __name__ == '__main__': 37 | file_id = '1h6vVnlFpWk7G2dlzc9DZuKzUuqvloA25' 38 | file_path = os.path.join(os.getcwd(), 'ckpt.zip') 39 | print("download...") 40 | download_file_from_google_drive(file_id, file_path) 41 | 42 | with zipfile.ZipFile(file_path) as f: 43 | for file in f.namelist(): 44 | f.extract(file, os.getcwd()) 45 | 46 | shutil.move("ckpt/deeplab_model_best.pth.tar","deeplab_trimap/checkpoint/") 47 | shutil.move("ckpt/stylegan2-ffhq-config-f.pt","./") 48 | shutil.move("ckpt/gca-dist-all-data.pth","gca_matting/checkpoints_finetune/") 49 | 50 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | IndexNet Matting 3 | 4 | Indices Matter: Learning to Index for Deep Image Matting 5 | IEEE/CVF International Conference on Computer Vision, 2019 6 | 7 | This software is strictly limited to academic purposes only 8 | Copyright (c) 2019, Hao Lu (hao.lu@adelaide.edu.au) 9 | All rights reserved. 10 | 11 | Redistribution and use in source and binary forms, with or without 12 | modification, are permitted provided that the following conditions are met: 13 | 14 | * Redistributions of source code must retain the above copyright notice, this 15 | list of conditions and the following disclaimer. 16 | * Redistributions in binary form must reproduce the above copyright notice, 17 | this list of conditions and the following disclaimer in the documentation 18 | and/or other materials provided with the distribution. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | """ 31 | 32 | import numpy as np 33 | import cv2 as cv 34 | from scipy.ndimage import gaussian_filter, morphology 35 | from skimage.measure import label, regionprops 36 | 37 | 38 | # compute the SAD error given a prediction, a ground truth and a mask 39 | def compute_sad_loss(pd, gt, mask): 40 | cv.normalize(pd, pd, 0.0, 255.0, cv.NORM_MINMAX, dtype=cv.CV_32F) 41 | cv.normalize(gt, gt, 0.0, 255.0, cv.NORM_MINMAX, dtype=cv.CV_32F) 42 | error_map = np.abs(pd - gt) / 255. 43 | loss = np.sum(error_map * mask) 44 | # the loss is scaled by 1000 due to the large images 45 | loss = loss / 1000 46 | return loss 47 | 48 | 49 | # compute the MSE error 50 | def compute_mse_loss(pd, gt, mask): 51 | cv.normalize(pd, pd, 0.0, 255.0, cv.NORM_MINMAX, dtype=cv.CV_32F) 52 | cv.normalize(gt, gt, 0.0, 255.0, cv.NORM_MINMAX, dtype=cv.CV_32F) 53 | error_map = (pd - gt) / 255. 54 | loss = np.sum(np.square(error_map) * mask) / np.sum(mask) 55 | return loss 56 | 57 | 58 | # compute the gradient error 59 | def compute_gradient_loss(pd, gt, mask): 60 | cv.normalize(pd, pd, 0.0, 255.0, cv.NORM_MINMAX, dtype=cv.CV_32F) 61 | cv.normalize(gt, gt, 0.0, 255.0, cv.NORM_MINMAX, dtype=cv.CV_32F) 62 | pd = pd / 255. 63 | gt = gt / 255. 64 | pd_x = gaussian_filter(pd, sigma=1.4, order=[1, 0], output=np.float32) 65 | pd_y = gaussian_filter(pd, sigma=1.4, order=[0, 1], output=np.float32) 66 | gt_x = gaussian_filter(gt, sigma=1.4, order=[1, 0], output=np.float32) 67 | gt_y = gaussian_filter(gt, sigma=1.4, order=[0, 1], output=np.float32) 68 | pd_mag = np.sqrt(pd_x**2 + pd_y**2) 69 | gt_mag = np.sqrt(gt_x**2 + gt_y**2) 70 | 71 | error_map = np.square(pd_mag - gt_mag) 72 | loss = np.sum(error_map * mask) / 10 73 | return loss 74 | 75 | 76 | # compute the connectivity error 77 | def compute_connectivity_loss(pd, gt, mask, step=0.1): 78 | cv.normalize(pd, pd, 0, 255, cv.NORM_MINMAX) 79 | cv.normalize(gt, gt, 0, 255, cv.NORM_MINMAX) 80 | pd = pd / 255. 81 | gt = gt / 255. 
82 | 83 | h, w = pd.shape 84 | 85 | thresh_steps = np.arange(0, 1.1, step) 86 | l_map = -1 * np.ones((h, w), dtype=np.float32) 87 | lambda_map = np.ones((h, w), dtype=np.float32) 88 | for i in range(1, thresh_steps.size): 89 | pd_th = pd >= thresh_steps[i] 90 | gt_th = gt >= thresh_steps[i] 91 | 92 | label_image = label(pd_th & gt_th, connectivity=1) 93 | cc = regionprops(label_image) 94 | size_vec = np.array([c.area for c in cc]) 95 | if len(size_vec) == 0: 96 | continue 97 | max_id = np.argmax(size_vec) 98 | coords = cc[max_id].coords 99 | 100 | omega = np.zeros((h, w), dtype=np.float32) 101 | omega[coords[:, 0], coords[:, 1]] = 1 102 | 103 | flag = (l_map == -1) & (omega == 0) 104 | l_map[flag == 1] = thresh_steps[i-1] 105 | 106 | dist_maps = morphology.distance_transform_edt(omega==0) 107 | dist_maps = dist_maps / dist_maps.max() 108 | # lambda_map[flag == 1] = dist_maps.mean() 109 | l_map[l_map == -1] = 1 110 | 111 | # the definition of lambda is ambiguous 112 | d_pd = pd - l_map 113 | d_gt = gt - l_map 114 | # phi_pd = 1 - lambda_map * d_pd * (d_pd >= 0.15).astype(np.float32) 115 | # phi_gt = 1 - lambda_map * d_gt * (d_gt >= 0.15).astype(np.float32) 116 | phi_pd = 1 - d_pd * (d_pd >= 0.15).astype(np.float32) 117 | phi_gt = 1 - d_gt * (d_gt >= 0.15).astype(np.float32) 118 | loss = np.sum(np.abs(phi_pd - phi_gt) * mask) / 1000 119 | return loss 120 | 121 | 122 | def image_alignment(x, output_stride, odd=False): 123 | imsize = np.asarray(x.shape[:2], dtype=np.float) 124 | if odd: 125 | new_imsize = np.ceil(imsize / output_stride) * output_stride + 1 126 | else: 127 | new_imsize = np.ceil(imsize / output_stride) * output_stride 128 | h, w = int(new_imsize[0]), int(new_imsize[1]) 129 | 130 | x1 = x[:, :, 0:3] 131 | x2 = x[:, :, 3] 132 | new_x1 = cv.resize(x1, dsize=(w,h), interpolation=cv.INTER_CUBIC) 133 | new_x2 = cv.resize(x2, dsize=(w,h), interpolation=cv.INTER_NEAREST) 134 | 135 | new_x2 = np.expand_dims(new_x2, axis=2) 136 | new_x = np.concatenate((new_x1, new_x2), axis=2) 137 | 138 | return new_x 139 | 140 | 141 | def image_rescale(x, scale): 142 | x1 = x[:, :, 0:3] 143 | x2 = x[:, :, 3] 144 | new_x1 = cv.resize(x1, None, fx=scale, fy=scale, interpolation=cv.INTER_CUBIC) 145 | new_x2 = cv.resize(x2, None, fx=scale, fy=scale, interpolation=cv.INTER_NEAREST) 146 | new_x2 = np.expand_dims(new_x2, axis=2) 147 | new_x = np.concatenate((new_x1,new_x2), axis=2) 148 | return new_x 149 | -------------------------------------------------------------------------------- /fg_pred.py: -------------------------------------------------------------------------------- 1 | from pymatting import * 2 | import numpy as np 3 | 4 | scale = 1.0 5 | 6 | image = load_image("./image/00130.png", "RGB", scale, "box") 7 | trimap = load_image("./trimap/00130.png", "GRAY", scale, "nearest") 8 | 9 | # estimate alpha from image and trimap 10 | alpha = estimate_alpha_cf(image, trimap) 11 | 12 | # make gray background 13 | background = np.zeros(image.shape) 14 | background[:, :] = [0.5, 0.5, 0.5] 15 | 16 | # estimate foreground from image and alpha 17 | foreground = estimate_foreground_ml(image, alpha) 18 | 19 | save_image("foreground/00130.png", foreground) 20 | # # blend foreground with background and alpha, less color bleeding 21 | # new_image = blend(foreground, background, alpha) 22 | 23 | # # save results in a grid 24 | # images = [image, trimap, alpha, new_image] 25 | # grid = make_grid(images) 26 | # save_image("lemur_grid.png", grid) 27 | 28 | # # save cutout 29 | # cutout = stack_images(foreground, alpha) 30 
| # save_image("lemur_cutout.png", cutout) 31 | 32 | # # just blending the image with alpha results in color bleeding 33 | # color_bleeding = blend(image, background, alpha) 34 | # grid = make_grid([color_bleeding, new_image]) 35 | # save_image("lemur_color_bleeding.png", grid) 36 | -------------------------------------------------------------------------------- /func.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import random 4 | import os 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn, autograd, optim 9 | from torch.nn import functional as F 10 | from torch.utils import data 11 | import torch.distributed as dist 12 | from torchvision import transforms, utils 13 | from tqdm import tqdm 14 | from PIL import Image 15 | def remove_prefix_state_dict(state_dict, prefix="module"): 16 | """ 17 | remove prefix from the key of pretrained state dict for Data-Parallel 18 | """ 19 | new_state_dict = {} 20 | first_state_name = list(state_dict.keys())[0] 21 | if not first_state_name.startswith(prefix): 22 | for key, value in state_dict.items(): 23 | new_state_dict[key] = state_dict[key].float() 24 | else: 25 | for key, value in state_dict.items(): 26 | new_state_dict[key[len(prefix)+1:]] = state_dict[key].float() 27 | return new_state_dict 28 | 29 | 30 | 31 | def data_sampler(dataset, shuffle, distributed): 32 | if distributed: 33 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 34 | 35 | if shuffle: 36 | return data.RandomSampler(dataset) 37 | 38 | else: 39 | return data.SequentialSampler(dataset) 40 | 41 | 42 | def requires_grad(model, flag=True): 43 | for p in model.parameters(): 44 | p.requires_grad = flag 45 | 46 | 47 | def accumulate(model1, model2, decay=0.999): 48 | par1 = dict(model1.named_parameters()) 49 | par2 = dict(model2.named_parameters()) 50 | 51 | for k in par1.keys(): 52 | par1[k].data.mul_(decay).add_(1 - decay, par2[k].data) 53 | 54 | 55 | def sample_data(loader): 56 | while True: 57 | for batch in loader: 58 | yield batch 59 | 60 | 61 | def d_logistic_loss(real_pred, fake_pred): 62 | real_loss = F.softplus(-real_pred) 63 | fake_loss = F.softplus(fake_pred) 64 | 65 | return real_loss.mean() + fake_loss.mean() 66 | 67 | 68 | def d_r1_loss(real_pred, real_img): 69 | grad_real, = autograd.grad( 70 | outputs=real_pred.sum(), inputs=real_img, create_graph=True 71 | ) 72 | grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() 73 | 74 | return grad_penalty 75 | 76 | 77 | def g_nonsaturating_loss(fake_pred): 78 | loss = F.softplus(-fake_pred).mean() 79 | 80 | return loss 81 | 82 | 83 | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): 84 | noise = torch.randn_like(fake_img) / math.sqrt( 85 | fake_img.shape[2] * fake_img.shape[3] 86 | ) 87 | grad, = autograd.grad( 88 | outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True 89 | ) 90 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) 91 | 92 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) 93 | 94 | path_penalty = (path_lengths - path_mean).pow(2).mean() 95 | 96 | return path_penalty, path_mean.detach(), path_lengths 97 | 98 | 99 | def make_noise(batch, latent_dim, n_noise, device): 100 | if n_noise == 1: 101 | return torch.randn(batch, latent_dim, device=device) 102 | 103 | noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0) 104 | 105 | return noises 106 | 107 | 108 | def mixing_noise(batch, latent_dim, 
prob, device): 109 | if prob > 0 and random.random() < prob: 110 | return make_noise(batch, latent_dim, 2, device) 111 | 112 | else: 113 | return [make_noise(batch, latent_dim, 1, device)] 114 | 115 | 116 | def set_grad_none(model, targets): 117 | for n, p in model.named_parameters(): 118 | if n in targets: 119 | p.grad = None 120 | 121 | def remove_prefix_state_dict(state_dict, prefix="module"): 122 | """ 123 | remove prefix from the key of pretrained state dict for Data-Parallel 124 | """ 125 | new_state_dict = {} 126 | first_state_name = list(state_dict.keys())[0] 127 | if not first_state_name.startswith(prefix): 128 | for key, value in state_dict.items(): 129 | new_state_dict[key] = state_dict[key].float() 130 | else: 131 | for key, value in state_dict.items(): 132 | new_state_dict[key[len(prefix)+1:]] = state_dict[key].float() 133 | return new_state_dict -------------------------------------------------------------------------------- /gca_matting/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/gca_matting/.DS_Store -------------------------------------------------------------------------------- /gca_matting/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | from .generators import * -------------------------------------------------------------------------------- /gca_matting/.ipynb_checkpoints/generators-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from . import encoders, decoders 5 | 6 | 7 | class Generator(nn.Module): 8 | def __init__(self, encoder, decoder): 9 | 10 | super(Generator, self).__init__() 11 | 12 | if encoder not in encoders.__all__: 13 | raise NotImplementedError("Unknown Encoder {}".format(encoder)) 14 | self.encoder = encoders.__dict__[encoder]() 15 | 16 | if decoder not in decoders.__all__: 17 | raise NotImplementedError("Unknown Decoder {}".format(decoder)) 18 | self.decoder = decoders.__dict__[decoder]() 19 | 20 | def forward(self, image, trimap): 21 | inp = torch.cat((image, trimap), dim=1) 22 | embedding, mid_fea = self.encoder(inp) 23 | alpha, info_dict = self.decoder(embedding, mid_fea) 24 | 25 | return alpha, info_dict 26 | 27 | 28 | def get_generator(encoder, decoder): 29 | generator = Generator(encoder=encoder, decoder=decoder) 30 | return generator 31 | 32 | def remove_prefix_state_dict(state_dict, prefix="module"): 33 | """ 34 | remove prefix from the key of pretrained state dict for Data-Parallel 35 | """ 36 | new_state_dict = {} 37 | first_state_name = list(state_dict.keys())[0] 38 | if not first_state_name.startswith(prefix): 39 | for key, value in state_dict.items(): 40 | new_state_dict[key] = state_dict[key].float() 41 | else: 42 | for key, value in state_dict.items(): 43 | new_state_dict[key[len(prefix)+1:]] = state_dict[key].float() 44 | return new_state_dict 45 | 46 | if __name__ == "__main__": 47 | gca_model = get_generator(encoder='resnet_gca_encoder_29', decoder='res_gca_decoder_22') 48 | gca_ckpt = torch.load("checkpoints\gca-dist-all-data.pth") 49 | gca_model.load_state_dict(remove_prefix_state_dict(gca_ckpt['state_dict']), strict=True) 50 | gca_model.eval() 51 | input = torch.rand(1, 3, 512, 512) 52 | output = gca_model(input, input) 53 | print(output[0].shape) 54 | print(output[1]) 
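The `remove_prefix_state_dict` helper above (the same function also appears in `func.py` and `gca_matting/generators.py`) strips the `module.` prefix that `torch.nn.DataParallel` prepends to parameter names when a checkpoint is saved from the wrapped model, so the weights can be loaded into a bare, unwrapped network. Below is a minimal, self-contained sketch of that behavior; the key names and tensors are made up purely for illustration and are not part of the repository.

```python
import torch

def remove_prefix_state_dict(state_dict, prefix="module"):
    # Same logic as in the repository: if the checkpoint was saved from a
    # DataParallel-wrapped model, drop "<prefix>." from every key.
    new_state_dict = {}
    first_state_name = list(state_dict.keys())[0]
    if not first_state_name.startswith(prefix):
        for key, value in state_dict.items():
            new_state_dict[key] = value.float()
    else:
        for key, value in state_dict.items():
            new_state_dict[key[len(prefix) + 1:]] = value.float()
    return new_state_dict

# A DataParallel checkpoint stores parameters as "module.<name>".
wrapped = {"module.encoder.conv1.weight": torch.zeros(1)}
print(list(remove_prefix_state_dict(wrapped).keys()))  # ['encoder.conv1.weight']

# Keys that never had the prefix are passed through unchanged.
plain = {"encoder.conv1.weight": torch.zeros(1)}
print(list(remove_prefix_state_dict(plain).keys()))    # ['encoder.conv1.weight']
```

This is why the GCA checkpoint is loaded as `gca_model.load_state_dict(remove_prefix_state_dict(gca_ckpt['state_dict']), strict=True)` in the `__main__` block above.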
-------------------------------------------------------------------------------- /gca_matting/__init__.py: -------------------------------------------------------------------------------- 1 | from .generators import * -------------------------------------------------------------------------------- /gca_matting/decoders/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/gca_matting/decoders/.DS_Store -------------------------------------------------------------------------------- /gca_matting/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet_dec import ResNet_D_Dec, BasicBlock 2 | from .res_shortcut_dec import ResShortCut_D_Dec 3 | from .res_gca_dec import ResGuidedCxtAtten_Dec 4 | 5 | 6 | __all__ = ['res_shortcut_decoder_22', 'res_gca_decoder_22'] 7 | 8 | 9 | def _res_shortcut_D_dec(block, layers, **kwargs): 10 | model = ResShortCut_D_Dec(block, layers, **kwargs) 11 | return model 12 | 13 | 14 | def _res_gca_D_dec(block, layers, **kwargs): 15 | model = ResGuidedCxtAtten_Dec(block, layers, **kwargs) 16 | return model 17 | 18 | 19 | def res_shortcut_decoder_22(**kwargs): 20 | """Constructs a resnet_encoder_14 model. 21 | """ 22 | return _res_shortcut_D_dec(BasicBlock, [2, 3, 3, 2], **kwargs) 23 | 24 | 25 | def res_gca_decoder_22(**kwargs): 26 | """Constructs a resnet_encoder_14 model. 27 | """ 28 | return _res_gca_D_dec(BasicBlock, [2, 3, 3, 2], **kwargs) -------------------------------------------------------------------------------- /gca_matting/decoders/res_gca_dec.py: -------------------------------------------------------------------------------- 1 | from .ops import GuidedCxtAtten, SpectralNorm 2 | from .res_shortcut_dec import ResShortCut_D_Dec 3 | 4 | 5 | class ResGuidedCxtAtten_Dec(ResShortCut_D_Dec): 6 | 7 | def __init__(self, block, layers, norm_layer=None, large_kernel=False): 8 | super(ResGuidedCxtAtten_Dec, self).__init__(block, layers, norm_layer, large_kernel) 9 | self.gca = GuidedCxtAtten(128, 128) 10 | 11 | def forward(self, x, mid_fea): 12 | fea1, fea2, fea3, fea4, fea5 = mid_fea['shortcut'] 13 | im = mid_fea['image_fea'] 14 | x = self.layer1(x) + fea5 # N x 256 x 32 x 32 15 | x = self.layer2(x) + fea4 # N x 128 x 64 x 64 16 | x, offset = self.gca(im, x, mid_fea['unknown']) # contextual attention 17 | x = self.layer3(x) + fea3 # N x 64 x 128 x 128 18 | x = self.layer4(x) + fea2 # N x 32 x 256 x 256 19 | x = self.conv1(x) 20 | x = self.bn1(x) 21 | x = self.leaky_relu(x) + fea1 22 | x = self.conv2(x) 23 | 24 | alpha = (self.tanh(x) + 1.0) / 2.0 25 | 26 | return alpha, {'offset_1': mid_fea['offset_1'], 'offset_2': offset} 27 | 28 | -------------------------------------------------------------------------------- /gca_matting/decoders/res_shortcut_dec.py: -------------------------------------------------------------------------------- 1 | from .resnet_dec import ResNet_D_Dec 2 | 3 | 4 | class ResShortCut_D_Dec(ResNet_D_Dec): 5 | 6 | def __init__(self, block, layers, norm_layer=None, large_kernel=False, late_downsample=False): 7 | super(ResShortCut_D_Dec, self).__init__(block, layers, norm_layer, large_kernel, 8 | late_downsample=late_downsample) 9 | 10 | def forward(self, x, mid_fea): 11 | fea1, fea2, fea3, fea4, fea5 = mid_fea['shortcut'] 12 | x = self.layer1(x) + fea5 13 | x = self.layer2(x) + fea4 14 | x = self.layer3(x) + fea3 15 | x = self.layer4(x) + fea2 16 | x = 
self.conv1(x) 17 | x = self.bn1(x) 18 | x = self.leaky_relu(x) + fea1 19 | x = self.conv2(x) 20 | 21 | alpha = (self.tanh(x) + 1.0) / 2.0 22 | 23 | return alpha, None 24 | 25 | -------------------------------------------------------------------------------- /gca_matting/encoders/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/gca_matting/encoders/.DS_Store -------------------------------------------------------------------------------- /gca_matting/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from .resnet_enc import ResNet_D, BasicBlock 3 | from .res_shortcut_enc import ResShortCut_D 4 | from .res_gca_enc import ResGuidedCxtAtten 5 | 6 | 7 | __all__ = ['res_shortcut_encoder_29', 'resnet_gca_encoder_29'] 8 | 9 | 10 | def _res_shortcut_D(block, layers, **kwargs): 11 | model = ResShortCut_D(block, layers, **kwargs) 12 | return model 13 | 14 | 15 | def _res_gca_D(block, layers, **kwargs): 16 | model = ResGuidedCxtAtten(block, layers, **kwargs) 17 | return model 18 | 19 | 20 | def resnet_gca_encoder_29(**kwargs): 21 | """Constructs a resnet_encoder_29 model. 22 | """ 23 | return _res_gca_D(BasicBlock, [3, 4, 4, 2], **kwargs) 24 | 25 | 26 | def res_shortcut_encoder_29(**kwargs): 27 | """Constructs a resnet_encoder_25 model. 28 | """ 29 | return _res_shortcut_D(BasicBlock, [3, 4, 4, 2], **kwargs) 30 | 31 | 32 | if __name__ == "__main__": 33 | import torch 34 | logging.basicConfig(level=logging.DEBUG, format='[%(asctime)s] %(levelname)s: %(message)s', 35 | datefmt='%m-%d %H:%M:%S') 36 | resnet_encoder = res_shortcut_encoder_29() 37 | x = torch.randn(4,6,512,512) 38 | z = resnet_encoder(x) 39 | print(z[0].shape) 40 | -------------------------------------------------------------------------------- /gca_matting/encoders/res_gca_enc.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from .resnet_enc import ResNet_D 5 | from .ops import GuidedCxtAtten, SpectralNorm 6 | 7 | 8 | class ResGuidedCxtAtten(ResNet_D): 9 | 10 | def __init__(self, block, layers, norm_layer=None, late_downsample=False): 11 | super(ResGuidedCxtAtten, self).__init__(block, layers, norm_layer, late_downsample=late_downsample) 12 | self.trimap_channel = 3 13 | first_inplane = 3 + self.trimap_channel 14 | self.shortcut_inplane = [first_inplane, self.midplanes, 64, 128, 256] 15 | self.shortcut_plane = [32, self.midplanes, 64, 128, 256] 16 | 17 | self.shortcut = nn.ModuleList() 18 | for stage, inplane in enumerate(self.shortcut_inplane): 19 | self.shortcut.append(self._make_shortcut(inplane, self.shortcut_plane[stage])) 20 | 21 | self.guidance_head = nn.Sequential( 22 | nn.ReflectionPad2d(1), 23 | SpectralNorm(nn.Conv2d(3, 16, kernel_size=3, padding=0, stride=2, bias=False)), 24 | nn.ReLU(inplace=True), 25 | self._norm_layer(16), 26 | nn.ReflectionPad2d(1), 27 | SpectralNorm(nn.Conv2d(16, 32, kernel_size=3, padding=0, stride=2, bias=False)), 28 | nn.ReLU(inplace=True), 29 | self._norm_layer(32), 30 | nn.ReflectionPad2d(1), 31 | SpectralNorm(nn.Conv2d(32, 128, kernel_size=3, padding=0, stride=2, bias=False)), 32 | nn.ReLU(inplace=True), 33 | self._norm_layer(128) 34 | ) 35 | 36 | self.gca = GuidedCxtAtten(128, 128) 37 | 38 | # initialize guidance head 39 | for layers in range(len(self.guidance_head)): 40 | m = 
self.guidance_head[layers] 41 | if isinstance(m, nn.Conv2d): 42 | if hasattr(m, "weight_bar"): 43 | nn.init.xavier_uniform_(m.weight_bar) 44 | elif isinstance(m, nn.BatchNorm2d): 45 | nn.init.constant_(m.weight, 1) 46 | nn.init.constant_(m.bias, 0) 47 | 48 | def _make_shortcut(self, inplane, planes): 49 | return nn.Sequential( 50 | SpectralNorm(nn.Conv2d(inplane, planes, kernel_size=3, padding=1, bias=False)), 51 | nn.ReLU(inplace=True), 52 | self._norm_layer(planes), 53 | SpectralNorm(nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)), 54 | nn.ReLU(inplace=True), 55 | self._norm_layer(planes) 56 | ) 57 | 58 | def forward(self, x): 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.activation(out) 63 | out = self.conv2(out) 64 | out = self.bn2(out) 65 | x1 = self.activation(out) # N x 32 x 256 x 256 66 | out = self.conv3(x1) 67 | out = self.bn3(out) 68 | out = self.activation(out) 69 | 70 | im_fea = self.guidance_head(x[:,:3,...]) # downsample origin image and extract features 71 | if self.trimap_channel == 3: 72 | unknown = F.interpolate(x[:,4:5,...], scale_factor=1/8, mode='nearest') 73 | else: 74 | unknown = F.interpolate(x[:,3:,...].eq(1.).float(), scale_factor=1/8, mode='nearest') 75 | 76 | x2 = self.layer1(out) # N x 64 x 128 x 128 77 | x3= self.layer2(x2) # N x 128 x 64 x 64 78 | x3, offset = self.gca(im_fea, x3, unknown) # contextual attention 79 | x4 = self.layer3(x3) # N x 256 x 32 x 32 80 | out = self.layer_bottleneck(x4) # N x 512 x 16 x 16 81 | 82 | fea1 = self.shortcut[0](x) # input image and trimap 83 | fea2 = self.shortcut[1](x1) 84 | fea3 = self.shortcut[2](x2) 85 | fea4 = self.shortcut[3](x3) 86 | fea5 = self.shortcut[4](x4) 87 | 88 | return out, {'shortcut':(fea1, fea2, fea3, fea4, fea5), 89 | 'image_fea':im_fea, 90 | 'unknown':unknown, 91 | 'offset_1':offset} 92 | 93 | 94 | if __name__ == "__main__": 95 | from networks.encoders.resnet_enc import BasicBlock 96 | m = ResGuidedCxtAtten(BasicBlock, [3, 4, 4, 2]) 97 | for m in m.modules(): 98 | print(m) 99 | -------------------------------------------------------------------------------- /gca_matting/encoders/res_shortcut_enc.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .resnet_enc import ResNet_D 3 | from .ops import SpectralNorm 4 | 5 | 6 | class ResShortCut_D(ResNet_D): 7 | 8 | def __init__(self, block, layers, norm_layer=None, late_downsample=False): 9 | super(ResShortCut_D, self).__init__(block, layers, norm_layer, late_downsample=late_downsample) 10 | first_inplane = 3 + 3#CONFIG.model.trimap_channel 11 | self.shortcut_inplane = [first_inplane, self.midplanes, 64, 128, 256] 12 | self.shortcut_plane = [32, self.midplanes, 64, 128, 256] 13 | 14 | self.shortcut = nn.ModuleList() 15 | for stage, inplane in enumerate(self.shortcut_inplane): 16 | self.shortcut.append(self._make_shortcut(inplane, self.shortcut_plane[stage])) 17 | 18 | def _make_shortcut(self, inplane, planes): 19 | return nn.Sequential( 20 | SpectralNorm(nn.Conv2d(inplane, planes, kernel_size=3, padding=1, bias=False)), 21 | nn.ReLU(inplace=True), 22 | self._norm_layer(planes), 23 | SpectralNorm(nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)), 24 | nn.ReLU(inplace=True), 25 | self._norm_layer(planes) 26 | ) 27 | 28 | def forward(self, x): 29 | 30 | out = self.conv1(x) 31 | out = self.bn1(out) 32 | out = self.activation(out) 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | x1 = self.activation(out) # N x 32 x 256 x 256 36 | out 
= self.conv3(x1) 37 | out = self.bn3(out) 38 | out = self.activation(out) 39 | 40 | x2 = self.layer1(out) # N x 64 x 128 x 128 41 | x3= self.layer2(x2) # N x 128 x 64 x 64 42 | x4 = self.layer3(x3) # N x 256 x 32 x 32 43 | out = self.layer_bottleneck(x4) # N x 512 x 16 x 16 44 | 45 | fea1 = self.shortcut[0](x) # input image and trimap 46 | fea2 = self.shortcut[1](x1) 47 | fea3 = self.shortcut[2](x2) 48 | fea4 = self.shortcut[3](x3) 49 | fea5 = self.shortcut[4](x4) 50 | 51 | return out, {'shortcut':(fea1, fea2, fea3, fea4, fea5), 'image':x[:,:3,...]} -------------------------------------------------------------------------------- /gca_matting/generators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from . import encoders, decoders 5 | 6 | 7 | class Generator(nn.Module): 8 | def __init__(self, encoder, decoder): 9 | 10 | super(Generator, self).__init__() 11 | 12 | if encoder not in encoders.__all__: 13 | raise NotImplementedError("Unknown Encoder {}".format(encoder)) 14 | self.encoder = encoders.__dict__[encoder]() 15 | 16 | if decoder not in decoders.__all__: 17 | raise NotImplementedError("Unknown Decoder {}".format(decoder)) 18 | self.decoder = decoders.__dict__[decoder]() 19 | 20 | def forward(self, image, trimap): 21 | inp = torch.cat((image, trimap), dim=1) 22 | embedding, mid_fea = self.encoder(inp) 23 | alpha, info_dict = self.decoder(embedding, mid_fea) 24 | 25 | return alpha, info_dict 26 | 27 | 28 | def get_generator(encoder, decoder): 29 | generator = Generator(encoder=encoder, decoder=decoder) 30 | return generator 31 | 32 | def remove_prefix_state_dict(state_dict, prefix="module"): 33 | """ 34 | remove prefix from the key of pretrained state dict for Data-Parallel 35 | """ 36 | new_state_dict = {} 37 | first_state_name = list(state_dict.keys())[0] 38 | if not first_state_name.startswith(prefix): 39 | for key, value in state_dict.items(): 40 | new_state_dict[key] = state_dict[key].float() 41 | else: 42 | for key, value in state_dict.items(): 43 | new_state_dict[key[len(prefix)+1:]] = state_dict[key].float() 44 | return new_state_dict 45 | 46 | if __name__ == "__main__": 47 | gca_model = get_generator(encoder='resnet_gca_encoder_29', decoder='res_gca_decoder_22') 48 | gca_ckpt = torch.load("checkpoints\gca-dist-all-data.pth") 49 | gca_model.load_state_dict(remove_prefix_state_dict(gca_ckpt['state_dict']), strict=True) 50 | gca_model.eval() 51 | input = torch.rand(1, 3, 512, 512) 52 | output = gca_model(input, input) 53 | print(output[0].shape) 54 | print(output[1]) -------------------------------------------------------------------------------- /indexnet_matting/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/indexnet_matting/.DS_Store -------------------------------------------------------------------------------- /indexnet_matting/.ipynb_checkpoints/demo-checkpoint.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import os 4 | import cv2 as cv 5 | from time import time 6 | from PIL import Image 7 | from collections import OrderedDict 8 | 9 | import torch 10 | import torch.nn as nn 11 | from hlmobilenetv2 import hlmobilenetv2 12 | 13 | # ignore warnings 14 | import warnings 15 | warnings.filterwarnings("ignore") 16 | 17 | IMG_SCALE = 1./255 18 | IMG_MEAN = np.array([0.485, 0.456, 
0.406, 0]).reshape((1, 1, 4)) 19 | IMG_STD = np.array([0.229, 0.224, 0.225, 1]).reshape((1, 1, 4)) 20 | 21 | STRIDE = 32 22 | RESTORE_FROM = 'indexnet_matting.pth.tar' 23 | RESULT_DIR = './examples/mattes' 24 | 25 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 26 | 27 | if not os.path.exists(RESULT_DIR): 28 | os.makedirs(RESULT_DIR) 29 | 30 | # load pretrained model 31 | net = hlmobilenetv2( 32 | pretrained=False, 33 | freeze_bn=True, 34 | output_stride=STRIDE, 35 | apply_aspp=True, 36 | conv_operator='std_conv', 37 | decoder='indexnet', 38 | decoder_kernel_size=5, 39 | indexnet='depthwise', 40 | index_mode='m2o', 41 | use_nonlinear=True, 42 | use_context=True 43 | ) 44 | 45 | try: 46 | checkpoint = torch.load(RESTORE_FROM, map_location=device) 47 | pretrained_dict = OrderedDict() 48 | for key, value in checkpoint['state_dict'].items(): 49 | if 'module' in key: 50 | key = key[7:] 51 | pretrained_dict[key] = value 52 | except: 53 | raise Exception('Please download the pretrained model!') 54 | net.load_state_dict(pretrained_dict) 55 | net.to(device) 56 | if torch.cuda.is_available(): 57 | net = nn.DataParallel(net) 58 | 59 | # switch to eval mode 60 | net.eval() 61 | 62 | def read_image(x): 63 | img_arr = np.array(Image.open(x)) 64 | return img_arr 65 | 66 | def image_alignment(x, output_stride, odd=False): 67 | imsize = np.asarray(x.shape[:2], dtype=np.float) 68 | if odd: 69 | new_imsize = np.ceil(imsize / output_stride) * output_stride + 1 70 | else: 71 | new_imsize = np.ceil(imsize / output_stride) * output_stride 72 | h, w = int(new_imsize[0]), int(new_imsize[1]) 73 | 74 | x1 = x[:, :, 0:3] 75 | x2 = x[:, :, 3] 76 | new_x1 = cv.resize(x1, dsize=(w,h), interpolation=cv.INTER_CUBIC) 77 | new_x2 = cv.resize(x2, dsize=(w,h), interpolation=cv.INTER_NEAREST) 78 | 79 | new_x2 = np.expand_dims(new_x2, axis=2) 80 | new_x = np.concatenate((new_x1, new_x2), axis=2) 81 | 82 | return new_x 83 | 84 | def inference(image_path, trimap_path): 85 | with torch.no_grad(): 86 | image, trimap = read_image(image_path), read_image(trimap_path) 87 | trimap = np.expand_dims(trimap, axis=2) 88 | image = np.concatenate((image, trimap), axis=2) 89 | 90 | h, w = image.shape[:2] 91 | 92 | image = image.astype('float32') 93 | image = (IMG_SCALE * image - IMG_MEAN) / IMG_STD 94 | image = image.astype('float32') 95 | 96 | image = image_alignment(image, STRIDE) 97 | inputs = torch.from_numpy(np.expand_dims(image.transpose(2, 0, 1), axis=0)) 98 | inputs = inputs.to(device) 99 | 100 | print(inputs[0,3,:,:]) 101 | input(inputs.shape) 102 | # inference 103 | start = time() 104 | outputs = net(inputs) 105 | end = time() 106 | 107 | outputs = outputs.squeeze().cpu().numpy() 108 | 109 | alpha = cv.resize(outputs, dsize=(w,h), interpolation=cv.INTER_CUBIC) 110 | alpha = np.clip(alpha, 0, 1) * 255. 
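        # Blend with the trimap: pixels the trimap marks as known (0 or 255) are copied
        # through unchanged; the network prediction is kept only inside the unknown band
        # (trimap == 128), which is what the mask below selects.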
111 | trimap = trimap.squeeze() 112 | mask = np.equal(trimap, 128).astype(np.float32) 113 | alpha = (1 - mask) * trimap + mask * alpha 114 | 115 | _, image_name = os.path.split(image_path) 116 | Image.fromarray(alpha.astype(np.uint8)).save(os.path.join(RESULT_DIR, image_name)) 117 | # Image.fromarray(alpha.astype(np.uint8)).show() 118 | 119 | running_frame_rate = 1 * float(1 / (end - start)) # batch_size = 1 120 | print('framerate: {0:.2f}Hz'.format(running_frame_rate)) 121 | 122 | 123 | if __name__ == "__main__": 124 | image_path = [ 125 | '/home/zeyang/dataset/dataset/testing/00298.png' 126 | ] 127 | trimap_path = [ 128 | '/home/zeyang/dataset/dataset/gt/00298_matte.png' 129 | ] 130 | for image, trimap in zip(image_path, trimap_path): 131 | inference(image, trimap) 132 | -------------------------------------------------------------------------------- /indexnet_matting/.ipynb_checkpoints/demo_indexnet_matting-checkpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | IndexNet Matting 3 | 4 | Indices Matter: Learning to Index for Deep Image Matting 5 | IEEE/CVF International Conference on Computer Vision, 2019 6 | 7 | This software is strictly limited to academic purposes only 8 | Copyright (c) 2019, Hao Lu (hao.lu@adelaide.edu.au) 9 | All rights reserved. 10 | 11 | Redistribution and use in source and binary forms, with or without 12 | modification, are permitted provided that the following conditions are met: 13 | 14 | * Redistributions of source code must retain the above copyright notice, this 15 | list of conditions and the following disclaimer. 16 | * Redistributions in binary form must reproduce the above copyright notice, 17 | this list of conditions and the following disclaimer in the documentation 18 | and/or other materials provided with the distribution. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | """ 31 | 32 | import os 33 | import cv2 34 | from time import time 35 | from PIL import Image 36 | 37 | import torch 38 | import torch.nn as nn 39 | from torchvision import transforms 40 | from torch.utils.data import DataLoader 41 | 42 | from hlvggnet import hlvgg16 43 | from hlmobilenetv2 import hlmobilenetv2 44 | from hldataset import AdobeImageMattingDataset, Normalize, ToTensor 45 | from utils import * 46 | 47 | IMG_SCALE = 1./255 48 | IMG_MEAN = np.array([0.485, 0.456, 0.406, 0]).reshape((1, 1, 4)) 49 | IMG_STD = np.array([0.229, 0.224, 0.225, 1]).reshape((1, 1, 4)) 50 | DATA_DIR = '/media/hao/DATA/Combined_Dataset' 51 | DATA_TEST_LIST = './lists/test.txt' 52 | 53 | STRIDE = 32 54 | RESTORE_FROM = './pretrained/indexnet_matting.pth.tar' 55 | RESULT_DIR = './results/indexnet_matting' 56 | 57 | if not os.path.exists(RESULT_DIR): 58 | os.makedirs(RESULT_DIR) 59 | 60 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 61 | 62 | 63 | # instantiate network 64 | net = hlmobilenetv2( 65 | pretrained=False, 66 | freeze_bn=True, 67 | output_stride=STRIDE, 68 | apply_aspp=True, 69 | conv_operator='std_conv', 70 | decoder='indexnet', 71 | decoder_kernel_size=5, 72 | indexnet='depthwise', 73 | index_mode='m2o', 74 | use_nonlinear=True, 75 | use_context=True 76 | ) 77 | net = nn.DataParallel(net) 78 | net.to(device) 79 | 80 | try: 81 | checkpoint = torch.load(RESTORE_FROM) 82 | pretrained_dict = checkpoint['state_dict'] 83 | except: 84 | raise Exception('Please download the pretrained model!') 85 | 86 | net.load_state_dict(pretrained_dict) 87 | net.to(device) 88 | 89 | dataset = AdobeImageMattingDataset 90 | testset = dataset( 91 | data_file=DATA_TEST_LIST, 92 | data_dir=DATA_DIR, 93 | train=False, 94 | transform=transforms.Compose([ 95 | Normalize(IMG_SCALE, IMG_MEAN, IMG_STD), 96 | ToTensor()] 97 | ) 98 | ) 99 | test_loader = DataLoader( 100 | testset, 101 | batch_size=1, 102 | shuffle=False, 103 | num_workers=0, 104 | pin_memory=False 105 | ) 106 | 107 | image_list = [name.split('\t') for name in open(DATA_TEST_LIST).read().splitlines()] 108 | # switch to eval mode 109 | net.eval() 110 | 111 | with torch.no_grad(): 112 | sad = [] 113 | mse = [] 114 | grad = [] 115 | conn = [] 116 | avg_frame_rate = 0 117 | start = time() 118 | for i, sample in enumerate(test_loader): 119 | image, target = sample['image'], sample['alpha'] 120 | 121 | h, w = image.size()[2:] 122 | image = image.squeeze().numpy().transpose(1, 2, 0) 123 | image = image_alignment(image, STRIDE, odd=False) 124 | inputs = torch.from_numpy(np.expand_dims(image.transpose(2, 0, 1), axis=0)) 125 | 126 | # inference 127 | torch.cuda.synchronize() 128 | start = time() 129 | outputs = net(inputs.cuda()).squeeze().cpu().numpy() 130 | torch.cuda.synchronize() 131 | end = time() 132 | 133 | alpha = cv.resize(outputs, dsize=(w,h), interpolation=cv.INTER_CUBIC) 134 | alpha = np.clip(alpha, 0, 1) * 255. 135 | trimap = target[:, 1, :, :].squeeze().numpy() 136 | mask = np.equal(trimap, 128).astype(np.float32) 137 | 138 | alpha = (1 - mask) * trimap + mask * alpha 139 | gt_alpha = target[:, 0, :, :].squeeze().numpy() * 255. 
140 | 141 | _, image_name = os.path.split(image_list[i][0]) 142 | Image.fromarray(alpha.astype(np.uint8)).save( 143 | os.path.join(RESULT_DIR, image_name) 144 | ) 145 | # Image.fromarray(gt_alpha.astype(np.uint8)).show() 146 | 147 | sad.append(compute_sad_loss(alpha, gt_alpha, mask)) 148 | mse.append(compute_mse_loss(alpha, gt_alpha, mask)) 149 | 150 | running_frame_rate = 1 * float(1 / (end - start)) # batch_size = 1 151 | avg_frame_rate = (avg_frame_rate*i + running_frame_rate)/(i+1) 152 | print( 153 | 'test: {0}/{1}, sad: {2:.2f}, SAD: {3:.2f}, MSE: {4:.4f},' 154 | ' framerate: {5:.2f}Hz/{6:.2f}Hz' 155 | .format(i+1, len(test_loader), sad[-1], np.mean(sad), np.mean(mse), 156 | running_frame_rate, avg_frame_rate) 157 | ) 158 | -------------------------------------------------------------------------------- /indexnet_matting/.ipynb_checkpoints/hlconv-checkpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | IndexNet Matting 3 | 4 | Indices Matter: Learning to Index for Deep Image Matting 5 | IEEE/CVF International Conference on Computer Vision, 2019 6 | 7 | This software is strictly limited to academic purposes only 8 | Copyright (c) 2019, Hao Lu (hao.lu@adelaide.edu.au) 9 | All rights reserved. 10 | 11 | Redistribution and use in source and binary forms, with or without 12 | modification, are permitted provided that the following conditions are met: 13 | 14 | * Redistributions of source code must retain the above copyright notice, this 15 | list of conditions and the following disclaimer. 16 | * Redistributions in binary form must reproduce the above copyright notice, 17 | this list of conditions and the following disclaimer in the documentation 18 | and/or other materials provided with the distribution. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | """ 31 | 32 | 33 | import torch 34 | import torch.nn as nn 35 | from .lib.nn import SynchronizedBatchNorm2d 36 | 37 | def conv_bn(inp, oup, k=3, s=1, BatchNorm2d=SynchronizedBatchNorm2d): 38 | return nn.Sequential( 39 | nn.Conv2d(inp, oup, k, s, padding=k//2, bias=False), 40 | BatchNorm2d(oup), 41 | nn.ReLU6(inplace=True) 42 | ) 43 | 44 | def dep_sep_conv_bn(inp, oup, k=3, s=1, BatchNorm2d=SynchronizedBatchNorm2d): 45 | return nn.Sequential( 46 | nn.Conv2d(inp, inp, k, s, padding=k//2, groups=inp, bias=False), 47 | BatchNorm2d(inp), 48 | nn.ReLU6(inplace=True), 49 | nn.Conv2d(inp, oup, 1, 1, padding=0, bias=False), 50 | BatchNorm2d(oup), 51 | nn.ReLU6(inplace=True) 52 | ) 53 | 54 | hlconv = { 55 | 'std_conv': conv_bn, 56 | 'dep_sep_conv': dep_sep_conv_bn 57 | } -------------------------------------------------------------------------------- /indexnet_matting/Composition_code.py: -------------------------------------------------------------------------------- 1 | ##Copyright 2017 Adobe Systems Inc. 2 | ## 3 | ##Licensed under the Apache License, Version 2.0 (the "License"); 4 | ##you may not use this file except in compliance with the License. 5 | ##You may obtain a copy of the License at 6 | ## 7 | ## http://www.apache.org/licenses/LICENSE-2.0 8 | ## 9 | ##Unless required by applicable law or agreed to in writing, software 10 | ##distributed under the License is distributed on an "AS IS" BASIS, 11 | ##WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | ##See the License for the specific language governing permissions and 13 | ##limitations under the License. 14 | 15 | 16 | ############################################################## 17 | #Set your paths here 18 | 19 | #path to provided foreground images 20 | fg_path = 'fg/' 21 | 22 | #path to provided alpha mattes 23 | a_path = 'alpha/' 24 | 25 | #Path to background images (MSCOCO) 26 | bg_path = 'train2014/' 27 | 28 | #Path to folder where you want the composited images to go 29 | out_path = 'merged_cv/' 30 | 31 | ############################################################## 32 | 33 | import numpy as np 34 | from PIL import Image 35 | import os 36 | import math 37 | import time 38 | import cv2 as cv 39 | 40 | def composite4(fg, bg, a, w, h): 41 | fg = np.array(fg, np.float32) 42 | bg = np.array(bg[0:h, 0:w], np.float32) 43 | alpha = np.zeros((h, w, 1), np.float32) 44 | alpha[:, :, 0] = a / 255. 
45 | comp = alpha * fg + (1 - alpha) * bg 46 | comp = comp.astype(np.uint8) 47 | return comp 48 | 49 | num_bgs = 100 50 | 51 | fg_files = [name for name in open('training_fg_names.txt').read().splitlines()] 52 | bg_files = [name for name in open('training_bg_names.txt').read().splitlines()] 53 | 54 | bg_iter = iter(bg_files) 55 | for k, im_name in enumerate(fg_files): 56 | 57 | im = cv.imread(fg_path + im_name) 58 | a = cv.imread(a_path + im_name, 0) 59 | h, w = im.shape[:2] 60 | 61 | bcount = 0 62 | for i in range(num_bgs): 63 | 64 | bg_name = next(bg_iter) 65 | bg = cv.imread(bg_path + bg_name) 66 | bh, bw = bg.shape[:2] 67 | wratio = float(w) / float(bw) 68 | hratio = float(h) / float(bh) 69 | ratio = wratio if wratio > hratio else hratio 70 | if ratio > 1: 71 | bg = cv.resize(src=bg, dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv.INTER_CUBIC) 72 | 73 | out = composite4(im, bg, a, w, h) 74 | filename = out_path + im_name[:len(im_name)-4] + '_' + str(bcount) + '.png' 75 | 76 | cv.imwrite(filename, out, [cv.IMWRITE_PNG_COMPRESSION, 9]) 77 | 78 | bcount += 1 79 | print(k*num_bgs + bcount) 80 | -------------------------------------------------------------------------------- /indexnet_matting/demo.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import os 4 | import cv2 as cv 5 | from time import time 6 | from PIL import Image 7 | from collections import OrderedDict 8 | 9 | import torch 10 | import torch.nn as nn 11 | from hlmobilenetv2 import hlmobilenetv2 12 | 13 | # ignore warnings 14 | import warnings 15 | warnings.filterwarnings("ignore") 16 | 17 | IMG_SCALE = 1./255 18 | IMG_MEAN = np.array([0.485, 0.456, 0.406, 0]).reshape((1, 1, 4)) 19 | IMG_STD = np.array([0.229, 0.224, 0.225, 1]).reshape((1, 1, 4)) 20 | 21 | STRIDE = 32 22 | RESTORE_FROM = 'indexnet_matting.pth.tar' 23 | RESULT_DIR = './examples/mattes' 24 | 25 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 26 | 27 | if not os.path.exists(RESULT_DIR): 28 | os.makedirs(RESULT_DIR) 29 | 30 | # load pretrained model 31 | net = hlmobilenetv2( 32 | pretrained=False, 33 | freeze_bn=True, 34 | output_stride=STRIDE, 35 | apply_aspp=True, 36 | conv_operator='std_conv', 37 | decoder='indexnet', 38 | decoder_kernel_size=5, 39 | indexnet='depthwise', 40 | index_mode='m2o', 41 | use_nonlinear=True, 42 | use_context=True 43 | ) 44 | 45 | try: 46 | checkpoint = torch.load(RESTORE_FROM, map_location=device) 47 | pretrained_dict = OrderedDict() 48 | for key, value in checkpoint['state_dict'].items(): 49 | if 'module' in key: 50 | key = key[7:] 51 | pretrained_dict[key] = value 52 | except: 53 | raise Exception('Please download the pretrained model!') 54 | net.load_state_dict(pretrained_dict) 55 | net.to(device) 56 | if torch.cuda.is_available(): 57 | net = nn.DataParallel(net) 58 | 59 | # switch to eval mode 60 | net.eval() 61 | 62 | def read_image(x): 63 | img_arr = np.array(Image.open(x)) 64 | return img_arr 65 | 66 | def image_alignment(x, output_stride, odd=False): 67 | imsize = np.asarray(x.shape[:2], dtype=np.float) 68 | if odd: 69 | new_imsize = np.ceil(imsize / output_stride) * output_stride + 1 70 | else: 71 | new_imsize = np.ceil(imsize / output_stride) * output_stride 72 | h, w = int(new_imsize[0]), int(new_imsize[1]) 73 | 74 | x1 = x[:, :, 0:3] 75 | x2 = x[:, :, 3] 76 | new_x1 = cv.resize(x1, dsize=(w,h), interpolation=cv.INTER_CUBIC) 77 | new_x2 = cv.resize(x2, dsize=(w,h), interpolation=cv.INTER_NEAREST) 78 | 79 | 
new_x2 = np.expand_dims(new_x2, axis=2) 80 | new_x = np.concatenate((new_x1, new_x2), axis=2) 81 | 82 | return new_x 83 | 84 | def inference(image_path, trimap_path): 85 | with torch.no_grad(): 86 | image, trimap = read_image(image_path), read_image(trimap_path) 87 | # [1024,1024,3] [1024,1024,1] 88 | trimap = np.expand_dims(trimap, axis=2) 89 | image = np.concatenate((image, trimap), axis=2) 90 | 91 | h, w = image.shape[:2] 92 | 93 | image = image.astype('float32') 94 | image = (IMG_SCALE * image - IMG_MEAN) / IMG_STD 95 | image = image.astype('float32') 96 | 97 | image = image_alignment(image, STRIDE) 98 | inputs = torch.from_numpy(np.expand_dims(image.transpose(2, 0, 1), axis=0)) 99 | inputs = inputs.to(device) 100 | 101 | print(inputs[0,3,:,:]) 102 | input(inputs.shape) 103 | # inference 104 | start = time() 105 | outputs = net(inputs) 106 | end = time() 107 | 108 | outputs = outputs.squeeze().cpu().numpy() 109 | 110 | alpha = cv.resize(outputs, dsize=(w,h), interpolation=cv.INTER_CUBIC) 111 | alpha = np.clip(alpha, 0, 1) * 255. 112 | trimap = trimap.squeeze() 113 | mask = np.equal(trimap, 128).astype(np.float32) 114 | alpha = (1 - mask) * trimap + mask * alpha 115 | 116 | _, image_name = os.path.split(image_path) 117 | Image.fromarray(alpha.astype(np.uint8)).save(os.path.join(RESULT_DIR, image_name)) 118 | # Image.fromarray(alpha.astype(np.uint8)).show() 119 | 120 | running_frame_rate = 1 * float(1 / (end - start)) # batch_size = 1 121 | print('framerate: {0:.2f}Hz'.format(running_frame_rate)) 122 | 123 | 124 | if __name__ == "__main__": 125 | image_path = [ 126 | '/home/zeyang/dataset/dataset/testing/00298.png' 127 | ] 128 | trimap_path = [ 129 | '/home/zeyang/dataset/dataset/gt/00298_matte.png' 130 | ] 131 | for image, trimap in zip(image_path, trimap_path): 132 | inference(image, trimap) 133 | -------------------------------------------------------------------------------- /indexnet_matting/demo_deep_matting.py: -------------------------------------------------------------------------------- 1 | """ 2 | IndexNet Matting 3 | 4 | Indices Matter: Learning to Index for Deep Image Matting 5 | IEEE/CVF International Conference on Computer Vision, 2019 6 | 7 | This software is strictly limited to academic purposes only 8 | Copyright (c) 2019, Hao Lu (hao.lu@adelaide.edu.au) 9 | All rights reserved. 10 | 11 | Redistribution and use in source and binary forms, with or without 12 | modification, are permitted provided that the following conditions are met: 13 | 14 | * Redistributions of source code must retain the above copyright notice, this 15 | list of conditions and the following disclaimer. 16 | * Redistributions in binary form must reproduce the above copyright notice, 17 | this list of conditions and the following disclaimer in the documentation 18 | and/or other materials provided with the distribution. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | """ 31 | 32 | import os 33 | import cv2 34 | from time import time 35 | from PIL import Image 36 | 37 | import torch 38 | import torch.nn as nn 39 | from torchvision import transforms 40 | from torch.utils.data import DataLoader 41 | 42 | from hlvggnet import hlvgg16 43 | from hlmobilenetv2 import hlmobilenetv2 44 | from hldataset import AdobeImageMattingDataset, Normalize, ToTensor 45 | from utils import * 46 | 47 | IMG_SCALE = 1./255 48 | IMG_MEAN = np.array([0.485, 0.456, 0.406, 0]).reshape((1, 1, 4)) 49 | IMG_STD = np.array([0.229, 0.224, 0.225, 1]).reshape((1, 1, 4)) 50 | DATA_TEST_LIST = './lists/test.txt' 51 | DATA_DIR = '/media/hao/DATA/Combined_Dataset' 52 | 53 | STRIDE = 32 54 | MODEL = 'deep_matting' 55 | RESTORE_FROM = './pretrained/'+MODEL+'.pth.tar' 56 | RESULT_DIR = './results/'+MODEL 57 | 58 | if not os.path.exists(RESULT_DIR): 59 | os.makedirs(RESULT_DIR) 60 | 61 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 62 | 63 | 64 | # instantiate network 65 | net = hlvgg16(pretrained=False) 66 | net.to(device) 67 | 68 | try: 69 | checkpoint = torch.load(RESTORE_FROM) 70 | pretrained_dict = checkpoint['state_dict'] 71 | except: 72 | raise Exception('Please download the pretrained model!') 73 | 74 | net.load_state_dict(pretrained_dict) 75 | net.to(device) 76 | 77 | dataset = AdobeImageMattingDataset 78 | testset = dataset( 79 | data_file=DATA_TEST_LIST, 80 | data_dir=DATA_DIR, 81 | train=False, 82 | transform=transforms.Compose([ 83 | Normalize(IMG_SCALE, IMG_MEAN, IMG_STD), 84 | ToTensor()] 85 | ) 86 | ) 87 | test_loader = DataLoader( 88 | testset, 89 | batch_size=1, 90 | shuffle=False, 91 | num_workers=0, 92 | pin_memory=False 93 | ) 94 | 95 | image_list = [name.split('\t') for name in open(DATA_TEST_LIST).read().splitlines()] 96 | # switch to eval mode 97 | net.eval() 98 | 99 | with torch.no_grad(): 100 | sad = [] 101 | mse = [] 102 | grad = [] 103 | conn = [] 104 | avg_frame_rate = 0 105 | start = time() 106 | for i in range(0, len(test_loader)): 107 | sample = testset.__getitem__(i) 108 | image, target = sample['image'], sample['alpha'] 109 | image = torch.unsqueeze(image, dim=0) 110 | target = torch.unsqueeze(target, dim=0) 111 | 112 | h, w = image.size()[2:] 113 | image = image.squeeze().numpy().transpose(1, 2, 0) 114 | 115 | # ---------------------------------------------------------------- 116 | # NOTICE: 117 | # comment the following line to test the image in full resolution 118 | # notice that full resolution requires a lot of GPU memory 119 | # alternatively, you can test the image on CPU is GPU is out of memory 120 | # this is for demo use only, formal evaluation should be on the full resolution 121 | image = image_rescale(image, 0.5) 122 | # ---------------------------------------------------------------- 123 | 124 | image = image_alignment(image, STRIDE, odd=False) 125 | inputs = torch.from_numpy(np.expand_dims(image.transpose(2, 0, 1), axis=0)) 126 | 127 | # inference 128 | torch.cuda.synchronize() 129 
| start = time() 130 | outputs = net(inputs.to(device)).squeeze().cpu().numpy() 131 | torch.cuda.synchronize() 132 | end = time() 133 | 134 | alpha = cv.resize(outputs, dsize=(w,h), interpolation=cv.INTER_CUBIC) 135 | alpha = np.clip(alpha, 0, 1) * 255. 136 | 137 | trimap = target[:, 1, :, :].squeeze().numpy() 138 | mask = np.equal(trimap, 128).astype(np.float32) 139 | 140 | alpha = (1 - mask) * trimap + mask * alpha 141 | gt_alpha = target[:, 0, :, :].squeeze().numpy() * 255. 142 | 143 | alpha.astype(np.uint8) 144 | gt_alpha.astype(np.uint8) 145 | 146 | path, image_name = os.path.split(image_list[i][0]) 147 | Image.fromarray(alpha.astype(np.uint8)).save(os.path.join(RESULT_DIR, image_name)) 148 | # Image.fromarray(gt_alpha.astype(np.uint8)).show() 149 | 150 | sad.append(compute_sad_loss(alpha, gt_alpha, mask)) 151 | mse.append(compute_mse_loss(alpha, gt_alpha, mask)) 152 | 153 | running_frame_rate = 1 * float(1 / (end - start)) # batch_size = 1 154 | avg_frame_rate = (avg_frame_rate*i + running_frame_rate)/(i+1) 155 | print( 156 | 'test: {0}/{1}, sad: {2:.2f}, SAD: {3:.2f}, MSE: {4:.4f},' 157 | ' framerate: {5:.2f}Hz/{6:.2f}Hz' 158 | .format(i+1, len(test_loader), sad[-1], np.mean(sad), np.mean(mse), 159 | running_frame_rate, avg_frame_rate) 160 | ) 161 | -------------------------------------------------------------------------------- /indexnet_matting/demo_indexnet_matting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from time import time 4 | from PIL import Image 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torchvision import transforms 9 | from torch.utils.data import DataLoader 10 | 11 | from hlvggnet import hlvgg16 12 | from hlmobilenetv2 import hlmobilenetv2 13 | from hldataset import AdobeImageMattingDataset, Normalize, ToTensor 14 | from utils import * 15 | 16 | IMG_SCALE = 1./255 17 | IMG_MEAN = np.array([0.485, 0.456, 0.406, 0]).reshape((1, 1, 4)) 18 | IMG_STD = np.array([0.229, 0.224, 0.225, 1]).reshape((1, 1, 4)) 19 | DATA_DIR = '/media/hao/DATA/Combined_Dataset' 20 | DATA_TEST_LIST = './lists/test.txt' 21 | 22 | STRIDE = 32 23 | RESTORE_FROM = './pretrained/indexnet_matting.pth.tar' 24 | RESULT_DIR = './results/indexnet_matting' 25 | 26 | if not os.path.exists(RESULT_DIR): 27 | os.makedirs(RESULT_DIR) 28 | 29 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 30 | 31 | 32 | # instantiate network 33 | net = hlmobilenetv2( 34 | pretrained=False, 35 | freeze_bn=True, 36 | output_stride=STRIDE, 37 | apply_aspp=True, 38 | conv_operator='std_conv', 39 | decoder='indexnet', 40 | decoder_kernel_size=5, 41 | indexnet='depthwise', 42 | index_mode='m2o', 43 | use_nonlinear=True, 44 | use_context=True 45 | ) 46 | net = nn.DataParallel(net) 47 | net.to(device) 48 | 49 | try: 50 | checkpoint = torch.load(RESTORE_FROM) 51 | pretrained_dict = checkpoint['state_dict'] 52 | except: 53 | raise Exception('Please download the pretrained model!') 54 | 55 | net.load_state_dict(pretrained_dict) 56 | net.to(device) 57 | 58 | dataset = AdobeImageMattingDataset 59 | testset = dataset( 60 | data_file=DATA_TEST_LIST, 61 | data_dir=DATA_DIR, 62 | train=False, 63 | transform=transforms.Compose([ 64 | Normalize(IMG_SCALE, IMG_MEAN, IMG_STD), 65 | ToTensor()] 66 | ) 67 | ) 68 | test_loader = DataLoader( 69 | testset, 70 | batch_size=1, 71 | shuffle=False, 72 | num_workers=0, 73 | pin_memory=False 74 | ) 75 | 76 | image_list = [name.split('\t') for name in open(DATA_TEST_LIST).read().splitlines()] 77 | # 
switch to eval mode 78 | net.eval() 79 | 80 | with torch.no_grad(): 81 | sad = [] 82 | mse = [] 83 | grad = [] 84 | conn = [] 85 | avg_frame_rate = 0 86 | start = time() 87 | for i, sample in enumerate(test_loader): 88 | image, target = sample['image'], sample['alpha'] 89 | 90 | h, w = image.size()[2:] 91 | image = image.squeeze().numpy().transpose(1, 2, 0) 92 | image = image_alignment(image, STRIDE, odd=False) 93 | inputs = torch.from_numpy(np.expand_dims(image.transpose(2, 0, 1), axis=0)) 94 | 95 | # inference 96 | torch.cuda.synchronize() 97 | start = time() 98 | outputs = net(inputs.cuda()).squeeze().cpu().numpy() 99 | torch.cuda.synchronize() 100 | end = time() 101 | 102 | alpha = cv.resize(outputs, dsize=(w,h), interpolation=cv.INTER_CUBIC) 103 | alpha = np.clip(alpha, 0, 1) * 255. 104 | trimap = target[:, 1, :, :].squeeze().numpy() 105 | mask = np.equal(trimap, 128).astype(np.float32) 106 | 107 | alpha = (1 - mask) * trimap + mask * alpha 108 | gt_alpha = target[:, 0, :, :].squeeze().numpy() * 255. 109 | 110 | _, image_name = os.path.split(image_list[i][0]) 111 | Image.fromarray(alpha.astype(np.uint8)).save( 112 | os.path.join(RESULT_DIR, image_name) 113 | ) 114 | # Image.fromarray(gt_alpha.astype(np.uint8)).show() 115 | 116 | sad.append(compute_sad_loss(alpha, gt_alpha, mask)) 117 | mse.append(compute_mse_loss(alpha, gt_alpha, mask)) 118 | 119 | running_frame_rate = 1 * float(1 / (end - start)) # batch_size = 1 120 | avg_frame_rate = (avg_frame_rate*i + running_frame_rate)/(i+1) 121 | print( 122 | 'test: {0}/{1}, sad: {2:.2f}, SAD: {3:.2f}, MSE: {4:.4f},' 123 | ' framerate: {5:.2f}Hz/{6:.2f}Hz' 124 | .format(i+1, len(test_loader), sad[-1], np.mean(sad), np.mean(mse), 125 | running_frame_rate, avg_frame_rate) 126 | ) 127 | -------------------------------------------------------------------------------- /indexnet_matting/evaluation_code/compute_connectivity_error.m: -------------------------------------------------------------------------------- 1 | % compute the connectivity error given a prediction, a ground truth and a trimap. 
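% The metric thresholds both mattes from 0 to 1 in increments of `step`, tracks the largest
% 4-connected region shared by prediction and ground truth, and accumulates the per-pixel
% degree of disconnection inside the unknown (trimap==128) region.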
2 | % author Ning Xu 3 | % date 2018-1-1 4 | 5 | % pred: the predicted alpha matte 6 | % target: the ground truth alpha matte 7 | % trimap: the given trimap 8 | % step = 0.1 9 | 10 | function loss = compute_connectivity_error(pred,target,trimap,step) 11 | pred = single(pred)/255; 12 | target = single(target)/255; 13 | 14 | [dimy,dimx] = size(pred); 15 | 16 | thresh_steps = 0:step:1; 17 | l_map = ones(size(pred))*(-1); 18 | dist_maps = zeros([dimy,dimx,numel(thresh_steps)]); 19 | for ii = 2:numel(thresh_steps) 20 | pred_alpha_thresh = pred>=thresh_steps(ii); 21 | target_alpha_thresh = target>=thresh_steps(ii); 22 | 23 | cc = bwconncomp(pred_alpha_thresh & target_alpha_thresh,4); 24 | size_vec = cellfun(@numel,cc.PixelIdxList); 25 | if isempty(size_vec) 26 | continue 27 | end 28 | [~,max_id] = max(size_vec); 29 | 30 | omega = zeros([dimy,dimx]); 31 | omega(cc.PixelIdxList{max_id}) = 1; 32 | 33 | flag = l_map==-1 & omega==0; 34 | l_map(flag==1) = thresh_steps(ii-1); 35 | 36 | dist_maps(:,:,ii) = bwdist(omega); 37 | dist_maps(:,:,ii) = dist_maps(:,:,ii) / max(max(dist_maps(:,:,ii))); 38 | end 39 | l_map(l_map==-1) = 1; 40 | 41 | pred_d = pred - l_map; 42 | target_d = target - l_map; 43 | 44 | pred_phi = 1 - pred_d .* single(pred_d>=0.15); 45 | target_phi = 1 - target_d .* single(target_d>=0.15); 46 | 47 | loss = sum(sum(abs(pred_phi - target_phi).*single(trimap==128))); 48 | 49 | -------------------------------------------------------------------------------- /indexnet_matting/evaluation_code/compute_gradient_loss.m: -------------------------------------------------------------------------------- 1 | % compute the gradient error given a prediction, a ground truth and a trimap. 2 | % author Ning Xu 3 | % date 2018-1-1 4 | 5 | % pred: the predicted alpha matte 6 | % target: the ground truth alpha matte 7 | % trimap: the given trimap 8 | % step = 0.1 9 | 10 | function loss = compute_gradient_loss(pred,target,trimap) 11 | pred = mat2gray(pred); 12 | target = mat2gray(target); 13 | [pred_x,pred_y] = gaussgradient(pred,1.4); 14 | [target_x,target_y] = gaussgradient(target,1.4); 15 | pred_amp = sqrt(pred_x.^2 + pred_y.^2); 16 | target_amp = sqrt(target_x.^2 + target_y.^2); 17 | 18 | error_map = (single(pred_amp) - single(target_amp)).^2; 19 | loss = sum(sum(error_map.*single(trimap==128))); 20 | -------------------------------------------------------------------------------- /indexnet_matting/evaluation_code/compute_mse_loss.m: -------------------------------------------------------------------------------- 1 | % compute the MSE error given a prediction, a ground truth and a trimap. 2 | % author Ning Xu 3 | % date 2018-1-1 4 | 5 | % pred: the predicted alpha matte 6 | % target: the ground truth alpha matte 7 | % trimap: the given trimap 8 | 9 | function loss = compute_mse_loss(pred,target,trimap) 10 | error_map = (single(pred)-single(target))/255; 11 | loss = sum(sum(error_map.^2.*single(trimap==128))) / sum(sum(single(trimap==128))); 12 | -------------------------------------------------------------------------------- /indexnet_matting/evaluation_code/compute_sad_loss.m: -------------------------------------------------------------------------------- 1 | % compute the SAD error given a prediction, a ground truth and a trimap. 
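% SAD is the sum of |pred - target| / 255 over the unknown (trimap==128) region,
% reported in thousands (the final loss is divided by 1000 below).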
2 | % author Ning Xu 3 | % date 2018-1-1 4 | 5 | function loss = compute_sad_loss(pred,target,trimap) 6 | error_map = abs(single(pred)-single(target))/255; 7 | loss = sum(sum(error_map.*single(trimap==128))) ; 8 | 9 | % the loss is scaled by 1000 due to the large images used in our experiment. 10 | % Please check the result table in our paper to make sure the result is correct. 11 | loss = loss / 1000 ; 12 | -------------------------------------------------------------------------------- /indexnet_matting/evaluation_code/gaussgradient.m: -------------------------------------------------------------------------------- 1 | function [gx,gy]=gaussgradient(IM,sigma) 2 | %GAUSSGRADIENT Gradient using first order derivative of Gaussian. 3 | % [gx,gy]=gaussgradient(IM,sigma) outputs the gradient image gx and gy of 4 | % image IM using a 2-D Gaussian kernel. Sigma is the standard deviation of 5 | % this kernel along both directions. 6 | % 7 | % Contributed by Guanglei Xiong (xgl99@mails.tsinghua.edu.cn) 8 | % at Tsinghua University, Beijing, China. 9 | 10 | %determine the appropriate size of kernel. The smaller epsilon, the larger 11 | %size. 12 | epsilon=1e-2; 13 | halfsize=ceil(sigma*sqrt(-2*log(sqrt(2*pi)*sigma*epsilon))); 14 | size=2*halfsize+1; 15 | %generate a 2-D Gaussian kernel along x direction 16 | for i=1:size 17 | for j=1:size 18 | u=[i-halfsize-1 j-halfsize-1]; 19 | hx(i,j)=gauss(u(1),sigma)*dgauss(u(2),sigma); 20 | end 21 | end 22 | hx=hx/sqrt(sum(sum(abs(hx).*abs(hx)))); 23 | %generate a 2-D Gaussian kernel along y direction 24 | hy=hx'; 25 | %2-D filtering 26 | gx=imfilter(IM,hx,'replicate','conv'); 27 | gy=imfilter(IM,hy,'replicate','conv'); 28 | 29 | function y = gauss(x,sigma) 30 | %Gaussian 31 | y = exp(-x^2/(2*sigma^2)) / (sigma*sqrt(2*pi)); 32 | 33 | function y = dgauss(x,sigma) 34 | %first order derivative of Gaussian 35 | y = -x * gauss(x,sigma) / sigma^2; -------------------------------------------------------------------------------- /indexnet_matting/evaluation_code/hleval.m: -------------------------------------------------------------------------------- 1 | 2 | clear; close all; clc 3 | 4 | GT_DIR = '/media/hao/DATA/Combined_Dataset'; 5 | RE_DIR = '/home/hao/Pytorch_Codes/ICCV19 Matting Results/Baseline/mobilenet_width_mult_1dot4_unet_decoder_std_conv5x5_stride32'; 6 | DATA_TEST_LIST = '../test.txt'; 7 | 8 | fid = fopen(DATA_TEST_LIST); 9 | imlist = textscan(fid, '%s%s%s', 'Delimiter','\t'); 10 | fclose(fid); 11 | 12 | sad = zeros(length(imlist{1}), 1); 13 | mse = zeros(length(imlist{1}), 1); 14 | grad = zeros(length(imlist{1}), 1); 15 | conn = zeros(length(imlist{1}), 1); 16 | parfor i = 1:length(imlist{1}) 17 | [~, imname, ~] = fileparts(imlist{1}{i}); 18 | 19 | pd = imread(fullfile(RE_DIR, [imname '.png'])); 20 | gt = imread(fullfile(GT_DIR, imlist{2}{i})); 21 | tr = imread(fullfile(GT_DIR, imlist{3}{i})); 22 | 23 | gt = gt(:, :, 1); 24 | 25 | sad(i) = compute_sad_loss(pd, gt, tr); 26 | mse(i) = compute_mse_loss(pd, gt, tr); 27 | grad(i) = compute_gradient_loss(pd, gt, tr) / 1e3; 28 | conn(i) = compute_connectivity_error(pd, gt, tr, 0.1) / 1e3; 29 | 30 | fprintf('test: %d\n', i) 31 | end 32 | 33 | SAD = mean(sad); 34 | MSE = mean(mse); 35 | GRAD = mean(grad); 36 | CONN = mean(conn); 37 | 38 | fprintf('SAD: %.2f, MSE: %.4f, Grad: %.2f, Conn: %.2f\n', ... 
39 | SAD, MSE, GRAD, CONN) 40 | 41 | -------------------------------------------------------------------------------- /indexnet_matting/examples/mattes/.ipynb_checkpoints/00298-checkpoint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/indexnet_matting/examples/mattes/.ipynb_checkpoints/00298-checkpoint.png -------------------------------------------------------------------------------- /indexnet_matting/examples/mattes/00298.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/indexnet_matting/examples/mattes/00298.png -------------------------------------------------------------------------------- /indexnet_matting/hlconv.py: -------------------------------------------------------------------------------- 1 | """ 2 | IndexNet Matting 3 | 4 | Indices Matter: Learning to Index for Deep Image Matting 5 | IEEE/CVF International Conference on Computer Vision, 2019 6 | 7 | This software is strictly limited to academic purposes only 8 | Copyright (c) 2019, Hao Lu (hao.lu@adelaide.edu.au) 9 | All rights reserved. 10 | 11 | Redistribution and use in source and binary forms, with or without 12 | modification, are permitted provided that the following conditions are met: 13 | 14 | * Redistributions of source code must retain the above copyright notice, this 15 | list of conditions and the following disclaimer. 16 | * Redistributions in binary form must reproduce the above copyright notice, 17 | this list of conditions and the following disclaimer in the documentation 18 | and/or other materials provided with the distribution. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | """ 31 | 32 | 33 | import torch 34 | import torch.nn as nn 35 | from .lib.nn import SynchronizedBatchNorm2d 36 | 37 | def conv_bn(inp, oup, k=3, s=1, BatchNorm2d=SynchronizedBatchNorm2d): 38 | return nn.Sequential( 39 | nn.Conv2d(inp, oup, k, s, padding=k//2, bias=False), 40 | BatchNorm2d(oup), 41 | nn.ReLU6(inplace=True) 42 | ) 43 | 44 | def dep_sep_conv_bn(inp, oup, k=3, s=1, BatchNorm2d=SynchronizedBatchNorm2d): 45 | return nn.Sequential( 46 | nn.Conv2d(inp, inp, k, s, padding=k//2, groups=inp, bias=False), 47 | BatchNorm2d(inp), 48 | nn.ReLU6(inplace=True), 49 | nn.Conv2d(inp, oup, 1, 1, padding=0, bias=False), 50 | BatchNorm2d(oup), 51 | nn.ReLU6(inplace=True) 52 | ) 53 | 54 | hlconv = { 55 | 'std_conv': conv_bn, 56 | 'dep_sep_conv': dep_sep_conv_bn 57 | } -------------------------------------------------------------------------------- /indexnet_matting/lib/nn/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/indexnet_matting/lib/nn/.DS_Store -------------------------------------------------------------------------------- /indexnet_matting/lib/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 3 | -------------------------------------------------------------------------------- /indexnet_matting/lib/nn/modules/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/indexnet_matting/lib/nn/modules/.DS_Store -------------------------------------------------------------------------------- /indexnet_matting/lib/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /indexnet_matting/lib/nn/modules/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 
29 |             self._result = result
30 |             self._cond.notify()
31 | 
32 |     def get(self):
33 |         with self._lock:
34 |             if self._result is None:
35 |                 self._cond.wait()
36 | 
37 |             res = self._result
38 |             self._result = None
39 |             return res
40 | 
41 | 
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 | 
45 | 
46 | class SlavePipe(_SlavePipeBase):
47 |     """Pipe for master-slave communication."""
48 | 
49 |     def run_slave(self, msg):
50 |         self.queue.put((self.identifier, msg))
51 |         ret = self.result.get()
52 |         self.queue.put(True)
53 |         return ret
54 | 
55 | 
56 | class SyncMaster(object):
57 |     """An abstract `SyncMaster` object.
58 | 
59 |     - During replication, as data parallelism triggers a callback for each module, all slave devices should
60 |     call `register(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from slave devices are collected
62 |     and passed to a registered callback.
63 |     - After receiving the messages, the master device gathers the information and determines the message to pass
64 |     back to each slave device.
65 |     """
66 | 
67 |     def __init__(self, master_callback):
68 |         """
69 | 
70 |         Args:
71 |             master_callback: a callback to be invoked after having collected messages from slave devices.
72 |         """
73 |         self._master_callback = master_callback
74 |         self._queue = queue.Queue()
75 |         self._registry = collections.OrderedDict()
76 |         self._activated = False
77 | 
78 |     def register_slave(self, identifier):
79 |         """
80 |         Register a slave device.
81 | 
82 |         Args:
83 |             identifier: an identifier, usually the device id.
84 | 
85 |         Returns: a `SlavePipe` object which can be used to communicate with the master device.
86 | 
87 |         """
88 |         if self._activated:
89 |             assert self._queue.empty(), 'Queue is not clean before next initialization.'
90 |             self._activated = False
91 |             self._registry.clear()
92 |         future = FutureResult()
93 |         self._registry[identifier] = _MasterRegistry(future)
94 |         return SlavePipe(identifier, self._queue, future)
95 | 
96 |     def run_master(self, master_msg):
97 |         """
98 |         Main entry for the master device in each forward pass.
99 |         The messages are first collected from each device (including the master device), and then
100 |         a callback is invoked to compute the message to be sent back to each device
101 |         (including the master device).
102 | 
103 |         Args:
104 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
105 |             message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 | 
107 |         Returns: the message to be sent back to the master device.
108 | 
109 |         """
110 |         self._activated = True
111 | 
112 |         intermediates = [(0, master_msg)]
113 |         for i in range(self.nr_slaves):
114 |             intermediates.append(self._queue.get())
115 | 
116 |         results = self._master_callback(intermediates)
117 |         assert results[0][0] == 0, 'The first result should belong to the master.'
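        # Send each slave's computed message back through its registered FutureResult;
        # entry 0 is the master's own result and is returned directly at the end.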
118 | 119 | for i, res in results: 120 | if i == 0: 121 | continue 122 | self._registry[i].result.put(res) 123 | 124 | for i in range(self.nr_slaves): 125 | assert self._queue.get() is True 126 | 127 | return results[0][1] 128 | 129 | @property 130 | def nr_slaves(self): 131 | return len(self._registry) 132 | -------------------------------------------------------------------------------- /indexnet_matting/lib/nn/modules/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 
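    The patch simply wraps the object's `replicate` method so that the replication
    callbacks run after every call to `replicate`.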
74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /indexnet_matting/lib/nn/modules/tests/test_numeric_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_numeric_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm.unittest import TorchTestCase 16 | 17 | 18 | def handy_var(a, unbias=True): 19 | n = a.size(0) 20 | asum = a.sum(dim=0) 21 | as_sum = (a ** 2).sum(dim=0) # a square sum 22 | sumvar = as_sum - asum * asum / n 23 | if unbias: 24 | return sumvar / (n - 1) 25 | else: 26 | return sumvar / n 27 | 28 | 29 | class NumericTestCase(TorchTestCase): 30 | def testNumericBatchNorm(self): 31 | a = torch.rand(16, 10) 32 | bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False) 33 | bn.train() 34 | 35 | a_var1 = Variable(a, requires_grad=True) 36 | b_var1 = bn(a_var1) 37 | loss1 = b_var1.sum() 38 | loss1.backward() 39 | 40 | a_var2 = Variable(a, requires_grad=True) 41 | a_mean2 = a_var2.mean(dim=0, keepdim=True) 42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) 43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) 44 | b_var2 = (a_var2 - a_mean2) / a_std2 45 | loss2 = b_var2.sum() 46 | loss2.backward() 47 | 48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0)) 49 | self.assertTensorClose(bn.running_var, handy_var(a)) 50 | self.assertTensorClose(a_var1.data, a_var2.data) 51 | self.assertTensorClose(b_var1.data, b_var2.data) 52 | self.assertTensorClose(a_var1.grad, a_var2.grad) 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /indexnet_matting/lib/nn/modules/tests/test_sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_sync_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback 16 | from sync_batchnorm.unittest import TorchTestCase 17 | 18 | 19 | def handy_var(a, unbias=True): 20 | n = a.size(0) 21 | asum = a.sum(dim=0) 22 | as_sum = (a ** 2).sum(dim=0) # a square sum 23 | sumvar = as_sum - asum * asum / n 24 | if unbias: 25 | return sumvar / (n - 1) 26 | else: 27 | return sumvar / n 28 | 29 | 30 | def _find_bn(module): 31 | for m in module.modules(): 32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): 33 | return m 34 | 35 | 36 | class SyncTestCase(TorchTestCase): 37 | def _syncParameters(self, bn1, bn2): 38 | bn1.reset_parameters() 39 | bn2.reset_parameters() 40 | if bn1.affine and bn2.affine: 41 | bn2.weight.data.copy_(bn1.weight.data) 42 | bn2.bias.data.copy_(bn1.bias.data) 43 | 44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): 45 | """Check the forward and backward for the customized batch normalization.""" 46 | bn1.train(mode=is_train) 47 | bn2.train(mode=is_train) 48 | 49 | if cuda: 50 | input = input.cuda() 51 | 52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2)) 53 | 54 | input1 = Variable(input, requires_grad=True) 55 | output1 = bn1(input1) 56 | output1.sum().backward() 57 | input2 = Variable(input, requires_grad=True) 58 | output2 = bn2(input2) 59 | output2.sum().backward() 60 | 61 | self.assertTensorClose(input1.data, input2.data) 62 | self.assertTensorClose(output1.data, output2.data) 63 | self.assertTensorClose(input1.grad, input2.grad) 64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) 65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) 66 | 67 | def testSyncBatchNormNormalTrain(self): 68 | bn = nn.BatchNorm1d(10) 69 | sync_bn = SynchronizedBatchNorm1d(10) 70 | 71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) 72 | 73 | def testSyncBatchNormNormalEval(self): 74 | bn = nn.BatchNorm1d(10) 75 | sync_bn = SynchronizedBatchNorm1d(10) 76 | 77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) 78 | 79 | def testSyncBatchNormSyncTrain(self): 80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 83 | 84 | bn.cuda() 85 | sync_bn.cuda() 86 | 87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) 88 | 89 | def testSyncBatchNormSyncEval(self): 90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 93 | 94 | bn.cuda() 95 | sync_bn.cuda() 96 | 97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) 98 | 99 | def testSyncBatchNorm2DSyncTrain(self): 100 | bn = nn.BatchNorm2d(10) 101 | sync_bn = SynchronizedBatchNorm2d(10) 102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 103 | 104 | bn.cuda() 105 | sync_bn.cuda() 106 | 107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) 108 | 109 | 110 | if __name__ == '__main__': 111 | unittest.main() 112 | -------------------------------------------------------------------------------- /indexnet_matting/lib/nn/modules/unittest.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /indexnet_matting/lib/nn/parallel/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/indexnet_matting/lib/nn/parallel/.DS_Store -------------------------------------------------------------------------------- /indexnet_matting/lib/nn/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 2 | -------------------------------------------------------------------------------- /indexnet_matting/lib/nn/parallel/data_parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf8 -*- 2 | 3 | import torch.cuda as cuda 4 | import torch.nn as nn 5 | import torch 6 | import collections 7 | from torch.nn.parallel._functions import Gather 8 | 9 | 10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] 11 | 12 | 13 | def async_copy_to(obj, dev, main_stream=None): 14 | if torch.is_tensor(obj): 15 | v = obj.cuda(dev, non_blocking=True) 16 | if main_stream is not None: 17 | v.data.record_stream(main_stream) 18 | return v 19 | elif isinstance(obj, collections.Mapping): 20 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} 21 | elif isinstance(obj, collections.Sequence): 22 | return [async_copy_to(o, dev, main_stream) for o in obj] 23 | else: 24 | return obj 25 | 26 | 27 | def dict_gather(outputs, target_device, dim=0): 28 | """ 29 | Gathers variables from different GPUs on a specified device 30 | (-1 means the CPU), with dictionary support. 
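    Nested mappings and sequences are gathered recursively, and zero-dimensional tensors
    are unsqueezed so that every tensor has at least one dimension before gathering.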
31 | """ 32 | def gather_map(outputs): 33 | out = outputs[0] 34 | if torch.is_tensor(out): 35 | # MJY(20180330) HACK:: force nr_dims > 0 36 | if out.dim() == 0: 37 | outputs = [o.unsqueeze(0) for o in outputs] 38 | return Gather.apply(target_device, dim, *outputs) 39 | elif out is None: 40 | return None 41 | elif isinstance(out, collections.Mapping): 42 | return {k: gather_map([o[k] for o in outputs]) for k in out} 43 | elif isinstance(out, collections.Sequence): 44 | return type(out)(map(gather_map, zip(*outputs))) 45 | return gather_map(outputs) 46 | 47 | 48 | class DictGatherDataParallel(nn.DataParallel): 49 | def gather(self, outputs, output_device): 50 | return dict_gather(outputs, output_device, dim=self.dim) 51 | 52 | 53 | class UserScatteredDataParallel(DictGatherDataParallel): 54 | def scatter(self, inputs, kwargs, device_ids): 55 | assert len(inputs) == 1 56 | inputs = inputs[0] 57 | inputs = _async_copy_stream(inputs, device_ids) 58 | inputs = [[i] for i in inputs] 59 | assert len(kwargs) == 0 60 | kwargs = [{} for _ in range(len(inputs))] 61 | 62 | return inputs, kwargs 63 | 64 | 65 | def user_scattered_collate(batch): 66 | return batch 67 | 68 | 69 | def _async_copy(inputs, device_ids): 70 | nr_devs = len(device_ids) 71 | assert type(inputs) in (tuple, list) 72 | assert len(inputs) == nr_devs 73 | 74 | outputs = [] 75 | for i, dev in zip(inputs, device_ids): 76 | with cuda.device(dev): 77 | outputs.append(async_copy_to(i, dev)) 78 | 79 | return tuple(outputs) 80 | 81 | 82 | def _async_copy_stream(inputs, device_ids): 83 | nr_devs = len(device_ids) 84 | assert type(inputs) in (tuple, list) 85 | assert len(inputs) == nr_devs 86 | 87 | outputs = [] 88 | streams = [_get_stream(d) for d in device_ids] 89 | for i, dev, stream in zip(inputs, device_ids, streams): 90 | with cuda.device(dev): 91 | main_stream = cuda.current_stream() 92 | with cuda.stream(stream): 93 | outputs.append(async_copy_to(i, dev, main_stream=main_stream)) 94 | main_stream.wait_stream(stream) 95 | 96 | return outputs 97 | 98 | 99 | """Adapted from: torch/nn/parallel/_functions.py""" 100 | # background streams used for copying 101 | _streams = None 102 | 103 | 104 | def _get_stream(device): 105 | """Gets a background stream for copying between CPU and GPU""" 106 | global _streams 107 | if device == -1: 108 | return None 109 | if _streams is None: 110 | _streams = [None] * cuda.device_count() 111 | if _streams[device] is None: _streams[device] = cuda.Stream(device) 112 | return _streams[device] 113 | -------------------------------------------------------------------------------- /indexnet_matting/lists/generate_imlist.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | num_bgs_train = 100 5 | num_bgs_test = 20 6 | train_data_file = 'train.txt' 7 | test_data_file = 'test.txt' 8 | 9 | train_file = '/media/hao/DATA/Combined_Dataset/Training_set/training_fg_names.txt' 10 | train_file_bg = '/media/hao/DATA/Combined_Dataset/Training_set/training_bg_names.txt' 11 | test_file = '/media/hao/DATA/Combined_Dataset/Test_set/test_fg_names.txt' 12 | fg_names = [name for name in open(train_file).read().splitlines()] 13 | bg_names = [name for name in open(train_file_bg).read().splitlines()] 14 | fg_names_test = [name for name in open(test_file).read().splitlines()] 15 | 16 | image_head = 'merged' 17 | alpha_head = 'alpha' 18 | trimap_head = 'trimaps' 19 | fg_head = 'fg' 20 | bg_head = 'train2014' 21 | 22 | def write_datalist(img_name, bg_names, idx, f): 23 | prefix = 
'Training_set' 24 | img_path = os.path.join(prefix, image_head, img_name+'_'+str(idx)+'.png') 25 | msk_path = os.path.join(prefix, alpha_head, img_name+'.jpg') 26 | fg_path = os.path.join(prefix, fg_head, img_name+'.jpg') 27 | bg_path = os.path.join(prefix, bg_head, bg_names) 28 | f.write(img_path+'\t'+msk_path+'\t'+fg_path+'\t'+bg_path+'\n') 29 | 30 | def write_datalist_test(img_name, idx, f): 31 | prefix = 'Test_set' 32 | img_path = os.path.join(prefix, image_head, img_name+'_'+str(idx)+'.png') 33 | msk_path = os.path.join(prefix, alpha_head, img_name+'.png') 34 | trimap_path = os.path.join(prefix, trimap_head, img_name+'_'+str(idx)+'.png') 35 | f.write(img_path+'\t'+msk_path+'\t'+trimap_path+'\n') 36 | 37 | if __name__ == '__main__': 38 | with open(train_data_file, 'w') as f: 39 | count = 0 40 | for name in fg_names: 41 | img_name, ext = os.path.splitext(name) 42 | for idx in range(num_bgs_train): 43 | write_datalist(img_name, bg_names[count], idx, f) 44 | count += 1 45 | 46 | with open(test_data_file, 'w') as f: 47 | for name in fg_names_test: 48 | img_name, ext = os.path.splitext(name) 49 | for idx in range(num_bgs_test): 50 | write_datalist_test(img_name, idx, f) -------------------------------------------------------------------------------- /indexnet_matting/modelsummary.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # Modified by Ke Sun (sunk@mail.ustc.edu.cn) 6 | # ------------------------------------------------------------------------------ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import os 13 | import logging 14 | from collections import namedtuple 15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | def get_model_summary(model, *input_tensors, item_length=26, verbose=False): 20 | """ 21 | :param model: 22 | :param input_tensors: 23 | :param item_length: 24 | :return: 25 | """ 26 | 27 | summary = [] 28 | 29 | ModuleDetails = namedtuple( 30 | "Layer", ["name", "input_size", "num_parameters", "multiply_adds"]) 31 | hooks = [] 32 | layer_instances = {} 33 | 34 | def add_hooks(module): 35 | 36 | def hook(module, input, output): 37 | class_name = str(module.__class__.__name__) 38 | 39 | instance_index = 1 40 | if class_name not in layer_instances: 41 | layer_instances[class_name] = instance_index 42 | else: 43 | instance_index = layer_instances[class_name] + 1 44 | layer_instances[class_name] = instance_index 45 | 46 | layer_name = class_name + "_" + str(instance_index) 47 | 48 | params = 0 49 | 50 | if class_name.find("Conv") != -1 or class_name.find("BatchNorm") != -1 or \ 51 | class_name.find("Linear") != -1: 52 | for param_ in module.parameters(): 53 | params += param_.view(-1).size(0) 54 | 55 | flops = "Not Available" 56 | if class_name.find("Conv") != -1 and hasattr(module, "weight"): 57 | flops = ( 58 | torch.prod( 59 | torch.LongTensor(list(module.weight.data.size()))) * 60 | torch.prod( 61 | torch.LongTensor(list(output.size())[2:]))).item() 62 | elif isinstance(module, nn.Linear): 63 | flops = (torch.prod(torch.LongTensor(list(output.size()))) \ 64 | * input[0].size(1)).item() 65 | 66 | if isinstance(input[0], list): 67 | input = input[0] 68 | if isinstance(output, list): 69 | output = output[0] 70 | 71 | summary.append( 72 | 
ModuleDetails( 73 | name=layer_name, 74 | input_size=list(input[0].size()), 75 | # output_size=list(output.size()), 76 | num_parameters=params, 77 | multiply_adds=flops) 78 | ) 79 | 80 | if not isinstance(module, nn.ModuleList) \ 81 | and not isinstance(module, nn.Sequential) \ 82 | and module != model: 83 | hooks.append(module.register_forward_hook(hook)) 84 | 85 | model.eval() 86 | model.apply(add_hooks) 87 | 88 | space_len = item_length 89 | 90 | model(*input_tensors) 91 | for hook in hooks: 92 | hook.remove() 93 | 94 | details = '' 95 | if verbose: 96 | details = "Model Summary" + \ 97 | os.linesep + \ 98 | "Name{}Input Size{}Output Size{}Parameters{}Multiply Adds (Flops){}".format( 99 | ' ' * (space_len - len("Name")), 100 | ' ' * (space_len - len("Input Size")), 101 | ' ' * (space_len - len("Output Size")), 102 | ' ' * (space_len - len("Parameters")), 103 | ' ' * (space_len - len("Multiply Adds (Flops)"))) \ 104 | + os.linesep + '-' * space_len * 5 + os.linesep 105 | 106 | params_sum = 0 107 | flops_sum = 0 108 | for layer in summary: 109 | params_sum += layer.num_parameters 110 | if layer.multiply_adds != "Not Available": 111 | flops_sum += layer.multiply_adds 112 | if verbose: 113 | details += "{}{}{}{}{}{}{}{}{}{}".format( 114 | layer.name, 115 | ' ' * (space_len - len(layer.name)), 116 | layer.input_size, 117 | ' ' * (space_len - len(str(layer.input_size))), 118 | # layer.output_size, 119 | # ' ' * (space_len - len(str(layer.output_size))), 120 | layer.num_parameters, 121 | ' ' * (space_len - len(str(layer.num_parameters))), 122 | layer.multiply_adds, 123 | ' ' * (space_len - len(str(layer.multiply_adds)))) \ 124 | + os.linesep + '-' * space_len * 5 + os.linesep 125 | 126 | details += os.linesep \ 127 | + "Total Parameters: {:,}".format(params_sum) \ 128 | + os.linesep + '-' * space_len * 5 + os.linesep 129 | details += "Total Multiply Adds (For Convolution and Linear Layers only): {:,} GFLOPs".format(flops_sum/(1024**3)) \ 130 | + os.linesep + '-' * space_len * 5 + os.linesep 131 | details += "Number of Layers" + os.linesep 132 | for layer in layer_instances: 133 | details += "{} : {} layers ".format(layer, layer_instances[layer]) 134 | 135 | return details -------------------------------------------------------------------------------- /indexnet_matting/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | IndexNet Matting 3 | 4 | Indices Matter: Learning to Index for Deep Image Matting 5 | IEEE/CVF International Conference on Computer Vision, 2019 6 | 7 | This software is strictly limited to academic purposes only 8 | Copyright (c) 2019, Hao Lu (hao.lu@adelaide.edu.au) 9 | All rights reserved. 10 | 11 | Redistribution and use in source and binary forms, with or without 12 | modification, are permitted provided that the following conditions are met: 13 | 14 | * Redistributions of source code must retain the above copyright notice, this 15 | list of conditions and the following disclaimer. 16 | * Redistributions in binary form must reproduce the above copyright notice, 17 | this list of conditions and the following disclaimer in the documentation 18 | and/or other materials provided with the distribution. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | """ 31 | 32 | import numpy as np 33 | import cv2 as cv 34 | from scipy.ndimage import gaussian_filter, morphology 35 | from skimage.measure import label, regionprops 36 | 37 | 38 | # compute the SAD error given a pdiction, a ground truth and a mask 39 | def compute_sad_loss(pd, gt, mask): 40 | cv.normalize(pd, pd, 0.0, 255.0, cv.NORM_MINMAX, dtype=cv.CV_32F) 41 | cv.normalize(gt, gt, 0.0, 255.0, cv.NORM_MINMAX, dtype=cv.CV_32F) 42 | error_map = np.abs(pd - gt) / 255. 43 | loss = np.sum(error_map * mask) 44 | # the loss is scaled by 1000 due to the large images 45 | loss = loss / 1000 46 | return loss 47 | 48 | 49 | # compute the MSE error 50 | def compute_mse_loss(pd, gt, mask): 51 | cv.normalize(pd, pd, 0.0, 255.0, cv.NORM_MINMAX, dtype=cv.CV_32F) 52 | cv.normalize(gt, gt, 0.0, 255.0, cv.NORM_MINMAX, dtype=cv.CV_32F) 53 | error_map = (pd - gt) / 255. 54 | loss = np.sum(np.square(error_map) * mask) / np.sum(mask) 55 | return loss 56 | 57 | 58 | # compute the gradient error 59 | def compute_gradient_loss(pd, gt, mask): 60 | cv.normalize(pd, pd, 0.0, 255.0, cv.NORM_MINMAX, dtype=cv.CV_32F) 61 | cv.normalize(gt, gt, 0.0, 255.0, cv.NORM_MINMAX, dtype=cv.CV_32F) 62 | pd = pd / 255. 63 | gt = gt / 255. 64 | pd_x = gaussian_filter(pd, sigma=1.4, order=[1, 0], output=np.float32) 65 | pd_y = gaussian_filter(pd, sigma=1.4, order=[0, 1], output=np.float32) 66 | gt_x = gaussian_filter(gt, sigma=1.4, order=[1, 0], output=np.float32) 67 | gt_y = gaussian_filter(gt, sigma=1.4, order=[0, 1], output=np.float32) 68 | pd_mag = np.sqrt(pd_x**2 + pd_y**2) 69 | gt_mag = np.sqrt(gt_x**2 + gt_y**2) 70 | 71 | error_map = np.square(pd_mag - gt_mag) 72 | loss = np.sum(error_map * mask) / 10 73 | return loss 74 | 75 | 76 | # compute the connectivity error 77 | def compute_connectivity_loss(pd, gt, mask, step=0.1): 78 | cv.normalize(pd, pd, 0, 255, cv.NORM_MINMAX) 79 | cv.normalize(gt, gt, 0, 255, cv.NORM_MINMAX) 80 | pd = pd / 255. 81 | gt = gt / 255. 
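# connectivity metric: sweep thresholds from 0 to 1, keep the largest region that is
# 4-connected in both prediction and GT at each level, and record in l_map the last
# threshold at which each pixel was still part of that shared region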
82 | 83 | h, w = pd.shape 84 | 85 | thresh_steps = np.arange(0, 1.1, step) 86 | l_map = -1 * np.ones((h, w), dtype=np.float32) 87 | lambda_map = np.ones((h, w), dtype=np.float32) 88 | for i in range(1, thresh_steps.size): 89 | pd_th = pd >= thresh_steps[i] 90 | gt_th = gt >= thresh_steps[i] 91 | 92 | label_image = label(pd_th & gt_th, connectivity=1) 93 | cc = regionprops(label_image) 94 | size_vec = np.array([c.area for c in cc]) 95 | if len(size_vec) == 0: 96 | continue 97 | max_id = np.argmax(size_vec) 98 | coords = cc[max_id].coords 99 | 100 | omega = np.zeros((h, w), dtype=np.float32) 101 | omega[coords[:, 0], coords[:, 1]] = 1 102 | 103 | flag = (l_map == -1) & (omega == 0) 104 | l_map[flag == 1] = thresh_steps[i-1] 105 | 106 | dist_maps = morphology.distance_transform_edt(omega==0) 107 | dist_maps = dist_maps / dist_maps.max() 108 | # lambda_map[flag == 1] = dist_maps.mean() 109 | l_map[l_map == -1] = 1 110 | 111 | # the definition of lambda is ambiguous 112 | d_pd = pd - l_map 113 | d_gt = gt - l_map 114 | # phi_pd = 1 - lambda_map * d_pd * (d_pd >= 0.15).astype(np.float32) 115 | # phi_gt = 1 - lambda_map * d_gt * (d_gt >= 0.15).astype(np.float32) 116 | phi_pd = 1 - d_pd * (d_pd >= 0.15).astype(np.float32) 117 | phi_gt = 1 - d_gt * (d_gt >= 0.15).astype(np.float32) 118 | loss = np.sum(np.abs(phi_pd - phi_gt) * mask) / 1000 119 | return loss 120 | 121 | 122 | def image_alignment(x, output_stride, odd=False): 123 | imsize = np.asarray(x.shape[:2], dtype=np.float) 124 | if odd: 125 | new_imsize = np.ceil(imsize / output_stride) * output_stride + 1 126 | else: 127 | new_imsize = np.ceil(imsize / output_stride) * output_stride 128 | h, w = int(new_imsize[0]), int(new_imsize[1]) 129 | 130 | x1 = x[:, :, 0:3] 131 | x2 = x[:, :, 3] 132 | new_x1 = cv.resize(x1, dsize=(w,h), interpolation=cv.INTER_CUBIC) 133 | new_x2 = cv.resize(x2, dsize=(w,h), interpolation=cv.INTER_NEAREST) 134 | 135 | new_x2 = np.expand_dims(new_x2, axis=2) 136 | new_x = np.concatenate((new_x1, new_x2), axis=2) 137 | 138 | return new_x 139 | 140 | 141 | def image_rescale(x, scale): 142 | x1 = x[:, :, 0:3] 143 | x2 = x[:, :, 3] 144 | new_x1 = cv.resize(x1, None, fx=scale, fy=scale, interpolation=cv.INTER_CUBIC) 145 | new_x2 = cv.resize(x2, None, fx=scale, fy=scale, interpolation=cv.INTER_NEAREST) 146 | new_x2 = np.expand_dims(new_x2, axis=2) 147 | new_x = np.concatenate((new_x1,new_x2), axis=2) 148 | return new_x 149 | -------------------------------------------------------------------------------- /lpips/.ipynb_checkpoints/base_model-checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | from pdb import set_trace as st 5 | from IPython import embed 6 | 7 | class BaseModel(): 8 | def __init__(self): 9 | pass; 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def initialize(self, use_gpu=True, gpu_ids=[0]): 15 | self.use_gpu = use_gpu 16 | self.gpu_ids = gpu_ids 17 | 18 | def forward(self): 19 | pass 20 | 21 | def get_image_paths(self): 22 | pass 23 | 24 | def optimize_parameters(self): 25 | pass 26 | 27 | def get_current_visuals(self): 28 | return self.input 29 | 30 | def get_current_errors(self): 31 | return {} 32 | 33 | def save(self, label): 34 | pass 35 | 36 | # helper saving function that can be used by subclasses 37 | def save_network(self, network, path, network_label, epoch_label): 38 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 39 | save_path = os.path.join(path, 
save_filename) 40 | torch.save(network.state_dict(), save_path) 41 | 42 | # helper loading function that can be used by subclasses 43 | def load_network(self, network, network_label, epoch_label): 44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 45 | save_path = os.path.join(self.save_dir, save_filename) 46 | print('Loading network from %s'%save_path) 47 | network.load_state_dict(torch.load(save_path)) 48 | 49 | def update_learning_rate(): 50 | pass 51 | 52 | def get_image_paths(self): 53 | return self.image_paths 54 | 55 | def save_done(self, flag=False): 56 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 57 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 58 | 59 | -------------------------------------------------------------------------------- /lpips/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/lpips/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /lpips/__pycache__/base_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/lpips/__pycache__/base_model.cpython-39.pyc -------------------------------------------------------------------------------- /lpips/__pycache__/dist_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/lpips/__pycache__/dist_model.cpython-39.pyc -------------------------------------------------------------------------------- /lpips/__pycache__/networks_basic.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/lpips/__pycache__/networks_basic.cpython-39.pyc -------------------------------------------------------------------------------- /lpips/__pycache__/pretrained_networks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/lpips/__pycache__/pretrained_networks.cpython-39.pyc -------------------------------------------------------------------------------- /lpips/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | from pdb import set_trace as st 5 | from IPython import embed 6 | 7 | class BaseModel(): 8 | def __init__(self): 9 | pass; 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def initialize(self, use_gpu=True, gpu_ids=[0]): 15 | self.use_gpu = use_gpu 16 | self.gpu_ids = gpu_ids 17 | 18 | def forward(self): 19 | pass 20 | 21 | def get_image_paths(self): 22 | pass 23 | 24 | def optimize_parameters(self): 25 | pass 26 | 27 | def get_current_visuals(self): 28 | return self.input 29 | 30 | def get_current_errors(self): 31 | return {} 32 | 33 | def save(self, label): 34 | pass 35 | 36 | # helper saving function that can be used by subclasses 37 | def save_network(self, network, path, network_label, epoch_label): 38 | save_filename = '%s_net_%s.pth' % (epoch_label, 
network_label) 39 | save_path = os.path.join(path, save_filename) 40 | torch.save(network.state_dict(), save_path) 41 | 42 | # helper loading function that can be used by subclasses 43 | def load_network(self, network, network_label, epoch_label): 44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 45 | save_path = os.path.join(self.save_dir, save_filename) 46 | print('Loading network from %s'%save_path) 47 | network.load_state_dict(torch.load(save_path)) 48 | 49 | def update_learning_rate(): 50 | pass 51 | 52 | def get_image_paths(self): 53 | return self.image_paths 54 | 55 | def save_done(self, flag=False): 56 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 57 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 58 | 59 | -------------------------------------------------------------------------------- /lpips/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/lpips/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /lpips/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/lpips/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /lpips/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/lpips/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /lpips/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/lpips/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /lpips/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/lpips/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /lpips/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/lpips/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /op/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/op/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /op/__pycache__/fused_act.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/op/__pycache__/fused_act.cpython-39.pyc -------------------------------------------------------------------------------- /op/__pycache__/upfirdn2d.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cnnlstm/StyleGAN_Matting/93c13fe0ca1e5371d22de5fa06ba9551a59f62d1/op/__pycache__/upfirdn2d.cpython-39.pyc -------------------------------------------------------------------------------- /op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | fused = load( 11 | 'fused', 12 | sources=[ 13 | os.path.join(module_path, 'fused_bias_act.cpp'), 14 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 15 | ], 16 | ) 17 | 18 | 19 | class FusedLeakyReLUFunctionBackward(Function): 20 | @staticmethod 21 | def forward(ctx, grad_output, out, negative_slope, scale): 22 | ctx.save_for_backward(out) 23 | ctx.negative_slope = negative_slope 24 | ctx.scale = scale 25 | 26 | empty = grad_output.new_empty(0) 27 | 28 | grad_input = fused.fused_bias_act( 29 | grad_output, empty, out, 3, 1, negative_slope, scale 30 | ) 31 | 32 | dim = [0] 33 | 34 | if grad_input.ndim > 2: 35 | dim += list(range(2, grad_input.ndim)) 36 | 37 | grad_bias = grad_input.sum(dim).detach() 38 | 39 | return grad_input, grad_bias 40 | 41 | @staticmethod 42 | def backward(ctx, gradgrad_input, gradgrad_bias): 43 | out, = ctx.saved_tensors 44 | gradgrad_out = fused.fused_bias_act( 45 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 46 | ) 47 | 48 | return gradgrad_out, None, None, None 49 | 50 | 51 | class FusedLeakyReLUFunction(Function): 52 | @staticmethod 53 | def forward(ctx, input, bias, negative_slope, scale): 54 | empty = input.new_empty(0) 55 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 56 | ctx.save_for_backward(out) 57 | ctx.negative_slope = negative_slope 58 | ctx.scale = scale 59 | 60 | return out 61 | 62 | @staticmethod 63 | def backward(ctx, grad_output): 64 | out, = ctx.saved_tensors 65 | 66 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 67 | grad_output, out, ctx.negative_slope, ctx.scale 68 | ) 69 | 70 | return grad_input, grad_bias, None, None 71 | 72 | 73 | class FusedLeakyReLU(nn.Module): 74 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 75 | super().__init__() 76 | 77 | self.bias = nn.Parameter(torch.zeros(channel)) 78 | self.negative_slope = negative_slope 79 | self.scale = scale 80 | 81 | def forward(self, input): 82 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 83 | 84 | 85 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 86 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 87 | -------------------------------------------------------------------------------- /op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA 
tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.utils.cpp_extension import load 6 | 7 | 8 | module_path = os.path.dirname(__file__) 9 | upfirdn2d_op = load( 10 | 'upfirdn2d', 11 | sources=[ 12 | os.path.join(module_path, 'upfirdn2d.cpp'), 13 | os.path.join(module_path, 'upfirdn2d_kernel.cu'), 14 | ], 15 | ) 16 | 17 | 18 | class UpFirDn2dBackward(Function): 19 | @staticmethod 20 | def forward( 21 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 22 | ): 23 | 24 | up_x, up_y = up 25 | down_x, down_y = down 26 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 27 | 28 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 29 | 30 | grad_input = upfirdn2d_op.upfirdn2d( 31 | grad_output, 32 | grad_kernel, 33 | down_x, 34 | down_y, 35 | up_x, 36 | up_y, 37 | g_pad_x0, 38 | g_pad_x1, 39 | g_pad_y0, 40 | g_pad_y1, 41 | ) 42 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 43 | 44 | ctx.save_for_backward(kernel) 45 | 46 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 47 | 48 | ctx.up_x = up_x 49 | ctx.up_y = up_y 50 | ctx.down_x = down_x 51 | ctx.down_y = down_y 52 | ctx.pad_x0 = pad_x0 53 | ctx.pad_x1 = pad_x1 54 | ctx.pad_y0 = pad_y0 55 | ctx.pad_y1 = pad_y1 56 | ctx.in_size = in_size 57 | ctx.out_size = out_size 58 | 59 | return grad_input 60 | 61 | @staticmethod 62 | def backward(ctx, gradgrad_input): 63 | kernel, = ctx.saved_tensors 64 | 65 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 66 | 67 | 
gradgrad_out = upfirdn2d_op.upfirdn2d( 68 | gradgrad_input, 69 | kernel, 70 | ctx.up_x, 71 | ctx.up_y, 72 | ctx.down_x, 73 | ctx.down_y, 74 | ctx.pad_x0, 75 | ctx.pad_x1, 76 | ctx.pad_y0, 77 | ctx.pad_y1, 78 | ) 79 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 80 | gradgrad_out = gradgrad_out.view( 81 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 82 | ) 83 | 84 | return gradgrad_out, None, None, None, None, None, None, None, None 85 | 86 | 87 | class UpFirDn2d(Function): 88 | @staticmethod 89 | def forward(ctx, input, kernel, up, down, pad): 90 | up_x, up_y = up 91 | down_x, down_y = down 92 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 93 | 94 | kernel_h, kernel_w = kernel.shape 95 | batch, channel, in_h, in_w = input.shape 96 | ctx.in_size = input.shape 97 | 98 | input = input.reshape(-1, in_h, in_w, 1) 99 | 100 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 101 | 102 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 103 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 104 | ctx.out_size = (out_h, out_w) 105 | 106 | ctx.up = (up_x, up_y) 107 | ctx.down = (down_x, down_y) 108 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 109 | 110 | g_pad_x0 = kernel_w - pad_x0 - 1 111 | g_pad_y0 = kernel_h - pad_y0 - 1 112 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 113 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 114 | 115 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 116 | 117 | out = upfirdn2d_op.upfirdn2d( 118 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 119 | ) 120 | # out = out.view(major, out_h, out_w, minor) 121 | out = out.view(-1, channel, out_h, out_w) 122 | 123 | return out 124 | 125 | @staticmethod 126 | def backward(ctx, grad_output): 127 | kernel, grad_kernel = ctx.saved_tensors 128 | 129 | grad_input = UpFirDn2dBackward.apply( 130 | grad_output, 131 | kernel, 132 | grad_kernel, 133 | ctx.up, 134 | ctx.down, 135 | ctx.pad, 136 | ctx.g_pad, 137 | ctx.in_size, 138 | ctx.out_size, 139 | ) 140 | 141 | return grad_input, None, None, None, None 142 | 143 | 144 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 145 | out = UpFirDn2d.apply( 146 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 147 | ) 148 | 149 | return out 150 | 151 | 152 | def upfirdn2d_native( 153 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 154 | ): 155 | _, in_h, in_w, minor = input.shape 156 | kernel_h, kernel_w = kernel.shape 157 | 158 | out = input.view(-1, in_h, 1, in_w, 1, minor) 159 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 160 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 161 | 162 | out = F.pad( 163 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 164 | ) 165 | out = out[ 166 | :, 167 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 168 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 169 | :, 170 | ] 171 | 172 | out = out.permute(0, 3, 1, 2) 173 | out = out.reshape( 174 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 175 | ) 176 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 177 | out = F.conv2d(out, w) 178 | out = out.reshape( 179 | -1, 180 | minor, 181 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 182 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 183 | ) 184 | out = out.permute(0, 2, 3, 1) 185 | 186 | return out[:, ::down_y, ::down_x, :] 187 | 188 | 
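# Illustrative sketch (not part of the original file): a typical StyleGAN2-style
# blur built on upfirdn2d, assuming a normalized 4-tap binomial kernel and an
# NCHW CUDA tensor `img`; with up=1, down=1 and pad=(2, 1) the spatial size is
# preserved (out = in + pad0 + pad1 - k + 1).
#
#     k = torch.tensor([1., 3., 3., 1.])
#     k = k[None, :] * k[:, None]
#     k = (k / k.sum()).cuda()
#     blurred = upfirdn2d(img, k, up=1, down=1, pad=(2, 1))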
-------------------------------------------------------------------------------- /sgmatting.yaml: -------------------------------------------------------------------------------- 1 | name: sg_matting 2 | channels: 3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch 4 | - pytorch 5 | - conda-forge 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/ 9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge 10 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 11 | - defaults 12 | dependencies: 13 | - _libgcc_mutex=0.1=conda_forge 14 | - _openmp_mutex=4.5=2_gnu 15 | - binutils_impl_linux-64=2.36.1=h193b22a_2 16 | - blas=1.0=mkl 17 | - brotlipy=0.7.0=py39hb9d737c_1004 18 | - bzip2=1.0.8=h7f98852_4 19 | - ca-certificates=2022.9.24=ha878542_0 20 | - cffi=1.15.1=py39he91dace_0 21 | - charset-normalizer=2.1.1=pyhd8ed1ab_0 22 | - cryptography=37.0.1=py39h9ce1e76_0 23 | - cudatoolkit=11.3.1=h9edb442_10 24 | - dlib=19.24.0=py39h9ea2e13_0 25 | - ffmpeg=4.3=hf484d3e_0 26 | - freetype=2.12.1=hca18f0e_0 27 | - gcc=9.4.0=h192d537_10 28 | - gcc_impl_linux-64=9.4.0=h03d3576_16 29 | - gmp=6.2.1=h58526e2_0 30 | - gnutls=3.6.13=h85f3911_1 31 | - gxx=9.4.0=h192d537_10 32 | - gxx_impl_linux-64=9.4.0=h03d3576_16 33 | - intel-openmp=2022.1.0=h9e868ea_3769 34 | - jpeg=9e=h166bdaf_2 35 | - kernel-headers_linux-64=2.6.32=he073ed8_15 36 | - lame=3.100=h7f98852_1001 37 | - lcms2=2.12=hddcbb42_0 38 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 39 | - lerc=4.0.0=h27087fc_0 40 | - libblas=3.9.0=16_linux64_mkl 41 | - libcblas=3.9.0=16_linux64_mkl 42 | - libdeflate=1.14=h166bdaf_0 43 | - libffi=3.4.2=h7f98852_5 44 | - libgcc-devel_linux-64=9.4.0=hd854feb_16 45 | - libgcc-ng=12.1.0=h8d9b700_16 46 | - libgomp=12.1.0=h8d9b700_16 47 | - libiconv=1.17=h166bdaf_0 48 | - liblapack=3.9.0=16_linux64_mkl 49 | - libnsl=2.0.0=h7f98852_0 50 | - libpng=1.6.38=h753d276_0 51 | - libsanitizer=9.4.0=h79bfe98_16 52 | - libsqlite=3.39.3=h753d276_0 53 | - libstdcxx-devel_linux-64=9.4.0=hd854feb_16 54 | - libstdcxx-ng=12.1.0=ha89aaad_16 55 | - libtiff=4.4.0=h55922b4_4 56 | - libuuid=2.32.1=h7f98852_1000 57 | - libuv=1.44.2=h166bdaf_0 58 | - libwebp-base=1.2.4=h166bdaf_0 59 | - libxcb=1.13=h7f98852_1004 60 | - libzlib=1.2.12=h166bdaf_3 61 | - mkl=2022.1.0=hc2b9512_224 62 | - ncurses=6.3=h27087fc_1 63 | - nettle=3.6=he412f7d_0 64 | - numpy=1.23.3=py39hba7629e_0 65 | - openh264=2.1.1=h4ff587b_0 66 | - openjpeg=2.5.0=h7d73246_1 67 | - openssl=3.0.5=h166bdaf_2 68 | - pillow=9.2.0=py39hd5dbb17_2 69 | - pip=22.2.2=pyhd8ed1ab_0 70 | - pthread-stubs=0.4=h36c2ea0_1001 71 | - pycparser=2.21=pyhd8ed1ab_0 72 | - pyopenssl=22.0.0=pyhd8ed1ab_0 73 | - pysocks=1.7.1=pyha2e5f31_6 74 | - python=3.9.13=h2660328_0_cpython 75 | - python_abi=3.9=2_cp39 76 | - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0 77 | - pytorch-mutex=1.0=cuda 78 | - readline=8.1.2=h0f457ee_0 79 | - requests=2.28.1=pyhd8ed1ab_1 80 | - setuptools=65.3.0=pyhd8ed1ab_1 81 | - sqlite=3.39.3=h4ff8645_0 82 | - sysroot_linux-64=2.12=he073ed8_15 83 | - tk=8.6.12=h27826a3_0 84 | - torchaudio=0.11.0=py39_cu113 85 | - typing_extensions=4.3.0=pyha770c72_0 86 | - tzdata=2022c=h191b570_0 87 | - wheel=0.37.1=pyhd8ed1ab_0 88 | - xorg-libxau=1.0.9=h7f98852_0 89 | - xorg-libxdmcp=1.1.3=h7f98852_0 90 | - xz=5.2.6=h166bdaf_0 91 | - zlib=1.2.12=h166bdaf_3 92 | - zstd=1.5.2=h6239696_4 93 | - pip: 94 | - absl-py==1.2.0 95 | - asttokens==2.0.8 96 | - 
backcall==0.2.0 97 | - beautifulsoup4==4.11.1 98 | - cachetools==5.2.0 99 | - certifi==2022.9.14 100 | - click==8.1.3 101 | - contourpy==1.0.5 102 | - cupy-cuda113==10.6.0 103 | - cycler==0.11.0 104 | - decorator==5.1.1 105 | - executing==1.1.0 106 | - fastrlock==0.8 107 | - filelock==3.8.0 108 | - fonttools==4.37.3 109 | - gdown==4.5.1 110 | - google-auth==2.11.0 111 | - google-auth-oauthlib==0.4.6 112 | - grpcio==1.49.0 113 | - idna==3.4 114 | - imageio==2.22.0 115 | - importlib-metadata==4.12.0 116 | - ipython==8.5.0 117 | - jedi==0.18.1 118 | - kiwisolver==1.4.4 119 | - markdown==3.4.1 120 | - markupsafe==2.1.1 121 | - matplotlib==3.6.0 122 | - matplotlib-inline==0.1.6 123 | - mypy-extensions==0.4.3 124 | - networkx==2.8.7 125 | - ninja==1.10.2.3 126 | - oauthlib==3.2.1 127 | - opencv-python==4.6.0.66 128 | - packaging==21.3 129 | - parso==0.8.3 130 | - pexpect==4.8.0 131 | - pickleshare==0.7.5 132 | - prompt-toolkit==3.0.31 133 | - protobuf==3.19.5 134 | - psutil==5.9.2 135 | - ptyprocess==0.7.0 136 | - pure-eval==0.2.2 137 | - pyasn1==0.4.8 138 | - pyasn1-modules==0.2.8 139 | - pygments==2.13.0 140 | - pyparsing==3.0.9 141 | - pyrallis==0.3.1 142 | - python-dateutil==2.8.2 143 | - pytorch-fid==0.2.1 144 | - pywavelets==1.4.1 145 | - pyyaml==6.0 146 | - requests-oauthlib==1.3.1 147 | - rsa==4.9 148 | - scikit-image==0.19.3 149 | - scipy==1.9.1 150 | - six==1.16.0 151 | - soupsieve==2.3.2.post1 152 | - stack-data==0.5.1 153 | - tdqm==0.0.1 154 | - tensorboard==2.10.0 155 | - tensorboard-data-server==0.6.1 156 | - tensorboard-plugin-wit==1.8.1 157 | - tensorboardx==2.5.1 158 | - tifffile==2022.10.10 159 | - torch==1.10.1 160 | - torchvision==0.11.2 161 | - tqdm==4.64.1 162 | - traitlets==5.4.0 163 | - typing-inspect==0.8.0 164 | - urllib3==1.26.12 165 | - wcwidth==0.2.5 166 | - werkzeug==2.2.2 167 | - wget==3.2 168 | - zipp==3.8.1 169 | prefix: /home/xxxx/anaconda3/envs/sg_matting 170 | -------------------------------------------------------------------------------- /stylegan++.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | 5 | import torch 6 | from torch import optim 7 | from torch.nn import functional as F 8 | from torchvision import transforms, utils 9 | from PIL import Image 10 | from tqdm import tqdm 11 | 12 | import lpips 13 | from model import Generator 14 | from compare import compare 15 | 16 | def make_image(tensor): 17 | return ( 18 | tensor.detach() 19 | .clamp_(min=-1, max=1) 20 | .add(1) 21 | .div_(2) 22 | .mul(255) 23 | .type(torch.uint8) 24 | .permute(0, 2, 3, 1) 25 | .to('cpu') 26 | .numpy() 27 | .squeeze(0) 28 | ) 29 | 30 | 31 | if __name__ == '__main__': 32 | device = 'cuda' 33 | 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--ckpt', type=str, default='stylegan2-ffhq-config-f.pt') 36 | parser.add_argument('--size', type=int, default=1024) 37 | parser.add_argument('--w_lr', type=float, default=0.01) 38 | parser.add_argument('--n_lr', type=float, default=5) 39 | parser.add_argument('--w_step', type=int, default=5000) 40 | parser.add_argument('--n_step', type=int, default=3000) 41 | parser.add_argument('--is_face', action='store_true') 42 | parser.add_argument('files', metavar='FILES', nargs='+') 43 | 44 | args = parser.parse_args() 45 | n_mean_latent = 10000 46 | 47 | resize = args.size 48 | 49 | transform = transforms.Compose( 50 | [ 51 | transforms.Resize((resize,resize)), 52 | transforms.CenterCrop(resize), 53 | transforms.ToTensor(), 54 | 
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 55 | ] 56 | ) 57 | 58 | g_ema = Generator(args.size, 512, 8) 59 | g_ema.load_state_dict(torch.load(args.ckpt)['g_ema'], strict=False) 60 | g_ema.eval() 61 | g_ema = g_ema.to(device) 62 | 63 | percept = lpips.PerceptualLoss( 64 | model='net-lin', net='vgg', use_gpu=device.startswith('cuda') 65 | ) 66 | 67 | for file_idx, imgfile in enumerate(args.files): 68 | with torch.no_grad(): 69 | if args.is_face: 70 | noise_sample = torch.randn(n_mean_latent, 512, device=device) 71 | latent_out = g_ema.style(noise_sample) 72 | latent_in = latent_out.mean(0).detach().clone().unsqueeze(0).unsqueeze(0).repeat(1, g_ema.n_latent, 1).cuda() 73 | else: 74 | # uniform initialize 75 | latent_in = torch.FloatTensor(g_ema.n_latent, 512).unsqueeze(0).uniform_(-1, 1).cuda() 76 | 77 | noises = g_ema.make_noise() 78 | optimizer_w = optim.Adam([latent_in], lr=args.w_lr) 79 | optimizer_n = optim.Adam(noises, lr=args.n_lr) 80 | 81 | img = transform(Image.open(imgfile).convert('RGB')).unsqueeze(0).cuda() 82 | img_list = [img.detach().cpu()] 83 | 84 | w_pbar = tqdm(range(args.w_step)) 85 | n_pbar = tqdm(range(args.n_step)) 86 | for i in w_pbar: 87 | # fix n 88 | latent_in.requires_grad = True 89 | for noise in noises: 90 | noise.requires_grad = False 91 | 92 | img_gen, _ = g_ema([latent_in], input_is_latent=True, noise=noises) 93 | batch, channel, height, width = img_gen.shape 94 | 95 | if height > 256: 96 | factor = height // 256 97 | p_img_gen = img_gen.reshape( 98 | batch, channel, height // factor, factor, width // factor, factor 99 | ) 100 | p_img_gen = p_img_gen.mean([3, 5]) 101 | p_img = img.reshape( 102 | batch, channel, height // factor, factor, width // factor, factor 103 | ) 104 | p_img = p_img.mean([3, 5]) 105 | 106 | 107 | p_loss = percept(p_img_gen, p_img).sum() 108 | mse_loss = F.mse_loss(img_gen, img) 109 | 110 | loss = 0.01 * p_loss + 1 * mse_loss 111 | 112 | optimizer_w.zero_grad() 113 | loss.backward() 114 | optimizer_w.step() 115 | w_pbar.set_description( 116 | ( 117 | f'perceptual: {p_loss.item():.4f}; mse: {mse_loss.item():.4f}' 118 | ) 119 | ) 120 | img_list.append(img_gen.detach().cpu()) 121 | for i in n_pbar: 122 | # fix w 123 | latent_in.requires_grad = False 124 | for noise in noises: 125 | noise.requires_grad = True 126 | 127 | img_gen, _ = g_ema([latent_in], input_is_latent=True, noise=noises) 128 | 129 | mse_loss = F.mse_loss(img_gen, img) 130 | loss = 0.1 * mse_loss 131 | 132 | optimizer_n.zero_grad() 133 | loss.backward() 134 | optimizer_n.step() 135 | n_pbar.set_description( 136 | ( 137 | f'mse: {mse_loss.item():.4f}' 138 | ) 139 | ) 140 | img_name = os.path.basename(imgfile).split(".")[0] 141 | arr = {"latent_in": latent_in, "noises":noises} 142 | torch.save(arr, "dataset/testing/param/" + img_name +".pt") 143 | 144 | 145 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py --name test --iter 501 --optimize_choices both --backbone gca --mse 10 --g 1 --e 1 --------------------------------------------------------------------------------