├── .gitignore ├── LICENSE ├── README.md ├── core ├── README.md ├── config.py ├── config_yacs.py ├── configs │ ├── ResNet34.yaml │ ├── Swin_T.yaml │ └── ViTAE_S.yaml ├── data.py ├── data_util.py ├── eval.py ├── evaluate.py ├── infer.py ├── logger.py ├── metrics.py ├── network │ ├── RestNet34 │ │ ├── P3mNet.py │ │ ├── __init__.py │ │ └── resnet_mp.py │ ├── Swin_T │ │ ├── __init__.py │ │ ├── decoder.py │ │ └── swin_stem_pooling5_transformer.py │ ├── ViTAE_S │ │ ├── NormalCell.py │ │ ├── SELayer.py │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── models.py │ │ ├── token_performer.py │ │ └── token_transformer.py │ ├── __init__.py │ ├── build_model.py │ └── modules.py ├── test.py ├── train.py └── util.py ├── demo ├── README.md ├── face_obfuscation.jpg ├── gif │ ├── 2.gif │ ├── 3.gif │ ├── 4.gif │ └── p_3cf7997c.gif ├── imgs │ ├── alpha │ │ ├── p_07141906.png │ │ ├── p_09ba26b4.png │ │ ├── p_22fc0130.png │ │ ├── p_514ca06a.png │ │ ├── p_51c916fc.png │ │ ├── p_818e689d.png │ │ ├── p_a09f6d7a.png │ │ ├── p_bac3c1ff.png │ │ ├── p_bc5cfad1.png │ │ ├── p_bd6af989.png │ │ └── p_d684dae3.png │ ├── color │ │ ├── p_07141906.png │ │ ├── p_09ba26b4.png │ │ ├── p_22fc0130.png │ │ ├── p_514ca06a.png │ │ ├── p_51c916fc.png │ │ ├── p_818e689d.png │ │ ├── p_a09f6d7a.png │ │ ├── p_bac3c1ff.png │ │ ├── p_bc5cfad1.png │ │ ├── p_bd6af989.png │ │ └── p_d684dae3.png │ └── original │ │ ├── p_07141906.jpg │ │ ├── p_09ba26b4.jpg │ │ ├── p_22fc0130.jpg │ │ ├── p_514ca06a.jpg │ │ ├── p_51c916fc.jpg │ │ ├── p_818e689d.jpg │ │ ├── p_a09f6d7a.jpg │ │ ├── p_bac3c1ff.jpg │ │ ├── p_bc5cfad1.jpg │ │ ├── p_bd6af989.jpg │ │ └── p_d684dae3.jpg ├── network.png ├── p3m-cp.png ├── p3m-net-variants.png ├── p3m_dataset.png ├── results1.png └── results2.png ├── requirements.txt ├── samples ├── original │ ├── p_015cd10e.jpg │ ├── p_0865636e.jpg │ └── p_819ea202.jpg ├── result_alpha │ ├── p_015cd10e.png │ ├── p_0865636e.png │ └── p_819ea202.png └── result_color │ ├── p_015cd10e.png │ ├── p_0865636e.png │ └── p_819ea202.png └── scripts ├── eval.sh ├── test.sh ├── test_dataset.sh ├── test_samples.sh └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | */DS_Store 3 | __pycache__/ 4 | */__pycache__ 5 | *.html 6 | *.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Sihan Ma and Jizhizi Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Rethinking Portrait Matting with Privacy Preserving [IJCV-2023]

2 | 3 |

4 | 5 | 6 | 7 | 8 |

9 | 10 | 11 |

This is the official repository of the paper [IJCV'23] Rethinking Portrait Matting with Privacy Preserving. 12 | 13 | For further questions, please contact Sihan Ma at [sima7436@uni.sydney.edu.au](mailto:sima7436@uni.sydney.edu.au) or Jizhizi Li at [jili8515@uni.sydney.edu.au](mailto:jili8515@uni.sydney.edu.au).

14 | 15 | 16 |
Sihan Ma, Jizhizi Li, Jing Zhang, He Zhang, and Dacheng Tao
17 | 18 |

19 | Introduction | 20 | PPT and P3M-10k | 21 | P3M-Net | 22 | P3M-CP | 23 | Results | 24 | Train | 25 | Inference code | 26 | Statement 27 |

28 | 29 | 30 | 31 | *** 32 | >

:postbox: News

33 | > 34 | > [2024-3-31]: [Code for training](./core/README.md) is available! 35 | > 36 | > [2023-11-05]: Publish the ViTAE-S and SWIN-T backbone models pretrained on ImageNet that can be used to train our P3M-Net from scratch. 37 | > 38 | > [2023-03-28]: The paper has been accepted by the International Journal of Computer Vision ([IJCV](https://www.springer.com/journal/11263))! 🎉 39 | > 40 | > [2022-03-31]: Publish the inference code and the pretrained model that can be used to test with our SOTA model P3M-Net(ViTAE-S) on your own privacy-preserving or normal portrait images. 41 | > 42 | > [2021-12-06]: Publish the [P3M-10k](#ppt-setting-and-p3m-10k-dataset) dataset. 43 | > 44 | > [2021-11-21]: Publish the conference paper ACM MM 2021 "[Privacy-preserving Portrait Matting](https://dl.acm.org/doi/10.1145/3474085.3475512)". The code and data are available at [github repo](https://github.com/JizhiziLi/P3M). 45 | > 46 | > Other applications of ViTAE Transformer include: [image classification](https://github.com/ViTAE-Transformer/ViTAE-Transformer/tree/main/Image-Classification) | [object detection](https://github.com/ViTAE-Transformer/ViTAE-Transformer/tree/main/Object-Detection) | [semantic segmentation](https://github.com/ViTAE-Transformer/ViTAE-Transformer/tree/main/Semantic-Segmentation) | [animal pose segmentation](https://github.com/ViTAE-Transformer/ViTAE-Transformer/tree/main/Animal-Pose-Estimation) | [remote sensing](https://github.com/ViTAE-Transformer/ViTAE-Transformer-Remote-Sensing) 47 | 48 | 49 | ## Introduction 50 | 51 | 52 |

Recently, there has been increasing concern about the privacy issues raised by using personally identifiable information in machine learning. However, previous portrait matting methods all relied on identifiable portrait images.

53 | 54 |

To fill this gap, we present P3M-10k, the first large-scale anonymized benchmark for Privacy-Preserving Portrait Matting. P3M-10k consists of 10,000 high-resolution face-blurred portrait images along with high-quality alpha mattes. We systematically evaluate both trimap-free and trimap-based matting methods on P3M-10k and find that existing matting methods show different generalization capabilities under the Privacy-Preserving Training (PPT) setting, i.e., training on face-blurred images and testing on arbitrary images.

55 | 56 |

To devise a better trimap-free portrait matting model, we propose P3M-Net, which consists of three carefully designed integration modules that can perform privacy-insensitive semantic perception and detail-reserved matting simultaneously. We further design multiple variants of P3M-Net with different CNN and transformer backbones and identify the differences in their generalization abilities.

57 | 58 |

To further mitigate this issue, we devise a simple yet effective Copy and Paste strategy (P3M-CP) that can borrow facial information from public celebrity images without privacy concerns and direct the network to reacquire the face context at both the data and feature levels. P3M-CP introduces only a small amount of extra computation during training, while enabling the matting model to process both face-blurred and normal images without extra effort during inference.

59 | 60 |

Extensive experiments on P3M-10k demonstrate the superiority of P3M-Net over state-of-the-art methods and the effectiveness of P3M-CP in improving the generalization ability of P3M-Net, highlighting the great significance of P3M for future research and real-world applications.

61 | 62 | 63 | ## PPT Setting and P3M-10k Dataset 64 | 65 | 66 |

PPT Setting: Due to privacy concerns, we propose the Privacy-Preserving Training (PPT) setting in portrait matting, i.e., training on privacy-preserved images (e.g., processed by face obfuscation) and testing on arbitrary images with or without privacy content. As an initial step towards the privacy-preserving portrait matting problem, we only define the identifiable faces in frontal and some profile portrait images as the private content in this work.

67 | 68 | 69 |
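For illustration only, below is a minimal sketch of the face obfuscation step, assuming a binary face mask is available (as provided with P3M-10k); the Gaussian kernel size is an arbitrary choice and this is not the exact pipeline used to build the dataset.

```python
# Minimal, illustrative sketch of face obfuscation by blurring -- not the exact
# pipeline used to build P3M-10k. Assumes `image` is an HxWx3 uint8 array and
# `face_mask` is an HxW array that is non-zero on the face region.
import cv2
import numpy as np

def blur_face(image, face_mask, ksize=51):
    """Blur only the pixels covered by the face mask (kernel size is arbitrary)."""
    blurred = cv2.GaussianBlur(image, (ksize, ksize), 0)
    keep = (face_mask > 0)[..., None]          # HxWx1 boolean selector
    return np.where(keep, blurred, image).astype(image.dtype)
```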

P3M-10k Dataset: To further explore the effect of the PPT setting, we establish the first large-scale privacy-preserving portrait matting benchmark, named P3M-10k. It contains 10,000 anonymized (face-obfuscated) high-resolution portrait images along with high-quality ground-truth alpha mattes. Specifically, we carefully collect, filter, and annotate about 10,000 high-resolution images from the Internet with free-use licenses. There are 9,421 images in the training set and 500 images in the test set, denoted as P3M-500-P. In addition, we collect and annotate another 500 public celebrity images from the Internet without face obfuscation, denoted as P3M-500-NP, to evaluate the performance of matting models under the PPT setting on normal portrait images. We show some examples below, where (a) is from the training set, (b) is from P3M-500-P, and (c) is from P3M-500-NP.

70 | 71 | 72 | P3M-10k and the facemask are now published!! You can get access to it from the following links, please make sure that you have read and agreed to the agreement. Note that the facemask is not used in our work. So it's optional to download it. 73 | 74 | 78 | 79 | | Dataset |

Dataset Link
(Google Drive)

|

Dataset Link
(Baidu Wangpan 百度网盘)

| Dataset Release Agreement| 80 | | :----:| :----: | :----: | :----: | 81 | |P3M-10k|[Link](https://drive.google.com/file/d/1odzHp2zbQApLm90HH_Cvr5b5OwJVhEQG/view?usp=sharing)|[Link](https://pan.baidu.com/s/1aEmEXO5BflSp5hiA-erVBA?pwd=cied) (pw: cied) |[Agreement (MIT License)](https://jizhizili.github.io/files/p3m_dataset_agreement/P3M-10k_Dataset_Release_Agreement.pdf)| 82 | 83 | 84 | 85 | ![](demo/p3m_dataset.png) 86 | 87 | ## P3M-Net and Variants 88 | 89 | ![](demo/network.png) 90 | 91 |

Our P3M-Net models the comprehensive interactions between the shared encoder and the two decoders through three carefully designed integration modules: 1) a tripartite-feature integration (TFI) module to enable the interaction between the encoder and both decoders; 2) a deep bipartite-feature integration (dBFI) module to enhance the interaction between the encoder and the segmentation decoder; and 3) a shallow bipartite-feature integration (sBFI) module to promote the interaction between the encoder and the matting decoder.

92 | 93 |
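The real modules live in `core/network/`; purely as a rough illustration, a tripartite integration block could be sketched as below. The channel sizes and the concat-then-conv fusion are assumptions for demonstration, not the released implementation.

```python
# Rough PyTorch sketch of the tripartite-feature integration idea: fuse one
# encoder feature with the segmentation- and matting-decoder features of the
# same resolution. Channel choices are illustrative; see core/network/modules.py
# and core/network/build_model.py for the actual P3M-Net blocks.
import torch
import torch.nn as nn

class TFISketch(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # project each of the three inputs to half the channels
        self.project = nn.ModuleList(
            [nn.Conv2d(channels, channels // 2, kernel_size=1) for _ in range(3)]
        )
        # fuse the concatenated projections back to `channels`
        self.fuse = nn.Sequential(
            nn.Conv2d(3 * (channels // 2), channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, enc_feat, seg_feat, mat_feat):
        feats = [p(f) for p, f in zip(self.project, (enc_feat, seg_feat, mat_feat))]
        return self.fuse(torch.cat(feats, dim=1))
```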

We design three variants of the P3M basic blocks based on CNN and vision transformer backbones, yielding P3M-Net(ResNet-34), P3M-Net(Swin-T), and P3M-Net(ViTAE-S). We leverage the ability of transformers to model long-range dependencies to extract more accurate global information, and their locality modeling ability to preserve fine details in the transition areas. The structures are shown in the following figures.

94 | 95 | ![](demo/p3m-net-variants.png) 96 | 97 | 103 | 104 | 105 | ## P3M-CP 106 | 107 | 108 |

To further improve the generalization ability of P3M-Net, we devise a simple yet effective Copy and Paste strategy (P3M-CP) that borrows facial information from publicly available celebrity images without privacy concerns and guides the network to reacquire the face context at both the data and feature levels, yielding two variants: P3M-ICP and P3M-FCP. The pipeline is shown in the following figure.
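Before the pipeline figure, here is a minimal illustrative sketch of the image-level variant (P3M-ICP). The actual implementation (see `core/data.py` and the `DATA.CUT_AND_PASTE` options in `core/config_yacs.py`) handles alignment, augmentation, and sampling probability, all of which are omitted here.

```python
# Illustrative sketch only: paste the face region of a public source image onto
# the obfuscated face region of a training image, guided by binary face masks.
import cv2
import numpy as np

def image_level_copy_paste(blurred_img, face_mask, source_img, source_face_mask):
    ys, xs = np.nonzero(face_mask)
    sy, sx = np.nonzero(source_face_mask)
    if len(ys) == 0 or len(sy) == 0:
        return blurred_img                     # nothing to paste
    # bounding boxes of both face regions
    y0, y1, x0, x1 = ys.min(), ys.max() + 1, xs.min(), xs.max() + 1
    sy0, sy1, sx0, sx1 = sy.min(), sy.max() + 1, sx.min(), sx.max() + 1
    # resize the source face crop to the target face box and paste it under the mask
    src_face = cv2.resize(source_img[sy0:sy1, sx0:sx1], (x1 - x0, y1 - y0))
    out = blurred_img.copy()
    keep = (face_mask[y0:y1, x0:x1] > 0)[..., None]
    out[y0:y1, x0:x1] = np.where(keep, src_face, out[y0:y1, x0:x1])
    return out
```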

109 | 110 | ![](demo/p3m-cp.png) 111 | 112 | ## Results 113 | 114 |

We test our network on the proposed P3M-500-P and P3M-500-NP benchmarks and compare it with previous SOTA methods; the results are listed below.

115 | 116 | ![](demo/results1.png) 117 | ![](demo/results2.png) 118 | 119 | ## Quick start - Train 120 | 121 | Please follow this [instruction page](./core/README.md). 122 | 123 | ## Inference Code - How to Test on Your Images 124 | 125 |

Here we provide the procedure for testing on sample images with our pretrained P3M-Net(ViTAE-S) model:

126 | 127 | 1. Setup environment following this [instruction page](./core/README.md); 128 | 129 | 2. Insert the path `REPOSITORY_ROOT_PATH` in the file `core/config.py`; 130 | 131 | 3. Download the pretrained P3M-Net(ViTAE-S) model from here ([Google Drive](https://drive.google.com/file/d/1QbSjPA_Mxs7rITp_a9OJiPeFRDwxemqK/view?usp=sharing) | [Baidu Wangpan](https://pan.baidu.com/s/19FuiR1RwamqvxfhdXDL1fg) (pw: hxxy))) and unzip to the folder `models/pretrained/`; 132 | 133 | 4. Save your sample images in folder `samples/original/.`; 134 | 135 | 5. Setup parameters in the file `scripts/test_samples.sh` and run by: 136 | 137 | `chmod +x scripts/test_samples.sh` 138 | 139 | `scripts/test_samples.sh`; 140 | 141 | 6. The results of alpha matte and transparent color image will be saved in folder `samples/result_alpha/.` and `samples/result_color/.`. 142 | 143 |
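As a follow-up to step 6: the transparent result is essentially the original image with the predicted alpha matte used as its alpha channel. A minimal sketch (not the repository's exact post-processing code) using one of the provided samples:

```python
# Minimal sketch of turning a predicted alpha matte into a transparent RGBA
# image; the repository saves its own version in samples/result_color/.
import cv2

image = cv2.imread('samples/original/p_015cd10e.jpg')         # BGR, HxWx3
alpha = cv2.imread('samples/result_alpha/p_015cd10e.png', 0)   # grayscale, HxW
rgba = cv2.cvtColor(image, cv2.COLOR_BGR2BGRA)
rgba[:, :, 3] = alpha                                          # alpha matte -> A channel
cv2.imwrite('p_015cd10e_transparent.png', rgba)
```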

We show some sample images, the predicted alpha mattes, and their transparent results below. We use the pretrained P3M-Net(ViTAE-S) model from the section P3M-Net and Variants with the `RESIZE` test strategy.

144 | 145 | 146 | 147 | 148 | 149 | 150 | ## Statement 151 | 152 |

If you are interested in our work, please consider citing the following:

153 | 154 | ``` 155 | @article{rethink_p3m, 156 | title={Rethinking Portrait Matting with Privacy Preserving}, 157 | author={Ma, Sihan and Li, Jizhizi and Zhang, Jing and Zhang, He and Tao, Dacheng}, 158 | journal={International Journal of Computer Vision}, 159 | publisher={Springer}, 160 | ISSN={1573-1405}, 161 | year={2023} 162 | } 163 | ``` 164 | 165 |

This project is released under the MIT license.

166 | 167 | For further questions, please contact Sihan Ma at [sima7436@uni.sydney.edu.au](mailto:sima7436@uni.sydney.edu.au) or Jizhizi Li at [jili8515@uni.sydney.edu.au](mailto:jili8515@uni.sydney.edu.au). 168 | 169 | 170 | ## Relevant Projects 171 | 172 | 173 |

174 | 175 | [1] Deep Automatic Natural Image Matting, IJCAI, 2021 | [Paper](https://www.ijcai.org/proceedings/2021/0111.pdf) | [Github](https://github.com/JizhiziLi/AIM) 176 |
     Jizhizi Li, Jing Zhang, and Dacheng Tao 177 | 178 | [2] Privacy-preserving Portrait Matting, ACM MM, 2021 | [Paper](https://dl.acm.org/doi/10.1145/3474085.3475512) | [Github](https://github.com/JizhiziLi/P3M) 179 |
     Jizhizi Li, Sihan Ma, Jing Zhang, Dacheng Tao 180 | 181 | [3] Bridging Composite and Real: Towards End-to-end Deep Image Matting, IJCV, 2022 | [Paper](https://link.springer.com/article/10.1007/s11263-021-01541-0) | [Github](https://github.com/JizhiziLi/GFM) 182 |
     Jizhizi Li, Jing Zhang, Stephen J. Maybank, Dacheng Tao 183 | 184 | [4] Referring Image Matting, CVPR, 2023 | [Paper](https://arxiv.org/pdf/2206.05149.pdf) | [Github](https://github.com/JizhiziLi/RIM) 185 |
     Jizhizi Li, Jing Zhang, and Dacheng Tao 186 | 187 | 188 | [5] Deep Image Matting: A Comprehensive Survey, ArXiv, 2023 | [Paper](https://arxiv.org/abs/2304.04672) | [Github](https://github.com/jizhiziLi/matting-survey) 189 |
     Jizhizi Li, Jing Zhang, and Dacheng Tao -------------------------------------------------------------------------------- /core/README.md: -------------------------------------------------------------------------------- 1 |

Rethinking Portrait Matting with Privacy Preserving

2 | 3 |

4 | Installation | 5 | Prepare Datasets | 6 | Pretrained Models | 7 | Train on P3M-10k | 8 | Test | 9 | Inference 10 |

11 | 12 | 13 | ## Installation 14 | Requirements: 15 | 16 | - Python 3.7.7+ with Numpy and scikit-image 17 | - Pytorch (version>=1.7.1) 18 | - Torchvision (version 0.8.2) 19 | 20 | 1. Clone this repository 21 | 22 | `git clone https://github.com/ViTAE-Transformer/P3M-Net.git`; 23 | 24 | 2. Go into the repository 25 | 26 | `cd P3M-Net`; 27 | 28 | 3. Create conda environment and activate 29 | 30 | `conda create -n p3m python=3.7.7`, 31 | 32 | `conda activate p3m`; 33 | 34 | 4. Install dependencies, install pytorch and torchvision separately if you need 35 | 36 | `pip install -r requirements.txt`, 37 | 38 | `conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.2 -c pytorch`. 39 | 40 | Our code has been tested with Python 3.7.7, Pytorch 1.7.1, Torchvision 0.8.2, CUDA 10.2 on Ubuntu 18.04. 41 | 42 | 43 | ## Prepare Datasets 44 | 45 | 49 | 50 | | Dataset |

Dataset Link
(Google Drive)

|

Dataset Link
(Baidu Wangpan 百度网盘)

| Dataset Release Agreement| 51 | | :----:| :----: | :----: | :----: | 52 | |P3M-10k|[Link](https://drive.google.com/file/d/1odzHp2zbQApLm90HH_Cvr5b5OwJVhEQG/view?usp=sharing)|[Link](https://pan.baidu.com/s/1aEmEXO5BflSp5hiA-erVBA?pwd=cied) (pw: cied) |[Agreement (MIT License)](https://jizhizili.github.io/files/p3m_dataset_agreement/P3M-10k_Dataset_Release_Agreement.pdf)| 53 | 54 | 55 | 1. Download the datasets P3M-10k from the above links and unzip to the folders `P3M_DATASET_ROOT_PATH`, set up the configuratures in the file `core/config.py`. Please make sure that you have checked out and agreed to the agreements. 56 | 57 | After dataset preparation, the structure of the complete datasets should be like the following. 58 | ```text 59 | P3M-10k 60 | ├── train 61 | ├── blurred_image 62 | ├── mask (alpha mattes) 63 | ├── fg_blurred 64 | ├── bg 65 | ├── facemask 66 | ├── validation 67 | ├── P3M-500-P 68 | ├── blurred_image 69 | ├── mask 70 | ├── trimap 71 | ├── facemask 72 | ├── P3M-500-NP 73 | ├── original_image 74 | ├── mask 75 | ├── trimap 76 | ``` 77 | 78 | 2. If you want to test on RWP Test set, please download the original images and alpha mattes at this link. 79 | 80 | After datasets preparation, the structure of the complete datasets should be like the following. 81 | ```text 82 | RealWorldPortrait-636 83 | ├── image 84 | ├── alpha 85 | ├── ... 86 | ``` 87 | 88 | ## Pretrained Models 89 | 90 | Here we provide the model P3M-Net(ViTAE-S) that is trained on P3M-10k for testing. 91 | 92 | | Model| Google Drive | Baidu Wangpan(百度网盘) | 93 | | :----: | :----:| :----: | 94 | | P3M-Net(ViTAE-S) | [Link](https://drive.google.com/file/d/1QbSjPA_Mxs7rITp_a9OJiPeFRDwxemqK/view?usp=sharing) | [Link](https://pan.baidu.com/s/19FuiR1RwamqvxfhdXDL1fg) (pw: hxxy) | 95 | 96 | Here we provide the pretrained models of all backbones for training. 97 | 98 | | Model| Google Drive | Baidu Wangpan(百度网盘) | 99 | | :----: | :----:| :----: | 100 | | pretrained models | [Link](https://drive.google.com/file/d/1V2xt0BWCVx550Ll7GGfquvopTLseX9gY/view?usp=sharing) | [Link](https://pan.baidu.com/s/1eJ7mTLQszEtMJHJ2zn3dag?pwd=gxn9) (pw:gxn9) | 101 | 102 | 103 | ## Train on P3M-10k 104 | 105 | 1. Download P3M-10k dataset in root `P3M_DATASET_ROOT_PATH` (set up in `core/config.py`); 106 | 107 | 2. Download the pretrained models of all backbones in the previous section, and set up the output folder `REPOSITORY_ROOT_PATH` in `core/config.py`. The folder structure should be like the following, 108 | ```text 109 | [REPOSITORY_ROOT_PATH] 110 | ├── logs 111 | ├── models 112 | ├── pretrained 113 | ├── r34mp_pretrained_imagenet.pth.tar 114 | ├── swin_pretrained_epoch_299.pth 115 | ├── vitae_pretrained_ckpt.pth.tar 116 | ├── trained 117 | ``` 118 | 119 | 3. Set up parameters in `scripts/train.sh`, specify config file `cfg`, name for the run `nickname`, etc. Run the file: 120 | 121 | `chmod +x scripts/train.sh` 122 | 123 | `./scripts/train.sh` 124 | 125 | ## Test 126 | 127 | Set up parameters in `scripts/test.sh`, specify config file `cfg`, name for the run `nickname`, etc. Run the file: 128 | 129 | `chmod +x scripts/test.sh` 130 | 131 | `./scripts/test.sh` 132 | 133 | 134 |
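If you already have predicted alpha mattes saved to disk, you can also compute the metrics directly with `core/eval.py` (run `python eval.py --pred_dir ... --alpha_dir ... [--trimap_dir ...] [--fast_test]` from the `core/` directory), or call its helper from Python as in the sketch below. The paths are placeholders that follow the dataset layout above.

```python
# Hypothetical usage sketch of core/eval.py's evaluate_folder (run from core/);
# the paths below are placeholders following the P3M-10k layout above.
from types import SimpleNamespace
from eval import evaluate_folder

args = SimpleNamespace(
    pred_dir='/path/to/your/predicted/alphas/',
    alpha_dir='/path/to/P3M-10k/validation/P3M-500-P/mask/',
    trimap_dir='/path/to/P3M-10k/validation/P3M-500-P/trimap/',
    fast_test=True,   # skip the slower Gradient and Connectivity metrics
)
metrics = evaluate_folder(args)
print('SAD {:.2f} | MSE {:.4f} | MAD {:.4f}'.format(
    metrics['SAD'], metrics['MSE'], metrics['MAD']))
```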
135 | Test using provided model 136 | 137 | ### Test on P3M-10k 138 | 139 | 1. Download provided model on P3M-10k as shown in the previous section, unzip to the folder `models/pretrained/`; 140 | 141 | 2. Download P3M-10k dataset in root `P3M_DATASET_ROOT_PATH` (set up in `core/config.py`); 142 | 143 | 3. Setup parameters in `scripts/test_dataset.sh`, choose `dataset=P3M10K`, and `valset=P3M_500_NP` or `valset=P3M_500_P` depends on which validation set you want to use, run the file: 144 | 145 | `chmod +x scripts/test_dataset.sh` 146 | 147 | `./scripts/test_dataset.sh` 148 | 149 | 4. The results of the alpha matte will be saved in folder `args.test_result_dir`. Note that there may be some slight differences of the evaluation results with the ones reported in the paper due to some packages versions differences and the testing strategy. 150 | 151 | ### Test on RWP 152 | 153 | 1. Download provided model on P3M-10k as shown in the previous section, unzip to the folder `models/pretrained/`; 154 | 155 | 2. Download RWP dataset in root `RWP_TEST_SET_ROOT_PATH` (set up in `core/config.py`). Download link is here; 156 | 157 | 3. Setup parameters in `scripts/test_dataset.sh`, choose `dataset=RWP` and `valset=RWP`, run the file: 158 | 159 | `chmod +x scripts/test_dataset.sh` 160 | 161 | `./scripts/test_dataset.sh` 162 | 163 | 4. The results of the alpha matte will be saved in folder `args.test_result_dir`. Note that there may be some slight differences of the evaluation results with the ones reported in the paper due to some packages versions differences and the testing strategy. 164 | 165 | ### Test on Samples 166 | 167 | 1. Download provided model on P3M-10k as shown in the previous section, unzip to the folder `models/pretrained/`; 168 | 169 | 2. Download images in root `SAMPLES_ROOT_PATH/original` (set up in config.py) 170 | 171 | 3. Set up parameters in `scripts/test_samples.sh`, and run the file: 172 | 173 | `chmod +x samples/original/*` 174 | 175 | `chmod +x scripts/test_samples.sh` 176 | 177 | `./scripts/test_samples.sh` 178 | 179 | 4. The results of the alpha matte will be saved in folder `SAMPLES_RESULT_ALPHA_PATH` (set up in config.py). The color results will be saved in folder `SAMPLES_RESULT_COLOR_PATH` (set up in config.py). Note that there may be some slight differences of the evaluation results with the ones reported in the paper due to some packages versions differences and the testing strategy. 180 | 181 |
182 | 183 | ## Inference Code - How to Test on Your Images 184 | 185 |

Here we provide the procedure for testing on sample images with our pretrained P3M-Net(ViTAE-S) model:

186 | 187 | 1. Setup environment following this [instruction page](https://github.com/ViTAE-Transformer/P3M-Net/tree/main/core); 188 | 189 | 2. Insert the path `REPOSITORY_ROOT_PATH` in the file `core/config.py`; 190 | 191 | 3. Download the pretrained P3M-Net(ViTAE-S) model from here ([Google Drive](https://drive.google.com/file/d/1QbSjPA_Mxs7rITp_a9OJiPeFRDwxemqK/view?usp=sharing) | [Baidu Wangpan](https://pan.baidu.com/s/19FuiR1RwamqvxfhdXDL1fg) (pw: hxxy))) and unzip to the folder `models/pretrained/`; 192 | 193 | 4. Save your sample images in folder `samples/original/.`; 194 | 195 | 5. Setup parameters in the file `scripts/test_samples.sh` and run by: 196 | 197 | `chmod +x scripts/test_samples.sh` 198 | 199 | `scripts/test_samples.sh`; 200 | 201 | 6. The results of alpha matte and transparent color image will be saved in folder `samples/result_alpha/.` and `samples/result_color/.`. 202 | 203 |

We show some sample images, the predicted alpha mattes, and their transparent results below. We use the pretrained P3M-Net(ViTAE-S) model from the section P3M-Net and Variants with the `RESIZE` test strategy.

204 | 205 | 206 | 207 | 208 | 209 | -------------------------------------------------------------------------------- /core/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rethinking Portrait Matting with Privacy Preserving 3 | config file. 4 | 5 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 6 | Licensed under the MIT License (see LICENSE for details) 7 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net 8 | Paper link: https://arxiv.org/abs/2203.16828 9 | 10 | """ 11 | 12 | 13 | ########## Root paths and logging files paths 14 | REPOSITORY_ROOT_PATH = 'p3m_out/' 15 | P3M_DATASET_ROOT_PATH = 'P3M-10k/' 16 | RWP_TEST_SET_ROOT_PATH = '' 17 | 18 | WANDB_KEY_FILE = '' # setup wandb key file path. Write your wandb key in a text file and set up the path for this file here. 19 | 20 | ########## Paths for training 21 | TRAIN_LOGS_FOLDER = REPOSITORY_ROOT_PATH+'logs/train_logs/' 22 | TEST_LOGS_FOLDER = REPOSITORY_ROOT_PATH+'logs/test_logs/' 23 | WANDB_LOGS_FOLDER = REPOSITORY_ROOT_PATH+'logs/' 24 | CKPT_SAVE_FOLDER = REPOSITORY_ROOT_PATH+'models/trained/' 25 | DEBUG_FOLDER = REPOSITORY_ROOT_PATH+'debug/' 26 | TEST_RESULT_FOLDER = REPOSITORY_ROOT_PATH+'result/' 27 | 28 | ######### Paths of datasets 29 | VALID_TEST_DATASET_CHOICE = [ 30 | 'P3M_500_P', 31 | 'P3M_500_NP', 32 | 'VAL500P', 33 | 'VAL500NP', 34 | 'VAL500P_NORMAL', 35 | 'VAL500P_MOSAIC', 36 | 'VAL500P_ZERO', 37 | 'RealWorldPortrait636'] 38 | VALID_TEST_DATA_CHOICE = [*VALID_TEST_DATASET_CHOICE, 'SAMPLES'] 39 | 40 | DATASET_PATHS_DICT={ 41 | 'P3M10K':{ 42 | 'TRAIN':{ 43 | 'ROOT_PATH':P3M_DATASET_ROOT_PATH+'train/', 44 | 'ORIGINAL_PATH':P3M_DATASET_ROOT_PATH+'train/facemask_blurred/', 45 | 'MASK_PATH':P3M_DATASET_ROOT_PATH+'train/mask/', 46 | 'FG_PATH':P3M_DATASET_ROOT_PATH+'train/fg_blurred/', 47 | 'BG_PATH':P3M_DATASET_ROOT_PATH+'train/bg/', 48 | 'PRIVACY_MASK_PATH': P3M_DATASET_ROOT_PATH+'train/facemask/', 49 | 'NORMAL_ORIGINAL_PATH': P3M_DATASET_ROOT_PATH+'train/original/', 50 | 'SAMPLE_NUMBER':9421 51 | }, 52 | 'P3M_500_P':{ 53 | 'ROOT_PATH':P3M_DATASET_ROOT_PATH+'validation/P3M-500-P/', 54 | 'ORIGINAL_PATH':P3M_DATASET_ROOT_PATH+'validation/P3M-500-P/blurred_image/', 55 | 'MASK_PATH':P3M_DATASET_ROOT_PATH+'validation/P3M-500-P/mask/', 56 | 'TRIMAP_PATH':P3M_DATASET_ROOT_PATH+'validation/P3M-500-P/trimap/', 57 | 'PRIVACY_MASK_PATH': P3M_DATASET_ROOT_PATH+'validation/P3M-500-P/facemask/', 58 | 'SAMPLE_NUMBER':500 59 | }, 60 | 'P3M_500_NP':{ 61 | 'ROOT_PATH':P3M_DATASET_ROOT_PATH+'validation/P3M-500-NP/', 62 | 'ORIGINAL_PATH':P3M_DATASET_ROOT_PATH+'validation/P3M-500-NP/original_image/', 63 | 'MASK_PATH':P3M_DATASET_ROOT_PATH+'validation/P3M-500-NP/mask/', 64 | 'TRIMAP_PATH':P3M_DATASET_ROOT_PATH+'validation/P3M-500-NP/trimap/', 65 | 'PRIVACY_MASK_PATH': None, 66 | 'SAMPLE_NUMBER':500 67 | }, 68 | }, 69 | 'RWP': { 70 | 'RWP': { 71 | 'ROOT_PATH': RWP_TEST_SET_ROOT_PATH, 72 | 'ORIGINAL_PATH': RWP_TEST_SET_ROOT_PATH+'image/', 73 | 'MASK_PATH': RWP_TEST_SET_ROOT_PATH+'alpha/', 74 | 'TRIMAP_PATH': None, 75 | 'PRIVACY_MASK_PATH': None, 76 | 'SAMPLE_NUMBER': 636 77 | } 78 | } 79 | } 80 | 81 | ######### Paths of samples for test 82 | 83 | SAMPLES_ORIGINAL_PATH = REPOSITORY_ROOT_PATH+'samples/original/' 84 | SAMPLES_RESULT_ALPHA_PATH = REPOSITORY_ROOT_PATH+'samples/result_alpha/' 85 | SAMPLES_RESULT_COLOR_PATH = REPOSITORY_ROOT_PATH+'samples/result_color/' 86 | 87 | ######### Paths of pretrained model 88 | PRETRAINED_R34_MP = 
REPOSITORY_ROOT_PATH+'models/pretrained/r34mp_pretrained_imagenet.pth.tar' 89 | PRETRAINED_SWIN_STEM_POOLING5 = REPOSITORY_ROOT_PATH+'models/pretrained/swin_pretrained_epoch_299.pth' 90 | PRETRAINED_VITAE_NORC_MAXPOOLING_BIAS_BASIC_STAGE4_14 = REPOSITORY_ROOT_PATH+'models/pretrained/vitae_pretrained_ckpt.pth.tar' 91 | 92 | ######### Test config 93 | MAX_SIZE_H = 1600 94 | MAX_SIZE_W = 1600 95 | MIN_SIZE_H = 512 96 | MIN_SIZE_W = 512 97 | SHORTER_PATH_LIMITATION = 1080 98 | -------------------------------------------------------------------------------- /core/config_yacs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ipdb 3 | import datetime 4 | import yaml 5 | from yacs.config import CfgNode as CN 6 | 7 | _C = CN() 8 | 9 | # base config files 10 | _C.BASE = [''] 11 | 12 | # ----------------------------------------------------------------------------- 13 | # Data settings 14 | # ----------------------------------------------------------------------------- 15 | _C.DATA = CN() 16 | _C.DATA.BATCH_SIZE = 8 17 | _C.DATA.DATASET = 'P3M10K' 18 | _C.DATA.TRAIN_SET = 'TRAIN' # other options: [TRAIN_NORMAL, TRAIN_MOSAIC, TRAIN_ZERO] 19 | _C.DATA.HYBRID_OBFUSCATION = None 20 | _C.DATA.NUM_WORKERS = 8 21 | _C.DATA.RESIZE_SIZE = 512 22 | _C.DATA.CROP_SIZE = [512, 768, 1024] # RATIO: 1, 1.5, 2 23 | 24 | # global data setting | for two stage network only 25 | _C.DATA.GLOBAL = CN() 26 | _C.DATA.GLOBAL.CROP_SIZE = [256, 384, 512] 27 | _C.DATA.GLOBAL.RESIZE_SIZE = 256 28 | 29 | # local data setting | for two stage network only 30 | _C.DATA.LOCAL = CN() 31 | _C.DATA.LOCAL.GLOBAL_SIZE = None 32 | _C.DATA.LOCAL.PATCH_RESIZE_SIZE = 64 33 | _C.DATA.LOCAL.PATCH_SIZE = 64 34 | _C.DATA.LOCAL.PATCH_NUMBER = 64 35 | 36 | # Settings for cut and paste at data level 37 | _C.DATA.CUT_AND_PASTE = CN() 38 | _C.DATA.CUT_AND_PASTE.TYPE = 'NONE' # CHOICES: ['NONE', 'VANILLA', 'AUG', 'RESIZE2FILL', 'GRID_SAMPLE'] 39 | _C.DATA.CUT_AND_PASTE.SOURCE_DATASET = '' # CHOICES: ['SELF', 'P3M10K', 'CELEBAMASK_HQ'] 40 | _C.DATA.CUT_AND_PASTE.PROB = 0.5 41 | # cut and paste: aug 42 | _C.DATA.CUT_AND_PASTE.AUG = CN() 43 | _C.DATA.CUT_AND_PASTE.AUG.DEGREE = 30 44 | _C.DATA.CUT_AND_PASTE.AUG.SCALE = [0.8,1.2] 45 | _C.DATA.CUT_AND_PASTE.AUG.SHEAR = None 46 | _C.DATA.CUT_AND_PASTE.AUG.FLIP = [0.5,0] 47 | _C.DATA.CUT_AND_PASTE.AUG.RANDOM_PASTE = False 48 | # cut and paste: grid_sample 49 | _C.DATA.CUT_AND_PASTE.GRID_SAMPLE = CN() 50 | _C.DATA.CUT_AND_PASTE.GRID_SAMPLE.SELECT_RANGE = [10, 2, 10, 2] 51 | _C.DATA.CUT_AND_PASTE.GRID_SAMPLE.DOWN_SCALE = 4 52 | 53 | # ----------------------------------------------------------------------------- 54 | # Model settings 55 | # ----------------------------------------------------------------------------- 56 | _C.MODEL = CN() 57 | _C.MODEL.TYPE = '' 58 | _C.MODEL.PRETRAINED = True 59 | 60 | # Settings for swin transformer 61 | _C.MODEL.SWIN = CN() 62 | _C.MODEL.SWIN.PATCH_SIZE = 4 63 | _C.MODEL.SWIN.IN_CHANS = 3 64 | _C.MODEL.SWIN.EMBED_DIM = 96 65 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 66 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 67 | _C.MODEL.SWIN.WINDOW_SIZE = 7 68 | _C.MODEL.SWIN.MLP_RATIO = 4. 
69 | _C.MODEL.SWIN.QKV_BIAS = True 70 | _C.MODEL.SWIN.QK_SCALE = None 71 | _C.MODEL.SWIN.APE = False 72 | _C.MODEL.SWIN.PATCH_NORM = True 73 | _C.MODEL.SWIN.USE_CHECKPOINT = False 74 | _C.MODEL.SWIN.DROP_RATE = 0.0 75 | _C.MODEL.SWIN.DROP_PATH_RATE = 0.2 76 | 77 | # Settings for cut and paste at feature level 78 | _C.MODEL.CUT_AND_PASTE = CN() 79 | _C.MODEL.CUT_AND_PASTE.TYPE = 'NONE' # CHOICES: ['NONE', 'CP', 'SHUFFLE'], cp shuffle 只能二选一 80 | _C.MODEL.CUT_AND_PASTE.PROB = 0.5 81 | _C.MODEL.CUT_AND_PASTE.START_EPOCH = 50 82 | _C.MODEL.CUT_AND_PASTE.LAYER = [] # can be a list of layers 83 | _C.MODEL.CUT_AND_PASTE.DETACH = True 84 | 85 | # Settings for cut and paste cp 86 | _C.MODEL.CUT_AND_PASTE.CP = CN() 87 | _C.MODEL.CUT_AND_PASTE.CP.TYPE = 'VANILLA' # CHOICES: ['VANILLA', 'MULTISCALE'], control the fea cut and paste data class type 88 | _C.MODEL.CUT_AND_PASTE.CP.MODEL = 'SELF' # CHOICES: ['COPY_EVERY_ITER', 'COPY_EVERY_EPOCH', 'SELF'], control the model used to extract source fea 89 | _C.MODEL.CUT_AND_PASTE.CP.SOURCE_DATASET = 'SELF' # CHOIES: ['SELF', 'P3M10K', 'CELEBAMASK_HQ'], control the data used to extract source fea 90 | _C.MODEL.CUT_AND_PASTE.CP.SOURCE_BATCH_SIZE = 1 91 | _C.MODEL.CUT_AND_PASTE.CP.CELEBAMASK_HQ = CN() 92 | _C.MODEL.CUT_AND_PASTE.CP.CELEBAMASK_HQ.DEFREE = None # degree ranges, [-degree, +degree] 93 | _C.MODEL.CUT_AND_PASTE.CP.CELEBAMASK_HQ.FLIP = None # [prob horizon, prob vertical] 94 | _C.MODEL.CUT_AND_PASTE.CP.CELEBAMASK_HQ.CROP_SIZE = None 95 | _C.MODEL.CUT_AND_PASTE.CP.CELEBAMASK_HQ.RESIZE_SIZE = None 96 | _C.MODEL.CUT_AND_PASTE.CP.CELEBAMASK_HQ.SCALE = [0.3, 0.7] # scale < 1.0 97 | 98 | # Settings for cut and paste shuffle 99 | _C.MODEL.CUT_AND_PASTE.SHUFFLE = CN() 100 | _C.MODEL.CUT_AND_PASTE.SHUFFLE.TYPE = '' # CHOICES: ['NONE', 'FG2FACE'] 101 | _C.MODEL.CUT_AND_PASTE.SHUFFLE.KERNEL_SIZE = 1 102 | 103 | # TODO 104 | # .SHUFFLE.TYPE: FG2FACE 105 | 106 | # ----------------------------------------------------------------------------- 107 | # Training settings 108 | # ----------------------------------------------------------------------------- 109 | _C.TRAIN = CN() 110 | _C.TRAIN.START_EPOCH = 1 111 | _C.TRAIN.EPOCHS = 150 112 | _C.TRAIN.WARMUP_EPOCHS = 0 113 | _C.TRAIN.LR_DECAY = False 114 | _C.TRAIN.LR = 0.00001 115 | _C.TRAIN.CLIP_GRAD = False 116 | _C.TRAIN.RESUME_CKPT = None 117 | 118 | # Settings for optimizer 119 | _C.TRAIN.OPTIMIZER = CN() 120 | _C.TRAIN.OPTIMIZER.TYPE = 'ADAM' 121 | 122 | # ----------------------------------------------------------------------------- 123 | # Test settings 124 | # ----------------------------------------------------------------------------- 125 | _C.TEST = CN() 126 | _C.TEST.DATASET = 'VAL500P' 127 | _C.TEST.TEST_METHOD = 'HYBRID' 128 | _C.TEST.CKPT_NAME = 'best_SAD_VAL500P' 129 | _C.TEST.FAST_TEST = False 130 | _C.TEST.TEST_PRIVACY = False 131 | _C.TEST.SAVE_RESULT = False 132 | _C.TEST.LOCAL_PATCH_NUM = 1024 # for test only 133 | 134 | # ----------------------------------------------------------------------------- 135 | # Other settings, e.g. 
logging, wandb, dist, tag, test freq, save ckpt 136 | # ----------------------------------------------------------------------------- 137 | _C.TAG = 'debug' 138 | _C.ENABLE_WANDB = False 139 | _C.TEST_FREQ = 1 140 | _C.AUTO_RESUME = False 141 | _C.SEED = 10007 142 | _C.DIST = False 143 | _C.LOCAL_RANK = 0 144 | 145 | 146 | def _update_config_from_file(config, cfg_file): 147 | config.defrost() 148 | with open(cfg_file, 'r') as f: 149 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 150 | 151 | for cfg in yaml_cfg.setdefault('BASE', ['']): 152 | if cfg: 153 | _update_config_from_file( 154 | config, os.path.join(os.path.dirname(cfg_file), cfg) 155 | ) 156 | print('=> merge config from {}'.format(cfg_file)) 157 | config.merge_from_file(cfg_file) 158 | config.freeze() 159 | 160 | 161 | def update_config(config, args): 162 | # merge config from other config 163 | if getattr(args, 'existing_cfg', None): # for test only 164 | if hasattr(args.existing_cfg.MODEL, 'CUT_AND_PASTE') and hasattr(args.existing_cfg.MODEL.CUT_AND_PASTE, 'LAYER'): 165 | if type(args.existing_cfg.MODEL.CUT_AND_PASTE.LAYER) != list: 166 | args.existing_cfg.defrost() 167 | args.existing_cfg.MODEL.CUT_AND_PASTE.LAYER = [args.existing_cfg.MODEL.CUT_AND_PASTE.LAYER] 168 | args.existing_cfg.freeze() 169 | config.merge_from_other_cfg(args.existing_cfg) 170 | assert args.tag == config.TAG 171 | 172 | # merge config from file 173 | if getattr(args, 'cfg', None): 174 | _update_config_from_file(config, args.cfg) 175 | 176 | config.defrost() 177 | if getattr(args, 'opts', None): 178 | config.merge_from_list(args.opts) 179 | 180 | # merge from specific arguments 181 | if getattr(args, 'arch', None): 182 | config.MODEL.TYPE = args.arch 183 | 184 | if getattr(args, 'train_from_scratch', None): 185 | config.MODEL.PRETRAINED = False 186 | 187 | if getattr(args, 'tag', None): 188 | config.TAG = args.tag 189 | 190 | if getattr(args, 'nEpochs', None): 191 | config.TRAIN.EPOCHS = args.nEpochs 192 | 193 | if getattr(args, 'warmup_nEpochs', None): 194 | config.TRAIN.WARMUP_EPOCHS = args.warmup_nEpochs 195 | 196 | if getattr(args, 'batchSize', None): 197 | config.DATA.BATCH_SIZE = args.batchSize 198 | 199 | if getattr(args, 'lr', None): 200 | config.TRAIN.LR = args.lr 201 | 202 | if getattr(args, 'lr_decay', None): 203 | config.TRAIN.LR_DECAY = args.lr_decay 204 | 205 | if getattr(args, 'clip_grad', None): 206 | config.TRAIN.CLIP_GRAD = args.clip_grad 207 | 208 | if getattr(args, 'threads', None): 209 | config.DATA.NUM_WORKERS = args.threads 210 | 211 | if getattr(args, 'test_freq', None): 212 | config.TEST_FREQ = args.test_freq 213 | 214 | if getattr(args, 'enable_wandb', None): 215 | config.ENABLE_WANDB = args.enable_wandb 216 | 217 | if getattr(args, 'auto_resume', None): 218 | config.AUTO_RESUME = args.auto_resume 219 | 220 | if getattr(args, 'source_batch_size', None): 221 | config.MODEL.CUT_AND_PASTE.CP.SOURCE_BATCH_SIZE = args.source_batch_size 222 | 223 | if getattr(args, 'dataset', None): 224 | config.DATA.DATASET = args.dataset 225 | 226 | if getattr(args, 'train_set', None): 227 | config.DATA.TRAIN_SET = args.train_set 228 | 229 | if getattr(args, 'seed', None): 230 | config.SEED = args.seed 231 | 232 | if getattr(args, 'test_dataset', None): 233 | config.TEST.DATASET = args.test_dataset 234 | 235 | if getattr(args, 'test_ckpt', None): 236 | config.TEST.CKPT_NAME = args.test_ckpt 237 | 238 | if getattr(args, 'test_method', None): 239 | config.TEST.TEST_METHOD = args.test_method 240 | 241 | if getattr(args, 'fast_test', None): 242 | 
config.TEST.FAST_TEST = args.fast_test 243 | 244 | if getattr(args, 'test_privacy', None): 245 | config.TEST.TEST_PRIVACY = args.test_privacy 246 | 247 | if getattr(args, 'save_result', None): 248 | config.TEST.SAVE_RESULT = args.save_result 249 | 250 | if getattr(args, 'local_rank', None): 251 | config.LOCAL_RANK = args.local_rank 252 | 253 | config.freeze() 254 | 255 | def get_config(args): 256 | config = _C.clone() 257 | update_config(config, args) 258 | return config -------------------------------------------------------------------------------- /core/configs/ResNet34.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: 'r34' 3 | CUT_AND_PASTE: 4 | TYPE: 'NONE' 5 | 6 | DATA: 7 | CUT_AND_PASTE: 8 | TYPE: 'NONE' 9 | -------------------------------------------------------------------------------- /core/configs/Swin_T.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: 'swin' 3 | CUT_AND_PASTE: 4 | TYPE: 'NONE' 5 | 6 | DATA: 7 | CUT_AND_PASTE: 8 | TYPE: 'NONE' -------------------------------------------------------------------------------- /core/configs/ViTAE_S.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: 'vitae' 3 | CUT_AND_PASTE: 4 | TYPE: 'NONE' 5 | 6 | DATA: 7 | CUT_AND_PASTE: 8 | TYPE: 'NONE' 9 | -------------------------------------------------------------------------------- /core/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rethinking Portrait Matting with Privacy Preserving 3 | Inferernce file. 4 | 5 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 6 | Licensed under the MIT License (see LICENSE for details) 7 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net 8 | Paper link: https://arxiv.org/abs/2203.16828 9 | 10 | """ 11 | import torch 12 | import cv2 13 | import random 14 | import numpy as np 15 | from PIL import Image 16 | from torchvision import transforms 17 | from copy import deepcopy 18 | import ipdb 19 | import math 20 | from torch.utils.data import DataLoader 21 | from torch.utils.data._utils.collate import default_collate 22 | from collections.abc import Iterable 23 | 24 | from data_util import * 25 | from config import * 26 | from util import * 27 | 28 | 29 | ######################### 30 | ## collate 31 | ######################### 32 | 33 | def collate_two_stage(batch): 34 | transposed = list(zip(*batch)) 35 | global_data = transposed[:6] 36 | local_data = transposed[6:12] 37 | params = transposed[12:14] 38 | 39 | global_batch = [default_collate(list(elem)) for elem in global_data] 40 | local_batch = [torch.cat(list(elem), dim=0) for elem in local_data] 41 | return [*global_batch, *local_batch, *params] 42 | 43 | ######################### 44 | ## Data transformer 45 | ######################### 46 | class MattingTransform(object): 47 | def __init__(self, crop_size, resize_size): 48 | super(MattingTransform, self).__init__() 49 | self.crop_size = crop_size 50 | self.resize_size = resize_size 51 | 52 | # args: image(blurred), mask, fg, bg, trimap, facemask, source_img, source_facemask 53 | def __call__(self, *args): 54 | ori = args[0] 55 | trimap = args[4] 56 | 57 | h, w, c = ori.shape 58 | crop_size = random.choice(self.crop_size) 59 | crop_size = crop_size if crop_size < min(h, w) else 512 60 | resize_size = self.resize_size 61 | 62 | target = np.where(trimap == 128) if 
random.random() < 0.5 else np.where(trimap > -100) 63 | if len(target[0]) == 0: 64 | target = np.where(trimap > -100) 65 | 66 | random_idx = np.random.choice(len(target[0])) 67 | centerh = target[0][random_idx] 68 | centerw = target[1][random_idx] 69 | crop_loc = self.safe_crop([centerh, centerw], crop_size, trimap.shape) 70 | 71 | flip_flag = True if random.random() < 0.5 else False 72 | 73 | args_transform = [] 74 | for item in args: 75 | item = item[crop_loc[0]:crop_loc[2], crop_loc[1]:crop_loc[3]] 76 | if flip_flag: 77 | item = cv2.flip(item, 1) 78 | item = cv2.resize(item, (resize_size, resize_size), interpolation=cv2.INTER_LINEAR) 79 | args_transform.append(item) 80 | 81 | return args_transform 82 | 83 | def safe_crop(self, center_pt, crop_size, img_size): 84 | h, w = img_size[:2] 85 | crop_size = min(h, w, crop_size) # make sure crop_size <= min(h,w) 86 | 87 | center_h, center_w = center_pt 88 | 89 | left_top_h = max(center_h-crop_size//2, 0) 90 | right_bottom_h = min(h, left_top_h+crop_size) 91 | left_top_h = min(left_top_h, right_bottom_h-crop_size) 92 | 93 | left_top_w = max(center_w-crop_size//2, 0) 94 | right_bottom_w = min(w, left_top_w+crop_size) 95 | left_top_w = min(left_top_w, right_bottom_w-crop_size) 96 | 97 | return (left_top_h, left_top_w, right_bottom_h, right_bottom_w) 98 | 99 | 100 | class MattingDataset(torch.utils.data.Dataset): 101 | def __init__(self, config, transform): 102 | # Prepare transform 103 | self.transform = transform 104 | 105 | # Load data 106 | self.samples=[] 107 | self.samples += generate_paths_for_dataset(dataset=config.DATA.DATASET, trainset=config.DATA.TRAIN_SET) 108 | 109 | def __getitem__(self,index): 110 | # Prepare training sample paths 111 | ori_path, mask_path, fg_path, bg_path, facemask_path = self.samples[index] 112 | 113 | ori = np.array(Image.open(ori_path)) 114 | mask = trim_img(np.array(Image.open(mask_path))) 115 | fg = np.array(Image.open(fg_path)) 116 | bg = np.array(Image.open(bg_path)) 117 | facemask = np.array(Image.open(facemask_path))[:,:,0:1] 118 | # Generate trimap/dilation/erosion online 119 | kernel_size = random.randint(15, 30) 120 | trimap = gen_trimap_with_dilate(mask, kernel_size) 121 | 122 | # Data transformation to generate samples (crop/flip/resize) 123 | # Transform input order: ori, mask, fg, bg, trimap 124 | argv = self.transform(ori, mask, fg, bg, trimap, facemask) 125 | argv_transform = [] 126 | for item in argv: 127 | if item.ndim<3: 128 | item = torch.from_numpy(item.astype(np.float32)[np.newaxis, :, :]) 129 | else: 130 | item = torch.from_numpy(item.astype(np.float32)).permute(2, 0, 1) 131 | argv_transform.append(item) 132 | 133 | [ori, mask, fg, bg, trimap, facemask] = argv_transform 134 | 135 | trimap[trimap > 180] = 255 136 | trimap[trimap < 50] = 0 137 | trimap[(trimap < 255) * (trimap > 0)] = 128 138 | 139 | facemask[facemask>100] = 255 140 | facemask[facemask<=100] = 0 141 | 142 | # normalize ori, fg, bg 143 | ori = ori/255.0 144 | fg = fg/255.0 145 | bg = bg/255.0 146 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 147 | std=[0.229, 0.224, 0.225]) 148 | ori = normalize(ori) 149 | fg = normalize(fg) 150 | bg = normalize(bg) 151 | 152 | # output order: ori, mask, fg, bg, trimap 153 | return ori, mask, fg, bg, trimap, facemask 154 | 155 | def __len__(self): 156 | return len(self.samples) 157 | 158 | -------------------------------------------------------------------------------- /core/data_util.py: -------------------------------------------------------------------------------- 1 | """ 
2 | Rethinking Portrait Matting with Privacy Preserving 3 | Inferernce file. 4 | 5 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 6 | Licensed under the MIT License (see LICENSE for details) 7 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net 8 | Paper link: https://arxiv.org/abs/2203.16828 9 | 10 | """ 11 | 12 | import ipdb 13 | import random 14 | import numpy as np 15 | from copy import deepcopy 16 | import math 17 | 18 | 19 | ######################################### 20 | # Pure Functions 21 | ######################################### 22 | 23 | def safe_crop(center_pt, crop_size, img_size): 24 | h, w = img_size[:2] 25 | crop_size = min(h, w, crop_size) # make sure crop_size <= min(h,w) 26 | 27 | center_h, center_w = center_pt 28 | 29 | left_top_h = max(center_h-crop_size//2, 0) 30 | right_bottom_h = min(h, left_top_h+crop_size) 31 | left_top_h = min(left_top_h, right_bottom_h-crop_size) 32 | 33 | left_top_w = max(center_w-crop_size//2, 0) 34 | right_bottom_w = min(w, left_top_w+crop_size) 35 | left_top_w = min(left_top_w, right_bottom_w-crop_size) 36 | 37 | return (left_top_h, left_top_w, right_bottom_h, right_bottom_w) 38 | 39 | 40 | ######################################### 41 | # Functions for Affine Transform 42 | ######################################### 43 | 44 | 45 | def get_random_params_for_inverse_affine_matrix(degrees, translate, scale_ranges, shears, flip, img_size): 46 | """Get parameters for affine transformation 47 | 48 | Returns: 49 | sequence: params to be passed to the affine transformation 50 | """ 51 | angle = random.uniform(-degrees, degrees) 52 | assert translate is None, "bugs here, figure out the xy axis and img size problem" 53 | if translate is not None: 54 | max_dx = translate[0] * img_size[0] 55 | max_dy = translate[1] * img_size[1] 56 | translations = (np.round(random.uniform(-max_dx, max_dx)), 57 | np.round(random.uniform(-max_dy, max_dy))) 58 | else: 59 | translations = (0, 0) 60 | 61 | if scale_ranges is not None: 62 | scale = (random.uniform(scale_ranges[0], scale_ranges[1]), 63 | random.uniform(scale_ranges[0], scale_ranges[1])) 64 | else: 65 | scale = (1.0, 1.0) 66 | 67 | if shears is not None: 68 | shear = random.uniform(shears[0], shears[1]) 69 | else: 70 | shear = 0.0 71 | 72 | if flip is not None: 73 | flip = 1 - (np.random.rand(2) < flip).astype(np.int) * 2 74 | flip = flip.tolist() 75 | else: 76 | flip = [1.0,1.0] 77 | 78 | return angle, translations, scale, shear, flip 79 | 80 | 81 | def get_inverse_affine_matrix(center, angle, translate=None, scale=None, shear=None, flip=None): 82 | # Helper method to compute inverse matrix for affine transformation 83 | 84 | # As it is explained in PIL.Image.rotate 85 | # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1 86 | # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] 87 | # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] 88 | # RSS is rotation with scale and shear matrix 89 | # It is different from the original function in torchvision 90 | # The order are changed to flip -> scale -> rotation -> shear 91 | # x and y have different scale factors 92 | # RSS(shear, a, scale, f) = [ cos(a + shear)*scale_x*f -sin(a + shear)*scale_y 0] 93 | # [ sin(a)*scale_x*f cos(a)*scale_y 0] 94 | # [ 0 0 1] 95 | # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1 96 | 97 | """ 98 | Args: 99 | center (tuple(int,int)) : the center of the image, required 100 | angle (int) : angle for 
rotation, range [-360,360], required 101 | translate (tuple(int,int)) : default (0,0). Bugs here, so DON'T use it. 102 | scale (tuple(double,double)): default (1.,1.), scale for x and y axis 103 | shear (double) : default 0.0 104 | flip (tuple(int,int)) : default no flip, choices [0 horizonal, 1 vertical] 105 | """ 106 | # assertions, check param range 107 | if translate is not None: 108 | assert translate == (0,0), "BUG UNSOLVED" 109 | else: 110 | translate = (0,0) 111 | 112 | if scale is not None: 113 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ 114 | "scale should be a list or tuple and it must be of length 2." 115 | for s in scale: 116 | if s <= 0: 117 | raise ValueError("scale values should be positive") 118 | else: 119 | scale = (1.,1.) 120 | 121 | if shear is None: 122 | shear = 0.0 123 | 124 | if flip is not None: 125 | assert isinstance(flip, (tuple, list)) and len(flip) == 2, \ 126 | "flip should be a list or tuple and it must be of length 2." 127 | for f in flip: 128 | if f != -1 and f != 1: 129 | raise ValueError("flip values should be -1 or 1.") 130 | else: 131 | # final flip value -1, means flip 132 | # final flip value 1, means not flip 133 | # flip = (np.random.rand(2) < flip).astype(np.int) * 2 - 1 134 | # flip[0] horizonal 135 | # flip[1] vertical 136 | flip = (1.0,1.0) 137 | 138 | angle = math.radians(angle) 139 | shear = math.radians(shear) 140 | scale_x = 1.0 / scale[0] * flip[0] 141 | scale_y = 1.0 / scale[1] * flip[1] 142 | 143 | # Inverted rotation matrix with scale and shear 144 | d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle) 145 | matrix = [ 146 | math.cos(angle) * scale_x, math.sin(angle + shear) * scale_x, 0, 147 | -math.sin(angle) * scale_y, math.cos(angle + shear) * scale_y, 0 148 | ] 149 | matrix = [m / d for m in matrix] 150 | 151 | # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 152 | matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1]) 153 | matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1]) 154 | 155 | # Apply center translation: C * RSS^-1 * C^-1 * T^-1 156 | matrix[2] += center[0] 157 | matrix[5] += center[1] 158 | 159 | return matrix 160 | 161 | -------------------------------------------------------------------------------- /core/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rethinking Portrait Matting with Privacy Preserving 3 | evaluation file. 
4 | 5 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 6 | Licensed under the MIT License (see LICENSE for details) 7 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net 8 | Paper link: https://arxiv.org/abs/2203.16828 9 | 10 | """ 11 | 12 | import os 13 | import argparse 14 | import numpy as np 15 | import cv2 16 | from tqdm import tqdm 17 | from metrics import * 18 | from util import listdir_nohidden 19 | 20 | 21 | def get_args(): 22 | parser = argparse.ArgumentParser(description='Arguments for the training purpose.') 23 | parser.add_argument('--pred_dir', type=str, required=True, help='path to predictions') 24 | parser.add_argument('--alpha_dir', type=str, required=True, help='path to groundtruth alpha') 25 | parser.add_argument('--trimap_dir', type=str, help='path to trimap') 26 | parser.add_argument('--fast_test', action='store_true', help='skip grad and conn') 27 | args, _ = parser.parse_known_args() 28 | 29 | print('Prediction dir: {}'.format(args.pred_dir)) 30 | print('Alpha dir: {}'.format(args.alpha_dir)) 31 | print('Trimap dir: {}'.format(args.trimap_dir)) 32 | if args.fast_test: 33 | print('Will skip gradient and connectivity...') 34 | return args 35 | 36 | def evaluate_folder(args): 37 | img_list = listdir_nohidden(args.alpha_dir) 38 | total_number = len(img_list) 39 | 40 | sad_diffs = 0. 41 | mse_diffs = 0. 42 | mad_diffs = 0. 43 | sad_trimap_diffs = 0. 44 | mse_trimap_diffs = 0. 45 | mad_trimap_diffs = 0. 46 | sad_fg_diffs = 0. 47 | sad_bg_diffs = 0. 48 | conn_diffs = 0. 49 | grad_diffs = 0. 50 | 51 | for img_name in tqdm(img_list): 52 | predict = cv2.imread(os.path.join(args.pred_dir, img_name), 0).astype(np.float32)/255.0 53 | alpha = cv2.imread(os.path.join(args.alpha_dir, img_name), 0).astype(np.float32)/255.0 54 | 55 | if args.trimap_dir is not None: 56 | trimap = cv2.imread(os.path.join(args.trimap_dir, img_name), 0).astype(np.float32) 57 | else: 58 | trimap = None 59 | 60 | sad_diff, mse_diff, mad_diff = calculate_sad_mse_mad_whole_img(predict, alpha) 61 | 62 | if trimap is not None: 63 | sad_trimap_diff, mse_trimap_diff, mad_trimap_diff = calculate_sad_mse_mad(predict, alpha, trimap) 64 | sad_fg_diff, sad_bg_diff = calculate_sad_fgbg(predict, alpha, trimap) 65 | else: 66 | sad_trimap_diff, mse_trimap_diff, mad_trimap_diff = 0.0, 0.0, 0.0 67 | sad_fg_diff, sad_bg_diff = 0.0, 0.0 68 | 69 | if args.fast_test: 70 | conn_diff = 0.0 71 | grad_diff = 0.0 72 | else: 73 | conn_diff = compute_connectivity_loss_whole_image(predict, alpha) 74 | grad_diff = compute_gradient_whole_image(predict, alpha) 75 | 76 | 77 | sad_diffs += sad_diff 78 | mse_diffs += mse_diff 79 | mad_diffs += mad_diff 80 | mse_trimap_diffs += mse_trimap_diff 81 | sad_trimap_diffs += sad_trimap_diff 82 | mad_trimap_diffs += mad_trimap_diff 83 | sad_fg_diffs += sad_fg_diff 84 | sad_bg_diffs += sad_bg_diff 85 | conn_diffs += conn_diff 86 | grad_diffs += grad_diff 87 | 88 | res_dict = {} 89 | res_dict['SAD'] = sad_diffs / total_number 90 | res_dict['MSE'] = mse_diffs / total_number 91 | res_dict['MAD'] = mad_diffs / total_number 92 | res_dict['SAD_TRIMAP'] = sad_trimap_diffs / total_number 93 | res_dict['MSE_TRIMAP'] = mse_trimap_diffs / total_number 94 | res_dict['MAD_TRIMAP'] = mad_trimap_diffs / total_number 95 | res_dict['SAD_FG'] = sad_fg_diffs / total_number 96 | res_dict['SAD_BG'] = sad_bg_diffs / total_number 97 | res_dict['CONN'] = conn_diffs / total_number 98 | res_dict['GRAD'] = grad_diffs / total_number 99 | 100 | print('Average 
results') 101 | print('Test image numbers: {}'.format(total_number)) 102 | print('Whole image SAD:', res_dict['SAD']) 103 | print('Whole image MSE:', res_dict['MSE']) 104 | print('Whole image MAD:', res_dict['MAD']) 105 | print('Unknown region SAD:', res_dict['SAD_TRIMAP']) 106 | print('Unknown region MSE:', res_dict['MSE_TRIMAP']) 107 | print('Unknown region MAD:', res_dict['MAD_TRIMAP']) 108 | print('Foreground SAD:', res_dict['SAD_FG']) 109 | print('Background SAD:', res_dict['SAD_BG']) 110 | print('Gradient:', res_dict['GRAD']) 111 | print('Connectivity:', res_dict['CONN']) 112 | return res_dict 113 | 114 | 115 | if __name__ == '__main__': 116 | args = get_args() 117 | evaluate_folder(args) 118 | -------------------------------------------------------------------------------- /core/evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rethinking Portrait Matting with Privacy Preserving 3 | Inferernce file. 4 | 5 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 6 | Licensed under the MIT License (see LICENSE for details) 7 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net 8 | Paper link: https://arxiv.org/abs/2203.16828 9 | 10 | """ 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import math 15 | from torch.autograd import Variable 16 | import torch.nn.functional as fnn 17 | import ipdb 18 | 19 | 20 | ############################## 21 | ### Training loses for P3M-NET 22 | ############################## 23 | def get_crossentropy_loss(gt,pre): 24 | gt_copy = gt.clone() 25 | gt_copy[gt_copy==0] = 0 26 | gt_copy[gt_copy==255] = 2 27 | gt_copy[gt_copy>2] = 1 28 | gt_copy = gt_copy.long() 29 | gt_copy = gt_copy[:,0,:,:] 30 | criterion = nn.CrossEntropyLoss() 31 | entropy_loss = criterion(pre, gt_copy) 32 | return entropy_loss 33 | 34 | def get_alpha_loss(predict, alpha, trimap): 35 | weighted = torch.zeros(trimap.shape).cuda() 36 | weighted[trimap == 128] = 1. 37 | alpha_f = alpha / 255. 38 | alpha_f = alpha_f.cuda() 39 | diff = predict - alpha_f 40 | diff = diff * weighted 41 | alpha_loss = torch.sqrt(diff ** 2 + 1e-12) 42 | alpha_loss_weighted = alpha_loss.sum() / (weighted.sum() + 1.) 43 | return alpha_loss_weighted 44 | 45 | def get_alpha_loss_whole_img(predict, alpha): 46 | weighted = torch.ones(alpha.shape).cuda() 47 | alpha_f = alpha / 255. 
48 | alpha_f = alpha_f.cuda() 49 | diff = predict - alpha_f 50 | alpha_loss = torch.sqrt(diff ** 2 + 1e-12) 51 | alpha_loss = alpha_loss.sum()/(weighted.sum()) 52 | return alpha_loss 53 | 54 | ## Laplacian loss is refer to 55 | ## https://gist.github.com/MarcoForte/a07c40a2b721739bb5c5987671aa5270 56 | def build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=False): 57 | if size % 2 != 1: 58 | raise ValueError("kernel size must be uneven") 59 | grid = np.float32(np.mgrid[0:size,0:size].T) 60 | gaussian = lambda x: np.exp((x - size//2)**2/(-2*sigma**2))**2 61 | kernel = np.sum(gaussian(grid), axis=2) 62 | kernel /= np.sum(kernel) 63 | kernel = np.tile(kernel, (n_channels, 1, 1)) 64 | kernel = torch.FloatTensor(kernel[:, None, :, :]).cuda() 65 | return Variable(kernel, requires_grad=False) 66 | 67 | def conv_gauss(img, kernel): 68 | """ convolve img with a gaussian kernel that has been built with build_gauss_kernel """ 69 | n_channels, _, kw, kh = kernel.shape 70 | img = fnn.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate') 71 | return fnn.conv2d(img, kernel, groups=n_channels) 72 | 73 | def laplacian_pyramid(img, kernel, max_levels=5): 74 | current = img 75 | pyr = [] 76 | for level in range(max_levels): 77 | filtered = conv_gauss(current, kernel) 78 | diff = current - filtered 79 | pyr.append(diff) 80 | current = fnn.avg_pool2d(filtered, 2) 81 | pyr.append(current) 82 | return pyr 83 | 84 | def get_laplacian_loss(predict, alpha, trimap): 85 | weighted = torch.zeros(trimap.shape).cuda() 86 | weighted[trimap == 128] = 1. 87 | alpha_f = alpha / 255. 88 | alpha_f = alpha_f.cuda() 89 | alpha_f = alpha_f.clone()*weighted 90 | predict = predict.clone()*weighted 91 | gauss_kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=True) 92 | pyr_alpha = laplacian_pyramid(alpha_f, gauss_kernel, 5) 93 | pyr_predict = laplacian_pyramid(predict, gauss_kernel, 5) 94 | laplacian_loss_weighted = sum(fnn.l1_loss(a, b) for a, b in zip(pyr_alpha, pyr_predict)) 95 | return laplacian_loss_weighted 96 | 97 | def get_laplacian_loss_whole_img(predict, alpha): 98 | alpha_f = alpha / 255. 99 | alpha_f = alpha_f.cuda() 100 | gauss_kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=True) 101 | pyr_alpha = laplacian_pyramid(alpha_f, gauss_kernel, 5) 102 | pyr_predict = laplacian_pyramid(predict, gauss_kernel, 5) 103 | laplacian_loss = sum(fnn.l1_loss(a, b) for a, b in zip(pyr_alpha, pyr_predict)) 104 | return laplacian_loss 105 | 106 | def get_composition_loss_whole_img(img, alpha, fg, bg, predict): 107 | weighted = torch.ones(alpha.shape).cuda() 108 | predict_3 = torch.cat((predict, predict, predict), 1) 109 | comp = predict_3 * fg + (1. - predict_3) * bg 110 | comp_loss = torch.sqrt((comp - img) ** 2 + 1e-12) # fixed bug, remove "/255." 111 | comp_loss = comp_loss.sum()/(weighted.sum()) 112 | return comp_loss 113 | 114 | def get_composition_loss(img, alpha, fg, bg, trimap, predict): 115 | wi = torch.zeros(trimap.shape) 116 | wi[trimap == 128] = 1. 117 | t_wi = wi.cuda() 118 | t3_wi = torch.cat((wi, wi, wi), 1).cuda() 119 | unknown_region_size = t_wi.sum() 120 | predict_3 = torch.cat((predict, predict, predict), 1) 121 | comp = predict_3 * fg + (1. - predict_3) * bg 122 | comp_loss = torch.sqrt((comp - img) ** 2 + 1e-12) / 255. 123 | comp_loss = (comp_loss * t3_wi).sum() / (unknown_region_size + 1.) / 3. 
124 | return comp_loss 125 | 126 | ############################## 127 | ### Test loss for matting 128 | ############################## 129 | def calculate_sad_mse_mad_privacy(predict_old, alpha, privacy): 130 | predict = np.copy(predict_old) 131 | pixel = float((privacy == 255).sum()) 132 | sad_diff_mask = np.abs(predict - alpha) 133 | sad_diff_mask[privacy != 255] = 0. 134 | sad_diff = np.sum(sad_diff_mask) / 1000 135 | if pixel == 0: 136 | return 0, 0, 0 137 | mse_diff = np.sum(sad_diff_mask ** 2)/pixel 138 | mad_diff = np.sum(sad_diff_mask)/pixel 139 | return sad_diff, mse_diff, mad_diff 140 | 141 | def calculate_sad_mse_mad(predict_old,alpha,trimap): 142 | predict = np.copy(predict_old) 143 | pixel = float((trimap == 128).sum()) 144 | predict[trimap == 255] = 1. 145 | predict[trimap == 0 ] = 0. 146 | sad_diff = np.sum(np.abs(predict - alpha))/1000 147 | if pixel==0: 148 | pixel = trimap.shape[0]*trimap.shape[1]-float((trimap==255).sum())-float((trimap==0).sum()) 149 | mse_diff = np.sum((predict - alpha) ** 2)/pixel 150 | mad_diff = np.sum(np.abs(predict - alpha))/pixel 151 | return sad_diff, mse_diff, mad_diff 152 | 153 | def calculate_sad_mse_mad_whole_img(predict, alpha): 154 | pixel = predict.shape[0]*predict.shape[1] 155 | sad_diff = np.sum(np.abs(predict - alpha))/1000 156 | mse_diff = np.sum((predict - alpha) ** 2)/pixel 157 | mad_diff = np.sum(np.abs(predict - alpha))/pixel 158 | return sad_diff, mse_diff, mad_diff 159 | 160 | def calculate_sad_fgbg(predict, alpha, trimap): 161 | sad_diff = np.abs(predict-alpha) 162 | weight_fg = np.zeros(predict.shape) 163 | weight_bg = np.zeros(predict.shape) 164 | weight_trimap = np.zeros(predict.shape) 165 | weight_fg[trimap==255] = 1. 166 | weight_bg[trimap==0 ] = 1. 167 | weight_trimap[trimap==128 ] = 1. 
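# Evaluation conventions used by the functions in this block: alphas and
# predictions are expected in [0, 1], trimaps use {0, 128, 255}, SAD is reported
# in thousands (hence the /1000) while MSE/MAD are per-pixel averages over the
# region measured. For example, a 500-pixel unknown band with a constant error of
# 0.1 (and exact fg/bg elsewhere) gives SAD = 0.05, MSE = 0.01, MAD = 0.1.
# Note that sad_trimap below is computed but only (sad_fg, sad_bg) is returned.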
168 | sad_fg = np.sum(sad_diff*weight_fg)/1000 169 | sad_bg = np.sum(sad_diff*weight_bg)/1000 170 | sad_trimap = np.sum(sad_diff*weight_trimap)/1000 171 | return sad_fg, sad_bg 172 | 173 | def compute_gradient_whole_image(pd, gt): 174 | from scipy.ndimage import gaussian_filter 175 | pd_x = gaussian_filter(pd, sigma=1.4, order=[1, 0], output=np.float32) 176 | pd_y = gaussian_filter(pd, sigma=1.4, order=[0, 1], output=np.float32) 177 | gt_x = gaussian_filter(gt, sigma=1.4, order=[1, 0], output=np.float32) 178 | gt_y = gaussian_filter(gt, sigma=1.4, order=[0, 1], output=np.float32) 179 | pd_mag = np.sqrt(pd_x**2 + pd_y**2) 180 | gt_mag = np.sqrt(gt_x**2 + gt_y**2) 181 | 182 | error_map = np.square(pd_mag - gt_mag) 183 | loss = np.sum(error_map) / 10 184 | return loss 185 | 186 | def compute_connectivity_loss_whole_image(pd, gt, step=0.1): 187 | 188 | from scipy.ndimage import morphology 189 | from skimage.measure import label, regionprops 190 | h, w = pd.shape 191 | thresh_steps = np.arange(0, 1.1, step) 192 | l_map = -1 * np.ones((h, w), dtype=np.float32) 193 | lambda_map = np.ones((h, w), dtype=np.float32) 194 | for i in range(1, thresh_steps.size): 195 | pd_th = pd >= thresh_steps[i] 196 | gt_th = gt >= thresh_steps[i] 197 | label_image = label(pd_th & gt_th, connectivity=1) 198 | cc = regionprops(label_image) 199 | size_vec = np.array([c.area for c in cc]) 200 | if len(size_vec) == 0: 201 | continue 202 | max_id = np.argmax(size_vec) 203 | coords = cc[max_id].coords 204 | omega = np.zeros((h, w), dtype=np.float32) 205 | omega[coords[:, 0], coords[:, 1]] = 1 206 | flag = (l_map == -1) & (omega == 0) 207 | l_map[flag == 1] = thresh_steps[i-1] 208 | dist_maps = morphology.distance_transform_edt(omega==0) 209 | dist_maps = dist_maps / dist_maps.max() 210 | l_map[l_map == -1] = 1 211 | d_pd = pd - l_map 212 | d_gt = gt - l_map 213 | phi_pd = 1 - d_pd * (d_pd >= 0.15).astype(np.float32) 214 | phi_gt = 1 - d_gt * (d_gt >= 0.15).astype(np.float32) 215 | loss = np.sum(np.abs(phi_pd - phi_gt)) / 1000 216 | return loss -------------------------------------------------------------------------------- /core/infer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rethinking Portrait Matting with Privacy Preserving 3 | Inference file.
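Example usage (a sketch: the flags are the ones defined in get_args() below,
the checkpoint path is illustrative, a CUDA device is assumed because the model
is moved to GPU unconditionally, and the command is run from core/ so the local
imports resolve; see also the shell scripts under scripts/):

    python infer.py --arch vitae --dataset SAMPLES --test_choice HYBRID \
        --model_path /path/to/p3mnet_checkpoint.pth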
4 | 5 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 6 | Licensed under the MIT License (see LICENSE for details) 7 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net 8 | Paper link: https://arxiv.org/abs/2203.16828 9 | 10 | """ 11 | 12 | import torch 13 | import cv2 14 | import argparse 15 | import numpy as np 16 | from tqdm import tqdm 17 | from PIL import Image 18 | from skimage.transform import resize 19 | from torchvision import transforms 20 | 21 | from config import * 22 | from util import * 23 | from network import build_model 24 | 25 | 26 | def get_args(): 27 | parser = argparse.ArgumentParser(description='PyTorch Super Res Example') 28 | parser.add_argument('--arch', type=str, required=True, choices=['r34', 'swin', 'vitae'], help='model architecture') 29 | parser.add_argument('--dataset', type=str, required=True, choices=['P3M10K', 'RWP', 'SAMPLES'], help='dataset to test') 30 | parser.add_argument('--test_set', type=str, choices=['RWP', 'P3M_500_P', 'P3M_500_NP'], help='the validation set to test') 31 | parser.add_argument('--model_path', type=str, required=True, help='path of checkpoint') 32 | parser.add_argument('--test_choice', type=str, choices=['HYBRID', 'RESIZE'], required=True, help='how to test') 33 | parser.add_argument('--test_result_dir', type=str, help='path to save results of datasets') 34 | args, _ = parser.parse_known_args() 35 | 36 | print('Model architecture: {}'.format(args.arch)) 37 | print('Model path: {}'.format(args.model_path)) 38 | print('Test dataset: {}, set: {}'.format(args.dataset, args.test_set)) 39 | print('Test choice: {}'.format(args.test_choice)) 40 | if args.dataset != 'SAMPLES': 41 | print('Save results to {}'.format(args.test_result_dir)) 42 | else: 43 | print('Save alpha results to {}'.format(SAMPLES_RESULT_ALPHA_PATH)) 44 | print('Save color results to {}'.format(SAMPLES_RESULT_COLOR_PATH)) 45 | 46 | return args 47 | 48 | def inference_once(args, model, scale_img, scale_trimap=None): 49 | if torch.cuda.device_count() > 0: 50 | tensor_img = torch.from_numpy(scale_img.astype(np.float32)[:, :, :]).permute(2, 0, 1).cuda() 51 | else: 52 | tensor_img = torch.from_numpy(scale_img.astype(np.float32)[:, :, :]).permute(2, 0, 1) 53 | input_t = tensor_img 54 | input_t = input_t/255.0 55 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 56 | std=[0.229, 0.224, 0.225]) 57 | input_t = normalize(input_t) 58 | input_t = input_t.unsqueeze(0) 59 | pred_global, pred_local, pred_fusion = model(input_t)[:3] 60 | pred_global = pred_global.data.cpu().numpy() 61 | pred_global = gen_trimap_from_segmap_e2e(pred_global) 62 | pred_local = pred_local.data.cpu().numpy()[0,0,:,:] 63 | pred_fusion = pred_fusion.data.cpu().numpy()[0,0,:,:] 64 | return pred_global, pred_local, pred_fusion 65 | 66 | def inference_img_p3m(args, model, img): 67 | h, w, c = img.shape 68 | new_h = min(MAX_SIZE_H, h - (h % 32)) 69 | new_w = min(MAX_SIZE_W, w - (w % 32)) 70 | 71 | # for P3M-Net(Swin-T) model on very small images in RWP 72 | if args.dataset in ['RWP', 'SAMPLES'] and args.arch == 'swin' and (h < MIN_SIZE_H or w < MIN_SIZE_W): 73 | ratioh = float(MIN_SIZE_H)/float(h) 74 | ratiow = float(MIN_SIZE_W)/float(w) 75 | ratio = max(ratioh, ratiow) 76 | new_h = int(ratio*h) 77 | new_w = int(ratio*w) 78 | new_h = min(MAX_SIZE_H, new_h - (new_h % 32)) 79 | new_w = min(MAX_SIZE_W, new_w - (new_w % 32)) 80 | scale_img = resize(img, (new_h,new_w))*255.0 81 | print(scale_img.shape) 82 | pred_global, pred_local, pred_fusion 
= inference_once(args, model, scale_img) 83 | pred_local = resize(pred_local,(h,w)) 84 | pred_global = resize(pred_global,(h,w))*255.0 85 | pred_fusion = resize(pred_fusion,(h,w)) 86 | return pred_fusion 87 | 88 | if args.test_choice=='HYBRID': 89 | global_ratio = 1/2 90 | local_ratio = 1 91 | resize_h = int(h*global_ratio) 92 | resize_w = int(w*global_ratio) 93 | new_h = min(MAX_SIZE_H, resize_h - (resize_h % 32)) 94 | new_w = min(MAX_SIZE_W, resize_w - (resize_w % 32)) 95 | scale_img = resize(img,(new_h,new_w))*255.0 96 | pred_coutour_1, pred_retouching_1, pred_fusion_1 = inference_once(args, model, scale_img) 97 | # torch.cuda.empty_cache() 98 | pred_coutour_1 = resize(pred_coutour_1,(h,w))*255.0 99 | resize_h = int(h*local_ratio) 100 | resize_w = int(w*local_ratio) 101 | new_h = min(MAX_SIZE_H, resize_h - (resize_h % 32)) 102 | new_w = min(MAX_SIZE_W, resize_w - (resize_w % 32)) 103 | scale_img = resize(img,(new_h,new_w))*255.0 104 | pred_coutour_2, pred_retouching_2, pred_fusion_2 = inference_once(args, model, scale_img) 105 | # torch.cuda.empty_cache() 106 | pred_retouching_2 = resize(pred_retouching_2,(h,w)) 107 | pred_fusion = get_masked_local_from_global_test(pred_coutour_1, pred_retouching_2) 108 | return pred_fusion 109 | elif args.test_choice=='RESIZE': 110 | resize_h = int(h/2) 111 | resize_w = int(w/2) 112 | new_h = min(MAX_SIZE_H, resize_h - (resize_h % 32)) 113 | new_w = min(MAX_SIZE_W, resize_w - (resize_w % 32)) 114 | scale_img = resize(img,(new_h,new_w))*255.0 115 | pred_global, pred_local, pred_fusion = inference_once(args, model, scale_img) 116 | pred_local = resize(pred_local,(h,w)) 117 | pred_global = resize(pred_global,(h,w))*255.0 118 | pred_fusion = resize(pred_fusion,(h,w)) 119 | return pred_fusion 120 | else: 121 | raise NotImplementedError 122 | 123 | def test_dataset(args, model): 124 | if torch.cuda.device_count() > 0: 125 | torch.cuda.empty_cache() 126 | else: 127 | print('NO GPU AVAILABLE') 128 | return 129 | 130 | ############################ 131 | # Some initial setting for paths 132 | ############################ 133 | ORIGINAL_PATH = DATASET_PATHS_DICT[args.dataset][args.test_set]['ORIGINAL_PATH'] 134 | 135 | ############################ 136 | # Start testing 137 | ############################ 138 | result_dir = args.test_result_dir 139 | refresh_folder(result_dir) 140 | 141 | model.eval() 142 | img_list = listdir_nohidden(ORIGINAL_PATH) 143 | 144 | for img_name in tqdm(img_list): 145 | img_path = ORIGINAL_PATH+img_name 146 | img = np.array(Image.open(img_path)) 147 | img = img[:,:,:3] if img.ndim>2 else img 148 | 149 | with torch.no_grad(): 150 | predict = inference_img_p3m(args, model, img) 151 | save_test_result(os.path.join(result_dir, extract_pure_name(img_name)+'.png'),predict) 152 | 153 | def test_samples(args, model): 154 | model.eval() 155 | img_list = listdir_nohidden(SAMPLES_ORIGINAL_PATH) 156 | refresh_folder(SAMPLES_RESULT_ALPHA_PATH) 157 | refresh_folder(SAMPLES_RESULT_COLOR_PATH) 158 | for img_name in tqdm(img_list): 159 | img_path = SAMPLES_ORIGINAL_PATH+img_name 160 | try: 161 | img = np.array(Image.open(img_path))[:,:,:3] 162 | except Exception as e: 163 | print(f'Error: {str(e)} | Name: {img_name}') 164 | h, w, c = img.shape 165 | if min(h, w)>SHORTER_PATH_LIMITATION: 166 | if h>=w: 167 | new_w = SHORTER_PATH_LIMITATION 168 | new_h = int(SHORTER_PATH_LIMITATION*h/w) 169 | img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR) 170 | else: 171 | new_h = SHORTER_PATH_LIMITATION 172 | new_w = 
int(SHORTER_PATH_LIMITATION*w/h) 173 | img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR) 174 | 175 | with torch.no_grad(): 176 | if torch.cuda.device_count() > 0: 177 | torch.cuda.empty_cache() 178 | predict = inference_img_p3m(args, model, img) 179 | 180 | composite = generate_composite_img(img, predict) 181 | cv2.imwrite(os.path.join(SAMPLES_RESULT_COLOR_PATH, extract_pure_name(img_name)+'.png'),composite) 182 | predict = predict*255.0 183 | predict = cv2.resize(predict, (w, h), interpolation=cv2.INTER_LINEAR) 184 | cv2.imwrite(os.path.join(SAMPLES_RESULT_ALPHA_PATH, extract_pure_name(img_name)+'.png'),predict.astype(np.uint8)) 185 | 186 | def load_model_and_deploy(args): 187 | ### build model 188 | model = build_model(args.arch, pretrained=False) 189 | 190 | ### load ckpt 191 | ckpt = torch.load(args.model_path) 192 | model.load_state_dict(ckpt['state_dict'], strict=True) 193 | model = model.cuda() 194 | 195 | ### Test 196 | if args.dataset=='SAMPLES': 197 | test_samples(args, model) 198 | elif args.dataset in ['P3M10K', 'RWP']: 199 | test_dataset(args, model) 200 | else: 201 | print('Please input the correct dataset_choice (SAMPLES, P3M10K or RWP).') 202 | 203 | if __name__ == '__main__': 204 | args = get_args() 205 | load_model_and_deploy(args) 206 | -------------------------------------------------------------------------------- /core/logger.py: -------------------------------------------------------------------------------- 1 | from termcolor import colored 2 | import logging 3 | import sys 4 | 5 | 6 | def create_logger(filename): 7 | # create logger 8 | logger = logging.getLogger('Logger') 9 | logger.setLevel(logging.DEBUG) 10 | logger.propagate = False 11 | 12 | # create formatter 13 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 14 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 15 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 16 | 17 | console_handler = logging.StreamHandler(sys.stdout) 18 | console_handler.setLevel(logging.DEBUG) 19 | console_handler.setFormatter( 20 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 21 | logger.addHandler(console_handler) 22 | 23 | # create file handlers 24 | file_handler = logging.FileHandler(filename, mode='a') 25 | file_handler.setLevel(logging.DEBUG) 26 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 27 | logger.addHandler(file_handler) 28 | 29 | return logger -------------------------------------------------------------------------------- /core/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rethinking Portrait Matting with Privacy Preserving 3 | 4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 5 | Licensed under the MIT License (see LICENSE for details) 6 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net 7 | Paper link: https://arxiv.org/abs/2203.16828 8 | 9 | """ 10 | 11 | import numpy as np 12 | 13 | ############################## 14 | ### Test loss for matting 15 | ############################## 16 | 17 | def calculate_sad_mse_mad(predict_old,alpha,trimap): 18 | predict = np.copy(predict_old) 19 | pixel = float((trimap == 128).sum()) 20 | predict[trimap == 255] = 1. 21 | predict[trimap == 0 ] = 0. 
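# metrics.py repeats the test-time metrics from evaluate.py as a standalone,
# torch-free copy. The prediction is overwritten with the known trimap labels
# above, so (given a consistent ground truth) only the unknown (128) band
# contributes error. Toy check with hypothetical arrays:
#   >>> import numpy as np
#   >>> pred, gt = np.full((4, 4), 0.5), np.ones((4, 4))
#   >>> trimap = np.full((4, 4), 128)
#   >>> calculate_sad_mse_mad(pred, gt, trimap)   # -> SAD 0.008, MSE 0.25, MAD 0.5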
22 | sad_diff = np.sum(np.abs(predict - alpha))/1000 23 | if pixel==0: 24 | pixel = trimap.shape[0]*trimap.shape[1]-float((trimap==255).sum())-float((trimap==0).sum()) 25 | mse_diff = np.sum((predict - alpha) ** 2)/pixel 26 | mad_diff = np.sum(np.abs(predict - alpha))/pixel 27 | return sad_diff, mse_diff, mad_diff 28 | 29 | def calculate_sad_mse_mad_whole_img(predict, alpha): 30 | pixel = predict.shape[0]*predict.shape[1] 31 | sad_diff = np.sum(np.abs(predict - alpha))/1000 32 | mse_diff = np.sum((predict - alpha) ** 2)/pixel 33 | mad_diff = np.sum(np.abs(predict - alpha))/pixel 34 | return sad_diff, mse_diff, mad_diff 35 | 36 | def calculate_sad_fgbg(predict, alpha, trimap): 37 | sad_diff = np.abs(predict-alpha) 38 | weight_fg = np.zeros(predict.shape) 39 | weight_bg = np.zeros(predict.shape) 40 | weight_trimap = np.zeros(predict.shape) 41 | weight_fg[trimap==255] = 1. 42 | weight_bg[trimap==0 ] = 1. 43 | weight_trimap[trimap==128 ] = 1. 44 | sad_fg = np.sum(sad_diff*weight_fg)/1000 45 | sad_bg = np.sum(sad_diff*weight_bg)/1000 46 | sad_trimap = np.sum(sad_diff*weight_trimap)/1000 47 | return sad_fg, sad_bg 48 | 49 | def compute_gradient_whole_image(pd, gt): 50 | from scipy.ndimage import gaussian_filter 51 | pd_x = gaussian_filter(pd, sigma=1.4, order=[1, 0], output=np.float32) 52 | pd_y = gaussian_filter(pd, sigma=1.4, order=[0, 1], output=np.float32) 53 | gt_x = gaussian_filter(gt, sigma=1.4, order=[1, 0], output=np.float32) 54 | gt_y = gaussian_filter(gt, sigma=1.4, order=[0, 1], output=np.float32) 55 | pd_mag = np.sqrt(pd_x**2 + pd_y**2) 56 | gt_mag = np.sqrt(gt_x**2 + gt_y**2) 57 | 58 | error_map = np.square(pd_mag - gt_mag) 59 | loss = np.sum(error_map) / 10 60 | return loss 61 | 62 | def compute_connectivity_loss_whole_image(pd, gt, step=0.1): 63 | from scipy.ndimage import morphology 64 | from skimage.measure import label, regionprops 65 | h, w = pd.shape 66 | thresh_steps = np.arange(0, 1.1, step) 67 | l_map = -1 * np.ones((h, w), dtype=np.float32) 68 | lambda_map = np.ones((h, w), dtype=np.float32) 69 | for i in range(1, thresh_steps.size): 70 | pd_th = pd >= thresh_steps[i] 71 | gt_th = gt >= thresh_steps[i] 72 | label_image = label(pd_th & gt_th, connectivity=1) 73 | cc = regionprops(label_image) 74 | size_vec = np.array([c.area for c in cc]) 75 | if len(size_vec) == 0: 76 | continue 77 | max_id = np.argmax(size_vec) 78 | coords = cc[max_id].coords 79 | omega = np.zeros((h, w), dtype=np.float32) 80 | omega[coords[:, 0], coords[:, 1]] = 1 81 | flag = (l_map == -1) & (omega == 0) 82 | l_map[flag == 1] = thresh_steps[i-1] 83 | dist_maps = morphology.distance_transform_edt(omega==0) 84 | dist_maps = dist_maps / dist_maps.max() 85 | l_map[l_map == -1] = 1 86 | d_pd = pd - l_map 87 | d_gt = gt - l_map 88 | phi_pd = 1 - d_pd * (d_pd >= 0.15).astype(np.float32) 89 | phi_gt = 1 - d_gt * (d_gt >= 0.15).astype(np.float32) 90 | loss = np.sum(np.abs(phi_pd - phi_gt)) / 1000 91 | return loss -------------------------------------------------------------------------------- /core/network/RestNet34/P3mNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rethinking Portrait Matting with Privacy Preserving 3 | 4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 5 | Licensed under the MIT License (see LICENSE for details) 6 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net 7 | Paper link: https://arxiv.org/abs/2203.16828 8 | 9 | """ 10 | import torch.nn as nn 11 | import 
torch.nn.functional as F 12 | from util import get_masked_local_from_global 13 | from ..modules import TFI, SBFI, DBFI 14 | from .resnet_mp import * 15 | 16 | 17 | class P3mNet(nn.Module): 18 | def __init__(self, pretrained=True): 19 | super().__init__() 20 | self.resnet = resnet34_mp(pretrained=pretrained) 21 | ############################ 22 | ### Encoder part - RESNETMP 23 | ############################ 24 | self.encoder0 = nn.Sequential( 25 | self.resnet.conv1, 26 | self.resnet.bn1, 27 | self.resnet.relu, 28 | ) 29 | self.mp0 = self.resnet.maxpool1 30 | self.encoder1 = nn.Sequential( 31 | self.resnet.layer1) 32 | self.mp1 = self.resnet.maxpool2 33 | self.encoder2 = self.resnet.layer2 34 | self.mp2 = self.resnet.maxpool3 35 | self.encoder3 = self.resnet.layer3 36 | self.mp3 = self.resnet.maxpool4 37 | self.encoder4 = self.resnet.layer4 38 | self.mp4 = self.resnet.maxpool5 39 | 40 | self.tfi_3 = TFI(256) 41 | self.tfi_2 = TFI(128) 42 | self.tfi_1 = TFI(64) 43 | self.tfi_0 = TFI(64) 44 | 45 | self.sbfi_2 = SBFI(128, 64, 8) 46 | self.sbfi_1 = SBFI(64, 64, 4) 47 | self.sbfi_0 = SBFI(64, 64, 2) 48 | 49 | self.dbfi_2 = DBFI(128, 512, 4) 50 | self.dbfi_1 = DBFI(64, 512, 8) 51 | self.dbfi_0 = DBFI(64, 512, 16) 52 | 53 | ########################## 54 | ### Decoder part - GLOBAL 55 | ########################## 56 | self.decoder4_g = nn.Sequential( 57 | nn.Conv2d(512,512,3,padding=1), 58 | nn.BatchNorm2d(512), 59 | nn.ReLU(inplace=True), 60 | nn.Conv2d(512,512,3,padding=1), 61 | nn.BatchNorm2d(512), 62 | nn.ReLU(inplace=True), 63 | nn.Conv2d(512,256,3,padding=1), 64 | nn.BatchNorm2d(256), 65 | nn.ReLU(inplace=True), 66 | nn.Upsample(scale_factor=2, mode='bilinear') ) 67 | self.decoder3_g = nn.Sequential( 68 | nn.Conv2d(256,256,3,padding=1), 69 | nn.BatchNorm2d(256), 70 | nn.ReLU(inplace=True), 71 | nn.Conv2d(256,256,3,padding=1), 72 | nn.BatchNorm2d(256), 73 | nn.ReLU(inplace=True), 74 | nn.Conv2d(256,128,3,padding=1), 75 | nn.BatchNorm2d(128), 76 | nn.ReLU(inplace=True), 77 | nn.Upsample(scale_factor=2, mode='bilinear') ) 78 | self.decoder2_g = nn.Sequential( 79 | nn.Conv2d(128,128,3,padding=1), 80 | nn.BatchNorm2d(128), 81 | nn.ReLU(inplace=True), 82 | nn.Conv2d(128,128,3,padding=1), 83 | nn.BatchNorm2d(128), 84 | nn.ReLU(inplace=True), 85 | nn.Conv2d(128,64,3,padding=1), 86 | nn.BatchNorm2d(64), 87 | nn.ReLU(inplace=True), 88 | nn.Upsample(scale_factor=2, mode='bilinear')) 89 | self.decoder1_g = nn.Sequential( 90 | nn.Conv2d(64,64,3,padding=1), 91 | nn.BatchNorm2d(64), 92 | nn.ReLU(inplace=True), 93 | nn.Conv2d(64,64,3,padding=1), 94 | nn.BatchNorm2d(64), 95 | nn.ReLU(inplace=True), 96 | nn.Conv2d(64,64,3,padding=1), 97 | nn.BatchNorm2d(64), 98 | nn.ReLU(inplace=True), 99 | nn.Upsample(scale_factor=2, mode='bilinear')) 100 | self.decoder0_g = nn.Sequential( 101 | nn.Conv2d(64,64,3,padding=1), 102 | nn.BatchNorm2d(64), 103 | nn.ReLU(inplace=True), 104 | nn.Conv2d(64,64,3,padding=1), 105 | nn.BatchNorm2d(64), 106 | nn.ReLU(inplace=True), 107 | nn.Conv2d(64,3,3,padding=1), 108 | nn.Upsample(scale_factor=2, mode='bilinear')) 109 | 110 | ########################## 111 | ### Decoder part - LOCAL 112 | ########################## 113 | self.decoder4_l = nn.Sequential( 114 | nn.Conv2d(512,512,3,padding=1), 115 | nn.BatchNorm2d(512), 116 | nn.ReLU(inplace=True), 117 | nn.Conv2d(512,512,3,padding=1), 118 | nn.BatchNorm2d(512), 119 | nn.ReLU(inplace=True), 120 | nn.Conv2d(512,256,3,padding=1), 121 | nn.BatchNorm2d(256), 122 | nn.ReLU(inplace=True)) 123 | self.decoder3_l = nn.Sequential( 124 | 
nn.Conv2d(256,256,3,padding=1), 125 | nn.BatchNorm2d(256), 126 | nn.ReLU(inplace=True), 127 | nn.Conv2d(256,256,3,padding=1), 128 | nn.BatchNorm2d(256), 129 | nn.ReLU(inplace=True), 130 | nn.Conv2d(256,128,3,padding=1), 131 | nn.BatchNorm2d(128), 132 | nn.ReLU(inplace=True)) 133 | self.decoder2_l = nn.Sequential( 134 | nn.Conv2d(128,128,3,padding=1), 135 | nn.BatchNorm2d(128), 136 | nn.ReLU(inplace=True), 137 | nn.Conv2d(128,128,3,padding=1), 138 | nn.BatchNorm2d(128), 139 | nn.ReLU(inplace=True), 140 | nn.Conv2d(128,64,3,padding=1), 141 | nn.BatchNorm2d(64), 142 | nn.ReLU(inplace=True)) 143 | self.decoder1_l = nn.Sequential( 144 | nn.Conv2d(64,64,3,padding=1), 145 | nn.BatchNorm2d(64), 146 | nn.ReLU(inplace=True), 147 | nn.Conv2d(64,64,3,padding=1), 148 | nn.BatchNorm2d(64), 149 | nn.ReLU(inplace=True), 150 | nn.Conv2d(64,64,3,padding=1), 151 | nn.BatchNorm2d(64), 152 | nn.ReLU(inplace=True)) 153 | self.decoder0_l = nn.Sequential( 154 | nn.Conv2d(64,64,3,padding=1), 155 | nn.BatchNorm2d(64), 156 | nn.ReLU(inplace=True), 157 | nn.Conv2d(64,64,3,padding=1), 158 | nn.BatchNorm2d(64), 159 | nn.ReLU(inplace=True)) 160 | self.decoder_final_l = nn.Conv2d(64,1,3,padding=1) 161 | 162 | 163 | def forward(self, input): 164 | ########################## 165 | ### Encoder part - RESNET 166 | ########################## 167 | e0 = self.encoder0(input) 168 | e0p, id0 = self.mp0(e0) 169 | e1p, id1 = self.mp1(e0p) 170 | e1 = self.encoder1(e1p) 171 | e2p, id2 = self.mp2(e1) 172 | e2 = self.encoder2(e2p) 173 | e3p, id3 = self.mp3(e2) 174 | e3 = self.encoder3(e3p) 175 | e4p, id4 = self.mp4(e3) 176 | e4 = self.encoder4(e4p) 177 | ########################### 178 | ### Decoder part - Global 179 | ########################### 180 | d4_g = self.decoder4_g(e4) 181 | d3_g = self.decoder3_g(d4_g) 182 | d2_g, global_sigmoid_side2 = self.dbfi_2(d3_g, e4) 183 | d2_g = self.decoder2_g(d2_g) 184 | d1_g, global_sigmoid_side1 = self.dbfi_1(d2_g, e4) 185 | d1_g = self.decoder1_g(d1_g) 186 | d0_g, global_sigmoid_side0 = self.dbfi_0(d1_g, e4) 187 | d0_g = self.decoder0_g(d0_g) 188 | global_sigmoid = d0_g 189 | ########################### 190 | ### Decoder part - Local 191 | ########################### 192 | d4_l = self.decoder4_l(e4) 193 | d4_l = F.max_unpool2d(d4_l, id4, kernel_size=2, stride=2) 194 | d3_l = self.tfi_3(d4_g, d4_l, e3) 195 | d3_l = self.decoder3_l(d3_l) 196 | d3_l = F.max_unpool2d(d3_l, id3, kernel_size=2, stride=2) 197 | d2_l = self.tfi_2(d3_g, d3_l, e2) 198 | d2_l = self.sbfi_2(d2_l, e0) 199 | d2_l = self.decoder2_l(d2_l) 200 | d2_l = F.max_unpool2d(d2_l, id2, kernel_size=2, stride=2) 201 | d1_l = self.tfi_1(d2_g, d2_l, e1) 202 | d1_l = self.sbfi_1(d1_l, e0) 203 | d1_l = self.decoder1_l(d1_l) 204 | d1_l = F.max_unpool2d(d1_l, id1, kernel_size=2, stride=2) 205 | d0_l = self.tfi_0(d1_g, d1_l, e0p) 206 | d0_l = self.sbfi_0(d0_l, e0) 207 | d0_l = self.decoder0_l(d0_l) 208 | d0_l = F.max_unpool2d(d0_l, id0, kernel_size=2, stride=2) 209 | d0_l = self.decoder_final_l(d0_l) 210 | local_sigmoid = F.sigmoid(d0_l) 211 | ########################## 212 | ### Fusion net - G/L 213 | ########################## 214 | fusion_sigmoid = get_masked_local_from_global(global_sigmoid, local_sigmoid) 215 | return global_sigmoid, local_sigmoid, fusion_sigmoid, global_sigmoid_side2, global_sigmoid_side1, global_sigmoid_side0 216 | -------------------------------------------------------------------------------- /core/network/RestNet34/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.P3mNet import P3mNet 2 | 3 | 4 | __all__ = ['p3mnet_resnet34'] 5 | 6 | def p3mnet_resnet34(pretrained=True, **kwargs): 7 | return P3mNet(pretrained=pretrained) -------------------------------------------------------------------------------- /core/network/RestNet34/resnet_mp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Callable, Optional 4 | from config import PRETRAINED_R34_MP 5 | from ..modules import conv3x3, conv1x1 6 | 7 | 8 | class BasicBlock(nn.Module): 9 | expansion: int = 1 10 | 11 | def __init__( 12 | self, 13 | inplanes: int, 14 | planes: int, 15 | stride: int = 1, 16 | downsample: Optional[nn.Module] = None, 17 | groups: int = 1, 18 | base_width: int = 64, 19 | dilation: int = 1, 20 | norm_layer: Optional[Callable[..., nn.Module]] = None 21 | ) -> None: 22 | super(BasicBlock, self).__init__() 23 | if norm_layer is None: 24 | norm_layer = nn.BatchNorm2d 25 | if groups != 1 or base_width != 64: 26 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 27 | if dilation > 1: 28 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 29 | 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = norm_layer(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = norm_layer(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | identity = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | out += identity 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | __constants__ = ['downsample'] 59 | 60 | def __init__(self, inplanes, planes,stride=1, downsample=None, groups=1, 61 | base_width=64, dilation=1, norm_layer=None): 62 | super(Bottleneck, self).__init__() 63 | if norm_layer is None: 64 | norm_layer = nn.BatchNorm2d 65 | width = int(planes * (base_width / 64.)) * groups 66 | self.conv1 = conv1x1(inplanes, width) 67 | self.bn1 = norm_layer(width) 68 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 69 | self.bn2 = norm_layer(width) 70 | self.conv3 = conv1x1(width, planes * self.expansion) 71 | self.bn3 = norm_layer(planes * self.expansion) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | identity = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | out = self.attention(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | if self.downsample is not None: 91 | identity = self.downsample(x) 92 | out += identity 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | 98 | class ResNet(nn.Module): 99 | 100 | def __init__(self, block, layers, zero_init_residual=False, 101 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 102 | norm_layer=None): 103 | super(ResNet, self).__init__() 104 | if norm_layer is None: 105 | norm_layer = nn.BatchNorm2d 106 | self._norm_layer = norm_layer 107 | self.inplanes = 64 108 | self.dilation = 1 109 | if replace_stride_with_dilation is None: 110 | replace_stride_with_dilation = [False, False, False] 111 | if len(replace_stride_with_dilation) 
!= 3: 112 | raise ValueError("replace_stride_with_dilation should be None " 113 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 114 | self.groups = groups 115 | self.base_width = width_per_group 116 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=1, padding=3, 117 | bias=False) 118 | self.bn1 = norm_layer(self.inplanes) 119 | self.relu = nn.ReLU(inplace=True) 120 | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True) 121 | self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True) 122 | self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True) 123 | self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True) 124 | self.maxpool5 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True) 125 | #pdb.set_trace() 126 | self.layer1 = self._make_layer(block, 64, layers[0]) 127 | self.layer2 = self._make_layer(block, 128, layers[1], stride=1, 128 | dilate=replace_stride_with_dilation[0]) 129 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 130 | dilate=replace_stride_with_dilation[1]) 131 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 132 | dilate=replace_stride_with_dilation[2]) 133 | 134 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 135 | self.fc = nn.Linear(512 * block.expansion, 1000) 136 | 137 | for m in self.modules(): 138 | if isinstance(m, nn.Conv2d): 139 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 140 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 141 | nn.init.constant_(m.weight, 1) 142 | nn.init.constant_(m.bias, 0) 143 | if zero_init_residual: 144 | for m in self.modules(): 145 | if isinstance(m, Bottleneck): 146 | nn.init.constant_(m.bn3.weight, 0) 147 | elif isinstance(m, BasicBlock): 148 | nn.init.constant_(m.bn2.weight, 0) 149 | 150 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 151 | norm_layer = self._norm_layer 152 | downsample = None 153 | previous_dilation = self.dilation 154 | if dilate: 155 | self.dilation *= stride 156 | stride = 1 157 | if stride != 1 or self.inplanes != planes * block.expansion: 158 | downsample = nn.Sequential( 159 | conv1x1(self.inplanes, planes * block.expansion, stride), 160 | norm_layer(planes * block.expansion), 161 | ) 162 | 163 | layers = [] 164 | layers.append(block(self.inplanes, planes,stride, downsample, self.groups, 165 | self.base_width, previous_dilation, norm_layer)) 166 | self.inplanes = planes * block.expansion 167 | for _ in range(1, blocks): 168 | layers.append(block(self.inplanes, planes,groups=self.groups, 169 | base_width=self.base_width, dilation=self.dilation, 170 | norm_layer=norm_layer)) 171 | 172 | return nn.Sequential(*layers) 173 | 174 | def _forward_impl(self, x): 175 | x1 = self.conv1(x) 176 | x1 = self.bn1(x1) 177 | x1 = self.relu(x1) 178 | x1, idx1 = self.maxpool1(x1) 179 | 180 | x2, idx2 = self.maxpool2(x1) 181 | x2 = self.layer1(x2) 182 | 183 | x3, idx3 = self.maxpool3(x2) 184 | x3 = self.layer2(x3) 185 | 186 | x4, idx4 = self.maxpool4(x3) 187 | x4 = self.layer3(x4) 188 | 189 | x5, idx5 = self.maxpool5(x4) 190 | x5 = self.layer4(x5) 191 | 192 | x_cls = self.avgpool(x5) 193 | x_cls = torch.flatten(x_cls, 1) 194 | x_cls = self.fc(x_cls) 195 | 196 | return x_cls 197 | 198 | def forward(self, x): 199 | return self._forward_impl(x) 200 | 201 | 202 | def resnet34_mp(pretrained=True, **kwargs): 203 | r"""ResNet-34 model from 204 | `"Deep Residual Learning for Image 
Recognition" ` 205 | """ 206 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 207 | if pretrained: 208 | checkpoint = torch.load(PRETRAINED_R34_MP) 209 | model.load_state_dict(checkpoint) 210 | print('Loaded pretrained model') 211 | return model 212 | 213 | 214 | -------------------------------------------------------------------------------- /core/network/Swin_T/__init__.py: -------------------------------------------------------------------------------- 1 | from .swin_stem_pooling5_transformer import swin_stem_pooling5_encoder 2 | from .swin_stem_pooling5_transformer import SwinStemPooling5TransformerMatting 3 | 4 | from .decoder import SwinStemPooling5TransformerDecoderV1 5 | 6 | 7 | __all__ = ['p3mnet_swin_t'] 8 | 9 | 10 | def p3mnet_swin_t(pretrained=True, img_size=512, **kwargs): 11 | encoder = swin_stem_pooling5_encoder(pretrained=pretrained, img_size=img_size, **kwargs) 12 | decoder = SwinStemPooling5TransformerDecoderV1() 13 | model = SwinStemPooling5TransformerMatting(encoder, decoder) 14 | return model 15 | -------------------------------------------------------------------------------- /core/network/Swin_T/decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rethinking Portrait Matting with Privacy Preserving 3 | 4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 5 | Licensed under the MIT License (see LICENSE for details) 6 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net 7 | Paper link: https://arxiv.org/abs/2203.16828 8 | 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from util import get_masked_local_from_global 16 | from ..modules import * 17 | 18 | 19 | class SwinStemPooling5TransformerDecoderV1(nn.Module): 20 | def __init__(self) -> None: 21 | super().__init__() 22 | ########################## 23 | ### Decoder part - GLOBAL 24 | ########################## 25 | self.decoder4_g = nn.Sequential( # 768 -> 384 26 | nn.Conv2d(768,384,3,padding=1), 27 | nn.BatchNorm2d(384), 28 | nn.ReLU(inplace=True), 29 | nn.Conv2d(384,384,3,padding=1), 30 | nn.BatchNorm2d(384), 31 | nn.ReLU(inplace=True), 32 | nn.Conv2d(384,384,3,padding=1), 33 | nn.BatchNorm2d(384), 34 | nn.ReLU(inplace=True), 35 | nn.Upsample(scale_factor=2, mode='bilinear') ) 36 | self.decoder3_g = nn.Sequential( # 384 -> 192 37 | nn.Conv2d(384,192,3,padding=1), 38 | nn.BatchNorm2d(192), 39 | nn.ReLU(inplace=True), 40 | nn.Conv2d(192,192,3,padding=1), 41 | nn.BatchNorm2d(192), 42 | nn.ReLU(inplace=True), 43 | nn.Conv2d(192,192,3,padding=1), 44 | nn.BatchNorm2d(192), 45 | nn.ReLU(inplace=True), 46 | nn.Upsample(scale_factor=2, mode='bilinear') ) 47 | self.decoder2_g = nn.Sequential( # 192 -> 96 48 | nn.Conv2d(192,96,3,padding=1), 49 | nn.BatchNorm2d(96), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(96,96,3,padding=1), 52 | nn.BatchNorm2d(96), 53 | nn.ReLU(inplace=True), 54 | nn.Conv2d(96,96,3,padding=1), 55 | nn.BatchNorm2d(96), 56 | nn.ReLU(inplace=True), 57 | nn.Upsample(scale_factor=2, mode='bilinear')) 58 | self.decoder1_g = nn.Sequential( # 96 -> 96 59 | nn.Conv2d(96,96,3,padding=1), 60 | nn.BatchNorm2d(96), 61 | nn.ReLU(inplace=True), 62 | nn.Conv2d(96,96,3,padding=1), 63 | nn.BatchNorm2d(96), 64 | nn.ReLU(inplace=True), 65 | nn.Conv2d(96,96,3,padding=1), 66 | nn.BatchNorm2d(96), 67 | nn.ReLU(inplace=True), 68 | nn.Upsample(scale_factor=2, mode='bilinear')) 69 | self.decoder0_g = nn.Sequential( # 96 -> 48 70 | nn.Conv2d(96,48,3,padding=1), 71 
| nn.BatchNorm2d(48), 72 | nn.ReLU(inplace=True), 73 | nn.Conv2d(48,48,3,padding=1), 74 | nn.BatchNorm2d(48), 75 | nn.ReLU(inplace=True), 76 | nn.Conv2d(48,48,3,padding=1), 77 | nn.Upsample(scale_factor=2, mode='bilinear')) 78 | self.decoder_final_g = nn.Conv2d(48, 3, kernel_size=3, padding=1) # 48 -> 3 79 | 80 | ########################## 81 | ### Decoder part - LOCAL 82 | ########################## 83 | 84 | self.decoder4_l = nn.Sequential( # 768 -> 384 85 | nn.Conv2d(768,384,3,padding=1), 86 | nn.BatchNorm2d(384), 87 | nn.ReLU(inplace=True), 88 | nn.Conv2d(384,384,3,padding=1), 89 | nn.BatchNorm2d(384), 90 | nn.ReLU(inplace=True), 91 | nn.Conv2d(384,384,3,padding=1), 92 | nn.BatchNorm2d(384), 93 | nn.ReLU(inplace=True)) 94 | self.decoder3_l = nn.Sequential( # 384 -> 192 95 | nn.Conv2d(384,192,3,padding=1), 96 | nn.BatchNorm2d(192), 97 | nn.ReLU(inplace=True), 98 | nn.Conv2d(192,192,3,padding=1), 99 | nn.BatchNorm2d(192), 100 | nn.ReLU(inplace=True), 101 | nn.Conv2d(192,192,3,padding=1), 102 | nn.BatchNorm2d(192), 103 | nn.ReLU(inplace=True)) 104 | self.decoder2_l = nn.Sequential( # 192 -> 96 105 | nn.Conv2d(192,96,3,padding=1), 106 | nn.BatchNorm2d(96), 107 | nn.ReLU(inplace=True), 108 | nn.Conv2d(96,96,3,padding=1), 109 | nn.BatchNorm2d(96), 110 | nn.ReLU(inplace=True), 111 | nn.Conv2d(96,96,3,padding=1), 112 | nn.BatchNorm2d(96), 113 | nn.ReLU(inplace=True)) 114 | self.decoder1_l = nn.Sequential( # 96 -> 96 115 | nn.Conv2d(96,96,3,padding=1), 116 | nn.BatchNorm2d(96), 117 | nn.ReLU(inplace=True), 118 | nn.Conv2d(96,96,3,padding=1), 119 | nn.BatchNorm2d(96), 120 | nn.ReLU(inplace=True), 121 | nn.Conv2d(96,96,3,padding=1), 122 | nn.BatchNorm2d(96), 123 | nn.ReLU(inplace=True)) 124 | self.decoder0_l = nn.Sequential( # 96 -> 48 125 | nn.Conv2d(96,48,3,padding=1), 126 | nn.BatchNorm2d(48), 127 | nn.ReLU(inplace=True), 128 | nn.Conv2d(48,48,3,padding=1), 129 | nn.BatchNorm2d(48), 130 | nn.ReLU(inplace=True)) 131 | self.decoder_final_l = nn.Conv2d(48,1,3,padding=1) # 16 -> 1 132 | 133 | ########################## 134 | ### Decoder part - MODULES 135 | ########################## 136 | 137 | self.tfi_3 = TFI(384) 138 | self.tfi_2 = TFI(192) 139 | self.tfi_1 = TFI(96) 140 | self.tfi_0 = TFI(96) 141 | 142 | self.sbfi_2 = SBFI(192, 48, 8) 143 | self.sbfi_1 = SBFI(96, 48, 4) 144 | self.sbfi_0 = SBFI(96, 48, 2) 145 | 146 | self.dbfi_2 = DBFI(192, 768, 4) 147 | self.dbfi_1 = DBFI(96, 768, 8) 148 | self.dbfi_0 = DBFI(96, 768, 16) 149 | 150 | def forward(self, x, indices, feas): 151 | r""" 152 | x: [, 768, 16, 16] 153 | 154 | indices: 155 | [None] 156 | [None] 157 | [, 96, 64, 64] 158 | [, 192, 32, 32] 159 | [, 384, 16, 16] 160 | 161 | feas: 162 | [, 3, 512, 512] 163 | [, 48, 256, 256] 164 | [, 96, 128, 128] 165 | [, 192, 64, 64] 166 | [, 384, 32, 32] 167 | """ 168 | ########################### 169 | ### Decoder part - Global 170 | ########################### 171 | d4_g = self.decoder4_g(x) # 768 -> 384 172 | d3_g = self.decoder3_g(d4_g) # 384 -> 192 173 | d2_g, global_sigmoid_side2 = self.dbfi_2(d3_g, x) # 192, 768 -> 192 174 | d2_g = self.decoder2_g(d2_g) # 192 -> 96 175 | d1_g, global_sigmoid_side1 = self.dbfi_1(d2_g, x) # 96, 768 -> 96 176 | d1_g = self.decoder1_g(d1_g) # 96 -> 96 177 | d0_g, global_sigmoid_side0 = self.dbfi_0(d1_g, x) # 96, 768 -> 48 178 | d0_g = self.decoder0_g(d0_g) # 96 -> 48 179 | global_sigmoid = self.decoder_final_g(d0_g) # 48 -> 3 180 | ########################### 181 | ### Decoder part - Local 182 | ########################### 183 | d4_l = self.decoder4_l(x) # 768 -> 384 
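# Local-branch pattern repeated at every stage below: un-pool with the encoder's
# saved max-pooling indices to restore exact spatial positions, fuse the matching
# global-decoder, local-decoder and encoder features with TFI, and enrich the
# shallower stages with the earliest full-resolution (3-channel) feature through
# SBFI before decoding; the fusion output finally combines the global
# segmentation-style prediction with the local alpha via get_masked_local_from_global.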
184 | d4_l = F.max_unpool2d(d4_l, indices[-1], kernel_size=2, stride=2) # 384 185 | d3_l = self.tfi_3(d4_g, d4_l, feas[-1]) # 384, 384, 384 -> 384 186 | d3_l = self.decoder3_l(d3_l) # 384 -> 192 187 | d3_l = F.max_unpool2d(d3_l, indices[-2], kernel_size=2, stride=2) # 192 188 | d2_l = self.tfi_2(d3_g, d3_l, feas[-2]) # 192, 192, 192 -> 192 189 | d2_l = self.sbfi_2(d2_l, feas[-5]) # 192, 3 -> 192 190 | d2_l = self.decoder2_l(d2_l) # 192 -> 96 191 | d2_l = F.max_unpool2d(d2_l, indices[-3], kernel_size=2, stride=2) # 96 192 | d1_l = self.tfi_1(d2_g, d2_l, feas[-3]) # 96, 96, 96 -> 96 193 | d1_l = self.sbfi_1(d1_l, feas[-5]) # 96, 3 -> 96 194 | d1_l = self.decoder1_l(d1_l) # 96 -> 96 195 | d1_l = F.max_unpool2d(d1_l, indices[-4], kernel_size=2, stride=2) # 96 196 | d0_l = self.tfi_0(d1_g, d1_l, feas[-4]) # 96, 96, 96 -> 96 197 | d0_l = self.sbfi_0(d0_l, feas[-5]) # 96 198 | d0_l = self.decoder0_l(d0_l) # 96 -> 48 199 | d0_l = F.max_unpool2d(d0_l, indices[-5], kernel_size=2, stride=2) # 48 200 | d0_l = self.decoder_final_l(d0_l) # 48 -> 1 201 | local_sigmoid = torch.sigmoid(d0_l) # 1 202 | fusion_sigmoid = get_masked_local_from_global(global_sigmoid, local_sigmoid) 203 | return global_sigmoid, local_sigmoid, fusion_sigmoid, global_sigmoid_side2, global_sigmoid_side1, global_sigmoid_side0 204 | -------------------------------------------------------------------------------- /core/network/ViTAE_S/NormalCell.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd. 2 | # 3 | # This source code is licensed under the Clear BSD License 4 | # LICENSE file in the root directory of this file 5 | # All rights reserved. 6 | """ 7 | Borrow from timm(https://github.com/rwightman/pytorch-image-models) 8 | """ 9 | import torch 10 | import torch.nn as nn 11 | import numpy as np 12 | from timm.models.layers import DropPath 13 | from .SELayer import SELayer 14 | import math 15 | 16 | class Mlp(nn.Module): 17 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 18 | super().__init__() 19 | out_features = out_features or in_features 20 | hidden_features = hidden_features or in_features 21 | self.hidden_features = hidden_features 22 | self.fc1 = nn.Linear(in_features, hidden_features) 23 | self.act = act_layer() 24 | self.fc2 = nn.Linear(hidden_features, out_features) 25 | self.drop = nn.Dropout(drop) 26 | 27 | def forward(self, x): 28 | x = self.fc1(x) 29 | x = self.act(x) 30 | x = self.drop(x) 31 | x = self.fc2(x) 32 | x = self.drop(x) 33 | return x 34 | 35 | class Attention(nn.Module): 36 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 37 | super().__init__() 38 | self.num_heads = num_heads 39 | head_dim = dim // num_heads 40 | 41 | self.scale = qk_scale or head_dim ** -0.5 42 | 43 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 44 | self.attn_drop = nn.Dropout(attn_drop) 45 | self.proj = nn.Linear(dim, dim) 46 | self.proj_drop = nn.Dropout(proj_drop) 47 | 48 | def forward(self, x): 49 | B, N, C = x.shape 50 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 51 | q, k, v = qkv[0], qkv[1], qkv[2] 52 | 53 | attn = (q @ k.transpose(-2, -1)) * self.scale 54 | attn = attn.softmax(dim=-1) 55 | attn = self.attn_drop(attn) 56 | 57 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 58 | x = self.proj(x) 59 | x = self.proj_drop(x) 60 | return x 61 | 62 | class 
AttentionPerformer(nn.Module): 63 | def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., kernel_ratio=0.5): 64 | super().__init__() 65 | self.head_dim = dim // num_heads 66 | self.emb = dim 67 | self.kqv = nn.Linear(dim, 3 * self.emb) 68 | self.dp = nn.Dropout(proj_drop) 69 | self.proj = nn.Linear(self.emb, self.emb) 70 | self.head_cnt = num_heads 71 | self.norm1 = nn.LayerNorm(dim, eps=1e-6) 72 | self.epsilon = 1e-8 # for stable in division 73 | self.drop_path = nn.Identity() 74 | 75 | self.m = int(self.head_dim * kernel_ratio) 76 | self.w = torch.randn(self.head_cnt, self.m, self.head_dim) 77 | for i in range(self.head_cnt): 78 | self.w[i] = nn.Parameter(nn.init.orthogonal_(self.w[i]) * math.sqrt(self.m), requires_grad=False) 79 | self.w.requires_grad_(False) 80 | 81 | def prm_exp(self, x): 82 | # part of the function is borrow from https://github.com/lucidrains/performer-pytorch 83 | # and Simo Ryu (https://github.com/cloneofsimo) 84 | # ==== positive random features for gaussian kernels ==== 85 | # x = (B, T, hs) 86 | # w = (m, hs) 87 | # return : x : B, T, m 88 | # SM(x, y) = E_w[exp(w^T x - |x|/2) exp(w^T y - |y|/2)] 89 | # therefore return exp(w^Tx - |x|/2)/sqrt(m) 90 | xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, 1, self.m) / 2 91 | wtx = torch.einsum('bhti,hmi->bhtm', x.float(), self.w.to(x.device)) 92 | 93 | return torch.exp(wtx - xd) / math.sqrt(self.m) 94 | 95 | def attn(self, x): 96 | B, N, C = x.shape 97 | kqv = self.kqv(x).reshape(B, N, 3, self.head_cnt, self.head_dim).permute(2, 0, 3, 1, 4) 98 | k, q, v = kqv[0], kqv[1], kqv[2] # (B, H, T, hs) 99 | 100 | kp, qp = self.prm_exp(k), self.prm_exp(q) # (B, H, T, m), (B, H, T, m) 101 | D = torch.einsum('bhti,bhi->bht', qp, kp.sum(dim=2)).unsqueeze(dim=-1) # (B, H, T, m) * (B, H, m) -> (B, H, T, 1) 102 | kptv = torch.einsum('bhin,bhim->bhnm', v.float(), kp) # (B, H, emb, m) 103 | y = torch.einsum('bhti,bhni->bhtn', qp, kptv) / (D.repeat(1, 1, 1, self.head_dim) + self.epsilon) # (B, H, T, emb)/Diag 104 | 105 | # skip connection 106 | 107 | y = y.permute(0, 2, 1, 3).reshape(B, N, self.emb) 108 | y = self.dp(self.proj(y)) # same as token_transformer in T2T layer, use v as skip connection 109 | 110 | return y 111 | 112 | def forward(self, x): 113 | x = self.attn(x) 114 | return x 115 | 116 | class NormalCell(nn.Module): 117 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 118 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, class_token=False, group=64, tokens_type='transformer', gamma=False, init_values=1e-4, SE=False): 119 | super().__init__() 120 | self.norm1 = norm_layer(dim) 121 | self.class_token = class_token 122 | if tokens_type == 'transformer': 123 | self.attn = Attention( 124 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 125 | elif tokens_type == 'performer': 126 | self.attn = AttentionPerformer( 127 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 128 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 129 | self.norm2 = norm_layer(dim) 130 | mlp_hidden_dim = int(dim * mlp_ratio) 131 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 132 | self.PCM = nn.Sequential( 133 | nn.Conv2d(dim, mlp_hidden_dim, 3, 1, 1, 1, group), 134 | nn.BatchNorm2d(mlp_hidden_dim), 135 | nn.SiLU(inplace=True), 136 | nn.Conv2d(mlp_hidden_dim, dim, 3, 1, 1, 1, group), 137 | nn.BatchNorm2d(dim), 138 | nn.SiLU(inplace=True), 139 | nn.Conv2d(dim, dim, 3, 1, 1, 1, group), 140 | nn.SiLU(inplace=True), 141 | ) 142 | if gamma: 143 | self.gamma1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 144 | self.gamma2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 145 | self.gamma3 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 146 | else: 147 | self.gamma1 = 1 148 | self.gamma2 = 1 149 | self.gamma3 = 1 150 | if SE: 151 | self.SE = SELayer(dim) 152 | else: 153 | self.SE = nn.Identity() 154 | 155 | def forward(self, x, input_resolution=None): 156 | b, n, c = x.shape 157 | if self.class_token: 158 | n = n - 1 159 | wh = input_resolution[0] if input_resolution is not None else int(math.sqrt(n)) 160 | ww = input_resolution[1] if input_resolution is not None else int(math.sqrt(n)) 161 | convX = self.drop_path(self.gamma2 * self.PCM(x[:, 1:, :].view(b, wh, ww, c).permute(0, 3, 1, 2).contiguous()).permute(0, 2, 3, 1).contiguous().view(b, n, c)) 162 | x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x))) 163 | x[:, 1:] = x[:, 1:] + convX 164 | else: 165 | wh = input_resolution[0] if input_resolution is not None else int(math.sqrt(n)) 166 | ww = input_resolution[1] if input_resolution is not None else int(math.sqrt(n)) 167 | convX = self.drop_path(self.gamma2 * self.PCM(x.view(b, wh, ww, c).permute(0, 3, 1, 2).contiguous()).permute(0, 2, 3, 1).contiguous().view(b, n, c)) 168 | x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x))) 169 | x = x + convX 170 | x = x + self.drop_path(self.gamma3 * self.mlp(self.norm2(x))) 171 | x = self.SE(x) 172 | return x 173 | 174 | def get_sinusoid_encoding(n_position, d_hid): 175 | ''' Sinusoid position encoding table ''' 176 | 177 | def get_position_angle_vec(position): 178 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 179 | 180 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 181 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 182 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 183 | 184 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) -------------------------------------------------------------------------------- /core/network/ViTAE_S/SELayer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rethinking Portrait Matting with Privacy Preserving 3 | 4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 5 | Licensed under the MIT License (see LICENSE for details) 6 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net 7 | Paper link: https://arxiv.org/abs/2203.16828 8 | 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | class SELayer(nn.Module): 15 | def __init__(self, channel, reduction=16): 16 | super(SELayer, self).__init__() 17 | self.avg_pool = nn.AdaptiveAvgPool1d(1) 18 | self.fc = nn.Sequential( 19 | nn.Linear(channel, channel // reduction, bias=False), 20 | nn.ReLU(inplace=True), 21 | 
nn.Linear(channel // reduction, channel, bias=False), 22 | nn.Sigmoid() 23 | ) 24 | 25 | def forward(self, x): # x: [B, N, C] 26 | x = torch.transpose(x, 1, 2) # [B, C, N] 27 | b, c, _ = x.size() 28 | y = self.avg_pool(x).view(b, c) 29 | y = self.fc(y).view(b, c, 1) 30 | x = x * y.expand_as(x) 31 | x = torch.transpose(x, 1, 2) # [B, N, C] 32 | return x -------------------------------------------------------------------------------- /core/network/ViTAE_S/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import ViTAE_noRC_MaxPooling_bias_basic_stages4_14 2 | from .models import ViTAE_noRC_MaxPooling_DecoderV1 3 | from .models import ViTAE_noRC_MaxPooling_Matting 4 | 5 | 6 | __all__ = ['p3mnet_vitae_s'] 7 | 8 | 9 | def p3mnet_vitae_s(pretrained=True, **kwargs): 10 | encoder = ViTAE_noRC_MaxPooling_bias_basic_stages4_14(pretrained=pretrained, **kwargs) 11 | decoder = ViTAE_noRC_MaxPooling_DecoderV1() 12 | model = ViTAE_noRC_MaxPooling_Matting(encoder, decoder) 13 | return model 14 | -------------------------------------------------------------------------------- /core/network/ViTAE_S/base_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rethinking Portrait Matting with Privacy Preserving 3 | 4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 5 | Licensed under the MIT License (see LICENSE for details) 6 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net 7 | Paper link: https://arxiv.org/abs/2203.16828 8 | 9 | """ 10 | 11 | from functools import partial 12 | import torch 13 | import torch.nn as nn 14 | from timm.models.layers import trunc_normal_ 15 | import numpy as np 16 | from .NormalCell import NormalCell 17 | from timm.models.layers import to_2tuple, trunc_normal_ 18 | 19 | class PatchMerging(nn.Module): 20 | r""" Patch Merging Layer. 21 | 22 | Args: 23 | input_resolution (tuple[int]): Resolution of input feature. 24 | dim (int): Number of input channels. 25 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 26 | """ 27 | 28 | def __init__(self, input_resolution, in_dim, out_dim, norm_layer=nn.LayerNorm): 29 | super().__init__() 30 | self.input_resolution = input_resolution 31 | self.pooling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 32 | self.in_dim = in_dim 33 | self.out_dim = out_dim 34 | # self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 35 | self.norm = norm_layer(in_dim) 36 | self.linear = nn.Linear(in_dim, out_dim, bias=False) 37 | 38 | def forward(self, x, input_resolution=None): 39 | """ 40 | x: B, H*W, C 41 | """ 42 | if input_resolution is None: 43 | input_resolution = self.input_resolution 44 | H, W = input_resolution 45 | B, L, C = x.shape 46 | assert L == H * W, "input feature has wrong size" 47 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 
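# Unlike Swin's linear 4C->2C patch merging, this layer downsamples with
# MaxPool2d(return_indices=True) followed by a linear projection, and it also
# returns the pooling indices and the pre-pooled feature map so the matting
# decoder can run max_unpool2d and use them as skip connections.
# Shape flow (illustrative): (B, H*W, in_dim) -> norm/view -> (B, in_dim, H, W)
# -> maxpool (keep idx) -> (B, H/2*W/2, in_dim) -> linear -> (B, H/2*W/2, out_dim).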
48 | 49 | x = self.norm(x) 50 | x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous() 51 | 52 | x_fea = x # 64, 128, 256 53 | x, idx = self.pooling(x) 54 | x = x.permute(0, 2, 3, 1).reshape(B, -1, C).contiguous() 55 | x = self.linear(x) 56 | 57 | return x, [idx], [x_fea] 58 | 59 | def extra_repr(self) -> str: 60 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 61 | 62 | def flops(self): 63 | H, W = self.input_resolution 64 | flops = H * W * self.dim 65 | flops += (H // 2) * (W // 2) * self.dim * 2 * self.dim * 2 66 | return flops 67 | 68 | 69 | class PatchEmbed(nn.Module): 70 | r""" Image to Patch Embedding 71 | 72 | Args: 73 | img_size (int): Image size. Default: 224. 74 | patch_size (int): Patch token size. Default: 4. 75 | in_chans (int): Number of input image channels. Default: 3. 76 | embed_dim (int): Number of linear projection output channels. Default: 96. 77 | norm_layer (nn.Module, optional): Normalization layer. Default: None 78 | """ 79 | 80 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 81 | super().__init__() 82 | _patch_size = patch_size 83 | img_size = to_2tuple(img_size) 84 | patch_size = to_2tuple(patch_size) 85 | strides = [] 86 | while _patch_size > 1: 87 | strides.append(2) 88 | _patch_size = _patch_size // 2 89 | if len(strides) < 3: 90 | strides.append(1) 91 | self.strides = strides 92 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 93 | self.img_size = img_size 94 | self.patch_size = patch_size 95 | self.patches_resolution = patches_resolution 96 | self.num_patches = patches_resolution[0] * patches_resolution[1] 97 | 98 | self.in_chans = in_chans 99 | self.inter_chans = embed_dim // 2 100 | self.embed_dim = embed_dim 101 | # self.proj = nn.Sequential( 102 | # nn.Conv2d(self.in_chans, self.inter_chans, 3, stride=strides[0], padding=1), 103 | # nn.ReLU(), 104 | # nn.Conv2d(self.inter_chans, embed_dim, 3, stride=strides[1], padding=1), 105 | # nn.ReLU(), 106 | # nn.Conv2d(embed_dim, embed_dim, 3, stride=strides[2], padding=1) 107 | # ) 108 | 109 | self.proj1 = nn.Conv2d(self.in_chans, self.inter_chans, 3, stride=1, padding=1) 110 | self.relu1 = nn.ReLU() 111 | self.proj2 = nn.Conv2d(self.inter_chans, embed_dim, 3, stride=1, padding=1) 112 | self.relu2 = nn.ReLU() 113 | self.proj3 = nn.Conv2d(embed_dim, embed_dim, 3, stride=1, padding=1) 114 | 115 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 116 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 117 | 118 | if norm_layer is not None: 119 | self.norm = norm_layer(embed_dim) 120 | else: 121 | self.norm = None 122 | 123 | def forward(self, x, input_resolution=None): 124 | x = self.relu1(self.proj1(x)) 125 | 126 | x_fea1 = x # 32 127 | x, idx1 = self.maxpool1(x) 128 | 129 | x = self.relu2(self.proj2(x)) 130 | x_fea2 = x # 64 131 | x, idx2 = self.maxpool2(x) 132 | 133 | x = self.proj3(x) 134 | x = x.flatten(2).transpose(1, 2) # B Ph*Pw C 135 | if self.norm is not None: 136 | x = self.norm(x) 137 | return x, [idx1, idx2], [x_fea1, x_fea2] 138 | 139 | 140 | class BasicLayer(nn.Module): 141 | def __init__(self, img_size=224, in_chans=3, embed_dims=64, token_dims=64, downsample_ratios=4, kernel_size=7, RC_heads=1, NC_heads=6, dilations=[1, 2, 3, 4], 142 | RC_op='cat', RC_tokens_type='performer', NC_tokens_type='transformer', RC_group=1, NC_group=64, NC_depth=2, dpr=0.1, mlp_ratio=4., qkv_bias=True, 143 | qk_scale=None, drop=0, attn_drop=0., norm_layer=nn.LayerNorm, 
class_token=False, gamma=False, init_values=1e-4, SE=False): 144 | super().__init__() 145 | self.img_size = to_2tuple(img_size) 146 | self.in_chans = in_chans 147 | self.embed_dims = embed_dims 148 | self.token_dims = token_dims 149 | self.downsample_ratios = downsample_ratios 150 | self.out_size = to_2tuple(img_size // self.downsample_ratios) 151 | self.RC_kernel_size = kernel_size 152 | self.RC_heads = RC_heads 153 | self.NC_heads = NC_heads 154 | self.dilations = dilations 155 | self.RC_op = RC_op 156 | self.RC_tokens_type = RC_tokens_type 157 | self.RC_group = RC_group 158 | self.NC_group = NC_group 159 | self.NC_depth = NC_depth 160 | if downsample_ratios > 2: 161 | self.RC = PatchEmbed(img_size=self.img_size, embed_dim=token_dims, norm_layer=nn.LayerNorm) 162 | elif downsample_ratios == 2: 163 | self.RC = PatchMerging(input_resolution=self.img_size, in_dim=in_chans, out_dim=token_dims) 164 | else: 165 | self.RC = nn.Identity() 166 | 167 | self.NC = nn.ModuleList([ 168 | NormalCell(token_dims, NC_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, 169 | drop_path=dpr[i] if isinstance(dpr, list) else dpr, norm_layer=norm_layer, class_token=class_token, group=NC_group, tokens_type=NC_tokens_type, 170 | gamma=gamma, init_values=init_values, SE=SE) 171 | for i in range(NC_depth)]) 172 | 173 | def forward(self, x, input_resolution=None): 174 | if input_resolution is None: 175 | input_resolution = self.img_size 176 | x, indices, feas = self.RC(x, input_resolution=input_resolution) 177 | input_resolution = [input_resolution[0]//self.downsample_ratios, input_resolution[1]//self.downsample_ratios] 178 | for nc in self.NC: 179 | x = nc(x, input_resolution=input_resolution) 180 | return x, indices, feas 181 | 182 | 183 | class ViTAE_noRC_MaxPooling_basic(nn.Module): 184 | def __init__(self, img_size=224, in_chans=3, stages=4, embed_dims=64, token_dims=64, downsample_ratios=[4, 2, 2, 2], kernel_size=[7, 3, 3, 3], 185 | RC_heads=[1, 1, 1, 1], NC_heads=4, dilations=[[1, 2, 3, 4], [1, 2, 3], [1, 2], [1, 2]], 186 | RC_op='cat', RC_tokens_type=['performer', 'transformer', 'transformer', 'transformer'], NC_tokens_type='transformer', 187 | RC_group=[1, 1, 1, 1], NC_group=[1, 32, 64, 64], NC_depth=[2, 2, 6, 2], mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., 188 | attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), num_classes=1000, 189 | gamma=False, init_values=1e-4, SE=False): 190 | super().__init__() 191 | self.num_classes = num_classes 192 | self.stages = stages 193 | repeatOrNot = (lambda x, y, z=list: x if isinstance(x, z) else [x for _ in range(y)]) 194 | self.embed_dims = repeatOrNot(embed_dims, stages) 195 | self.tokens_dims = token_dims if isinstance(token_dims, list) else [token_dims * (2 ** i) for i in range(stages)] 196 | self.downsample_ratios = repeatOrNot(downsample_ratios, stages) 197 | self.kernel_size = repeatOrNot(kernel_size, stages) 198 | self.RC_heads = repeatOrNot(RC_heads, stages) 199 | self.NC_heads = repeatOrNot(NC_heads, stages) 200 | self.dilaions = repeatOrNot(dilations, stages) 201 | self.RC_op = repeatOrNot(RC_op, stages) 202 | self.RC_tokens_type = repeatOrNot(RC_tokens_type, stages) 203 | self.NC_tokens_type = repeatOrNot(NC_tokens_type, stages) 204 | self.RC_group = repeatOrNot(RC_group, stages) 205 | self.NC_group = repeatOrNot(NC_group, stages) 206 | self.NC_depth = repeatOrNot(NC_depth, stages) 207 | self.mlp_ratio = repeatOrNot(mlp_ratio, stages) 208 | self.qkv_bias = 
repeatOrNot(qkv_bias, stages) 209 | self.qk_scale = repeatOrNot(qk_scale, stages) 210 | self.drop = repeatOrNot(drop_rate, stages) 211 | self.attn_drop = repeatOrNot(attn_drop_rate, stages) 212 | self.norm_layer = repeatOrNot(norm_layer, stages) 213 | 214 | self.pos_drop = nn.Dropout(p=drop_rate) 215 | depth = np.sum(self.NC_depth) 216 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 217 | Layers = [] 218 | for i in range(stages): 219 | startDpr = 0 if i==0 else self.NC_depth[i - 1] 220 | Layers.append( 221 | BasicLayer(img_size, in_chans, self.embed_dims[i], self.tokens_dims[i], self.downsample_ratios[i], 222 | self.kernel_size[i], self.RC_heads[i], self.NC_heads[i], self.dilaions[i], self.RC_op[i], 223 | self.RC_tokens_type[i], self.NC_tokens_type[i], self.RC_group[i], self.NC_group[i], self.NC_depth[i], dpr[startDpr:self.NC_depth[i]+startDpr], 224 | mlp_ratio=self.mlp_ratio[i], qkv_bias=self.qkv_bias[i], qk_scale=self.qk_scale[i], drop=self.drop[i], attn_drop=self.attn_drop[i], 225 | norm_layer=self.norm_layer[i], gamma=gamma, init_values=init_values, SE=SE) 226 | ) 227 | img_size = img_size // self.downsample_ratios[i] 228 | in_chans = self.tokens_dims[i] 229 | self.layers = nn.ModuleList(Layers) 230 | 231 | # Classifier head 232 | self.head = nn.Linear(self.tokens_dims[-1], num_classes) if num_classes > 0 else nn.Identity() 233 | 234 | self.apply(self._init_weights) 235 | # load state dict here TODO 236 | 237 | def _init_weights(self, m): 238 | if isinstance(m, nn.Linear): 239 | trunc_normal_(m.weight, std=.02) 240 | if isinstance(m, nn.Linear) and m.bias is not None: 241 | nn.init.constant_(m.bias, 0) 242 | elif isinstance(m, nn.LayerNorm): 243 | nn.init.constant_(m.bias, 0) 244 | nn.init.constant_(m.weight, 1.0) 245 | 246 | @torch.jit.ignore 247 | def no_weight_decay(self): 248 | return {'cls_token'} 249 | 250 | def get_classifier(self): 251 | return self.head 252 | 253 | def reset_classifier(self, num_classes): 254 | self.num_classes = num_classes 255 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 256 | 257 | def forward_features(self, x): 258 | indices = [] 259 | feas = [] 260 | B, C, H, W = x.shape 261 | input_resolution = [H, W] 262 | 263 | for layer_idx in range(len(self.layers)): 264 | x, idx, fea = self.layers[layer_idx](x, input_resolution=input_resolution) 265 | indices = indices + idx 266 | feas = feas + fea 267 | input_resolution = [input_resolution[0]//self.downsample_ratios[layer_idx], input_resolution[1]//self.downsample_ratios[layer_idx]] 268 | 269 | # x = x.view(B, -1, input_resolution[0], input_resolution[1]).contiguous() 270 | x = x.view(B, input_resolution[0], input_resolution[1], -1).permute(0,3,1,2).contiguous() 271 | return x, indices, feas 272 | 273 | def forward(self, x): 274 | return self.forward_features(x) 275 | 276 | def train(self, mode=True, tag='default'): 277 | self.training = mode 278 | if tag == 'default': 279 | for module in self.modules(): 280 | if module.__class__.__name__ != 'ViTAE_noRC_MaxPooling_basic': 281 | module.train(mode) 282 | elif tag == 'linear': 283 | for module in self.modules(): 284 | if module.__class__.__name__ != 'ViTAE_noRC_MaxPooling_basic': 285 | module.eval() 286 | for param in module.parameters(): 287 | param.requires_grad = False 288 | elif tag == 'linearLNBN': 289 | for module in self.modules(): 290 | if module.__class__.__name__ != 'ViTAE_noRC_MaxPooling_basic': 291 | if isinstance(module, nn.LayerNorm) or isinstance(module, 
nn.BatchNorm2d): 292 | module.train(mode) 293 | for param in module.parameters(): 294 | param.requires_grad = True 295 | else: 296 | module.eval() 297 | for param in module.parameters(): 298 | param.requires_grad = False 299 | self.head.train(mode) 300 | for param in self.head.parameters(): 301 | param.requires_grad = True 302 | return self 303 | -------------------------------------------------------------------------------- /core/network/ViTAE_S/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd. 2 | # 3 | # This source code is licensed under the Clear BSD License 4 | # LICENSE file in the root directory of this file 5 | # All rights reserved. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from timm.models.registry import register_model 11 | 12 | from .base_model import ViTAE_noRC_MaxPooling_basic 13 | from ..modules import TFI, SBFI, DBFI 14 | from util import get_masked_local_from_global 15 | from config import PRETRAINED_VITAE_NORC_MAXPOOLING_BIAS_BASIC_STAGE4_14 16 | 17 | 18 | ########################################## 19 | ## Encoder 20 | ########################################## 21 | 22 | def _cfg(url='', **kwargs): 23 | return { 24 | 'url': url, 25 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 26 | 'crop_pct': .9, 'interpolation': 'bicubic', 27 | 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 28 | 'classifier': 'head', 29 | **kwargs 30 | } 31 | 32 | default_cfgs = { 33 | 'ViTAE_stages3_7': _cfg(), 34 | } 35 | 36 | @register_model 37 | def ViTAE_noRC_MaxPooling_bias_basic_stages4_14(pretrained=True, **kwargs): # adopt performer for tokens to token 38 | model = ViTAE_noRC_MaxPooling_basic(RC_tokens_type=['performer', 'transformer', 'transformer', 'transformer'], NC_tokens_type=['performer', 'transformer', 'transformer', 'transformer'], stages=4, embed_dims=[64, 64, 128, 256], token_dims=[64, 128, 256, 512], downsample_ratios=[4, 2, 2, 2], 39 | NC_depth=[2, 2, 12, 2], NC_heads=[1, 2, 4, 8], RC_heads=[1, 1, 2, 4], mlp_ratio=4., NC_group=[1, 32, 64, 128], RC_group=[1, 16, 32, 64], **kwargs) 40 | model.default_cfg = default_cfgs['ViTAE_stages3_7'] 41 | if pretrained: 42 | ckpt = torch.load(PRETRAINED_VITAE_NORC_MAXPOOLING_BIAS_BASIC_STAGE4_14)['state_dict_ema'] 43 | model.load_state_dict(ckpt, strict=True) 44 | return model 45 | 46 | 47 | ########################################## 48 | ## Decoder 49 | ########################################## 50 | 51 | class ViTAE_noRC_MaxPooling_DecoderV1(nn.Module): 52 | def __init__(self): 53 | super().__init__() 54 | ########################## 55 | ### Decoder part - GLOBAL 56 | ########################## 57 | in_chan = 512 58 | out_chan = 256 59 | self.decoder4_g = nn.Sequential( # 512 -> 256 60 | nn.Conv2d(in_chan,out_chan,3,padding=1), 61 | nn.BatchNorm2d(out_chan), 62 | nn.ReLU(inplace=True), 63 | nn.Conv2d(out_chan,out_chan,3,padding=1), 64 | nn.BatchNorm2d(out_chan), 65 | nn.ReLU(inplace=True), 66 | nn.Conv2d(out_chan,out_chan,3,padding=1), 67 | nn.BatchNorm2d(out_chan), 68 | nn.ReLU(inplace=True), 69 | nn.Upsample(scale_factor=2, mode='bilinear') ) 70 | 71 | in_chan = 256 72 | out_chan = 128 73 | self.decoder3_g = nn.Sequential( # 256 -> 128 74 | nn.Conv2d(in_chan,out_chan,3,padding=1), 75 | nn.BatchNorm2d(out_chan), 76 | nn.ReLU(inplace=True), 77 | nn.Conv2d(out_chan,out_chan,3,padding=1), 78 | nn.BatchNorm2d(out_chan), 79 | nn.ReLU(inplace=True), 80 | 
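The encoder above (the PatchEmbed stem plus the PatchMerging blocks) downsamples with max-pooling and deliberately keeps the pooling indices, which the local decoder branch defined further down inverts with F.max_unpool2d. A minimal round trip on a toy tensor (shapes are illustrative only, not the repo's):

import torch
import torch.nn.functional as F

x = torch.arange(16.).reshape(1, 1, 4, 4)                        # toy 4x4 feature map
pooled, idx = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
restored = F.max_unpool2d(pooled, idx, kernel_size=2, stride=2)  # 4x4 again
# each max value returns to its original position, zeros elsewhere --
# this is what the decoder*_l blocks do with indices[-1] ... indices[-5] in forward() below
print(pooled.shape, restored.shape)                              # (1, 1, 2, 2) (1, 1, 4, 4)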
nn.Conv2d(out_chan,out_chan,3,padding=1), 81 | nn.BatchNorm2d(out_chan), 82 | nn.ReLU(inplace=True), 83 | nn.Upsample(scale_factor=2, mode='bilinear') ) 84 | 85 | in_chan = 128 86 | out_chan = 64 87 | self.decoder2_g = nn.Sequential( # 128 -> 64 88 | nn.Conv2d(in_chan,out_chan,3,padding=1), 89 | nn.BatchNorm2d(out_chan), 90 | nn.ReLU(inplace=True), 91 | nn.Conv2d(out_chan,out_chan,3,padding=1), 92 | nn.BatchNorm2d(out_chan), 93 | nn.ReLU(inplace=True), 94 | nn.Conv2d(out_chan,out_chan,3,padding=1), 95 | nn.BatchNorm2d(out_chan), 96 | nn.ReLU(inplace=True), 97 | nn.Upsample(scale_factor=2, mode='bilinear')) 98 | 99 | in_chan = 64 100 | out_chan = 64 101 | self.decoder1_g = nn.Sequential( # 64 -> 64 102 | nn.Conv2d(in_chan,out_chan,3,padding=1), 103 | nn.BatchNorm2d(out_chan), 104 | nn.ReLU(inplace=True), 105 | nn.Conv2d(out_chan,out_chan,3,padding=1), 106 | nn.BatchNorm2d(out_chan), 107 | nn.ReLU(inplace=True), 108 | nn.Conv2d(out_chan,out_chan,3,padding=1), 109 | nn.BatchNorm2d(out_chan), 110 | nn.ReLU(inplace=True), 111 | nn.Upsample(scale_factor=2, mode='bilinear')) 112 | 113 | in_chan = 64 114 | out_chan = 32 115 | self.decoder0_g = nn.Sequential( # 64 -> 32 116 | nn.Conv2d(in_chan,out_chan,3,padding=1), 117 | nn.BatchNorm2d(out_chan), 118 | nn.ReLU(inplace=True), 119 | nn.Conv2d(out_chan,out_chan,3,padding=1), 120 | nn.BatchNorm2d(out_chan), 121 | nn.ReLU(inplace=True), 122 | nn.Conv2d(out_chan,out_chan,3,padding=1), 123 | nn.Upsample(scale_factor=2, mode='bilinear')) 124 | self.decoder_final_g = nn.Conv2d(out_chan, 3, kernel_size=3, padding=1) # 32 -> 3 125 | 126 | ########################## 127 | ### Decoder part - LOCAL 128 | ########################## 129 | 130 | in_chan = 512 131 | out_chan = 256 132 | self.decoder4_l = nn.Sequential( # 512 -> 256 133 | nn.Conv2d(in_chan,out_chan,3,padding=1), 134 | nn.BatchNorm2d(out_chan), 135 | nn.ReLU(inplace=True), 136 | nn.Conv2d(out_chan,out_chan,3,padding=1), 137 | nn.BatchNorm2d(out_chan), 138 | nn.ReLU(inplace=True), 139 | nn.Conv2d(out_chan,out_chan,3,padding=1), 140 | nn.BatchNorm2d(out_chan), 141 | nn.ReLU(inplace=True)) 142 | 143 | in_chan = 256 144 | out_chan = 128 145 | self.decoder3_l = nn.Sequential( # 256 -> 128 146 | nn.Conv2d(in_chan,out_chan,3,padding=1), 147 | nn.BatchNorm2d(out_chan), 148 | nn.ReLU(inplace=True), 149 | nn.Conv2d(out_chan,out_chan,3,padding=1), 150 | nn.BatchNorm2d(out_chan), 151 | nn.ReLU(inplace=True), 152 | nn.Conv2d(out_chan,out_chan,3,padding=1), 153 | nn.BatchNorm2d(out_chan), 154 | nn.ReLU(inplace=True)) 155 | 156 | in_chan = 128 157 | out_chan = 64 158 | self.decoder2_l = nn.Sequential( # 128 -> 64 159 | nn.Conv2d(in_chan,out_chan,3,padding=1), 160 | nn.BatchNorm2d(out_chan), 161 | nn.ReLU(inplace=True), 162 | nn.Conv2d(out_chan,out_chan,3,padding=1), 163 | nn.BatchNorm2d(out_chan), 164 | nn.ReLU(inplace=True), 165 | nn.Conv2d(out_chan,out_chan,3,padding=1), 166 | nn.BatchNorm2d(out_chan), 167 | nn.ReLU(inplace=True)) 168 | 169 | in_chan = 64 170 | out_chan = 64 171 | self.decoder1_l = nn.Sequential( # 64 -> 64 172 | nn.Conv2d(in_chan,out_chan,3,padding=1), 173 | nn.BatchNorm2d(out_chan), 174 | nn.ReLU(inplace=True), 175 | nn.Conv2d(out_chan,out_chan,3,padding=1), 176 | nn.BatchNorm2d(out_chan), 177 | nn.ReLU(inplace=True), 178 | nn.Conv2d(out_chan,out_chan,3,padding=1), 179 | nn.BatchNorm2d(out_chan), 180 | nn.ReLU(inplace=True)) 181 | 182 | in_chan = 64 183 | out_chan = 32 184 | self.decoder0_l = nn.Sequential( # 64 -> 32 185 | nn.Conv2d(in_chan,out_chan,3,padding=1), 186 | 
nn.BatchNorm2d(out_chan), 187 | nn.ReLU(inplace=True), 188 | nn.Conv2d(out_chan,out_chan,3,padding=1), 189 | nn.BatchNorm2d(out_chan), 190 | nn.ReLU(inplace=True)) 191 | self.decoder_final_l = nn.Conv2d(out_chan,1,3,padding=1) # 32 -> 1 192 | 193 | ########################## 194 | ### Decoder part - MODULES 195 | ########################## 196 | 197 | self.tfi_3 = TFI(256) 198 | self.tfi_2 = TFI(128) 199 | self.tfi_1 = TFI(64) 200 | self.tfi_0 = TFI(64) 201 | 202 | self.sbfi_2 = SBFI(128, 32, 8) 203 | self.sbfi_1 = SBFI(64, 32, 4) 204 | self.sbfi_0 = SBFI(64, 32, 2) 205 | 206 | self.dbfi_2 = DBFI(128, 512, 4) 207 | self.dbfi_1 = DBFI(64, 512, 8) 208 | self.dbfi_0 = DBFI(64, 512, 16) 209 | 210 | def forward(self, x, indices, feas): 211 | ########################### 212 | ### Decoder part - Global 213 | ########################### 214 | d4_g = self.decoder4_g(x) # 512 -> 256 215 | d3_g = self.decoder3_g(d4_g) # 256 -> 128 216 | d2_g, global_sigmoid_side2 = self.dbfi_2(d3_g, x) # 128, 512 -> 128 217 | d2_g = self.decoder2_g(d2_g) # 128 -> 64 218 | d1_g, global_sigmoid_side1 = self.dbfi_1(d2_g, x) # 64, 512 -> 64 219 | d1_g = self.decoder1_g(d1_g) # 64 -> 64 220 | d0_g, global_sigmoid_side0 = self.dbfi_0(d1_g, x) # 64, 512 -> 64 221 | d0_g = self.decoder0_g(d0_g) # 64 -> 32 222 | global_sigmoid = self.decoder_final_g(d0_g) # 32 -> 3 223 | ########################### 224 | ### Decoder part - Local 225 | ########################### 226 | d4_l = self.decoder4_l(x) # 512 -> 256 227 | d4_l = F.max_unpool2d(d4_l, indices[-1], kernel_size=2, stride=2) # 256 228 | d3_l = self.tfi_3(d4_g, d4_l, feas[-1]) # 256, 256, 256 -> 256 229 | d3_l = self.decoder3_l(d3_l) # 256 -> 128 230 | d3_l = F.max_unpool2d(d3_l, indices[-2], kernel_size=2, stride=2) # 128 231 | d2_l = self.tfi_2(d3_g, d3_l, feas[-2]) # 128, 128, 128 -> 128 232 | d2_l = self.sbfi_2(d2_l, feas[-5]) # 128, 32 -> 128 233 | d2_l = self.decoder2_l(d2_l) # 128 -> 64 234 | d2_l = F.max_unpool2d(d2_l, indices[-3], kernel_size=2, stride=2) # 64 235 | d1_l = self.tfi_1(d2_g, d2_l, feas[-3]) # 64, 64, 64 -> 64 236 | d1_l = self.sbfi_1(d1_l, feas[-5]) # 64, 32 -> 64 237 | d1_l = self.decoder1_l(d1_l) # 64 -> 64 238 | d1_l = F.max_unpool2d(d1_l, indices[-4], kernel_size=2, stride=2) # 64 239 | d0_l = self.tfi_0(d1_g, d1_l, feas[-4]) # 64, 64, 64 -> 64 240 | d0_l = self.sbfi_0(d0_l, feas[-5]) # 64, 32 -> 64 241 | d0_l = self.decoder0_l(d0_l) # 64 -> 32 242 | d0_l = F.max_unpool2d(d0_l, indices[-5], kernel_size=2, stride=2) # 32 243 | d0_l = self.decoder_final_l(d0_l) # 32 -> 1 244 | local_sigmoid = torch.sigmoid(d0_l) # 1 245 | fusion_sigmoid = get_masked_local_from_global(global_sigmoid, local_sigmoid) 246 | return global_sigmoid, local_sigmoid, fusion_sigmoid, global_sigmoid_side2, global_sigmoid_side1, global_sigmoid_side0 247 | 248 | 249 | ########################################## 250 | ## Matting Model 251 | ########################################## 252 | 253 | class ViTAE_noRC_MaxPooling_Matting(nn.Module): 254 | def __init__(self, encoder, decoder) -> None: 255 | super().__init__() 256 | self.encoder = encoder 257 | self.decoder = decoder 258 | 259 | def forward(self, x): 260 | embeddings, indices, feas = self.encoder(x) 261 | return self.decoder(embeddings, indices, feas) 262 | -------------------------------------------------------------------------------- /core/network/ViTAE_S/token_performer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rethinking Portrait Matting with Privacy Preserving 3 | 
4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 5 | Licensed under the MIT License (see LICENSE for details) 6 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net 7 | Paper link: https://arxiv.org/abs/2203.16828 8 | 9 | """ 10 | 11 | """ 12 | Take Performer as T2T Transformer 13 | """ 14 | import math 15 | import torch 16 | import torch.nn as nn 17 | 18 | 19 | class Token_performer(nn.Module): 20 | def __init__(self, dim, in_dim, head_cnt=1, kernel_ratio=0.5, dp1=0.1, dp2 = 0.1, gamma=False, init_values=1e-4): 21 | super().__init__() 22 | self.head_dim = in_dim // head_cnt 23 | self.emb = in_dim 24 | self.kqv = nn.Linear(dim, 3 * self.emb) 25 | self.dp = nn.Dropout(dp1) 26 | self.proj = nn.Linear(self.emb, self.emb) 27 | self.head_cnt = head_cnt 28 | self.norm1 = nn.LayerNorm(dim) 29 | self.norm2 = nn.LayerNorm(self.emb) 30 | self.epsilon = 1e-8 # for stable in division 31 | self.drop_path = nn.Identity() 32 | 33 | self.mlp = nn.Sequential( 34 | nn.Linear(self.emb, 1 * self.emb), 35 | nn.GELU(), 36 | nn.Linear(1 * self.emb, self.emb), 37 | nn.Dropout(dp2), 38 | ) 39 | 40 | self.m = int(self.head_dim * kernel_ratio) 41 | self.w = torch.randn(head_cnt, self.m, self.head_dim) 42 | for i in range(self.head_cnt): 43 | self.w[i] = nn.Parameter(nn.init.orthogonal_(self.w[i]) * math.sqrt(self.m), requires_grad=False) 44 | self.w.requires_grad_(False) 45 | 46 | if gamma: 47 | self.gamma1 = nn.Parameter(init_values * torch.ones((self.emb)),requires_grad=True) 48 | else: 49 | self.gamma1 = 1 50 | 51 | def prm_exp(self, x): 52 | # part of the function is borrow from https://github.com/lucidrains/performer-pytorch 53 | # and Simo Ryu (https://github.com/cloneofsimo) 54 | # ==== positive random features for gaussian kernels ==== 55 | # x = (B, H, N, hs) 56 | # w = (H, m, hs) 57 | # return : x : B, T, m 58 | # SM(x, y) = E_w[exp(w^T x - |x|/2) exp(w^T y - |y|/2)] 59 | # therefore return exp(w^Tx - |x|/2)/sqrt(m) 60 | xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, 1, self.m) / 2 61 | wtx = torch.einsum('bhti,hmi->bhtm', x.float(), self.w.to(x.device)) 62 | 63 | return torch.exp(wtx - xd) / math.sqrt(self.m) 64 | 65 | def attn(self, x): 66 | B, N, C = x.shape 67 | kqv = self.kqv(x).reshape(B, N, 3, self.head_cnt, self.head_dim).permute(2, 0, 3, 1, 4) 68 | k, q, v = kqv[0], kqv[1], kqv[2] # (B, H, T, hs) 69 | 70 | kp, qp = self.prm_exp(k), self.prm_exp(q) # (B, H, T, m), (B, H, T, m) 71 | D = torch.einsum('bhti,bhi->bht', qp, kp.sum(dim=2)).unsqueeze(dim=-1) # (B, H, T, m) * (B, H, m) -> (B, H, T, 1) 72 | kptv = torch.einsum('bhin,bhim->bhnm', v.float(), kp) # (B, H, emb, m) 73 | y = torch.einsum('bhti,bhni->bhtn', qp, kptv) / (D.repeat(1, 1, 1, self.head_dim) + self.epsilon) # (B, H, T, emb)/Diag 74 | 75 | # skip connection 76 | 77 | y = y.permute(0, 2, 1, 3).reshape(B, N, self.emb) 78 | v = v.permute(0, 2, 1, 3).reshape(B, N, self.emb) 79 | 80 | y = v + self.dp(self.gamma1 * self.proj(y)) # same as token_transformer, use v as skip connection 81 | 82 | return y 83 | 84 | def forward(self, x): 85 | x = self.attn(self.norm1(x)) 86 | x = x + self.mlp(self.norm2(x)) 87 | return x -------------------------------------------------------------------------------- /core/network/ViTAE_S/token_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd. 
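The comment block in Token_performer.prm_exp above relies on the Performer identity SM(x, y) = E_w[exp(w^T x - |x|^2/2) * exp(w^T y - |y|^2/2)], i.e. the softmax kernel exp(x . y) can be estimated by a dot product of positive random features. A small standalone check of that identity (the function name, the dimensions and the plain Gaussian projections are my own choices; the repo additionally orthogonalises w):

import math
import torch

def phi(x, w):
    # positive random features: exp(w @ x - |x|^2 / 2) / sqrt(m), mirroring prm_exp
    return torch.exp(w @ x - (x * x).sum() / 2) / math.sqrt(w.shape[0])

torch.manual_seed(0)
d, m = 16, 8192
w = torch.randn(m, d)                        # i.i.d. Gaussian projections
x, y = 0.3 * torch.randn(d), 0.3 * torch.randn(d)
approx = (phi(x, w) * phi(y, w)).sum()       # Monte-Carlo estimate of the kernel
exact = torch.exp(x @ y)
print(float(approx), float(exact))           # the two numbers should roughly agree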
2 | # 3 | # This source code is licensed under the Clear BSD License 4 | # LICENSE file in the root directory of this file 5 | # All rights reserved. 6 | """ 7 | Take the standard Transformer as T2T Transformer 8 | """ 9 | import torch 10 | import torch.nn as nn 11 | from timm.models.layers import DropPath 12 | from .NormalCell import Mlp 13 | 14 | class Attention(nn.Module): 15 | def __init__(self, dim, num_heads=8, in_dim = None, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., gamma=False, init_values=1e-4): 16 | super().__init__() 17 | self.num_heads = num_heads 18 | self.in_dim = in_dim 19 | head_dim = in_dim // num_heads 20 | self.scale = qk_scale or head_dim ** -0.5 21 | 22 | self.qkv = nn.Linear(dim, in_dim * 3, bias=qkv_bias) 23 | self.attn_drop = nn.Dropout(attn_drop) 24 | self.proj = nn.Linear(in_dim, in_dim) 25 | self.proj_drop = nn.Dropout(proj_drop) 26 | if gamma: 27 | self.gamma1 = nn.Parameter(init_values * torch.ones((in_dim)),requires_grad=True) 28 | else: 29 | self.gamma1 = 1 30 | 31 | def forward(self, x): 32 | B, N, C = x.shape 33 | 34 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.in_dim // self.num_heads).permute(2, 0, 3, 1, 4) 35 | q, k, v = qkv[0], qkv[1], qkv[2] 36 | 37 | attn = (q @ k.transpose(-2, -1)) * self.scale 38 | attn = attn.softmax(dim=-1) 39 | attn = self.attn_drop(attn) 40 | 41 | x = (attn @ v).transpose(1, 2).reshape(B, N, self.in_dim) 42 | x = self.proj(x) 43 | x = self.proj_drop(self.gamma1 * x) 44 | v = v.permute(0, 2, 1, 3).view(B, N, self.in_dim).contiguous() 45 | # skip connection 46 | x = v + x # because the original x has different size with current x, use v to do skip connection 47 | 48 | return x 49 | 50 | class Token_transformer(nn.Module): 51 | 52 | def __init__(self, dim, in_dim, num_heads, mlp_ratio=1., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 53 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, gamma=False, init_values=1e-4): 54 | super().__init__() 55 | self.norm1 = norm_layer(dim) 56 | self.attn = Attention( 57 | dim, in_dim=in_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, gamma=gamma, init_values=init_values) 58 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 59 | self.norm2 = norm_layer(in_dim) 60 | self.mlp = Mlp(in_features=in_dim, hidden_features=int(in_dim*mlp_ratio), out_features=in_dim, act_layer=act_layer, drop=drop) 61 | 62 | def forward(self, x): 63 | x = self.attn(self.norm1(x)) 64 | x = x + self.drop_path(self.mlp(self.norm2(x))) 65 | return x 66 | -------------------------------------------------------------------------------- /core/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .build_model import build_model -------------------------------------------------------------------------------- /core/network/build_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rethinking Portrait Matting with Privacy Preserving 3 | 4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 5 | Licensed under the MIT License (see LICENSE for details) 6 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net 7 | Paper link: https://arxiv.org/abs/2203.16828 8 | 9 | """ 10 | 11 | from .RestNet34 import * 12 | from .Swin_T import * 13 | from .ViTAE_S import * 14 | 15 | 16 | def build_model(model_arch, **kwargs): 17 | if model_arch == "r34": 18 | model = p3mnet_resnet34(**kwargs) 19 | elif model_arch == "swin": 20 | model = p3mnet_swin_t(**kwargs) 21 | elif model_arch == "vitae": 22 | model = p3mnet_vitae_s(**kwargs) 23 | else: 24 | print(model_arch) 25 | raise NotImplementedError 26 | 27 | return model -------------------------------------------------------------------------------- /core/network/modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rethinking Portrait Matting with Privacy Preserving 3 | 4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 5 | Licensed under the MIT License (see LICENSE for details) 6 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net 7 | Paper link: https://arxiv.org/abs/2203.16828 8 | 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | "3x3 convolution with padding" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=1, bias=False) 19 | 20 | 21 | def conv1x1(in_planes, out_planes, stride=1): 22 | """1x1 convolution""" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 24 | 25 | 26 | class TFI(nn.Module): 27 | expansion = 1 28 | def __init__(self, planes,stride=1): 29 | super(TFI, self).__init__() 30 | middle_planes = int(planes/2) 31 | self.transform = conv1x1(planes, middle_planes) 32 | self.conv1 = conv3x3(middle_planes*3, planes, stride) 33 | self.bn1 = nn.BatchNorm2d(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.stride = stride 36 | 37 | def forward(self, input_s_guidance, input_m_decoder, input_m_encoder): 38 | input_s_guidance_transform = self.transform(input_s_guidance) 39 | input_m_decoder_transform = self.transform(input_m_decoder) 40 | input_m_encoder_transform = self.transform(input_m_encoder) 41 | x = torch.cat((input_s_guidance_transform,input_m_decoder_transform,input_m_encoder_transform),1) 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | return out 46 | 47 | 48 | class SBFI(nn.Module): 49 | def __init__(self, planes,planes2,stride=1): 50 | # planes2, min dim 51 | super(SBFI, self).__init__() 52 | self.stride = stride 53 | 
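build_model above is the single entry point that maps the --arch / MODEL.TYPE strings to the three P3M-Net variants. A hedged usage sketch (assumptions on my part: core/ is on the import path, and the "vitae" variant may first try to load the pretrained backbone checkpoint configured in config.py):

import torch
from network.build_model import build_model   # the same import test.py and train.py use

model = build_model("vitae")                   # choices: "r34", "swin", "vitae"
model.eval()
with torch.no_grad():
    dummy = torch.randn(1, 3, 512, 512)        # H and W should be multiples of 32
    pred_global, pred_local, pred_fusion = model(dummy)[:3]
# expected (untested here): pred_global (1, 3, 512, 512), pred_local and pred_fusion (1, 1, 512, 512)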
self.transform1 = conv1x1(planes, int(planes/2)) 54 | self.transform2 = conv1x1(planes2, int(planes/2)) 55 | self.maxpool = nn.MaxPool2d(2, stride=stride) 56 | self.conv1 = conv3x3(planes, planes, 1) 57 | self.bn1 = nn.BatchNorm2d(planes) 58 | self.relu = nn.ReLU(inplace=True) 59 | 60 | def forward(self, input_m_decoder,e0): 61 | input_m_decoder_transform = self.transform1(input_m_decoder) 62 | e0_maxpool = self.maxpool(e0) 63 | e0_transform = self.transform2(e0_maxpool) 64 | x = torch.cat((input_m_decoder_transform,e0_transform),1) 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | out = out+input_m_decoder 69 | return out 70 | 71 | 72 | class DBFI(nn.Module): 73 | def __init__(self, planes,planes2,stride=1): 74 | # planes2, max dim 75 | super(DBFI, self).__init__() 76 | self.stride = stride 77 | self.transform1 = conv1x1(planes, int(planes/2)) 78 | self.transform2 = conv1x1(planes2, int(planes/2)) 79 | self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear') 80 | self.conv1 = conv3x3(planes, planes, 1) 81 | self.bn1 = nn.BatchNorm2d(planes) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.conv2 = conv3x3(planes, 3, 1) 84 | self.upsample2 = nn.Upsample(scale_factor=int(32/stride), mode='bilinear') 85 | 86 | def forward(self, input_s_decoder,e4): 87 | input_s_decoder_transform = self.transform1(input_s_decoder) 88 | e4_transform = self.transform2(e4) 89 | e4_upsample = self.upsample(e4_transform) 90 | x = torch.cat((input_s_decoder_transform,e4_upsample),1) 91 | out = self.conv1(x) 92 | out = self.bn1(out) 93 | out = self.relu(out) 94 | out = out+input_s_decoder 95 | out_side = self.conv2(out) 96 | out_side = self.upsample2(out_side) 97 | return out, out_side 98 | -------------------------------------------------------------------------------- /core/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对test来说,和train是完全分开的。 3 | build model是不一样的。因为cp的原因,train model和test model不一样。 4 | path问题,path是不能和train一起公用的 5 | 另外很多args是train里没有的 6 | 但是test又需要用到config.MODEL里的参数。 7 | 8 | """ 9 | import torch 10 | import torch.nn.functional as F 11 | import ipdb 12 | import cv2 13 | import argparse 14 | import numpy as np 15 | from tqdm import tqdm 16 | from PIL import Image 17 | from skimage.transform import resize 18 | from torchvision import transforms 19 | import logging 20 | import warnings 21 | import yaml 22 | import torch.distributed as dist 23 | 24 | from config import * 25 | from util import * 26 | from evaluate import * 27 | from network.build_model import build_model 28 | from config_yacs import get_config, CN 29 | from logger import create_logger 30 | 31 | 32 | def get_args(): 33 | parser = argparse.ArgumentParser(description='PyTorch Super Res Example') 34 | parser.add_argument('--tag', type=str, required=True) 35 | parser.add_argument('--test_dataset', type=str, required=True, choices=VALID_TEST_DATA_CHOICE, help="which dataset to test") 36 | parser.add_argument('--test_ckpt', type=str, default='', required=False, help="path of model to use") 37 | parser.add_argument('--test_method', type=str, required=False, help="which dataset to test") 38 | parser.add_argument('--fast_test', action='store_true', default=False, help='skip conn and grad metrics for fast test') 39 | parser.add_argument('--test_privacy', action='store_true', default=False, help='test on the privacy content') 40 | parser.add_argument('--save_result', action='store_true') 41 | args, _ = parser.parse_known_args() 42 | 43 | if args.test_dataset 
== 'VAL500NP': 44 | args.test_privacy = False 45 | warnings.warn("NO FACEMASK FOR VAL500NP") 46 | 47 | args.existing_cfg = np.load(os.path.join(CKPT_SAVE_FOLDER, args.tag, 'args.npy'), allow_pickle=True).item() 48 | if type(args.existing_cfg) != CN: 49 | new_cfg = CN() 50 | new_cfg.TAG = args.tag 51 | new_cfg.MODEL = CN() 52 | new_cfg.MODEL.TYPE = args.existing_cfg.arch 53 | args.existing_cfg = new_cfg 54 | 55 | config = get_config(args) 56 | return args, config 57 | 58 | def inference_once(config, model, scale_img, scale_trimap=None): 59 | if torch.cuda.device_count() > 0: 60 | tensor_img = torch.from_numpy(scale_img.astype(np.float32)[:, :, :]).permute(2, 0, 1).cuda() 61 | else: 62 | tensor_img = torch.from_numpy(scale_img.astype(np.float32)[:, :, :]).permute(2, 0, 1) 63 | input_t = tensor_img 64 | input_t = input_t/255.0 65 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 66 | std=[0.229, 0.224, 0.225]) 67 | input_t = normalize(input_t) 68 | input_t = input_t.unsqueeze(0) 69 | pred_global, pred_local, pred_fusion = model(input_t)[:3] 70 | pred_global = pred_global.data.cpu().numpy() 71 | pred_global = gen_trimap_from_segmap_e2e(pred_global) 72 | pred_local = pred_local.data.cpu().numpy()[0,0,:,:] 73 | pred_fusion = pred_fusion.data.cpu().numpy()[0,0,:,:] 74 | return pred_global, pred_local, pred_fusion 75 | 76 | def inference_img_modnet(config, model, img, *args): 77 | im_h, im_w, c = img.shape 78 | 79 | if config.TEST.TEST_METHOD=='RESIZE512': 80 | ref_size = 512 81 | if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size: 82 | if im_w >= im_h: 83 | im_rh = ref_size 84 | im_rw = int(im_w / im_h * ref_size) 85 | elif im_w < im_h: 86 | im_rw = ref_size 87 | im_rh = int(im_h / im_w * ref_size) 88 | else: 89 | im_rh = im_h 90 | im_rw = im_w 91 | 92 | im_rw = im_rw - im_rw % 32 93 | im_rh = im_rh - im_rh % 32 94 | 95 | scale_img = resize(img,(im_rh, im_rw)) 96 | 97 | scale_img = scale_img*255.0 98 | 99 | elif config.TEST.TEST_METHOD=='ORIGIN': 100 | im_rh = min(MAX_SIZE_H, im_h - im_h % 32) 101 | im_rw = min(MAX_SIZE_W, im_w - im_w % 32) 102 | 103 | scale_img = resize(img, (im_rh, im_rw)) 104 | scale_img = scale_img * 255.0 105 | else: 106 | raise NotImplementedError 107 | 108 | tensor_img = torch.from_numpy(scale_img.astype(np.float32)[:, :, :]).permute(2, 0, 1).cuda() 109 | tensor_img = (tensor_img/255.0).unsqueeze(0) 110 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 111 | std=[0.229, 0.224, 0.225]) 112 | tensor_img = normalize(tensor_img) 113 | 114 | _,_,pred= model(tensor_img, inference=True) 115 | pred = pred.squeeze() 116 | pred = pred.cpu().data.numpy() 117 | pred = resize(pred,(im_h,im_w)) 118 | 119 | return pred 120 | 121 | def inference_img_p3m(config, model, img, *args): 122 | h, w, c = img.shape 123 | new_h = min(MAX_SIZE_H, h - (h % 32)) 124 | new_w = min(MAX_SIZE_W, w - (w % 32)) 125 | 126 | if config.TEST.TEST_METHOD=='HYBRID': 127 | global_ratio = 1/2 128 | local_ratio = 1 129 | resize_h = int(h*global_ratio) 130 | resize_w = int(w*global_ratio) 131 | new_h = min(MAX_SIZE_H, resize_h - (resize_h % 32)) 132 | new_w = min(MAX_SIZE_W, resize_w - (resize_w % 32)) 133 | scale_img = resize(img,(new_h,new_w))*255.0 134 | pred_coutour_1, pred_retouching_1, pred_fusion_1 = inference_once(config, model, scale_img) 135 | torch.cuda.empty_cache() 136 | pred_coutour_1 = resize(pred_coutour_1,(h,w))*255.0 137 | resize_h = int(h*local_ratio) 138 | resize_w = int(w*local_ratio) 139 | new_h = min(MAX_SIZE_H, resize_h - (resize_h % 32)) 140 | new_w = 
min(MAX_SIZE_W, resize_w - (resize_w % 32)) 141 | scale_img = resize(img,(new_h,new_w))*255.0 142 | pred_coutour_2, pred_retouching_2, pred_fusion_2 = inference_once(config, model, scale_img) 143 | torch.cuda.empty_cache() 144 | pred_retouching_2 = resize(pred_retouching_2,(h,w)) 145 | pred_fusion = get_masked_local_from_global_test(pred_coutour_1, pred_retouching_2) 146 | return pred_fusion 147 | elif config.TEST.TEST_METHOD=='RESIZE': 148 | resize_h = int(h/2) 149 | resize_w = int(w/2) 150 | new_h = min(MAX_SIZE_H, resize_h - (resize_h % 32)) 151 | new_w = min(MAX_SIZE_W, resize_w - (resize_w % 32)) 152 | scale_img = resize(img,(new_h,new_w))*255.0 153 | pred_global, pred_local, pred_fusion = inference_once(config, model, scale_img) 154 | pred_local = resize(pred_local,(h,w)) 155 | pred_global = resize(pred_global,(h,w))*255.0 156 | pred_fusion = resize(pred_fusion,(h,w)) 157 | return pred_fusion 158 | else: 159 | raise NotImplementedError() 160 | 161 | def test_p3m10k(config, model, logger): 162 | if torch.cuda.device_count() > 0: 163 | torch.cuda.empty_cache() 164 | 165 | arch_predict_dict = { 166 | 'r34': inference_img_p3m, 167 | 'swin': inference_img_p3m, 168 | 'vitae': inference_img_p3m, 169 | } 170 | 171 | ############################ 172 | # Some initial setting for paths 173 | ############################ 174 | if config.TEST.DATASET == 'VAL500P': 175 | val_option = 'P3M_500_P' 176 | elif config.TEST.DATASET == 'VAL500NP': 177 | val_option = 'P3M_500_NP' 178 | else: 179 | val_option = config.TEST.DATASET 180 | ORIGINAL_PATH = DATASET_PATHS_DICT['P3M10K'][val_option]['ORIGINAL_PATH'] 181 | MASK_PATH = DATASET_PATHS_DICT['P3M10K'][val_option]['MASK_PATH'] 182 | TRIMAP_PATH = DATASET_PATHS_DICT['P3M10K'][val_option]['TRIMAP_PATH'] 183 | if config.TEST.TEST_PRIVACY: 184 | PRIVACY_PATH = DATASET_PATHS_DICT['P3M10K'][val_option]['PRIVACY_MASK_PATH'] 185 | ############################ 186 | # Start testing 187 | ############################ 188 | sad_diffs = 0. 189 | mse_diffs = 0. 190 | mad_diffs = 0. 191 | sad_trimap_diffs = 0. 192 | mse_trimap_diffs = 0. 193 | mad_trimap_diffs = 0. 194 | sad_fg_diffs = 0. 195 | sad_bg_diffs = 0. 196 | conn_diffs = 0. 197 | grad_diffs = 0. 198 | sad_privacy_diffs = 0. # for test privacy only 199 | mse_privacy_diffs = 0. # for test privacy only 200 | mad_privacy_diffs = 0. 
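The HYBRID branch above fuses a half-resolution global (trimap) prediction with a full-resolution local (alpha) prediction through get_masked_local_from_global_test from core/util.py: the local alpha is used wherever the trimap is not a hard 0 or 255 (i.e. the unknown band), while the 0/255 regions come straight from the global prediction. A toy numeric illustration (values made up):

import numpy as np

g = np.array([[0., 128., 255.]])      # global prediction: background / unknown / foreground
l = np.array([[0.9, 0.7, 0.1]])       # local alpha prediction

w = np.ones_like(g)
w[g == 255] = 0
w[g == 0] = 0                         # weight is 1 only in the unknown band
fused = g * (1. - w) / 255 + l * w
print(fused)                          # [[0.  0.7 1. ]] -- the local value survives only where g == 128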
# for test privacy only 201 | if config.TEST.SAVE_RESULT: 202 | result_dir = os.path.join(TEST_RESULT_FOLDER, 'test_{}_{}_{}_{}'.format(config.TAG, config.TEST.DATASET, config.TEST.TEST_METHOD, config.TEST.CKPT_NAME.replace('/', '-'))) 203 | refresh_folder(result_dir) 204 | model.eval() 205 | img_list = listdir_nohidden(ORIGINAL_PATH) 206 | total_number = len(img_list) 207 | logger.info("===============================") 208 | logger.info(f'====> Start Testing\n\t--Dataset: {config.TEST.DATASET}\n\t--Test: {config.TEST.TEST_METHOD}\n\t--Number: {total_number}') 209 | 210 | if config.TAG.startswith("debug"): 211 | img_list = img_list[:10] 212 | 213 | if dist.is_initialized(): 214 | img_list = img_list[GET_RANK()::GET_WORLD_SIZE()] 215 | total_number = len(img_list) 216 | print('rank {}/{}, total num: {}'.format(GET_RANK(), GET_WORLD_SIZE(), total_number)) 217 | 218 | for img_name in tqdm(img_list): 219 | img_path = ORIGINAL_PATH+img_name 220 | alpha_path = MASK_PATH+extract_pure_name(img_name)+'.png' 221 | trimap_path = TRIMAP_PATH+extract_pure_name(img_name)+'.png' 222 | img = np.array(Image.open(img_path)) 223 | trimap = np.array(Image.open(trimap_path)) 224 | alpha = np.array(Image.open(alpha_path))/255. 225 | img = img[:,:,:3] if img.ndim>2 else img 226 | trimap = trimap[:,:,0] if trimap.ndim>2 else trimap 227 | alpha = alpha[:,:,0] if alpha.ndim>2 else alpha 228 | 229 | if config.TEST.TEST_PRIVACY: 230 | privacy_path = PRIVACY_PATH+extract_pure_name(img_name)+'.png' 231 | privacy = np.array(Image.open(privacy_path)) 232 | privacy = privacy[:, :, 0] if privacy.ndim>2 else privacy 233 | 234 | with torch.no_grad(): 235 | predict = arch_predict_dict.get(config.MODEL.TYPE, inference_img_p3m)(config, model, img) 236 | 237 | # test on whole image and trimap area 238 | sad_trimap_diff, mse_trimap_diff, mad_trimap_diff = calculate_sad_mse_mad(predict, alpha, trimap) 239 | sad_diff, mse_diff, mad_diff = calculate_sad_mse_mad_whole_img(predict, alpha) 240 | sad_fg_diff, sad_bg_diff = calculate_sad_fgbg(predict, alpha, trimap) 241 | if config.TEST.FAST_TEST: 242 | conn_diff = -1 243 | grad_diff = -1 244 | else: 245 | conn_diff = compute_connectivity_loss_whole_image(predict, alpha) 246 | grad_diff = compute_gradient_whole_image(predict, alpha) 247 | 248 | # test on privacy area 249 | if config.TEST.TEST_PRIVACY: 250 | sad_privacy_diff, mse_privacy_diff, mad_privacy_diff = calculate_sad_mse_mad_privacy(predict, alpha, privacy) 251 | else: 252 | sad_privacy_diff = -1 253 | mse_privacy_diff = -1 254 | mad_privacy_diff = -1 255 | 256 | logger.info(f"[{img_list.index(img_name)}/{total_number}]\nImage:{img_name}\nsad:{sad_diff}\nmse:{mse_diff}\nmad:{mad_diff}\nsad_trimap:{sad_trimap_diff}\nmse_trimap:{mse_trimap_diff}\nmad_trimap:{mad_trimap_diff}\nsad_fg:{sad_fg_diff}\nsad_bg:{sad_bg_diff}\nconn:{conn_diff}\ngrad:{grad_diff}\nsad_privacy:{sad_privacy_diff}\nmse_privacy:{mse_privacy_diff}\nmad_privacy:{mad_privacy_diff}\n-----------") 257 | sad_diffs += sad_diff 258 | mse_diffs += mse_diff 259 | mad_diffs += mad_diff 260 | mse_trimap_diffs += mse_trimap_diff 261 | sad_trimap_diffs += sad_trimap_diff 262 | mad_trimap_diffs += mad_trimap_diff 263 | sad_fg_diffs += sad_fg_diff 264 | sad_bg_diffs += sad_bg_diff 265 | conn_diffs += conn_diff 266 | grad_diffs += grad_diff 267 | sad_privacy_diffs += sad_privacy_diff 268 | mse_privacy_diffs += mse_privacy_diff 269 | mad_privacy_diffs += mad_privacy_diff 270 | 271 | if config.TEST.SAVE_RESULT: 272 | save_test_result(os.path.join(result_dir, 
extract_pure_name(img_name)+'.png'),predict) 273 | 274 | res_dict = {} 275 | logger.info("===============================") 276 | logger.info(f"Testing numbers: {total_number}") 277 | # res_dict['number'] = total_number 278 | logger.info("SAD: {}".format(sad_diffs / total_number)) 279 | res_dict['SAD'] = sad_diffs / total_number 280 | logger.info("MSE: {}".format(mse_diffs / total_number)) 281 | res_dict['MSE'] = mse_diffs / total_number 282 | logger.info("MAD: {}".format(mad_diffs / total_number)) 283 | res_dict['MAD'] = mad_diffs / total_number 284 | logger.info("SAD TRIMAP: {}".format(sad_trimap_diffs / total_number)) 285 | res_dict['SAD_TRIMAP'] = sad_trimap_diffs / total_number 286 | logger.info("MSE TRIMAP: {}".format(mse_trimap_diffs / total_number)) 287 | res_dict['MSE_TRIMAP'] = mse_trimap_diffs / total_number 288 | logger.info("MAD TRIMAP: {}".format(mad_trimap_diffs / total_number)) 289 | res_dict['MAD_TRIMAP'] = mad_trimap_diffs / total_number 290 | logger.info("SAD FG: {}".format(sad_fg_diffs / total_number)) 291 | res_dict['SAD_FG'] = sad_fg_diffs / total_number 292 | logger.info("SAD BG: {}".format(sad_bg_diffs / total_number)) 293 | res_dict['SAD_BG'] = sad_bg_diffs / total_number 294 | logger.info("CONN: {}".format(conn_diffs / total_number)) 295 | res_dict['CONN'] = conn_diffs / total_number 296 | logger.info("GRAD: {}".format(grad_diffs / total_number)) 297 | res_dict['GRAD'] = grad_diffs / total_number 298 | logger.info("SAD PRIVACY: {}".format(sad_privacy_diffs / total_number)) 299 | res_dict['SAD_PRIVACY'] = sad_privacy_diffs / total_number 300 | logger.info("MSE PRIVACY: {}".format(mse_privacy_diffs / total_number)) 301 | res_dict['MSE_PRIVACY'] = mse_privacy_diffs / total_number 302 | logger.info("MAD PRIVACY: {}".format(mad_privacy_diffs / total_number)) 303 | res_dict['MAD_PRIVACY'] = mad_privacy_diffs / total_number 304 | 305 | # return int(sad_diffs/total_number) 306 | print("SAD: {}\nMSE: {}\nMAD: {}\n".format(res_dict['SAD'], res_dict['MSE'], res_dict['MAD'])) 307 | 308 | if dist.is_initialized(): 309 | for k in res_dict.keys(): 310 | res_dict[k] = torch.tensor(k) # for reduce only 311 | print('rank') 312 | 313 | return res_dict 314 | 315 | def test_samples(config, model): 316 | 317 | arch_predict_dict = { 318 | 'r34': inference_img_p3m, 319 | 'swin': inference_img_p3m, 320 | 'vitae': inference_img_p3m, 321 | } 322 | 323 | print(f'=====> Test on samples and save alpha results') 324 | model.eval() 325 | img_list = listdir_nohidden(SAMPLES_ORIGINAL_PATH) 326 | refresh_folder(SAMPLES_RESULT_ALPHA_PATH) 327 | refresh_folder(SAMPLES_RESULT_COLOR_PATH) 328 | for img_name in tqdm(img_list): 329 | img_path = SAMPLES_ORIGINAL_PATH+img_name 330 | try: 331 | img = np.array(Image.open(img_path))[:,:,:3] 332 | except Exception as e: 333 | print(f'Error: {str(e)} | Name: {img_name}') 334 | h, w, c = img.shape 335 | if min(h, w)>SHORTER_PATH_LIMITATION: 336 | if h>=w: 337 | new_w = SHORTER_PATH_LIMITATION 338 | new_h = int(SHORTER_PATH_LIMITATION*h/w) 339 | img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR) 340 | else: 341 | new_h = SHORTER_PATH_LIMITATION 342 | new_w = int(SHORTER_PATH_LIMITATION*w/h) 343 | img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR) 344 | 345 | with torch.no_grad(): 346 | if torch.cuda.device_count() > 0: 347 | torch.cuda.empty_cache() 348 | predict = arch_predict_dict.get(config.MODEL.TYPE, inference_img_p3m)(config, model, img) 349 | 350 | composite = generate_composite_img(img, predict) 351 | 
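generate_composite_img (defined in core/util.py) scales the colour channels by the predicted alpha and appends the matte as a fourth channel before cv2.imwrite below saves it. To place the cutout over a new background instead, the standard over-composite looks like this (helper name and shapes are mine, not part of the repo):

import numpy as np

def composite_over(fg_rgb, alpha, bg_rgb):
    # fg_rgb / bg_rgb: HxWx3 uint8 images, alpha: HxW float matte in [0, 1]
    a = alpha[:, :, None]
    return (fg_rgb * a + bg_rgb * (1.0 - a)).astype(np.uint8)

# e.g. composite_over(img, predict, np.zeros_like(img)) puts the portrait on black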
cv2.imwrite(os.path.join(SAMPLES_RESULT_COLOR_PATH, extract_pure_name(img_name)+'.png'),composite) 352 | predict = predict*255.0 353 | predict = cv2.resize(predict, (w, h), interpolation=cv2.INTER_LINEAR) 354 | cv2.imwrite(os.path.join(SAMPLES_RESULT_ALPHA_PATH, extract_pure_name(img_name)+'.png'),predict.astype(np.uint8)) 355 | 356 | def load_model_and_deploy(config): 357 | 358 | ### build model 359 | model = build_model(config.MODEL.TYPE) 360 | 361 | ### load ckpt 362 | ckpt_path = os.path.join(CKPT_SAVE_FOLDER, '{}/{}.pth'.format(config.TAG, config.TEST.CKPT_NAME)) 363 | if torch.cuda.device_count()==0: 364 | print(f'Running on CPU...') 365 | ckpt = torch.load(ckpt_path, map_location=torch.device('cpu')) 366 | model.load_state_dict(ckpt['state_dict'], strict=True) 367 | else: 368 | print(f'Running on GPU with CUDA...') 369 | ckpt = torch.load(ckpt_path) 370 | model.load_state_dict(ckpt['state_dict'], strict=True) 371 | model = model.cuda() 372 | model.cuda() 373 | 374 | ### Test 375 | if config.TEST.DATASET=='SAMPLES': 376 | test_samples(config, model) 377 | elif config.TEST.DATASET in VALID_TEST_DATASET_CHOICE: 378 | logname = 'test_{}_{}_{}_{}'.format(config.TAG, config.TEST.DATASET, config.TEST.TEST_METHOD, config.TEST.CKPT_NAME.replace('/', '-')) 379 | logging_filename = TEST_LOGS_FOLDER+logname+'.log' 380 | logger = create_logger(logging_filename) 381 | test_p3m10k(config, model, logger) 382 | else: 383 | print('Please input the correct dataset_choice (SAMPLES, P3M_500_P or P3M_500_NP).') 384 | 385 | if __name__ == '__main__': 386 | args, config = get_args() 387 | load_model_and_deploy(config) 388 | -------------------------------------------------------------------------------- /core/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rethinking Portrait Matting with Privacy Preserving 3 | Inferernce file. 4 | 5 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 6 | Licensed under the MIT License (see LICENSE for details) 7 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net 8 | Paper link: https://arxiv.org/abs/2203.16828 9 | 10 | """ 11 | import os 12 | import ipdb 13 | import random 14 | import argparse 15 | import torch 16 | import torch.optim as optim 17 | from torch.autograd import Variable 18 | from torch.utils.data import DataLoader 19 | import torch.backends.cudnn as cudnn 20 | import numpy as np 21 | import datetime 22 | import wandb 23 | import shutil 24 | from tqdm import tqdm 25 | from yacs.config import CfgNode as CN 26 | 27 | from config import * 28 | from util import * 29 | from evaluate import * 30 | from test import test_p3m10k 31 | from data import MattingDataset, MattingTransform 32 | from network.build_model import build_model 33 | from config_yacs import get_config 34 | from logger import create_logger 35 | 36 | ######### Parsing arguments ######### 37 | def get_args(): 38 | parser = argparse.ArgumentParser(description='Arguments for the training purpose.') 39 | parser.add_argument('--cfg', type=str, help='path to config file', ) 40 | parser.add_argument( 41 | "--opts", 42 | help="Modify config options by adding 'KEY VALUE' pairs. 
", 43 | default=None, 44 | nargs='+', 45 | ) 46 | parser.add_argument('--arch', type=str, help="backbone architecture of the model") 47 | parser.add_argument('--train_from_scratch', action='store_true') 48 | parser.add_argument('--nEpochs', type=int, default=50, help='number of epochs to train for, 500 for ORI-Track and 100 for COMP-Track') 49 | parser.add_argument('--lr', type=float, default=0.00001, help='Learning Rate. Default=0.00001') 50 | parser.add_argument('--warmup_nEpochs', type=int, default=0, help='epochs for warming up') 51 | parser.add_argument('--lr_decay', action='store_true', default=None, help='whehter to use lr decay') 52 | parser.add_argument('--clip_grad', action='store_true', default=None, help='whether to clip gradient') 53 | parser.add_argument('--threads', type=int, default=8, help='number of threads for data loader to use') 54 | parser.add_argument('--batchSize', type=int, default=8, help='training batch size') 55 | parser.add_argument('--tag', type=str, default='debug') 56 | parser.add_argument('--enable_wandb', action='store_true', default=None) 57 | parser.add_argument('--test_freq', type=int) 58 | parser.add_argument('--auto_resume', action='store_true') 59 | parser.add_argument('--test_method', default=None) 60 | parser.add_argument('--dataset', type=str) 61 | parser.add_argument('--train_set', type=str) 62 | args, _ = parser.parse_known_args() 63 | 64 | if args.tag.lower().startswith('debug'): 65 | args.enable_wandb = False 66 | 67 | if args.auto_resume: 68 | # load existing cfg 69 | args_path = os.path.join(CKPT_SAVE_FOLDER, args.tag, "args.npy") 70 | if os.path.exists(args_path): 71 | args.existing_cfg = np.load(os.path.join(CKPT_SAVE_FOLDER, args.tag, 'args.npy'), allow_pickle=True).item() 72 | 73 | config = get_config(args) 74 | if args.auto_resume: 75 | set_seed(config.SEED) 76 | 77 | print(config) 78 | print(args) 79 | return args, config 80 | 81 | def set_seed(seed=None): 82 | if seed is None: 83 | seed = torch.seed() 84 | 85 | random.seed(seed) 86 | np.random.seed(seed) 87 | os.environ['PYTHONHASHSEED'] = str(seed) 88 | torch.manual_seed(seed) 89 | torch.cuda.manual_seed(seed) 90 | torch.cuda.manual_seed_all(seed) 91 | cudnn.benchmark = True 92 | return seed 93 | 94 | def load_dataset(config): 95 | train_transform = MattingTransform(crop_size=config.DATA.CROP_SIZE, resize_size=config.DATA.RESIZE_SIZE) 96 | train_set = MattingDataset(config, train_transform) 97 | collate_fn = None 98 | train_loader = DataLoader(dataset=train_set, num_workers=config.DATA.NUM_WORKERS, batch_size=config.DATA.BATCH_SIZE, collate_fn=collate_fn, shuffle=True) 99 | return train_loader 100 | 101 | def build_lr_scheduler(optimizer, total_epochs): 102 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epochs) 103 | return lr_scheduler 104 | 105 | def warmup_lr(initial_lr, cur_iter, total_iter): 106 | return cur_iter/total_iter*initial_lr 107 | 108 | def update_lr(cur_lr, optimizer): 109 | for param_group in optimizer.param_groups: 110 | param_group['lr'] = cur_lr 111 | 112 | def get_lr(optimizer): 113 | return optimizer.param_groups[0]['lr'] 114 | 115 | def train(config, model, optimizer, train_loader, epoch, lr_scheduler, clip_grad_args): 116 | model = torch.nn.DataParallel(model).cuda() 117 | model.train() 118 | torch.cuda.empty_cache() 119 | 120 | print("===============================") 121 | print("EPOCH: {}/{}".format(epoch, config.TRAIN.EPOCHS)) 122 | for iteration, batch in tqdm(enumerate(train_loader, 1)): 123 | ### update lr 124 | if 
config.TRAIN.WARMUP_EPOCHS > 0 and epoch <= config.TRAIN.TRAIN.WARMUP_EPOCHS: 125 | cur_lr = warmup_lr(config.TRAIN.LR, len(train_loader)*(epoch-1)+iteration, config.TRAIN.WARMUP_EPOCHS*len(train_loader)) 126 | update_lr(cur_lr, optimizer) 127 | elif config.TRAIN.LR_DECAY is True and epoch > config.TRAIN.WARMUP_EPOCHS: 128 | lr_scheduler.step() 129 | cur_lr = lr_scheduler.get_lr()[0] 130 | else: 131 | cur_lr = optimizer.param_groups[0]['lr'] 132 | 133 | ### get data for general model 134 | batch_new = [] 135 | for item in batch: 136 | if type(item) == torch.Tensor: 137 | item = Variable(item).cuda() 138 | batch_new.append(item) 139 | [ori, mask, fg, bg, trimap] = batch_new[:5] 140 | optimizer.zero_grad() 141 | 142 | ### get model prediction 143 | if config.MODEL.CUT_AND_PASTE.TYPE.upper() == 'NONE': 144 | out_list = model(ori) 145 | else: 146 | raise NotImplementedError() 147 | 148 | ### cal loss 149 | predict_global, predict_local, predict_fusion, predict_global_side2, predict_global_side1, predict_global_side0 = out_list 150 | predict_fusion = predict_fusion.cuda() 151 | loss_global =get_crossentropy_loss(trimap, predict_global) 152 | loss_global_side2 = get_crossentropy_loss(trimap, predict_global_side2) 153 | loss_global_side1 = get_crossentropy_loss(trimap, predict_global_side1) 154 | loss_global_side0 = get_crossentropy_loss(trimap, predict_global_side0) 155 | loss_global = loss_global_side2+loss_global_side1+loss_global_side0+3*loss_global 156 | loss_local = get_alpha_loss(predict_local, mask, trimap) + get_laplacian_loss(predict_local, mask, trimap) 157 | loss_fusion_alpha = get_alpha_loss_whole_img(predict_fusion, mask) + get_laplacian_loss_whole_img(predict_fusion, mask) 158 | loss_fusion_comp = get_composition_loss_whole_img(ori, mask, fg, bg, predict_fusion) 159 | loss = loss_global/6+loss_local*2+loss_fusion_alpha*2+loss_fusion_alpha+loss_fusion_comp 160 | loss.backward() 161 | 162 | ### optimize and clip gradient 163 | if config.TRAIN.CLIP_GRAD is True: 164 | if clip_grad_args.moving_max_grad == 0: 165 | clip_grad_args.moving_max_grad = torch.nn.utils.clip_grad_norm_(model.parameters(), 1e+6).cpu().item() 166 | clip_grad_args.max_grad = clip_grad_args.moving_max_grad 167 | else: 168 | clip_grad_args.max_grad = torch.nn.utils.clip_grad_norm_(model.parameters(), 2 * clip_grad_args.moving_max_grad).cpu().item() 169 | clip_grad_args.moving_max_grad = clip_grad_args.moving_max_grad * clip_grad_args.moving_grad_moment + clip_grad_args.max_grad * ( 170 | 1 - clip_grad_args.moving_grad_moment) 171 | optimizer.step() 172 | 173 | if config.ENABLE_WANDB: 174 | loss_dict = { 175 | 'train/loss': loss.item(), 176 | 'train/loss_global': loss_global.item(), 177 | 'train/loss_local': loss_local.item(), 178 | 'train/loss_fusion_alpha': loss_fusion_alpha.item(), 179 | 'train/loss_fusion_comp': loss_fusion_comp.item(), 180 | 'train/lr': cur_lr, 181 | } 182 | wandb.log(loss_dict, step=(epoch-1)*len(train_loader)+iteration, commit=False) 183 | 184 | if config.TAG.startswith("debug") and iteration > 20: 185 | break 186 | 187 | def save_latest_checkpoint(config, model, epoch, **kwargs): 188 | model_save_dir = os.path.join(CKPT_SAVE_FOLDER, config.TAG) 189 | create_folder_if_not_exists(model_save_dir) 190 | model_out_path = os.path.join(model_save_dir, 'ckpt_latest.pth') 191 | model_dict = {'state_dict':model.state_dict(), 'epoch':epoch} 192 | model_dict.update(kwargs) 193 | torch.save(model_dict, model_out_path) 194 | 195 | def save_checkpoint(config, model, epoch, prefix, **kwargs): 196 | 
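warmup_lr, update_lr and build_lr_scheduler above implement a linear warm-up followed by cosine decay, with the training loop stepping the cosine scheduler inside the batch loop once the warm-up epochs are over. A compact sketch of how the pieces fit together (the single dummy parameter and the numbers are placeholders for illustration):

import torch

initial_lr, warmup_iters, total_epochs = 1e-5, 100, 50
opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=initial_lr)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_epochs)  # build_lr_scheduler()

for it in range(1, warmup_iters + 1):
    lr = it / warmup_iters * initial_lr        # warmup_lr()
    for group in opt.param_groups:             # update_lr()
        group["lr"] = lr
# after warm-up, the loop calls cosine.step() instead of setting the learning rate manually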
model_save_dir = os.path.join(CKPT_SAVE_FOLDER, config.TAG) 197 | create_folder_if_not_exists(model_save_dir) 198 | model_out_path = os.path.join(model_save_dir, prefix+'_ckpt.pth') 199 | model_dict = {'state_dict':model.state_dict(), 'epoch':epoch} 200 | model_dict.update(kwargs) 201 | torch.save(model_dict, model_out_path) 202 | 203 | def save_all_ckpts(config, epoch): 204 | model_save_dir = os.path.join(CKPT_SAVE_FOLDER, config.TAG, str(epoch)) 205 | create_folder_if_not_exists(model_save_dir) 206 | source_dir = os.path.join(CKPT_SAVE_FOLDER, config.TAG) 207 | file_list = os.listdir(source_dir) 208 | for name in file_list: 209 | if name.endswith('.pth'): 210 | shutil.copyfile(os.path.join(source_dir, name), os.path.join(model_save_dir, name)) 211 | 212 | def save_args(config): 213 | save_dir = os.path.join(CKPT_SAVE_FOLDER, config.TAG) 214 | create_folder_if_not_exists(save_dir) 215 | np.save(os.path.join(save_dir, 'args.npy'), config) 216 | 217 | def main(): 218 | _, config = get_args() 219 | save_args(config) 220 | 221 | now = datetime.datetime.now() 222 | str_time = now.strftime("%Y-%m-%d-%H:%M") 223 | logger = create_logger(os.path.join(TRAIN_LOGS_FOLDER, config.TAG+'_{}.log'.format(str_time))) 224 | 225 | if not torch.cuda.is_available(): 226 | raise Exception("No GPU and cuda available, please try again") 227 | gpuNums = torch.cuda.device_count() 228 | logger.info(f'Running with GPUs and the number of GPUs: {gpuNums}') 229 | 230 | # Test configs 231 | config.defrost() 232 | config.TEST.DATASET = 'P3M_500_P' 233 | config.TEST.TEST_METHOD = config.TEST.TEST_METHOD 234 | config.TEST.CKPT_NAME = 'latest_epoch' 235 | config.TEST.FAST_TEST = True 236 | config.TEST.TEST_PRIVACY = True 237 | config.TEST.SAVE_RESULT = False 238 | config.freeze() 239 | 240 | # config wandb 241 | if config.ENABLE_WANDB: 242 | project_name = 'p3m_journal' 243 | WANDB_API_KEY = get_wandb_key(WANDB_KEY_FILE) # setup your wandb account 244 | os.environ["WANDB_API_KEY"] = WANDB_API_KEY 245 | os.environ["WANDB_DIR"] = WANDB_LOGS_FOLDER 246 | 247 | if config.AUTO_RESUME and os.path.exists(os.path.join(CKPT_SAVE_FOLDER, config.TAG, 'wandb_run_id.txt')): 248 | with open(os.path.join(CKPT_SAVE_FOLDER, config.TAG, 'wandb_run_id.txt'), 'r') as f: 249 | run_id = f.readline().strip('\n') 250 | wandb.init(id=run_id, project=project_name, resume='must') 251 | else: 252 | run_id = wandb.util.generate_id() 253 | wandb.init(project=project_name, entity='xymsh', id=run_id, resume='allow') 254 | wandb.config.update(get_wandb_config(config)) 255 | wandb.run.name = '{}_{}'.format(config.TAG, str_time) 256 | with open(os.path.join(CKPT_SAVE_FOLDER, config.TAG, 'wandb_run_id.txt'), 'w') as f: 257 | f.write(run_id) 258 | logger.info('===> Enabled wandb run {} at {}'.format(run_id, WANDB_LOGS_FOLDER)) 259 | 260 | # data loader 261 | logger.info('===> Load data') 262 | train_loader = load_dataset(config) 263 | 264 | # build model 265 | logger.info('===> Build the model {}'.format(config.MODEL.TYPE)) 266 | model = build_model(config.MODEL.TYPE).cuda() 267 | start_epoch = config.TRAIN.START_EPOCH 268 | 269 | # build optimizer 270 | logger.info('===> Initialize optimizer {} and lr scheduler {}'.format(config.TRAIN.OPTIMIZER.TYPE, config.TRAIN.LR_DECAY)) 271 | if config.TRAIN.OPTIMIZER.TYPE.upper() == 'ADAM': 272 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.TRAIN.LR) 273 | else: 274 | raise NotImplementedError 275 | lr_scheduler = build_lr_scheduler(optimizer, 
config.TRAIN.EPOCHS-config.TRAIN.WARMUP_EPOCHS) if config.TRAIN.LR_DECAY is True else None 276 | # init clip grad params 277 | clip_grad_args = CN() 278 | clip_grad_args.moving_max_grad = 0.0 279 | clip_grad_args.moving_grad_moment = 0.999 280 | clip_grad_args.max_grad = 0.0 281 | 282 | # train parameters 283 | best_result = {} # key format: "[metric]_[dataset]" and "[metric]_[dataset]_epoch" 284 | errors = {} 285 | 286 | # auto resume 287 | if config.AUTO_RESUME: 288 | try: 289 | # load latest ckpt 290 | ckpt_path = os.path.join(CKPT_SAVE_FOLDER, config.TAG, 'ckpt_latest.pth') 291 | ckpt_dict = torch.load(ckpt_path) 292 | 293 | model.load_state_dict(ckpt_dict['state_dict']) 294 | start_epoch = ckpt_dict['epoch'] + 1 295 | optimizer.load_state_dict(ckpt_dict['optimizer']) 296 | if config.TRAIN.LR_DECAY: 297 | raise NotImplementedError 298 | 299 | clip_grad_args = ckpt_dict['clip_grad_args'] 300 | best_result = ckpt_dict['best_result'] 301 | logger.info('===> Auto resume succeeded') 302 | del ckpt_dict 303 | except: 304 | pass 305 | 306 | # start training 307 | logger.info('===> Start Training') 308 | for epoch in range(start_epoch, config.TRAIN.EPOCHS + 1): 309 | logger.info("TRAIN Epoch: {}/{}, Warmup: {}".format(epoch, config.TRAIN.EPOCHS, epoch<=config.TRAIN.WARMUP_EPOCHS)) 310 | train(config, model, optimizer, train_loader, epoch, lr_scheduler, clip_grad_args) 311 | 312 | if (config.TEST_FREQ > 0 and epoch % config.TEST_FREQ == 0) or (epoch == config.TRAIN.EPOCHS): 313 | logger.info("TEST Epoch: {}/{}".format(epoch, config.TRAIN.EPOCHS)) 314 | 315 | for test_dataset_choice in ["P3M_500_P", "P3M_500_NP"]: 316 | config.defrost() 317 | config.TEST.DATASET = test_dataset_choice 318 | if test_dataset_choice == 'P3M_500_NP': 319 | config.TEST.TEST_PRIVACY = False 320 | else: 321 | config.TEST.TEST_PRIVACY = True 322 | config.freeze() 323 | errors = test_p3m10k(config, model, logger) 324 | 325 | for m in ['SAD', 'SAD_PRIVACY']: 326 | m_dataset = "{}_{}".format(m, test_dataset_choice) 327 | if m in errors.keys(): 328 | if m_dataset not in best_result.keys() or errors[m] <= best_result[m_dataset]: 329 | best_result[m_dataset] = errors[m] 330 | best_result["{}_epoch".format(m_dataset)] = epoch 331 | save_checkpoint(config, model, epoch, "best_{}".format(m_dataset), best_result=best_result) 332 | 333 | if config.ENABLE_WANDB: 334 | log_errors = {} 335 | for k, v in errors.items(): 336 | log_errors['test_{}/'.format(test_dataset_choice)+k] = v 337 | wandb.log(log_errors, step=epoch*len(train_loader), commit=False) 338 | 339 | if config.ENABLE_WANDB: 340 | # record the best result 341 | wandb.run.summary.update(best_result) 342 | for k, v in best_result.items(): 343 | log_errors['best_result_{}/{}'.format(config.TEST.TEST_METHOD, k)] = v 344 | wandb.log(log_errors, step=epoch*len(train_loader), commit=False) 345 | 346 | save_latest_checkpoint(config, model, epoch, optimizer=optimizer.state_dict(), clip_grad_args=clip_grad_args) 347 | 348 | 349 | if config.AUTO_RESUME: 350 | break 351 | if epoch in [50, 100]: 352 | save_all_ckpts(config, epoch) 353 | 354 | 355 | if __name__ == "__main__": 356 | main() -------------------------------------------------------------------------------- /core/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rethinking Portrait Matting with Privacy Preserving 3 | 4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 5 | Licensed under the MIT License (see LICENSE for 
details) 6 | Github repo: https://github.com/ViTAE-Transformer/ViTAE-Transformer-Matting.git 7 | Paper link: https://arxiv.org/abs/2203.16828 8 | 9 | """ 10 | import os 11 | import shutil 12 | import cv2 13 | import numpy as np 14 | import torch 15 | import torch.distributed as dist 16 | import glob 17 | import functools 18 | from torchvision import transforms 19 | from config import * 20 | 21 | 22 | def get_wandb_config(config): 23 | wandb_args = {} 24 | wandb_args['model'] = config.MODEL.TYPE 25 | wandb_args['logname'] = config.TAG 26 | return wandb_args 27 | 28 | def get_wandb_key(filename): 29 | f = open(filename) 30 | keys = f.readline().strip() 31 | f.close() 32 | return keys 33 | 34 | ########################## 35 | ### Pure functions 36 | ########################## 37 | 38 | def GET_RANK(): 39 | if dist.is_initialized(): 40 | return dist.get_rank() 41 | else: 42 | return 0 43 | 44 | def GET_WORLD_SIZE(): 45 | if dist.is_initialized(): 46 | return dist.get_world_size() 47 | else: 48 | return 1 49 | 50 | def is_main_process(): 51 | if not dist.is_initialized(): 52 | return True 53 | if dist.is_initialized() and dist.get_rank() == 0: 54 | return True 55 | return False 56 | 57 | def extract_pure_name(original_name): 58 | pure_name, extension = os.path.splitext(original_name) 59 | return pure_name 60 | 61 | def listdir_nohidden(path): 62 | new_list = [] 63 | for f in os.listdir(path): 64 | if not f.startswith('.'): 65 | new_list.append(f) 66 | new_list.sort() 67 | return new_list 68 | 69 | def create_folder_if_not_exists(folder_path): 70 | if not os.path.exists(folder_path): 71 | os.makedirs(folder_path) 72 | 73 | def refresh_folder(folder_path): 74 | if not os.path.exists(folder_path): 75 | os.makedirs(folder_path) 76 | else: 77 | shutil.rmtree(folder_path) 78 | os.makedirs(folder_path) 79 | 80 | def save_test_result(save_dir, predict): 81 | predict = (predict * 255).astype(np.uint8) 82 | cv2.imwrite(save_dir, predict) 83 | 84 | def generate_composite_img(img, alpha_channel): 85 | b_channel, g_channel, r_channel = cv2.split(img) 86 | b_channel = b_channel * alpha_channel 87 | g_channel = g_channel * alpha_channel 88 | r_channel = r_channel * alpha_channel 89 | alpha_channel = (alpha_channel*255).astype(b_channel.dtype) 90 | img_BGRA = cv2.merge((r_channel,g_channel,b_channel,alpha_channel)) 91 | return img_BGRA 92 | 93 | ########################## 94 | ### for dataset processing 95 | ########################## 96 | def trim_img(img): 97 | if img.ndim>2: 98 | img = img[:,:,0] 99 | return img 100 | 101 | def gen_trimap_with_dilate(alpha, kernel_size): 102 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size)) 103 | fg_and_unknown = np.array(np.not_equal(alpha, 0).astype(np.float32)) 104 | fg = np.array(np.equal(alpha, 255).astype(np.float32)) 105 | dilate = cv2.dilate(fg_and_unknown, kernel, iterations=1) 106 | erode = cv2.erode(fg, kernel, iterations=1) 107 | trimap = erode *255 + (dilate-erode)*128 108 | return trimap.astype(np.uint8) 109 | 110 | def normalize_batch_torch(data_t): 111 | normalize_transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], 112 | std=[0.229, 0.224, 0.225]) 113 | new_data = [] 114 | for i in range(data_t.shape[0]): 115 | new_data.append(normalize_transform(data_t[i])) 116 | return torch.stack(new_data, dim=0) 117 | 118 | ########################## 119 | ### Functions for fusion 120 | ########################## 121 | def gen_trimap_from_segmap_e2e(segmap): 122 | trimap = np.argmax(segmap, axis=1)[0] 123 | trimap =
trimap.astype(np.int64) 124 | trimap[trimap==1]=128 125 | trimap[trimap==2]=255 126 | return trimap.astype(np.uint8) 127 | 128 | def get_masked_local_from_global(global_sigmoid, local_sigmoid): 129 | values, index = torch.max(global_sigmoid,1) 130 | index = index[:,None,:,:].float() 131 | ### index <===> [0, 1, 2] 132 | ### bg_mask <===> [1, 0, 0] 133 | bg_mask = index.clone() 134 | bg_mask[bg_mask==2]=1 135 | bg_mask = 1- bg_mask 136 | ### trimap_mask <===> [0, 1, 0] 137 | trimap_mask = index.clone() 138 | trimap_mask[trimap_mask==2]=0 139 | ### fg_mask <===> [0, 0, 1] 140 | fg_mask = index.clone() 141 | fg_mask[fg_mask==1]=0 142 | fg_mask[fg_mask==2]=1 143 | fusion_sigmoid = local_sigmoid*trimap_mask+fg_mask 144 | return fusion_sigmoid 145 | 146 | def get_masked_local_from_global_test(global_result, local_result): 147 | weighted_global = np.ones(global_result.shape) 148 | weighted_global[global_result==255] = 0 149 | weighted_global[global_result==0] = 0 150 | fusion_result = global_result*(1.-weighted_global)/255+local_result*weighted_global 151 | return fusion_result 152 | 153 | ####################################### 154 | ### Function to generate training data 155 | ####################################### 156 | 157 | def generate_paths_for_dataset(dataset="P3M10K", trainset="TRAIN"): 158 | ORI_PATH = DATASET_PATHS_DICT[dataset][trainset]['ORIGINAL_PATH'] 159 | MASK_PATH = DATASET_PATHS_DICT[dataset][trainset]['MASK_PATH'] 160 | FG_PATH = DATASET_PATHS_DICT[dataset][trainset]['FG_PATH'] 161 | BG_PATH = DATASET_PATHS_DICT[dataset][trainset]['BG_PATH'] 162 | FACEMASK_PATH = DATASET_PATHS_DICT[dataset][trainset]['PRIVACY_MASK_PATH'] 163 | mask_list = listdir_nohidden(MASK_PATH) 164 | total_number = len(mask_list) 165 | paths_list = [] 166 | for mask_name in mask_list: 167 | path_list = [] 168 | ori_path = ORI_PATH+extract_pure_name(mask_name)+'.jpg' 169 | mask_path = MASK_PATH+mask_name 170 | fg_path = FG_PATH+mask_name 171 | bg_path = BG_PATH+extract_pure_name(mask_name)+'.jpg' 172 | facemask_path = FACEMASK_PATH+mask_name 173 | path_list.append(ori_path) 174 | path_list.append(mask_path) 175 | path_list.append(fg_path) 176 | path_list.append(bg_path) 177 | path_list.append(facemask_path) 178 | paths_list.append(path_list) 179 | return paths_list 180 | 181 | 182 | def get_valid_names(*dirs): 183 | # Extract valid names 184 | name_sets = [get_name_set(d) for d in dirs] 185 | 186 | # Reduce 187 | def _join_and(a, b): 188 | return a & b 189 | 190 | valid_names = list(functools.reduce(_join_and, name_sets)) 191 | if len(valid_names) == 0: 192 | return None 193 | 194 | valid_names.sort() 195 | 196 | return valid_names 197 | 198 | def get_name_set(dir_name): 199 | path_list = glob.glob(os.path.join(dir_name, '*')) 200 | name_set = set() 201 | for path in path_list: 202 | name = os.path.basename(path) 203 | name = os.path.splitext(name)[0] 204 | if name.startswith(".DS"): continue 205 | name_set.add(name) 206 | return name_set 207 | 208 | def list_abspath(data_dir, ext, data_list): 209 | return [os.path.join(data_dir, name + ext) 210 | for name in data_list] -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | # More Results Demo 2 | 3 | We present more results of our P3M-Net(ViTAE-S) model on P3M-10k dataset as follows. 
4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /demo/face_obfuscation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/face_obfuscation.jpg -------------------------------------------------------------------------------- /demo/gif/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/gif/2.gif -------------------------------------------------------------------------------- /demo/gif/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/gif/3.gif -------------------------------------------------------------------------------- /demo/gif/4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/gif/4.gif -------------------------------------------------------------------------------- /demo/gif/p_3cf7997c.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/gif/p_3cf7997c.gif -------------------------------------------------------------------------------- /demo/imgs/alpha/p_07141906.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_07141906.png -------------------------------------------------------------------------------- /demo/imgs/alpha/p_09ba26b4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_09ba26b4.png -------------------------------------------------------------------------------- /demo/imgs/alpha/p_22fc0130.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_22fc0130.png -------------------------------------------------------------------------------- /demo/imgs/alpha/p_514ca06a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_514ca06a.png -------------------------------------------------------------------------------- /demo/imgs/alpha/p_51c916fc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_51c916fc.png -------------------------------------------------------------------------------- /demo/imgs/alpha/p_818e689d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_818e689d.png 
-------------------------------------------------------------------------------- /demo/imgs/alpha/p_a09f6d7a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_a09f6d7a.png -------------------------------------------------------------------------------- /demo/imgs/alpha/p_bac3c1ff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_bac3c1ff.png -------------------------------------------------------------------------------- /demo/imgs/alpha/p_bc5cfad1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_bc5cfad1.png -------------------------------------------------------------------------------- /demo/imgs/alpha/p_bd6af989.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_bd6af989.png -------------------------------------------------------------------------------- /demo/imgs/alpha/p_d684dae3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_d684dae3.png -------------------------------------------------------------------------------- /demo/imgs/color/p_07141906.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_07141906.png -------------------------------------------------------------------------------- /demo/imgs/color/p_09ba26b4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_09ba26b4.png -------------------------------------------------------------------------------- /demo/imgs/color/p_22fc0130.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_22fc0130.png -------------------------------------------------------------------------------- /demo/imgs/color/p_514ca06a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_514ca06a.png -------------------------------------------------------------------------------- /demo/imgs/color/p_51c916fc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_51c916fc.png -------------------------------------------------------------------------------- /demo/imgs/color/p_818e689d.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_818e689d.png -------------------------------------------------------------------------------- /demo/imgs/color/p_a09f6d7a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_a09f6d7a.png -------------------------------------------------------------------------------- /demo/imgs/color/p_bac3c1ff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_bac3c1ff.png -------------------------------------------------------------------------------- /demo/imgs/color/p_bc5cfad1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_bc5cfad1.png -------------------------------------------------------------------------------- /demo/imgs/color/p_bd6af989.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_bd6af989.png -------------------------------------------------------------------------------- /demo/imgs/color/p_d684dae3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_d684dae3.png -------------------------------------------------------------------------------- /demo/imgs/original/p_07141906.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_07141906.jpg -------------------------------------------------------------------------------- /demo/imgs/original/p_09ba26b4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_09ba26b4.jpg -------------------------------------------------------------------------------- /demo/imgs/original/p_22fc0130.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_22fc0130.jpg -------------------------------------------------------------------------------- /demo/imgs/original/p_514ca06a.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_514ca06a.jpg -------------------------------------------------------------------------------- /demo/imgs/original/p_51c916fc.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_51c916fc.jpg -------------------------------------------------------------------------------- /demo/imgs/original/p_818e689d.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_818e689d.jpg -------------------------------------------------------------------------------- /demo/imgs/original/p_a09f6d7a.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_a09f6d7a.jpg -------------------------------------------------------------------------------- /demo/imgs/original/p_bac3c1ff.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_bac3c1ff.jpg -------------------------------------------------------------------------------- /demo/imgs/original/p_bc5cfad1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_bc5cfad1.jpg -------------------------------------------------------------------------------- /demo/imgs/original/p_bd6af989.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_bd6af989.jpg -------------------------------------------------------------------------------- /demo/imgs/original/p_d684dae3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_d684dae3.jpg -------------------------------------------------------------------------------- /demo/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/network.png -------------------------------------------------------------------------------- /demo/p3m-cp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/p3m-cp.png -------------------------------------------------------------------------------- /demo/p3m-net-variants.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/p3m-net-variants.png -------------------------------------------------------------------------------- /demo/p3m_dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/p3m_dataset.png -------------------------------------------------------------------------------- /demo/results1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/results1.png -------------------------------------------------------------------------------- /demo/results2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/results2.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.2 2 | opencv-python==4.4.0.46 3 | Pillow==8.0.0 4 | scikit-image==0.14.5 5 | scipy==1.5.3 6 | tqdm==4.51.0 7 | -------------------------------------------------------------------------------- /samples/original/p_015cd10e.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/original/p_015cd10e.jpg -------------------------------------------------------------------------------- /samples/original/p_0865636e.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/original/p_0865636e.jpg -------------------------------------------------------------------------------- /samples/original/p_819ea202.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/original/p_819ea202.jpg -------------------------------------------------------------------------------- /samples/result_alpha/p_015cd10e.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/result_alpha/p_015cd10e.png -------------------------------------------------------------------------------- /samples/result_alpha/p_0865636e.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/result_alpha/p_0865636e.png -------------------------------------------------------------------------------- /samples/result_alpha/p_819ea202.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/result_alpha/p_819ea202.png -------------------------------------------------------------------------------- /samples/result_color/p_015cd10e.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/result_color/p_015cd10e.png -------------------------------------------------------------------------------- /samples/result_color/p_0865636e.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/result_color/p_0865636e.png -------------------------------------------------------------------------------- /samples/result_color/p_819ea202.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/result_color/p_819ea202.png 
-------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Rethinking Portrait Matting with Privacy Preserving 4 | # 5 | # Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 6 | # Licensed under the MIT License (see LICENSE for details) 7 | # Github repo: https://github.com/ViTAE-Transformer/P3M-Net 8 | 9 | result_dir='' 10 | alpha_dir='' 11 | trimap_dir='' 12 | 13 | python core/eval.py \ 14 | --pred_dir $result_dir \ 15 | --alpha_dir $alpha_dir \ 16 | --trimap_dir $trimap_dir \ -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Rethinking Portrait Matting with Privacy Preserving 4 | # 5 | # Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 6 | # Licensed under the MIT License (see LICENSE for details) 7 | # Github repo: https://github.com/ViTAE-Transformer/P3M-Net 8 | # Paper link: https://arxiv.org/abs/2203.16828 9 | 10 | tag='' # run name 11 | dataset_choice='P3M_500_P' 12 | ckpt_name='ckpt_latest' 13 | test_choice='HYBRID' 14 | 15 | python core/test.py \ 16 | --tag=$tag \ 17 | --test_dataset=$dataset_choice \ 18 | --test_ckpt=$ckpt_name \ 19 | --test_method=$test_choice \ 20 | --fast_test \ 21 | --test_privacy -------------------------------------------------------------------------------- /scripts/test_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Rethinking Portrait Matting with Privacy Preserving 4 | # 5 | # Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 6 | # Licensed under the MIT License (see LICENSE for details) 7 | # Github repo: https://github.com/ViTAE-Transformer/P3M-Net 8 | # Paper link: https://arxiv.org/abs/2203.16828 9 | 10 | arch='vitae' 11 | model_path='./models/P3M-Net_ViTAE-S_trained_on_P3M-10k.pth' 12 | dataset='P3M10K' 13 | valset='P3M_500_P' 14 | test_choice='HYBRID' 15 | 16 | nickname='test_dataset' 17 | result_dir='results/'${nickname} 18 | 19 | python core/infer.py \ 20 | --arch $arch \ 21 | --dataset $dataset \ 22 | --test_set $valset \ 23 | --model_path $model_path \ 24 | --test_choice $test_choice \ 25 | --test_result_dir $result_dir 26 | 27 | -------------------------------------------------------------------------------- /scripts/test_samples.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Rethinking Portrait Matting with Privacy Preserving 4 | # 5 | # Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 6 | # Licensed under the MIT License (see LICENSE for details) 7 | # Github repo: https://github.com/ViTAE-Transformer/P3M-Net 8 | # Paper link: https://arxiv.org/abs/2203.16828 9 | 10 | arch='vitae' 11 | model_path='./models/P3M-Net_ViTAE-S_trained_on_P3M-10k.pth' 12 | dataset='SAMPLES' 13 | test_choice='RESIZE' 14 | 15 | python core/infer.py \ 16 | --arch $arch \ 17 | --dataset $dataset \ 18 | --model_path $model_path \ 19 | --test_choice $test_choice \ -------------------------------------------------------------------------------- /scripts/train.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Rethinking Portrait Matting with Privacy Preserving 4 | # 5 | # Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au) 6 | # Licensed under the MIT License (see LICENSE for details) 7 | # Github repo: https://github.com/ViTAE-Transformer/P3M-Net 8 | # Paper link: https://arxiv.org/abs/2203.16828 9 | 10 | 11 | cfg='core/configs/ViTAE_S.yaml' # change config file here 12 | nEpochs=150 13 | lr=0.00001 14 | nickname=debug_train_vitae_s # change run name here 15 | batchSize=8 16 | 17 | python core/train.py \ 18 | --cfg $cfg \ 19 | --tag $nickname \ 20 | --nEpochs $nEpochs \ 21 | --lr=$lr \ 22 | --batchSize=$batchSize --------------------------------------------------------------------------------
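For readers who want to sanity-check the matting utilities outside of the train/test entry points, below is a minimal usage sketch (not part of the repository) showing how the trimap and fusion helpers in core/util.py fit together. The input path, the dummy "global"/"local" predictions, and the assumption that core/ is on PYTHONPATH (so that util.py and its own `from config import *` can be imported) are illustrative only.

```python
# Hypothetical usage sketch for the helpers in core/util.py -- not shipped with the repo.
import cv2
import numpy as np

# Assumes core/ is on PYTHONPATH so that util.py (and its `from config import *`) resolves.
from util import gen_trimap_with_dilate, get_masked_local_from_global_test

# Placeholder path: any single-channel alpha matte with values in [0, 255].
alpha = cv2.imread('samples/result_alpha/p_015cd10e.png', cv2.IMREAD_GRAYSCALE)

# Erode the sure foreground (alpha == 255) and dilate foreground+unknown (alpha != 0);
# the band between the two becomes the unknown region (value 128) of the trimap.
trimap = gen_trimap_with_dilate(alpha, kernel_size=25)

# Treat the trimap as the "global" prediction (values 0/128/255) and a rescaled alpha
# as the "local" prediction in [0, 1]; the fusion keeps the local matte only inside
# the unknown band and trusts the global map everywhere else.
global_result = trimap.astype(np.float32)
local_result = alpha.astype(np.float32) / 255.0
fused = get_masked_local_from_global_test(global_result, local_result)

cv2.imwrite('fused_alpha.png', (fused * 255).astype(np.uint8))
```

The tensor-based counterpart, get_masked_local_from_global, applies the same idea at inference time, with the predicted three-class segmentation map playing the role of the trimap.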