├── .gitignore
├── LICENSE
├── README.md
├── core
│   ├── README.md
│   ├── config.py
│   ├── config_yacs.py
│   ├── configs
│   │   ├── ResNet34.yaml
│   │   ├── Swin_T.yaml
│   │   └── ViTAE_S.yaml
│   ├── data.py
│   ├── data_util.py
│   ├── eval.py
│   ├── evaluate.py
│   ├── infer.py
│   ├── logger.py
│   ├── metrics.py
│   ├── network
│   │   ├── RestNet34
│   │   │   ├── P3mNet.py
│   │   │   ├── __init__.py
│   │   │   └── resnet_mp.py
│   │   ├── Swin_T
│   │   │   ├── __init__.py
│   │   │   ├── decoder.py
│   │   │   └── swin_stem_pooling5_transformer.py
│   │   ├── ViTAE_S
│   │   │   ├── NormalCell.py
│   │   │   ├── SELayer.py
│   │   │   ├── __init__.py
│   │   │   ├── base_model.py
│   │   │   ├── models.py
│   │   │   ├── token_performer.py
│   │   │   └── token_transformer.py
│   │   ├── __init__.py
│   │   ├── build_model.py
│   │   └── modules.py
│   ├── test.py
│   ├── train.py
│   └── util.py
├── demo
│   ├── README.md
│   ├── face_obfuscation.jpg
│   ├── gif
│   │   ├── 2.gif
│   │   ├── 3.gif
│   │   ├── 4.gif
│   │   └── p_3cf7997c.gif
│   ├── imgs
│   │   ├── alpha
│   │   │   ├── p_07141906.png
│   │   │   ├── p_09ba26b4.png
│   │   │   ├── p_22fc0130.png
│   │   │   ├── p_514ca06a.png
│   │   │   ├── p_51c916fc.png
│   │   │   ├── p_818e689d.png
│   │   │   ├── p_a09f6d7a.png
│   │   │   ├── p_bac3c1ff.png
│   │   │   ├── p_bc5cfad1.png
│   │   │   ├── p_bd6af989.png
│   │   │   └── p_d684dae3.png
│   │   ├── color
│   │   │   ├── p_07141906.png
│   │   │   ├── p_09ba26b4.png
│   │   │   ├── p_22fc0130.png
│   │   │   ├── p_514ca06a.png
│   │   │   ├── p_51c916fc.png
│   │   │   ├── p_818e689d.png
│   │   │   ├── p_a09f6d7a.png
│   │   │   ├── p_bac3c1ff.png
│   │   │   ├── p_bc5cfad1.png
│   │   │   ├── p_bd6af989.png
│   │   │   └── p_d684dae3.png
│   │   └── original
│   │       ├── p_07141906.jpg
│   │       ├── p_09ba26b4.jpg
│   │       ├── p_22fc0130.jpg
│   │       ├── p_514ca06a.jpg
│   │       ├── p_51c916fc.jpg
│   │       ├── p_818e689d.jpg
│   │       ├── p_a09f6d7a.jpg
│   │       ├── p_bac3c1ff.jpg
│   │       ├── p_bc5cfad1.jpg
│   │       ├── p_bd6af989.jpg
│   │       └── p_d684dae3.jpg
│   ├── network.png
│   ├── p3m-cp.png
│   ├── p3m-net-variants.png
│   ├── p3m_dataset.png
│   ├── results1.png
│   └── results2.png
├── requirements.txt
├── samples
│   ├── original
│   │   ├── p_015cd10e.jpg
│   │   ├── p_0865636e.jpg
│   │   └── p_819ea202.jpg
│   ├── result_alpha
│   │   ├── p_015cd10e.png
│   │   ├── p_0865636e.png
│   │   └── p_819ea202.png
│   └── result_color
│       ├── p_015cd10e.png
│       ├── p_0865636e.png
│       └── p_819ea202.png
└── scripts
    ├── eval.sh
    ├── test.sh
    ├── test_dataset.sh
    ├── test_samples.sh
    └── train.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | */DS_Store
3 | __pycache__/
4 | */__pycache__
5 | *.html
6 | *.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Sihan Ma and Jizhizi Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Rethinking Portrait Matting with Privacy Preserving [IJCV-2023]
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | This is the official repository of the paper [IJCV'23] Rethinking Portrait Matting with Privacy Preserving.
12 |
13 | For further questions, please contact Sihan Ma at [sima7436@uni.sydney.edu.au](mailto:sima7436@uni.sydney.edu.au) or Jizhizi Li at [jili8515@uni.sydney.edu.au](mailto:jili8515@uni.sydney.edu.au).
14 |
15 |
16 | Sihan Ma∗, Jizhizi Li∗, Jing Zhang, He Zhang, and Dacheng Tao
17 |
18 |
19 | Introduction |
20 | PPT and P3M-10k |
21 | P3M-Net |
22 | P3M-CP |
23 | Results |
24 | Train |
25 | Inference code |
26 | Statement
27 |
28 |
29 |
30 |
31 | ***
32 | >:postbox: News
33 | >
34 | > [2024-3-31]: [Code for training](./core/README.md) is available!
35 | >
36 | > [2023-11-05]: Publish the ViTAE-S and SWIN-T backbone models pretrained on ImageNet that can be used to train our P3M-Net from scratch.
37 | >
38 | > [2023-03-28]: The paper has been accepted by the International Journal of Computer Vision ([IJCV](https://www.springer.com/journal/11263))! 🎉
39 | >
40 | > [2022-03-31]: Publish the inference code and the pretrained model that can be used to test with our SOTA model P3M-Net(ViTAE-S) on your own privacy-preserving or normal portrait images.
41 | >
42 | > [2021-12-06]: Publish the [P3M-10k](#ppt-setting-and-p3m-10k-dataset) dataset.
43 | >
44 | > [2021-11-21]: Publish the conference paper ACM MM 2021 "[Privacy-preserving Portrait Matting](https://dl.acm.org/doi/10.1145/3474085.3475512)". The code and data are available at [github repo](https://github.com/JizhiziLi/P3M).
45 | >
46 | > Other applications of ViTAE Transformer include: [image classification](https://github.com/ViTAE-Transformer/ViTAE-Transformer/tree/main/Image-Classification) | [object detection](https://github.com/ViTAE-Transformer/ViTAE-Transformer/tree/main/Object-Detection) | [semantic segmentation](https://github.com/ViTAE-Transformer/ViTAE-Transformer/tree/main/Semantic-Segmentation) | [animal pose estimation](https://github.com/ViTAE-Transformer/ViTAE-Transformer/tree/main/Animal-Pose-Estimation) | [remote sensing](https://github.com/ViTAE-Transformer/ViTAE-Transformer-Remote-Sensing)
47 |
48 |
49 | ## Introduction
50 |
51 |
52 | Recently, there has been an increasing concern about the privacy issue raised by using personally identifiable information in machine learning. However, previous portrait matting methods were all based on identifiable portrait images.
53 |
54 | To fill the gap, we present P3M-10k in this paper, which is the first large-scale anonymized benchmark for Privacy-Preserving Portrait Matting. P3M-10k consists of 10,000 high-resolution face-blurred portrait images along with high-quality alpha mattes. We systematically evaluate both trimap-free and trimap-based matting methods on P3M-10k and find that existing matting methods show different generalization capabilities when following the Privacy-Preserving Training (PPT) setting, i.e., training on face-blurred images and testing on arbitrary images.
55 |
56 | To devise a better trimap-free portrait matting model, we propose P3M-Net, consisting of three carefully designed integration modules that can perform privacy-insensitive semantic perception and detail-reserved matting simultaneously. We further design multiple variants of P3M-Net with different CNN and transformer backbones and identify the differences in their generalization abilities.
57 |
58 | To further mitigate this issue, we devise a simple yet effective Copy and Paste strategy (P3M-CP) that can borrow facial information from public celebrity images without privacy concerns and direct the network to reacquire the face context at both the data and feature levels. P3M-CP introduces only a small computational overhead during training, while enabling the matting model to process both face-blurred and normal images without extra effort during inference.
59 |
60 | Extensive experiments on P3M-10k demonstrate the superiority of P3M-Net over state-of-the-art methods and the effectiveness of P3M-CP in improving the generalization ability of P3M-Net, suggesting the great significance of P3M for future research and real-world applications.
61 |
62 |
63 | ## PPT Setting and P3M-10k Dataset
64 |
65 |
66 | PPT Setting: Due to the privacy concern, we propose the Privacy-Preserving Training (PPT) setting in portrait matting, i.e., training on privacy-preserved images (e.g., processed by face obfuscation) and testing on arbitrary images with or without privacy content. As an initial step towards the privacy-preserving portrait matting problem, we only define the identifiable faces in frontal and some profile portrait images as the private content in this work.
67 |
68 |
69 | P3M-10k Dataset: To further explore the effect of the PPT setting, we establish the first large-scale privacy-preserving portrait matting benchmark named P3M-10k. It contains 10,000 anonymized high-resolution portrait images obtained by face obfuscation, along with high-quality ground truth alpha mattes. Specifically, we carefully collect, filter, and annotate about 10,000 high-resolution images from the Internet with free-use licenses. There are 9,421 images in the training set and 500 images in the test set, denoted as P3M-500-P. In addition, we also collect and annotate another 500 public celebrity images from the Internet without face obfuscation, to evaluate the performance of matting models under the PPT setting on normal portrait images, denoted as P3M-500-NP. We show some examples below, where (a) is from the training set, (b) is from P3M-500-P, and (c) is from P3M-500-NP.
70 |
71 |
72 | P3M-10k and the facemask data are now published! You can access them from the following links; please make sure that you have read and agreed to the agreement. Note that the facemasks are not used in our work, so downloading them is optional.
73 |
74 |
78 |
79 | | Dataset | Dataset Link (Google Drive) | Dataset Link (Baidu Wangpan 百度网盘) | Dataset Release Agreement |
80 | | :----:| :----: | :----: | :----: |
81 | |P3M-10k|[Link](https://drive.google.com/file/d/1odzHp2zbQApLm90HH_Cvr5b5OwJVhEQG/view?usp=sharing)|[Link](https://pan.baidu.com/s/1aEmEXO5BflSp5hiA-erVBA?pwd=cied) (pw: cied) |[Agreement (MIT License)](https://jizhizili.github.io/files/p3m_dataset_agreement/P3M-10k_Dataset_Release_Agreement.pdf)|
82 |
83 |
84 |
85 | 
86 |
87 | ## P3M-Net and Variants
88 |
89 | 
90 |
91 | Our P3M-Net network models the comprehensive interactions between the shared encoder and the two decoders through three carefully designed integration modules, i.e., 1) a tripartite-feature integration (TFI) module to enable the interaction between the encoder and the two decoders; 2) a deep bipartite-feature integration (dBFI) module to enhance the interaction between the encoder and the segmentation decoder; and 3) a shallow bipartite-feature integration (sBFI) module to promote the interaction between the encoder and the matting decoder.
92 |
93 | We design three variants of the P3M basic blocks based on CNNs and vision transformers, namely P3M-Net(ResNet-34), P3M-Net(Swin-T), and P3M-Net(ViTAE-S). We leverage the ability of transformers to model long-range dependencies for extracting more accurate global information, and their locality modelling ability to preserve fine details in the transition areas. The structures are shown in the following figures.
94 |
95 | 
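To make the integration idea above more concrete, here is a minimal, illustrative PyTorch sketch of a tripartite-feature fusion block that combines encoder features with features from both decoders. It is only a conceptual sketch, not the repository's implementation; the real integration modules live in `core/network/modules.py`.

```python
import torch
import torch.nn as nn

class TripartiteFusionSketch(nn.Module):
    """Illustrative only: fuse encoder, segmentation-decoder and matting-decoder
    features of the same spatial size, in the spirit of the TFI module."""
    def __init__(self, enc_ch, seg_ch, mat_ch, out_ch):
        super().__init__()
        self.fuse = nn.Sequential(
            nn.Conv2d(enc_ch + seg_ch + mat_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, enc_feat, seg_feat, mat_feat):
        # concatenate the three feature streams along the channel dimension and fuse
        return self.fuse(torch.cat([enc_feat, seg_feat, mat_feat], dim=1))

# toy usage with random tensors
if __name__ == "__main__":
    fusion = TripartiteFusionSketch(64, 32, 32, 64)
    out = fusion(torch.randn(1, 64, 64, 64), torch.randn(1, 32, 64, 64), torch.randn(1, 32, 64, 64))
    print(out.shape)  # torch.Size([1, 64, 64, 64])
```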
96 |
97 |
103 |
104 |
105 | ## P3M-CP
106 |
107 |
108 | To further improve the generalization ability of P3M-Net, we devise a simple yet effective Copy and Paste strategy (P3M-CP) that can borrow facial information from publicly available celebrity images without privacy concerns and guide the network to reacquire the face context at both data and feature level, namely P3M-ICP and P3M-FCP. The pipeline is shown in the following figure.
109 |
110 | 
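For intuition, the sketch below illustrates the data-level idea in its simplest form: pasting the face region of an external (public) source image into a face-blurred training image using a binary facemask. This is a simplified illustration that assumes aligned, equally sized images; the actual P3M-ICP/P3M-FCP implementation in this repository is more involved.

```python
import numpy as np

def paste_face_sketch(blurred_img, source_img, facemask):
    """Illustrative data-level copy-and-paste: copy pixels of `source_img`
    into `blurred_img` wherever `facemask` marks the face region (value 255).

    blurred_img, source_img: HxWx3 uint8 arrays of the same size (assumption).
    facemask: HxW uint8 array, 255 inside the face region, 0 elsewhere.
    """
    out = blurred_img.copy()
    face = facemask == 255
    out[face] = source_img[face]
    return out

# toy usage with random data
blurred = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
source = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
mask = np.zeros((256, 256), dtype=np.uint8)
mask[64:128, 96:160] = 255  # pretend this box is the face region
composed = paste_face_sketch(blurred, source, mask)
```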
111 |
112 | ## Results
113 |
114 | We test our network on the proposed P3M-500-P and P3M-500-NP benchmarks and compare it with previous SOTA methods; the results are listed below.
115 |
116 | 
117 | 
118 |
119 | ## Quick start - Train
120 |
121 | Please follow this [instruction page](./core/README.md).
122 |
123 | ## Inference Code - How to Test on Your Images
124 |
125 | Here we provide the procedure for testing on sample images with our pretrained P3M-Net(ViTAE-S) model:
126 |
127 | 1. Setup environment following this [instruction page](./core/README.md);
128 |
129 | 2. Set the path `REPOSITORY_ROOT_PATH` in the file `core/config.py`;
130 |
131 | 3. Download the pretrained P3M-Net(ViTAE-S) model from here ([Google Drive](https://drive.google.com/file/d/1QbSjPA_Mxs7rITp_a9OJiPeFRDwxemqK/view?usp=sharing) | [Baidu Wangpan](https://pan.baidu.com/s/19FuiR1RwamqvxfhdXDL1fg) (pw: hxxy)) and unzip to the folder `models/pretrained/`;
132 |
133 | 4. Save your sample images in folder `samples/original/.`;
134 |
135 | 5. Setup parameters in the file `scripts/test_samples.sh` and run by:
136 |
137 | `chmod +x scripts/test_samples.sh`
138 |
139 | `scripts/test_samples.sh`;
140 |
141 | 6. The alpha matte and transparent color results will be saved in the folders `samples/result_alpha/.` and `samples/result_color/.`.
142 |
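For reference, the sketch below shows the kind of pre- and post-processing the `RESIZE` test strategy implies: resize the input to a fixed inference size, normalize with ImageNet statistics, run the network, and resize the predicted alpha back to the original resolution. How the model is built and loaded is left as a placeholder (see `core/network/build_model.py` and `core/test.py`); the function below is illustrative and not an API of this repository.

```python
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

def infer_resize(model, image_path, infer_size=512, device="cuda"):
    """Illustrative RESIZE-style inference: resize -> normalize -> forward -> resize back."""
    img = Image.open(image_path).convert("RGB")
    w, h = img.size
    to_tensor = transforms.Compose([
        transforms.Resize((infer_size, infer_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    x = to_tensor(img).unsqueeze(0).to(device)
    with torch.no_grad():
        pred_alpha = model(x)  # assumption: the model returns (or includes) a 1-channel alpha map
    pred_alpha = F.interpolate(pred_alpha, size=(h, w), mode="bilinear", align_corners=False)
    return pred_alpha.squeeze().clamp(0, 1).cpu().numpy()

# hypothetical usage:
# model = ...  # build P3M-Net(ViTAE-S) and load the pretrained checkpoint here
# alpha = infer_resize(model.eval().to("cuda"), "samples/original/p_015cd10e.jpg")
```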
143 | Below we show some sample images, the predicted alpha mattes, and their transparent results. We use the pretrained P3M-Net(ViTAE-S) model from the section P3M-Net and Variants with the `RESIZE` test strategy.
144 |
145 |
146 |
147 |
148 |
149 |
150 | ## Statement
151 |
152 | If you are interested in our work, please consider citing the following:
153 |
154 | ```
155 | @article{rethink_p3m,
156 | title={Rethinking Portrait Matting with Privacy Preserving},
157 | author={Ma, Sihan and Li, Jizhizi and Zhang, Jing and Zhang, He and Tao, Dacheng},
158 | journal={International Journal of Computer Vision},
159 | publisher={Springer},
160 | ISSN={1573-1405},
161 | year={2023}
162 | }
163 | ```
164 |
165 | This project is released under the MIT license.
166 |
167 | For further questions, please contact Sihan Ma at [sima7436@uni.sydney.edu.au](mailto:sima7436@uni.sydney.edu.au) or Jizhizi Li at [jili8515@uni.sydney.edu.au](mailto:jili8515@uni.sydney.edu.au).
168 |
169 |
170 | ## Relevant Projects
171 |
172 |
173 |
174 |
175 | [1] Deep Automatic Natural Image Matting, IJCAI, 2021 | [Paper](https://www.ijcai.org/proceedings/2021/0111.pdf) | [Github](https://github.com/JizhiziLi/AIM)
176 | Jizhizi Li, Jing Zhang, and Dacheng Tao
177 |
178 | [2] Privacy-preserving Portrait Matting, ACM MM, 2021 | [Paper](https://dl.acm.org/doi/10.1145/3474085.3475512) | [Github](https://github.com/JizhiziLi/P3M)
179 | Jizhizi Li∗, Sihan Ma∗, Jing Zhang, Dacheng Tao
180 |
181 | [3] Bridging Composite and Real: Towards End-to-end Deep Image Matting, IJCV, 2022 | [Paper](https://link.springer.com/article/10.1007/s11263-021-01541-0) | [Github](https://github.com/JizhiziLi/GFM)
182 | Jizhizi Li∗, Jing Zhang∗, Stephen J. Maybank, Dacheng Tao
183 |
184 | [4] Referring Image Matting, CVPR, 2023 | [Paper](https://arxiv.org/pdf/2206.05149.pdf) | [Github](https://github.com/JizhiziLi/RIM)
185 | Jizhizi Li, Jing Zhang, and Dacheng Tao
186 |
187 |
188 | [5] Deep Image Matting: A Comprehensive Survey, ArXiv, 2023 | [Paper](https://arxiv.org/abs/2304.04672) | [Github](https://github.com/jizhiziLi/matting-survey)
189 | Jizhizi Li, Jing Zhang, and Dacheng Tao
--------------------------------------------------------------------------------
/core/README.md:
--------------------------------------------------------------------------------
1 | Rethinking Portrait Matting with Privacy Preserving
2 |
3 |
4 | Installation |
5 | Prepare Datasets |
6 | Pretrained Models |
7 | Train on P3M-10k |
8 | Test |
9 | Inference
10 |
11 |
12 |
13 | ## Installation
14 | Requirements:
15 |
16 | - Python 3.7.7+ with Numpy and scikit-image
17 | - Pytorch (version>=1.7.1)
18 | - Torchvision (version 0.8.2)
19 |
20 | 1. Clone this repository
21 |
22 | `git clone https://github.com/ViTAE-Transformer/P3M-Net.git`;
23 |
24 | 2. Go into the repository
25 |
26 | `cd P3M-Net`;
27 |
28 | 3. Create conda environment and activate
29 |
30 | `conda create -n p3m python=3.7.7`,
31 |
32 | `conda activate p3m`;
33 |
34 | 4. Install dependencies; install pytorch and torchvision separately if needed
35 |
36 | `pip install -r requirements.txt`,
37 |
38 | `conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.2 -c pytorch`.
39 |
40 | Our code has been tested with Python 3.7.7, Pytorch 1.7.1, Torchvision 0.8.2, CUDA 10.2 on Ubuntu 18.04.
41 |
42 |
43 | ## Prepare Datasets
44 |
45 |
49 |
50 | | Dataset | Dataset Link (Google Drive) | Dataset Link (Baidu Wangpan 百度网盘) | Dataset Release Agreement |
51 | | :----:| :----: | :----: | :----: |
52 | |P3M-10k|[Link](https://drive.google.com/file/d/1odzHp2zbQApLm90HH_Cvr5b5OwJVhEQG/view?usp=sharing)|[Link](https://pan.baidu.com/s/1aEmEXO5BflSp5hiA-erVBA?pwd=cied) (pw: cied) |[Agreement (MIT License)](https://jizhizili.github.io/files/p3m_dataset_agreement/P3M-10k_Dataset_Release_Agreement.pdf)|
53 |
54 |
55 | 1. Download the P3M-10k dataset from the links above, unzip it to the folder `P3M_DATASET_ROOT_PATH`, and set up the configurations in the file `core/config.py`. Please make sure that you have read and agreed to the agreement.
56 |
57 | After dataset preparation, the structure of the complete dataset should look like the following (a quick layout sanity check is sketched after this list).
58 | ```text
59 | P3M-10k
60 | ├── train
61 | ├── blurred_image
62 | ├── mask (alpha mattes)
63 | ├── fg_blurred
64 | ├── bg
65 | ├── facemask
66 | ├── validation
67 | ├── P3M-500-P
68 | ├── blurred_image
69 | ├── mask
70 | ├── trimap
71 | ├── facemask
72 | ├── P3M-500-NP
73 | ├── original_image
74 | ├── mask
75 | ├── trimap
76 | ```
77 |
78 | 2. If you want to test on the RWP test set, please download the original images and alpha mattes at this link.
79 |
80 | After dataset preparation, the structure of the complete dataset should look like the following.
81 | ```text
82 | RealWorldPortrait-636
83 | ├── image
84 | ├── alpha
85 | ├── ...
86 | ```
87 |
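As referenced above, here is a small, optional sanity-check sketch (not part of the repository) that verifies the expected P3M-10k folder layout and image counts under `P3M_DATASET_ROOT_PATH`:

```python
import os

P3M_DATASET_ROOT_PATH = "P3M-10k/"  # should match the value in core/config.py

expected = {
    "train/blurred_image": 9421,
    "train/mask": 9421,
    "validation/P3M-500-P/blurred_image": 500,
    "validation/P3M-500-NP/original_image": 500,
}

for rel_path, count in expected.items():
    folder = os.path.join(P3M_DATASET_ROOT_PATH, rel_path)
    n = len([f for f in os.listdir(folder) if not f.startswith(".")]) if os.path.isdir(folder) else 0
    status = "OK" if n == count else "CHECK"
    print(f"[{status}] {folder}: {n} files (expected {count})")
```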
88 | ## Pretrained Models
89 |
90 | Here we provide the model P3M-Net(ViTAE-S) that is trained on P3M-10k for testing.
91 |
92 | | Model| Google Drive | Baidu Wangpan(百度网盘) |
93 | | :----: | :----:| :----: |
94 | | P3M-Net(ViTAE-S) | [Link](https://drive.google.com/file/d/1QbSjPA_Mxs7rITp_a9OJiPeFRDwxemqK/view?usp=sharing) | [Link](https://pan.baidu.com/s/19FuiR1RwamqvxfhdXDL1fg) (pw: hxxy) |
95 |
96 | Here we provide the pretrained models of all backbones for training.
97 |
98 | | Model| Google Drive | Baidu Wangpan(百度网盘) |
99 | | :----: | :----:| :----: |
100 | | pretrained models | [Link](https://drive.google.com/file/d/1V2xt0BWCVx550Ll7GGfquvopTLseX9gY/view?usp=sharing) | [Link](https://pan.baidu.com/s/1eJ7mTLQszEtMJHJ2zn3dag?pwd=gxn9) (pw:gxn9) |
101 |
102 |
103 | ## Train on P3M-10k
104 |
105 | 1. Download P3M-10k dataset in root `P3M_DATASET_ROOT_PATH` (set up in `core/config.py`);
106 |
107 | 2. Download the pretrained models of all backbones in the previous section, and set up the output folder `REPOSITORY_ROOT_PATH` in `core/config.py`. The folder structure should be like the following,
108 | ```text
109 | [REPOSITORY_ROOT_PATH]
110 | ├── logs
111 | ├── models
112 | ├── pretrained
113 | ├── r34mp_pretrained_imagenet.pth.tar
114 | ├── swin_pretrained_epoch_299.pth
115 | ├── vitae_pretrained_ckpt.pth.tar
116 | ├── trained
117 | ```
118 |
119 | 3. Set up the parameters in `scripts/train.sh`, specifying the config file `cfg`, the run name `nickname`, etc. Then run:
120 |
121 | `chmod +x scripts/train.sh`
122 |
123 | `./scripts/train.sh`
124 |
125 | ## Test
126 |
127 | Set up the parameters in `scripts/test.sh`, specifying the config file `cfg`, the run name `nickname`, etc. Then run:
128 |
129 | `chmod +x scripts/test.sh`
130 |
131 | `./scripts/test.sh`
132 |
133 |
134 |
135 | Test using provided model
136 |
137 | ### Test on P3M-10k
138 |
139 | 1. Download provided model on P3M-10k as shown in the previous section, unzip to the folder `models/pretrained/`;
140 |
141 | 2. Download P3M-10k dataset in root `P3M_DATASET_ROOT_PATH` (set up in `core/config.py`);
142 |
143 | 3. Set up the parameters in `scripts/test_dataset.sh`, choose `dataset=P3M10K`, and set `valset=P3M_500_NP` or `valset=P3M_500_P` depending on which validation set you want to use, then run:
144 |
145 | `chmod +x scripts/test_dataset.sh`
146 |
147 | `./scripts/test_dataset.sh`
148 |
149 | 4. The alpha matte results will be saved in the folder `args.test_result_dir`. Note that the evaluation results may differ slightly from those reported in the paper due to differences in package versions and the testing strategy.
150 |
151 | ### Test on RWP
152 |
153 | 1. Download provided model on P3M-10k as shown in the previous section, unzip to the folder `models/pretrained/`;
154 |
155 | 2. Download RWP dataset in root `RWP_TEST_SET_ROOT_PATH` (set up in `core/config.py`). Download link is here;
156 |
157 | 3. Set up the parameters in `scripts/test_dataset.sh`, choose `dataset=RWP` and `valset=RWP`, then run:
158 |
159 | `chmod +x scripts/test_dataset.sh`
160 |
161 | `./scripts/test_dataset.sh`
162 |
163 | 4. The alpha matte results will be saved in the folder `args.test_result_dir`. Note that the evaluation results may differ slightly from those reported in the paper due to differences in package versions and the testing strategy.
164 |
165 | ### Test on Samples
166 |
167 | 1. Download provided model on P3M-10k as shown in the previous section, unzip to the folder `models/pretrained/`;
168 |
169 | 2. Put the images to test in the root `SAMPLES_ROOT_PATH/original` (set up in config.py);
170 |
171 | 3. Set up parameters in `scripts/test_samples.sh`, and run the file:
172 |
173 | `chmod +x samples/original/*`
174 |
175 | `chmod +x scripts/test_samples.sh`
176 |
177 | `./scripts/test_samples.sh`
178 |
179 | 4. The alpha matte results will be saved in the folder `SAMPLES_RESULT_ALPHA_PATH` (set up in config.py). The color results will be saved in the folder `SAMPLES_RESULT_COLOR_PATH` (set up in config.py). Note that the evaluation results may differ slightly from those reported in the paper due to differences in package versions and the testing strategy.
180 |
181 |
182 |
183 | ## Inference Code - How to Test on Your Images
184 |
185 | Here we provide the procedure for testing on sample images with our pretrained P3M-Net(ViTAE-S) model:
186 |
187 | 1. Setup environment following this [instruction page](https://github.com/ViTAE-Transformer/P3M-Net/tree/main/core);
188 |
189 | 2. Set the path `REPOSITORY_ROOT_PATH` in the file `core/config.py`;
190 |
191 | 3. Download the pretrained P3M-Net(ViTAE-S) model from here ([Google Drive](https://drive.google.com/file/d/1QbSjPA_Mxs7rITp_a9OJiPeFRDwxemqK/view?usp=sharing) | [Baidu Wangpan](https://pan.baidu.com/s/19FuiR1RwamqvxfhdXDL1fg) (pw: hxxy)) and unzip to the folder `models/pretrained/`;
192 |
193 | 4. Save your sample images in folder `samples/original/.`;
194 |
195 | 5. Setup parameters in the file `scripts/test_samples.sh` and run by:
196 |
197 | `chmod +x scripts/test_samples.sh`
198 |
199 | `scripts/test_samples.sh`;
200 |
201 | 6. The alpha matte and transparent color results will be saved in the folders `samples/result_alpha/.` and `samples/result_color/.`.
202 |
203 | Below we show some sample images, the predicted alpha mattes, and their transparent results. We use the pretrained P3M-Net(ViTAE-S) model from the section P3M-Net and Variants with the `RESIZE` test strategy.
204 |
205 |
206 |
207 |
208 |
209 |
--------------------------------------------------------------------------------
/core/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Rethinking Portrait Matting with Privacy Preserving
3 | config file.
4 |
5 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
6 | Licensed under the MIT License (see LICENSE for details)
7 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net
8 | Paper link: https://arxiv.org/abs/2203.16828
9 |
10 | """
11 |
12 |
13 | ########## Root paths and logging files paths
14 | REPOSITORY_ROOT_PATH = 'p3m_out/'
15 | P3M_DATASET_ROOT_PATH = 'P3M-10k/'
16 | RWP_TEST_SET_ROOT_PATH = ''
17 |
18 | WANDB_KEY_FILE = '' # setup wandb key file path. Write your wandb key in a text file and set up the path for this file here.
19 |
20 | ########## Paths for training
21 | TRAIN_LOGS_FOLDER = REPOSITORY_ROOT_PATH+'logs/train_logs/'
22 | TEST_LOGS_FOLDER = REPOSITORY_ROOT_PATH+'logs/test_logs/'
23 | WANDB_LOGS_FOLDER = REPOSITORY_ROOT_PATH+'logs/'
24 | CKPT_SAVE_FOLDER = REPOSITORY_ROOT_PATH+'models/trained/'
25 | DEBUG_FOLDER = REPOSITORY_ROOT_PATH+'debug/'
26 | TEST_RESULT_FOLDER = REPOSITORY_ROOT_PATH+'result/'
27 |
28 | ######### Paths of datasets
29 | VALID_TEST_DATASET_CHOICE = [
30 | 'P3M_500_P',
31 | 'P3M_500_NP',
32 | 'VAL500P',
33 | 'VAL500NP',
34 | 'VAL500P_NORMAL',
35 | 'VAL500P_MOSAIC',
36 | 'VAL500P_ZERO',
37 | 'RealWorldPortrait636']
38 | VALID_TEST_DATA_CHOICE = [*VALID_TEST_DATASET_CHOICE, 'SAMPLES']
39 |
40 | DATASET_PATHS_DICT={
41 | 'P3M10K':{
42 | 'TRAIN':{
43 | 'ROOT_PATH':P3M_DATASET_ROOT_PATH+'train/',
44 | 'ORIGINAL_PATH':P3M_DATASET_ROOT_PATH+'train/facemask_blurred/',
45 | 'MASK_PATH':P3M_DATASET_ROOT_PATH+'train/mask/',
46 | 'FG_PATH':P3M_DATASET_ROOT_PATH+'train/fg_blurred/',
47 | 'BG_PATH':P3M_DATASET_ROOT_PATH+'train/bg/',
48 | 'PRIVACY_MASK_PATH': P3M_DATASET_ROOT_PATH+'train/facemask/',
49 | 'NORMAL_ORIGINAL_PATH': P3M_DATASET_ROOT_PATH+'train/original/',
50 | 'SAMPLE_NUMBER':9421
51 | },
52 | 'P3M_500_P':{
53 | 'ROOT_PATH':P3M_DATASET_ROOT_PATH+'validation/P3M-500-P/',
54 | 'ORIGINAL_PATH':P3M_DATASET_ROOT_PATH+'validation/P3M-500-P/blurred_image/',
55 | 'MASK_PATH':P3M_DATASET_ROOT_PATH+'validation/P3M-500-P/mask/',
56 | 'TRIMAP_PATH':P3M_DATASET_ROOT_PATH+'validation/P3M-500-P/trimap/',
57 | 'PRIVACY_MASK_PATH': P3M_DATASET_ROOT_PATH+'validation/P3M-500-P/facemask/',
58 | 'SAMPLE_NUMBER':500
59 | },
60 | 'P3M_500_NP':{
61 | 'ROOT_PATH':P3M_DATASET_ROOT_PATH+'validation/P3M-500-NP/',
62 | 'ORIGINAL_PATH':P3M_DATASET_ROOT_PATH+'validation/P3M-500-NP/original_image/',
63 | 'MASK_PATH':P3M_DATASET_ROOT_PATH+'validation/P3M-500-NP/mask/',
64 | 'TRIMAP_PATH':P3M_DATASET_ROOT_PATH+'validation/P3M-500-NP/trimap/',
65 | 'PRIVACY_MASK_PATH': None,
66 | 'SAMPLE_NUMBER':500
67 | },
68 | },
69 | 'RWP': {
70 | 'RWP': {
71 | 'ROOT_PATH': RWP_TEST_SET_ROOT_PATH,
72 | 'ORIGINAL_PATH': RWP_TEST_SET_ROOT_PATH+'image/',
73 | 'MASK_PATH': RWP_TEST_SET_ROOT_PATH+'alpha/',
74 | 'TRIMAP_PATH': None,
75 | 'PRIVACY_MASK_PATH': None,
76 | 'SAMPLE_NUMBER': 636
77 | }
78 | }
79 | }
80 |
81 | ######### Paths of samples for test
82 |
83 | SAMPLES_ORIGINAL_PATH = REPOSITORY_ROOT_PATH+'samples/original/'
84 | SAMPLES_RESULT_ALPHA_PATH = REPOSITORY_ROOT_PATH+'samples/result_alpha/'
85 | SAMPLES_RESULT_COLOR_PATH = REPOSITORY_ROOT_PATH+'samples/result_color/'
86 |
87 | ######### Paths of pretrained model
88 | PRETRAINED_R34_MP = REPOSITORY_ROOT_PATH+'models/pretrained/r34mp_pretrained_imagenet.pth.tar'
89 | PRETRAINED_SWIN_STEM_POOLING5 = REPOSITORY_ROOT_PATH+'models/pretrained/swin_pretrained_epoch_299.pth'
90 | PRETRAINED_VITAE_NORC_MAXPOOLING_BIAS_BASIC_STAGE4_14 = REPOSITORY_ROOT_PATH+'models/pretrained/vitae_pretrained_ckpt.pth.tar'
91 |
92 | ######### Test config
93 | MAX_SIZE_H = 1600
94 | MAX_SIZE_W = 1600
95 | MIN_SIZE_H = 512
96 | MIN_SIZE_W = 512
97 | SHORTER_PATH_LIMITATION = 1080
98 |
--------------------------------------------------------------------------------
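A minimal sketch of how the path dictionary above can be consumed (illustrative; it assumes `core/` is on `sys.path` so that `config` imports the file above, as the repository's own modules do with `from config import *`):

```python
import sys
sys.path.insert(0, "core")  # assumption: run from the repository root

from config import DATASET_PATHS_DICT, VALID_TEST_DATA_CHOICE

# look up the P3M-500-P validation split paths defined in core/config.py
val_p = DATASET_PATHS_DICT["P3M10K"]["P3M_500_P"]
print(val_p["ORIGINAL_PATH"])   # .../validation/P3M-500-P/blurred_image/
print(val_p["MASK_PATH"])       # .../validation/P3M-500-P/mask/
print(val_p["SAMPLE_NUMBER"])   # 500

print(VALID_TEST_DATA_CHOICE)   # valid values for the test-set choice
```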
/core/config_yacs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ipdb
3 | import datetime
4 | import yaml
5 | from yacs.config import CfgNode as CN
6 |
7 | _C = CN()
8 |
9 | # base config files
10 | _C.BASE = ['']
11 |
12 | # -----------------------------------------------------------------------------
13 | # Data settings
14 | # -----------------------------------------------------------------------------
15 | _C.DATA = CN()
16 | _C.DATA.BATCH_SIZE = 8
17 | _C.DATA.DATASET = 'P3M10K'
18 | _C.DATA.TRAIN_SET = 'TRAIN' # other options: [TRAIN_NORMAL, TRAIN_MOSAIC, TRAIN_ZERO]
19 | _C.DATA.HYBRID_OBFUSCATION = None
20 | _C.DATA.NUM_WORKERS = 8
21 | _C.DATA.RESIZE_SIZE = 512
22 | _C.DATA.CROP_SIZE = [512, 768, 1024] # RATIO: 1, 1.5, 2
23 |
24 | # global data setting | for two stage network only
25 | _C.DATA.GLOBAL = CN()
26 | _C.DATA.GLOBAL.CROP_SIZE = [256, 384, 512]
27 | _C.DATA.GLOBAL.RESIZE_SIZE = 256
28 |
29 | # local data setting | for two stage network only
30 | _C.DATA.LOCAL = CN()
31 | _C.DATA.LOCAL.GLOBAL_SIZE = None
32 | _C.DATA.LOCAL.PATCH_RESIZE_SIZE = 64
33 | _C.DATA.LOCAL.PATCH_SIZE = 64
34 | _C.DATA.LOCAL.PATCH_NUMBER = 64
35 |
36 | # Settings for cut and paste at data level
37 | _C.DATA.CUT_AND_PASTE = CN()
38 | _C.DATA.CUT_AND_PASTE.TYPE = 'NONE' # CHOICES: ['NONE', 'VANILLA', 'AUG', 'RESIZE2FILL', 'GRID_SAMPLE']
39 | _C.DATA.CUT_AND_PASTE.SOURCE_DATASET = '' # CHOICES: ['SELF', 'P3M10K', 'CELEBAMASK_HQ']
40 | _C.DATA.CUT_AND_PASTE.PROB = 0.5
41 | # cut and paste: aug
42 | _C.DATA.CUT_AND_PASTE.AUG = CN()
43 | _C.DATA.CUT_AND_PASTE.AUG.DEGREE = 30
44 | _C.DATA.CUT_AND_PASTE.AUG.SCALE = [0.8,1.2]
45 | _C.DATA.CUT_AND_PASTE.AUG.SHEAR = None
46 | _C.DATA.CUT_AND_PASTE.AUG.FLIP = [0.5,0]
47 | _C.DATA.CUT_AND_PASTE.AUG.RANDOM_PASTE = False
48 | # cut and paste: grid_sample
49 | _C.DATA.CUT_AND_PASTE.GRID_SAMPLE = CN()
50 | _C.DATA.CUT_AND_PASTE.GRID_SAMPLE.SELECT_RANGE = [10, 2, 10, 2]
51 | _C.DATA.CUT_AND_PASTE.GRID_SAMPLE.DOWN_SCALE = 4
52 |
53 | # -----------------------------------------------------------------------------
54 | # Model settings
55 | # -----------------------------------------------------------------------------
56 | _C.MODEL = CN()
57 | _C.MODEL.TYPE = ''
58 | _C.MODEL.PRETRAINED = True
59 |
60 | # Settings for swin transformer
61 | _C.MODEL.SWIN = CN()
62 | _C.MODEL.SWIN.PATCH_SIZE = 4
63 | _C.MODEL.SWIN.IN_CHANS = 3
64 | _C.MODEL.SWIN.EMBED_DIM = 96
65 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
66 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
67 | _C.MODEL.SWIN.WINDOW_SIZE = 7
68 | _C.MODEL.SWIN.MLP_RATIO = 4.
69 | _C.MODEL.SWIN.QKV_BIAS = True
70 | _C.MODEL.SWIN.QK_SCALE = None
71 | _C.MODEL.SWIN.APE = False
72 | _C.MODEL.SWIN.PATCH_NORM = True
73 | _C.MODEL.SWIN.USE_CHECKPOINT = False
74 | _C.MODEL.SWIN.DROP_RATE = 0.0
75 | _C.MODEL.SWIN.DROP_PATH_RATE = 0.2
76 |
77 | # Settings for cut and paste at feature level
78 | _C.MODEL.CUT_AND_PASTE = CN()
79 | _C.MODEL.CUT_AND_PASTE.TYPE = 'NONE' # CHOICES: ['NONE', 'CP', 'SHUFFLE']; CP and SHUFFLE are mutually exclusive (choose at most one)
80 | _C.MODEL.CUT_AND_PASTE.PROB = 0.5
81 | _C.MODEL.CUT_AND_PASTE.START_EPOCH = 50
82 | _C.MODEL.CUT_AND_PASTE.LAYER = [] # can be a list of layers
83 | _C.MODEL.CUT_AND_PASTE.DETACH = True
84 |
85 | # Settings for cut and paste cp
86 | _C.MODEL.CUT_AND_PASTE.CP = CN()
87 | _C.MODEL.CUT_AND_PASTE.CP.TYPE = 'VANILLA' # CHOICES: ['VANILLA', 'MULTISCALE'], controls the feature-level cut-and-paste data class type
88 | _C.MODEL.CUT_AND_PASTE.CP.MODEL = 'SELF' # CHOICES: ['COPY_EVERY_ITER', 'COPY_EVERY_EPOCH', 'SELF'], controls the model used to extract source features
89 | _C.MODEL.CUT_AND_PASTE.CP.SOURCE_DATASET = 'SELF' # CHOICES: ['SELF', 'P3M10K', 'CELEBAMASK_HQ'], controls the data used to extract source features
90 | _C.MODEL.CUT_AND_PASTE.CP.SOURCE_BATCH_SIZE = 1
91 | _C.MODEL.CUT_AND_PASTE.CP.CELEBAMASK_HQ = CN()
92 | _C.MODEL.CUT_AND_PASTE.CP.CELEBAMASK_HQ.DEFREE = None # degree ranges, [-degree, +degree]
93 | _C.MODEL.CUT_AND_PASTE.CP.CELEBAMASK_HQ.FLIP = None # [prob horizon, prob vertical]
94 | _C.MODEL.CUT_AND_PASTE.CP.CELEBAMASK_HQ.CROP_SIZE = None
95 | _C.MODEL.CUT_AND_PASTE.CP.CELEBAMASK_HQ.RESIZE_SIZE = None
96 | _C.MODEL.CUT_AND_PASTE.CP.CELEBAMASK_HQ.SCALE = [0.3, 0.7] # scale < 1.0
97 |
98 | # Settings for cut and paste shuffle
99 | _C.MODEL.CUT_AND_PASTE.SHUFFLE = CN()
100 | _C.MODEL.CUT_AND_PASTE.SHUFFLE.TYPE = '' # CHOICES: ['NONE', 'FG2FACE']
101 | _C.MODEL.CUT_AND_PASTE.SHUFFLE.KERNEL_SIZE = 1
102 |
103 | # TODO
104 | # .SHUFFLE.TYPE: FG2FACE
105 |
106 | # -----------------------------------------------------------------------------
107 | # Training settings
108 | # -----------------------------------------------------------------------------
109 | _C.TRAIN = CN()
110 | _C.TRAIN.START_EPOCH = 1
111 | _C.TRAIN.EPOCHS = 150
112 | _C.TRAIN.WARMUP_EPOCHS = 0
113 | _C.TRAIN.LR_DECAY = False
114 | _C.TRAIN.LR = 0.00001
115 | _C.TRAIN.CLIP_GRAD = False
116 | _C.TRAIN.RESUME_CKPT = None
117 |
118 | # Settings for optimizer
119 | _C.TRAIN.OPTIMIZER = CN()
120 | _C.TRAIN.OPTIMIZER.TYPE = 'ADAM'
121 |
122 | # -----------------------------------------------------------------------------
123 | # Test settings
124 | # -----------------------------------------------------------------------------
125 | _C.TEST = CN()
126 | _C.TEST.DATASET = 'VAL500P'
127 | _C.TEST.TEST_METHOD = 'HYBRID'
128 | _C.TEST.CKPT_NAME = 'best_SAD_VAL500P'
129 | _C.TEST.FAST_TEST = False
130 | _C.TEST.TEST_PRIVACY = False
131 | _C.TEST.SAVE_RESULT = False
132 | _C.TEST.LOCAL_PATCH_NUM = 1024 # for test only
133 |
134 | # -----------------------------------------------------------------------------
135 | # Other settings, e.g. logging, wandb, dist, tag, test freq, save ckpt
136 | # -----------------------------------------------------------------------------
137 | _C.TAG = 'debug'
138 | _C.ENABLE_WANDB = False
139 | _C.TEST_FREQ = 1
140 | _C.AUTO_RESUME = False
141 | _C.SEED = 10007
142 | _C.DIST = False
143 | _C.LOCAL_RANK = 0
144 |
145 |
146 | def _update_config_from_file(config, cfg_file):
147 | config.defrost()
148 | with open(cfg_file, 'r') as f:
149 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
150 |
151 | for cfg in yaml_cfg.setdefault('BASE', ['']):
152 | if cfg:
153 | _update_config_from_file(
154 | config, os.path.join(os.path.dirname(cfg_file), cfg)
155 | )
156 | print('=> merge config from {}'.format(cfg_file))
157 | config.merge_from_file(cfg_file)
158 | config.freeze()
159 |
160 |
161 | def update_config(config, args):
162 | # merge config from other config
163 | if getattr(args, 'existing_cfg', None): # for test only
164 | if hasattr(args.existing_cfg.MODEL, 'CUT_AND_PASTE') and hasattr(args.existing_cfg.MODEL.CUT_AND_PASTE, 'LAYER'):
165 | if type(args.existing_cfg.MODEL.CUT_AND_PASTE.LAYER) != list:
166 | args.existing_cfg.defrost()
167 | args.existing_cfg.MODEL.CUT_AND_PASTE.LAYER = [args.existing_cfg.MODEL.CUT_AND_PASTE.LAYER]
168 | args.existing_cfg.freeze()
169 | config.merge_from_other_cfg(args.existing_cfg)
170 | assert args.tag == config.TAG
171 |
172 | # merge config from file
173 | if getattr(args, 'cfg', None):
174 | _update_config_from_file(config, args.cfg)
175 |
176 | config.defrost()
177 | if getattr(args, 'opts', None):
178 | config.merge_from_list(args.opts)
179 |
180 | # merge from specific arguments
181 | if getattr(args, 'arch', None):
182 | config.MODEL.TYPE = args.arch
183 |
184 | if getattr(args, 'train_from_scratch', None):
185 | config.MODEL.PRETRAINED = False
186 |
187 | if getattr(args, 'tag', None):
188 | config.TAG = args.tag
189 |
190 | if getattr(args, 'nEpochs', None):
191 | config.TRAIN.EPOCHS = args.nEpochs
192 |
193 | if getattr(args, 'warmup_nEpochs', None):
194 | config.TRAIN.WARMUP_EPOCHS = args.warmup_nEpochs
195 |
196 | if getattr(args, 'batchSize', None):
197 | config.DATA.BATCH_SIZE = args.batchSize
198 |
199 | if getattr(args, 'lr', None):
200 | config.TRAIN.LR = args.lr
201 |
202 | if getattr(args, 'lr_decay', None):
203 | config.TRAIN.LR_DECAY = args.lr_decay
204 |
205 | if getattr(args, 'clip_grad', None):
206 | config.TRAIN.CLIP_GRAD = args.clip_grad
207 |
208 | if getattr(args, 'threads', None):
209 | config.DATA.NUM_WORKERS = args.threads
210 |
211 | if getattr(args, 'test_freq', None):
212 | config.TEST_FREQ = args.test_freq
213 |
214 | if getattr(args, 'enable_wandb', None):
215 | config.ENABLE_WANDB = args.enable_wandb
216 |
217 | if getattr(args, 'auto_resume', None):
218 | config.AUTO_RESUME = args.auto_resume
219 |
220 | if getattr(args, 'source_batch_size', None):
221 | config.MODEL.CUT_AND_PASTE.CP.SOURCE_BATCH_SIZE = args.source_batch_size
222 |
223 | if getattr(args, 'dataset', None):
224 | config.DATA.DATASET = args.dataset
225 |
226 | if getattr(args, 'train_set', None):
227 | config.DATA.TRAIN_SET = args.train_set
228 |
229 | if getattr(args, 'seed', None):
230 | config.SEED = args.seed
231 |
232 | if getattr(args, 'test_dataset', None):
233 | config.TEST.DATASET = args.test_dataset
234 |
235 | if getattr(args, 'test_ckpt', None):
236 | config.TEST.CKPT_NAME = args.test_ckpt
237 |
238 | if getattr(args, 'test_method', None):
239 | config.TEST.TEST_METHOD = args.test_method
240 |
241 | if getattr(args, 'fast_test', None):
242 | config.TEST.FAST_TEST = args.fast_test
243 |
244 | if getattr(args, 'test_privacy', None):
245 | config.TEST.TEST_PRIVACY = args.test_privacy
246 |
247 | if getattr(args, 'save_result', None):
248 | config.TEST.SAVE_RESULT = args.save_result
249 |
250 | if getattr(args, 'local_rank', None):
251 | config.LOCAL_RANK = args.local_rank
252 |
253 | config.freeze()
254 |
255 | def get_config(args):
256 | config = _C.clone()
257 | update_config(config, args)
258 | return config
--------------------------------------------------------------------------------
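A minimal sketch of building a config with `get_config` above (illustrative; the attribute names on `args` mirror those that `update_config` inspects, and in the repository they normally come from the argparse setup of the training script):

```python
import sys
from types import SimpleNamespace

sys.path.insert(0, "core")  # assumption: run from the repository root
from config_yacs import get_config

# emulate the parsed command-line arguments that update_config() looks for
args = SimpleNamespace(
    cfg="core/configs/ViTAE_S.yaml",  # merged via _update_config_from_file
    tag="vitae_debug_run",            # overrides TAG
    nEpochs=150,
    batchSize=8,
    lr=0.00001,
)
config = get_config(args)
print(config.MODEL.TYPE, config.TAG, config.TRAIN.EPOCHS, config.DATA.BATCH_SIZE)
```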
/core/configs/ResNet34.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: 'r34'
3 | CUT_AND_PASTE:
4 | TYPE: 'NONE'
5 |
6 | DATA:
7 | CUT_AND_PASTE:
8 | TYPE: 'NONE'
9 |
--------------------------------------------------------------------------------
/core/configs/Swin_T.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: 'swin'
3 | CUT_AND_PASTE:
4 | TYPE: 'NONE'
5 |
6 | DATA:
7 | CUT_AND_PASTE:
8 | TYPE: 'NONE'
--------------------------------------------------------------------------------
/core/configs/ViTAE_S.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: 'vitae'
3 | CUT_AND_PASTE:
4 | TYPE: 'NONE'
5 |
6 | DATA:
7 | CUT_AND_PASTE:
8 | TYPE: 'NONE'
9 |
--------------------------------------------------------------------------------
/core/data.py:
--------------------------------------------------------------------------------
1 | """
2 | Rethinking Portrait Matting with Privacy Preserving
3 | Inferernce file.
4 |
5 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
6 | Licensed under the MIT License (see LICENSE for details)
7 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net
8 | Paper link: https://arxiv.org/abs/2203.16828
9 |
10 | """
11 | import torch
12 | import cv2
13 | import random
14 | import numpy as np
15 | from PIL import Image
16 | from torchvision import transforms
17 | from copy import deepcopy
18 | import ipdb
19 | import math
20 | from torch.utils.data import DataLoader
21 | from torch.utils.data._utils.collate import default_collate
22 | from collections.abc import Iterable
23 |
24 | from data_util import *
25 | from config import *
26 | from util import *
27 |
28 |
29 | #########################
30 | ## collate
31 | #########################
32 |
33 | def collate_two_stage(batch):
34 | transposed = list(zip(*batch))
35 | global_data = transposed[:6]
36 | local_data = transposed[6:12]
37 | params = transposed[12:14]
38 |
39 | global_batch = [default_collate(list(elem)) for elem in global_data]
40 | local_batch = [torch.cat(list(elem), dim=0) for elem in local_data]
41 | return [*global_batch, *local_batch, *params]
42 |
43 | #########################
44 | ## Data transformer
45 | #########################
46 | class MattingTransform(object):
47 | def __init__(self, crop_size, resize_size):
48 | super(MattingTransform, self).__init__()
49 | self.crop_size = crop_size
50 | self.resize_size = resize_size
51 |
52 | # args: image(blurred), mask, fg, bg, trimap, facemask, source_img, source_facemask
53 | def __call__(self, *args):
54 | ori = args[0]
55 | trimap = args[4]
56 |
57 | h, w, c = ori.shape
58 | crop_size = random.choice(self.crop_size)
59 | crop_size = crop_size if crop_size < min(h, w) else 512
60 | resize_size = self.resize_size
61 |
62 | target = np.where(trimap == 128) if random.random() < 0.5 else np.where(trimap > -100)
63 | if len(target[0]) == 0:
64 | target = np.where(trimap > -100)
65 |
66 | random_idx = np.random.choice(len(target[0]))
67 | centerh = target[0][random_idx]
68 | centerw = target[1][random_idx]
69 | crop_loc = self.safe_crop([centerh, centerw], crop_size, trimap.shape)
70 |
71 | flip_flag = True if random.random() < 0.5 else False
72 |
73 | args_transform = []
74 | for item in args:
75 | item = item[crop_loc[0]:crop_loc[2], crop_loc[1]:crop_loc[3]]
76 | if flip_flag:
77 | item = cv2.flip(item, 1)
78 | item = cv2.resize(item, (resize_size, resize_size), interpolation=cv2.INTER_LINEAR)
79 | args_transform.append(item)
80 |
81 | return args_transform
82 |
83 | def safe_crop(self, center_pt, crop_size, img_size):
84 | h, w = img_size[:2]
85 | crop_size = min(h, w, crop_size) # make sure crop_size <= min(h,w)
86 |
87 | center_h, center_w = center_pt
88 |
89 | left_top_h = max(center_h-crop_size//2, 0)
90 | right_bottom_h = min(h, left_top_h+crop_size)
91 | left_top_h = min(left_top_h, right_bottom_h-crop_size)
92 |
93 | left_top_w = max(center_w-crop_size//2, 0)
94 | right_bottom_w = min(w, left_top_w+crop_size)
95 | left_top_w = min(left_top_w, right_bottom_w-crop_size)
96 |
97 | return (left_top_h, left_top_w, right_bottom_h, right_bottom_w)
98 |
99 |
100 | class MattingDataset(torch.utils.data.Dataset):
101 | def __init__(self, config, transform):
102 | # Prepare transform
103 | self.transform = transform
104 |
105 | # Load data
106 | self.samples=[]
107 | self.samples += generate_paths_for_dataset(dataset=config.DATA.DATASET, trainset=config.DATA.TRAIN_SET)
108 |
109 | def __getitem__(self,index):
110 | # Prepare training sample paths
111 | ori_path, mask_path, fg_path, bg_path, facemask_path = self.samples[index]
112 |
113 | ori = np.array(Image.open(ori_path))
114 | mask = trim_img(np.array(Image.open(mask_path)))
115 | fg = np.array(Image.open(fg_path))
116 | bg = np.array(Image.open(bg_path))
117 | facemask = np.array(Image.open(facemask_path))[:,:,0:1]
118 | # Generate trimap/dilation/erosion online
119 | kernel_size = random.randint(15, 30)
120 | trimap = gen_trimap_with_dilate(mask, kernel_size)
121 |
122 | # Data transformation to generate samples (crop/flip/resize)
123 | # Transform input order: ori, mask, fg, bg, trimap, facemask
124 | argv = self.transform(ori, mask, fg, bg, trimap, facemask)
125 | argv_transform = []
126 | for item in argv:
127 | if item.ndim<3:
128 | item = torch.from_numpy(item.astype(np.float32)[np.newaxis, :, :])
129 | else:
130 | item = torch.from_numpy(item.astype(np.float32)).permute(2, 0, 1)
131 | argv_transform.append(item)
132 |
133 | [ori, mask, fg, bg, trimap, facemask] = argv_transform
134 |
135 | trimap[trimap > 180] = 255
136 | trimap[trimap < 50] = 0
137 | trimap[(trimap < 255) * (trimap > 0)] = 128
138 |
139 | facemask[facemask>100] = 255
140 | facemask[facemask<=100] = 0
141 |
142 | # normalize ori, fg, bg
143 | ori = ori/255.0
144 | fg = fg/255.0
145 | bg = bg/255.0
146 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
147 | std=[0.229, 0.224, 0.225])
148 | ori = normalize(ori)
149 | fg = normalize(fg)
150 | bg = normalize(bg)
151 |
152 | # output order: ori, mask, fg, bg, trimap, facemask
153 | return ori, mask, fg, bg, trimap, facemask
154 |
155 | def __len__(self):
156 | return len(self.samples)
157 |
158 |
--------------------------------------------------------------------------------
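A small sketch exercising `MattingTransform` above on dummy arrays (illustrative only; in the repository the transform is driven by `MattingDataset`, which needs the real P3M-10k data on disk):

```python
import numpy as np

# assumes MattingTransform from core/data.py is importable,
# e.g. `from data import MattingTransform` with core/ on sys.path
transform = MattingTransform(crop_size=[512, 768, 1024], resize_size=512)

h, w = 1024, 1024
ori = np.random.randint(0, 256, (h, w, 3), dtype=np.uint8)   # blurred image
mask = np.random.randint(0, 256, (h, w), dtype=np.uint8)     # alpha matte
fg = ori.copy()
bg = ori.copy()
trimap = np.full((h, w), 128, dtype=np.uint8)                # all "unknown" for the demo
facemask = np.zeros((h, w), dtype=np.uint8)

out = transform(ori, mask, fg, bg, trimap, facemask)
print([item.shape for item in out])  # every item is cropped, optionally flipped, and resized to 512x512
```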
/core/data_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Rethinking Portrait Matting with Privacy Preserving
3 | Data utility file.
4 |
5 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
6 | Licensed under the MIT License (see LICENSE for details)
7 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net
8 | Paper link: https://arxiv.org/abs/2203.16828
9 |
10 | """
11 |
12 | import ipdb
13 | import random
14 | import numpy as np
15 | from copy import deepcopy
16 | import math
17 |
18 |
19 | #########################################
20 | # Pure Functions
21 | #########################################
22 |
23 | def safe_crop(center_pt, crop_size, img_size):
24 | h, w = img_size[:2]
25 | crop_size = min(h, w, crop_size) # make sure crop_size <= min(h,w)
26 |
27 | center_h, center_w = center_pt
28 |
29 | left_top_h = max(center_h-crop_size//2, 0)
30 | right_bottom_h = min(h, left_top_h+crop_size)
31 | left_top_h = min(left_top_h, right_bottom_h-crop_size)
32 |
33 | left_top_w = max(center_w-crop_size//2, 0)
34 | right_bottom_w = min(w, left_top_w+crop_size)
35 | left_top_w = min(left_top_w, right_bottom_w-crop_size)
36 |
37 | return (left_top_h, left_top_w, right_bottom_h, right_bottom_w)
38 |
39 |
40 | #########################################
41 | # Functions for Affine Transform
42 | #########################################
43 |
44 |
45 | def get_random_params_for_inverse_affine_matrix(degrees, translate, scale_ranges, shears, flip, img_size):
46 | """Get parameters for affine transformation
47 |
48 | Returns:
49 | sequence: params to be passed to the affine transformation
50 | """
51 | angle = random.uniform(-degrees, degrees)
52 | assert translate is None, "bugs here, figure out the xy axis and img size problem"
53 | if translate is not None:
54 | max_dx = translate[0] * img_size[0]
55 | max_dy = translate[1] * img_size[1]
56 | translations = (np.round(random.uniform(-max_dx, max_dx)),
57 | np.round(random.uniform(-max_dy, max_dy)))
58 | else:
59 | translations = (0, 0)
60 |
61 | if scale_ranges is not None:
62 | scale = (random.uniform(scale_ranges[0], scale_ranges[1]),
63 | random.uniform(scale_ranges[0], scale_ranges[1]))
64 | else:
65 | scale = (1.0, 1.0)
66 |
67 | if shears is not None:
68 | shear = random.uniform(shears[0], shears[1])
69 | else:
70 | shear = 0.0
71 |
72 | if flip is not None:
73 | flip = 1 - (np.random.rand(2) < flip).astype(np.int) * 2
74 | flip = flip.tolist()
75 | else:
76 | flip = [1.0,1.0]
77 |
78 | return angle, translations, scale, shear, flip
79 |
80 |
81 | def get_inverse_affine_matrix(center, angle, translate=None, scale=None, shear=None, flip=None):
82 | # Helper method to compute inverse matrix for affine transformation
83 |
84 | # As it is explained in PIL.Image.rotate
85 | # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1
86 | # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
87 | # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
88 | # RSS is rotation with scale and shear matrix
89 | # It is different from the original function in torchvision
90 | # The order are changed to flip -> scale -> rotation -> shear
91 | # x and y have different scale factors
92 | # RSS(shear, a, scale, f) = [ cos(a + shear)*scale_x*f -sin(a + shear)*scale_y 0]
93 | # [ sin(a)*scale_x*f cos(a)*scale_y 0]
94 | # [ 0 0 1]
95 | # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
96 |
97 | """
98 | Args:
99 | center (tuple(int,int)) : the center of the image, required
100 | angle (int) : angle for rotation, range [-360,360], required
101 | translate (tuple(int,int)) : default (0,0). Bugs here, so DON'T use it.
102 | scale (tuple(double,double)): default (1.,1.), scale for x and y axis
103 | shear (double) : default 0.0
104 | flip (tuple(int,int)) : default no flip, choices [0 horizontal, 1 vertical]
105 | """
106 | # assertions, check param range
107 | if translate is not None:
108 | assert translate == (0,0), "BUG UNSOLVED"
109 | else:
110 | translate = (0,0)
111 |
112 | if scale is not None:
113 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
114 | "scale should be a list or tuple and it must be of length 2."
115 | for s in scale:
116 | if s <= 0:
117 | raise ValueError("scale values should be positive")
118 | else:
119 | scale = (1.,1.)
120 |
121 | if shear is None:
122 | shear = 0.0
123 |
124 | if flip is not None:
125 | assert isinstance(flip, (tuple, list)) and len(flip) == 2, \
126 | "flip should be a list or tuple and it must be of length 2."
127 | for f in flip:
128 | if f != -1 and f != 1:
129 | raise ValueError("flip values should be -1 or 1.")
130 | else:
131 | # final flip value -1, means flip
132 | # final flip value 1, means not flip
133 | # flip = (np.random.rand(2) < flip).astype(np.int) * 2 - 1
134 | # flip[0] horizontal
135 | # flip[1] vertical
136 | flip = (1.0,1.0)
137 |
138 | angle = math.radians(angle)
139 | shear = math.radians(shear)
140 | scale_x = 1.0 / scale[0] * flip[0]
141 | scale_y = 1.0 / scale[1] * flip[1]
142 |
143 | # Inverted rotation matrix with scale and shear
144 | d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle)
145 | matrix = [
146 | math.cos(angle) * scale_x, math.sin(angle + shear) * scale_x, 0,
147 | -math.sin(angle) * scale_y, math.cos(angle + shear) * scale_y, 0
148 | ]
149 | matrix = [m / d for m in matrix]
150 |
151 | # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
152 | matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
153 | matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])
154 |
155 | # Apply center translation: C * RSS^-1 * C^-1 * T^-1
156 | matrix[2] += center[0]
157 | matrix[5] += center[1]
158 |
159 | return matrix
160 |
161 |
--------------------------------------------------------------------------------
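A short sketch of how the inverse affine matrix above can be used (illustrative; PIL's `Image.transform` with `Image.AFFINE` expects exactly this kind of inverse, output-to-input mapping):

```python
from PIL import Image

# assumes get_inverse_affine_matrix from core/data_util.py is importable
img = Image.new("RGB", (256, 256), "white")
center = (img.size[0] * 0.5, img.size[1] * 0.5)

# 30-degree rotation about the center, no translation, no shear, no flip
matrix = get_inverse_affine_matrix(center, angle=30, translate=None,
                                   scale=(1.0, 1.0), shear=0.0, flip=(1, 1))

# PIL applies the 6-coefficient inverse mapping to produce the warped image
warped = img.transform(img.size, Image.AFFINE, matrix, resample=Image.BILINEAR)
```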
/core/eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Rethinking Portrait Matting with Privacy Preserving
3 | evaluation file.
4 |
5 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
6 | Licensed under the MIT License (see LICENSE for details)
7 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net
8 | Paper link: https://arxiv.org/abs/2203.16828
9 |
10 | """
11 |
12 | import os
13 | import argparse
14 | import numpy as np
15 | import cv2
16 | from tqdm import tqdm
17 | from metrics import *
18 | from util import listdir_nohidden
19 |
20 |
21 | def get_args():
22 | parser = argparse.ArgumentParser(description='Arguments for the evaluation purpose.')
23 | parser.add_argument('--pred_dir', type=str, required=True, help='path to predictions')
24 | parser.add_argument('--alpha_dir', type=str, required=True, help='path to groundtruth alpha')
25 | parser.add_argument('--trimap_dir', type=str, help='path to trimap')
26 | parser.add_argument('--fast_test', action='store_true', help='skip grad and conn')
27 | args, _ = parser.parse_known_args()
28 |
29 | print('Prediction dir: {}'.format(args.pred_dir))
30 | print('Alpha dir: {}'.format(args.alpha_dir))
31 | print('Trimap dir: {}'.format(args.trimap_dir))
32 | if args.fast_test:
33 | print('Will skip gradient and connectivity...')
34 | return args
35 |
36 | def evaluate_folder(args):
37 | img_list = listdir_nohidden(args.alpha_dir)
38 | total_number = len(img_list)
39 |
40 | sad_diffs = 0.
41 | mse_diffs = 0.
42 | mad_diffs = 0.
43 | sad_trimap_diffs = 0.
44 | mse_trimap_diffs = 0.
45 | mad_trimap_diffs = 0.
46 | sad_fg_diffs = 0.
47 | sad_bg_diffs = 0.
48 | conn_diffs = 0.
49 | grad_diffs = 0.
50 |
51 | for img_name in tqdm(img_list):
52 | predict = cv2.imread(os.path.join(args.pred_dir, img_name), 0).astype(np.float32)/255.0
53 | alpha = cv2.imread(os.path.join(args.alpha_dir, img_name), 0).astype(np.float32)/255.0
54 |
55 | if args.trimap_dir is not None:
56 | trimap = cv2.imread(os.path.join(args.trimap_dir, img_name), 0).astype(np.float32)
57 | else:
58 | trimap = None
59 |
60 | sad_diff, mse_diff, mad_diff = calculate_sad_mse_mad_whole_img(predict, alpha)
61 |
62 | if trimap is not None:
63 | sad_trimap_diff, mse_trimap_diff, mad_trimap_diff = calculate_sad_mse_mad(predict, alpha, trimap)
64 | sad_fg_diff, sad_bg_diff = calculate_sad_fgbg(predict, alpha, trimap)
65 | else:
66 | sad_trimap_diff, mse_trimap_diff, mad_trimap_diff = 0.0, 0.0, 0.0
67 | sad_fg_diff, sad_bg_diff = 0.0, 0.0
68 |
69 | if args.fast_test:
70 | conn_diff = 0.0
71 | grad_diff = 0.0
72 | else:
73 | conn_diff = compute_connectivity_loss_whole_image(predict, alpha)
74 | grad_diff = compute_gradient_whole_image(predict, alpha)
75 |
76 |
77 | sad_diffs += sad_diff
78 | mse_diffs += mse_diff
79 | mad_diffs += mad_diff
80 | mse_trimap_diffs += mse_trimap_diff
81 | sad_trimap_diffs += sad_trimap_diff
82 | mad_trimap_diffs += mad_trimap_diff
83 | sad_fg_diffs += sad_fg_diff
84 | sad_bg_diffs += sad_bg_diff
85 | conn_diffs += conn_diff
86 | grad_diffs += grad_diff
87 |
88 | res_dict = {}
89 | res_dict['SAD'] = sad_diffs / total_number
90 | res_dict['MSE'] = mse_diffs / total_number
91 | res_dict['MAD'] = mad_diffs / total_number
92 | res_dict['SAD_TRIMAP'] = sad_trimap_diffs / total_number
93 | res_dict['MSE_TRIMAP'] = mse_trimap_diffs / total_number
94 | res_dict['MAD_TRIMAP'] = mad_trimap_diffs / total_number
95 | res_dict['SAD_FG'] = sad_fg_diffs / total_number
96 | res_dict['SAD_BG'] = sad_bg_diffs / total_number
97 | res_dict['CONN'] = conn_diffs / total_number
98 | res_dict['GRAD'] = grad_diffs / total_number
99 |
100 | print('Average results')
101 | print('Test image numbers: {}'.format(total_number))
102 | print('Whole image SAD:', res_dict['SAD'])
103 | print('Whole image MSE:', res_dict['MSE'])
104 | print('Whole image MAD:', res_dict['MAD'])
105 | print('Unknown region SAD:', res_dict['SAD_TRIMAP'])
106 | print('Unknown region MSE:', res_dict['MSE_TRIMAP'])
107 | print('Unknown region MAD:', res_dict['MAD_TRIMAP'])
108 | print('Foreground SAD:', res_dict['SAD_FG'])
109 | print('Background SAD:', res_dict['SAD_BG'])
110 | print('Gradient:', res_dict['GRAD'])
111 | print('Connectivity:', res_dict['CONN'])
112 | return res_dict
113 |
114 |
115 | if __name__ == '__main__':
116 | args = get_args()
117 | evaluate_folder(args)
118 |
--------------------------------------------------------------------------------
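Besides the command-line entry point, `evaluate_folder` above can also be called directly; a minimal sketch (the directory paths below are placeholders):

```python
import sys
from types import SimpleNamespace

sys.path.insert(0, "core")  # assumption: run from the repository root
from eval import evaluate_folder

args = SimpleNamespace(
    pred_dir="p3m_out/result/some_run",                # predicted alpha mattes (placeholder path)
    alpha_dir="P3M-10k/validation/P3M-500-P/mask",     # ground-truth alpha mattes
    trimap_dir="P3M-10k/validation/P3M-500-P/trimap",  # optional; may be None
    fast_test=True,                                    # skip gradient/connectivity metrics
)
results = evaluate_folder(args)
print(results["SAD"], results["MSE"], results["MAD"])
```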
/core/evaluate.py:
--------------------------------------------------------------------------------
1 | """
2 | Rethinking Portrait Matting with Privacy Preserving
3 | Loss functions file.
4 |
5 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
6 | Licensed under the MIT License (see LICENSE for details)
7 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net
8 | Paper link: https://arxiv.org/abs/2203.16828
9 |
10 | """
11 | import numpy as np
12 | import torch
13 | import torch.nn as nn
14 | import math
15 | from torch.autograd import Variable
16 | import torch.nn.functional as fnn
17 | import ipdb
18 |
19 |
20 | ##############################
21 | ### Training losses for P3M-NET
22 | ##############################
23 | def get_crossentropy_loss(gt,pre):
24 | gt_copy = gt.clone()
25 | gt_copy[gt_copy==0] = 0
26 | gt_copy[gt_copy==255] = 2
27 | gt_copy[gt_copy>2] = 1
28 | gt_copy = gt_copy.long()
29 | gt_copy = gt_copy[:,0,:,:]
30 | criterion = nn.CrossEntropyLoss()
31 | entropy_loss = criterion(pre, gt_copy)
32 | return entropy_loss
33 |
34 | def get_alpha_loss(predict, alpha, trimap):
35 | weighted = torch.zeros(trimap.shape).cuda()
36 | weighted[trimap == 128] = 1.
37 | alpha_f = alpha / 255.
38 | alpha_f = alpha_f.cuda()
39 | diff = predict - alpha_f
40 | diff = diff * weighted
41 | alpha_loss = torch.sqrt(diff ** 2 + 1e-12)
42 | alpha_loss_weighted = alpha_loss.sum() / (weighted.sum() + 1.)
43 | return alpha_loss_weighted
44 |
45 | def get_alpha_loss_whole_img(predict, alpha):
46 | weighted = torch.ones(alpha.shape).cuda()
47 | alpha_f = alpha / 255.
48 | alpha_f = alpha_f.cuda()
49 | diff = predict - alpha_f
50 | alpha_loss = torch.sqrt(diff ** 2 + 1e-12)
51 | alpha_loss = alpha_loss.sum()/(weighted.sum())
52 | return alpha_loss
53 |
54 | ## Laplacian loss refers to
55 | ## https://gist.github.com/MarcoForte/a07c40a2b721739bb5c5987671aa5270
56 | def build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=False):
57 | if size % 2 != 1:
58 | raise ValueError("kernel size must be uneven")
59 | grid = np.float32(np.mgrid[0:size,0:size].T)
60 | gaussian = lambda x: np.exp((x - size//2)**2/(-2*sigma**2))**2
61 | kernel = np.sum(gaussian(grid), axis=2)
62 | kernel /= np.sum(kernel)
63 | kernel = np.tile(kernel, (n_channels, 1, 1))
64 | kernel = torch.FloatTensor(kernel[:, None, :, :]).cuda()
65 | return Variable(kernel, requires_grad=False)
66 |
67 | def conv_gauss(img, kernel):
68 | """ convolve img with a gaussian kernel that has been built with build_gauss_kernel """
69 | n_channels, _, kw, kh = kernel.shape
70 | img = fnn.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
71 | return fnn.conv2d(img, kernel, groups=n_channels)
72 |
73 | def laplacian_pyramid(img, kernel, max_levels=5):
74 | current = img
75 | pyr = []
76 | for level in range(max_levels):
77 | filtered = conv_gauss(current, kernel)
78 | diff = current - filtered
79 | pyr.append(diff)
80 | current = fnn.avg_pool2d(filtered, 2)
81 | pyr.append(current)
82 | return pyr
83 |
84 | def get_laplacian_loss(predict, alpha, trimap):
85 | weighted = torch.zeros(trimap.shape).cuda()
86 | weighted[trimap == 128] = 1.
87 | alpha_f = alpha / 255.
88 | alpha_f = alpha_f.cuda()
89 | alpha_f = alpha_f.clone()*weighted
90 | predict = predict.clone()*weighted
91 | gauss_kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=True)
92 | pyr_alpha = laplacian_pyramid(alpha_f, gauss_kernel, 5)
93 | pyr_predict = laplacian_pyramid(predict, gauss_kernel, 5)
94 | laplacian_loss_weighted = sum(fnn.l1_loss(a, b) for a, b in zip(pyr_alpha, pyr_predict))
95 | return laplacian_loss_weighted
96 |
97 | def get_laplacian_loss_whole_img(predict, alpha):
98 | alpha_f = alpha / 255.
99 | alpha_f = alpha_f.cuda()
100 | gauss_kernel = build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=True)
101 | pyr_alpha = laplacian_pyramid(alpha_f, gauss_kernel, 5)
102 | pyr_predict = laplacian_pyramid(predict, gauss_kernel, 5)
103 | laplacian_loss = sum(fnn.l1_loss(a, b) for a, b in zip(pyr_alpha, pyr_predict))
104 | return laplacian_loss
105 |
106 | def get_composition_loss_whole_img(img, alpha, fg, bg, predict):
107 | weighted = torch.ones(alpha.shape).cuda()
108 | predict_3 = torch.cat((predict, predict, predict), 1)
109 | comp = predict_3 * fg + (1. - predict_3) * bg
110 | comp_loss = torch.sqrt((comp - img) ** 2 + 1e-12) # fixed bug, remove "/255."
111 | comp_loss = comp_loss.sum()/(weighted.sum())
112 | return comp_loss
113 |
114 | def get_composition_loss(img, alpha, fg, bg, trimap, predict):
115 | wi = torch.zeros(trimap.shape)
116 | wi[trimap == 128] = 1.
117 | t_wi = wi.cuda()
118 | t3_wi = torch.cat((wi, wi, wi), 1).cuda()
119 | unknown_region_size = t_wi.sum()
120 | predict_3 = torch.cat((predict, predict, predict), 1)
121 | comp = predict_3 * fg + (1. - predict_3) * bg
122 | comp_loss = torch.sqrt((comp - img) ** 2 + 1e-12) / 255.
123 | comp_loss = (comp_loss * t3_wi).sum() / (unknown_region_size + 1.) / 3.
124 | return comp_loss
125 |
126 | ##############################
127 | ### Test loss for matting
128 | ##############################
129 | def calculate_sad_mse_mad_privacy(predict_old, alpha, privacy):
130 | predict = np.copy(predict_old)
131 | pixel = float((privacy == 255).sum())
132 | sad_diff_mask = np.abs(predict - alpha)
133 | sad_diff_mask[privacy != 255] = 0.
134 | sad_diff = np.sum(sad_diff_mask) / 1000
135 | if pixel == 0:
136 | return 0, 0, 0
137 | mse_diff = np.sum(sad_diff_mask ** 2)/pixel
138 | mad_diff = np.sum(sad_diff_mask)/pixel
139 | return sad_diff, mse_diff, mad_diff
140 |
141 | def calculate_sad_mse_mad(predict_old,alpha,trimap):
142 | predict = np.copy(predict_old)
143 | pixel = float((trimap == 128).sum())
144 | predict[trimap == 255] = 1.
145 | predict[trimap == 0 ] = 0.
146 | sad_diff = np.sum(np.abs(predict - alpha))/1000
147 | if pixel==0:
148 | pixel = trimap.shape[0]*trimap.shape[1]-float((trimap==255).sum())-float((trimap==0).sum())
149 | mse_diff = np.sum((predict - alpha) ** 2)/pixel
150 | mad_diff = np.sum(np.abs(predict - alpha))/pixel
151 | return sad_diff, mse_diff, mad_diff
152 |
153 | def calculate_sad_mse_mad_whole_img(predict, alpha):
154 | pixel = predict.shape[0]*predict.shape[1]
155 | sad_diff = np.sum(np.abs(predict - alpha))/1000
156 | mse_diff = np.sum((predict - alpha) ** 2)/pixel
157 | mad_diff = np.sum(np.abs(predict - alpha))/pixel
158 | return sad_diff, mse_diff, mad_diff
159 |
160 | def calculate_sad_fgbg(predict, alpha, trimap):
161 | sad_diff = np.abs(predict-alpha)
162 | weight_fg = np.zeros(predict.shape)
163 | weight_bg = np.zeros(predict.shape)
164 | weight_trimap = np.zeros(predict.shape)
165 | weight_fg[trimap==255] = 1.
166 | weight_bg[trimap==0 ] = 1.
167 | weight_trimap[trimap==128 ] = 1.
168 | sad_fg = np.sum(sad_diff*weight_fg)/1000
169 | sad_bg = np.sum(sad_diff*weight_bg)/1000
170 | sad_trimap = np.sum(sad_diff*weight_trimap)/1000
171 | return sad_fg, sad_bg
172 |
173 | def compute_gradient_whole_image(pd, gt):
174 | from scipy.ndimage import gaussian_filter
175 | pd_x = gaussian_filter(pd, sigma=1.4, order=[1, 0], output=np.float32)
176 | pd_y = gaussian_filter(pd, sigma=1.4, order=[0, 1], output=np.float32)
177 | gt_x = gaussian_filter(gt, sigma=1.4, order=[1, 0], output=np.float32)
178 | gt_y = gaussian_filter(gt, sigma=1.4, order=[0, 1], output=np.float32)
179 | pd_mag = np.sqrt(pd_x**2 + pd_y**2)
180 | gt_mag = np.sqrt(gt_x**2 + gt_y**2)
181 |
182 | error_map = np.square(pd_mag - gt_mag)
183 | loss = np.sum(error_map) / 10
184 | return loss
185 |
186 | def compute_connectivity_loss_whole_image(pd, gt, step=0.1):
187 |
188 | from scipy.ndimage import morphology
189 | from skimage.measure import label, regionprops
190 | h, w = pd.shape
191 | thresh_steps = np.arange(0, 1.1, step)
192 | l_map = -1 * np.ones((h, w), dtype=np.float32)
193 | lambda_map = np.ones((h, w), dtype=np.float32)
194 | for i in range(1, thresh_steps.size):
195 | pd_th = pd >= thresh_steps[i]
196 | gt_th = gt >= thresh_steps[i]
197 | label_image = label(pd_th & gt_th, connectivity=1)
198 | cc = regionprops(label_image)
199 | size_vec = np.array([c.area for c in cc])
200 | if len(size_vec) == 0:
201 | continue
202 | max_id = np.argmax(size_vec)
203 | coords = cc[max_id].coords
204 | omega = np.zeros((h, w), dtype=np.float32)
205 | omega[coords[:, 0], coords[:, 1]] = 1
206 | flag = (l_map == -1) & (omega == 0)
207 | l_map[flag == 1] = thresh_steps[i-1]
208 | dist_maps = morphology.distance_transform_edt(omega==0)
209 | dist_maps = dist_maps / dist_maps.max()
210 | l_map[l_map == -1] = 1
211 | d_pd = pd - l_map
212 | d_gt = gt - l_map
213 | phi_pd = 1 - d_pd * (d_pd >= 0.15).astype(np.float32)
214 | phi_gt = 1 - d_gt * (d_gt >= 0.15).astype(np.float32)
215 | loss = np.sum(np.abs(phi_pd - phi_gt)) / 1000
216 | return loss
--------------------------------------------------------------------------------
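A minimal sketch of how the training losses above could be combined for one batch. This is illustrative only: it assumes a CUDA device (the loss functions call .cuda() internally), the repository's 0/128/255 trimap convention, and that core/ is on the Python path with its requirements (including ipdb) installed so that evaluate imports directly.

    import torch
    from evaluate import get_crossentropy_loss, get_alpha_loss, get_laplacian_loss

    B, H, W = 2, 64, 64
    values = torch.tensor([0., 128., 255.])
    trimap = values[torch.randint(0, 3, (B, 1, H, W))].cuda()   # 0 = bg, 128 = unknown, 255 = fg
    alpha_gt = (torch.rand(B, 1, H, W) * 255.).cuda()           # ground-truth alpha in [0, 255]
    trimap_logits = torch.randn(B, 3, H, W).cuda()              # 3-class trimap prediction
    alpha_pred = torch.rand(B, 1, H, W).cuda()                  # predicted alpha in [0, 1]

    loss = (get_crossentropy_loss(trimap, trimap_logits)
            + get_alpha_loss(alpha_pred, alpha_gt, trimap)
            + get_laplacian_loss(alpha_pred, alpha_gt, trimap))
    print(loss.item())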
/core/infer.py:
--------------------------------------------------------------------------------
1 | """
2 | Rethinking Portrait Matting with Privacy Preserving
3 | Inference file.
4 |
5 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
6 | Licensed under the MIT License (see LICENSE for details)
7 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net
8 | Paper link: https://arxiv.org/abs/2203.16828
9 |
10 | """
11 |
12 | import torch
13 | import cv2
14 | import argparse
15 | import numpy as np
16 | from tqdm import tqdm
17 | from PIL import Image
18 | from skimage.transform import resize
19 | from torchvision import transforms
20 |
21 | from config import *
22 | from util import *
23 | from network import build_model
24 |
25 |
26 | def get_args():
27 |     parser = argparse.ArgumentParser(description='P3M-Net inference')
28 | parser.add_argument('--arch', type=str, required=True, choices=['r34', 'swin', 'vitae'], help='model architecture')
29 | parser.add_argument('--dataset', type=str, required=True, choices=['P3M10K', 'RWP', 'SAMPLES'], help='dataset to test')
30 | parser.add_argument('--test_set', type=str, choices=['RWP', 'P3M_500_P', 'P3M_500_NP'], help='the validation set to test')
31 | parser.add_argument('--model_path', type=str, required=True, help='path of checkpoint')
32 | parser.add_argument('--test_choice', type=str, choices=['HYBRID', 'RESIZE'], required=True, help='how to test')
33 | parser.add_argument('--test_result_dir', type=str, help='path to save results of datasets')
34 | args, _ = parser.parse_known_args()
35 |
36 | print('Model architecture: {}'.format(args.arch))
37 | print('Model path: {}'.format(args.model_path))
38 | print('Test dataset: {}, set: {}'.format(args.dataset, args.test_set))
39 | print('Test choice: {}'.format(args.test_choice))
40 | if args.dataset != 'SAMPLES':
41 | print('Save results to {}'.format(args.test_result_dir))
42 | else:
43 | print('Save alpha results to {}'.format(SAMPLES_RESULT_ALPHA_PATH))
44 | print('Save color results to {}'.format(SAMPLES_RESULT_COLOR_PATH))
45 |
46 | return args
47 |
48 | def inference_once(args, model, scale_img, scale_trimap=None):
49 | if torch.cuda.device_count() > 0:
50 | tensor_img = torch.from_numpy(scale_img.astype(np.float32)[:, :, :]).permute(2, 0, 1).cuda()
51 | else:
52 | tensor_img = torch.from_numpy(scale_img.astype(np.float32)[:, :, :]).permute(2, 0, 1)
53 | input_t = tensor_img
54 | input_t = input_t/255.0
55 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
56 | std=[0.229, 0.224, 0.225])
57 | input_t = normalize(input_t)
58 | input_t = input_t.unsqueeze(0)
59 | pred_global, pred_local, pred_fusion = model(input_t)[:3]
60 | pred_global = pred_global.data.cpu().numpy()
61 | pred_global = gen_trimap_from_segmap_e2e(pred_global)
62 | pred_local = pred_local.data.cpu().numpy()[0,0,:,:]
63 | pred_fusion = pred_fusion.data.cpu().numpy()[0,0,:,:]
64 | return pred_global, pred_local, pred_fusion
65 |
66 | def inference_img_p3m(args, model, img):
67 | h, w, c = img.shape
68 | new_h = min(MAX_SIZE_H, h - (h % 32))
69 | new_w = min(MAX_SIZE_W, w - (w % 32))
70 |
71 | # for P3M-Net(Swin-T) model on very small images in RWP
72 | if args.dataset in ['RWP', 'SAMPLES'] and args.arch == 'swin' and (h < MIN_SIZE_H or w < MIN_SIZE_W):
73 | ratioh = float(MIN_SIZE_H)/float(h)
74 | ratiow = float(MIN_SIZE_W)/float(w)
75 | ratio = max(ratioh, ratiow)
76 | new_h = int(ratio*h)
77 | new_w = int(ratio*w)
78 | new_h = min(MAX_SIZE_H, new_h - (new_h % 32))
79 | new_w = min(MAX_SIZE_W, new_w - (new_w % 32))
80 | scale_img = resize(img, (new_h,new_w))*255.0
81 | print(scale_img.shape)
82 | pred_global, pred_local, pred_fusion = inference_once(args, model, scale_img)
83 | pred_local = resize(pred_local,(h,w))
84 | pred_global = resize(pred_global,(h,w))*255.0
85 | pred_fusion = resize(pred_fusion,(h,w))
86 | return pred_fusion
87 |
88 | if args.test_choice=='HYBRID':
89 | global_ratio = 1/2
90 | local_ratio = 1
91 | resize_h = int(h*global_ratio)
92 | resize_w = int(w*global_ratio)
93 | new_h = min(MAX_SIZE_H, resize_h - (resize_h % 32))
94 | new_w = min(MAX_SIZE_W, resize_w - (resize_w % 32))
95 | scale_img = resize(img,(new_h,new_w))*255.0
96 |         pred_contour_1, pred_retouching_1, pred_fusion_1 = inference_once(args, model, scale_img)
97 | # torch.cuda.empty_cache()
98 |         pred_contour_1 = resize(pred_contour_1,(h,w))*255.0
99 | resize_h = int(h*local_ratio)
100 | resize_w = int(w*local_ratio)
101 | new_h = min(MAX_SIZE_H, resize_h - (resize_h % 32))
102 | new_w = min(MAX_SIZE_W, resize_w - (resize_w % 32))
103 | scale_img = resize(img,(new_h,new_w))*255.0
104 |         pred_contour_2, pred_retouching_2, pred_fusion_2 = inference_once(args, model, scale_img)
105 | # torch.cuda.empty_cache()
106 | pred_retouching_2 = resize(pred_retouching_2,(h,w))
107 |         pred_fusion = get_masked_local_from_global_test(pred_contour_1, pred_retouching_2)
108 | return pred_fusion
109 | elif args.test_choice=='RESIZE':
110 | resize_h = int(h/2)
111 | resize_w = int(w/2)
112 | new_h = min(MAX_SIZE_H, resize_h - (resize_h % 32))
113 | new_w = min(MAX_SIZE_W, resize_w - (resize_w % 32))
114 | scale_img = resize(img,(new_h,new_w))*255.0
115 | pred_global, pred_local, pred_fusion = inference_once(args, model, scale_img)
116 | pred_local = resize(pred_local,(h,w))
117 | pred_global = resize(pred_global,(h,w))*255.0
118 | pred_fusion = resize(pred_fusion,(h,w))
119 | return pred_fusion
120 | else:
121 | raise NotImplementedError
122 |
123 | def test_dataset(args, model):
124 | if torch.cuda.device_count() > 0:
125 | torch.cuda.empty_cache()
126 | else:
127 | print('NO GPU AVAILABLE')
128 | return
129 |
130 | ############################
131 | # Some initial setting for paths
132 | ############################
133 | ORIGINAL_PATH = DATASET_PATHS_DICT[args.dataset][args.test_set]['ORIGINAL_PATH']
134 |
135 | ############################
136 | # Start testing
137 | ############################
138 | result_dir = args.test_result_dir
139 | refresh_folder(result_dir)
140 |
141 | model.eval()
142 | img_list = listdir_nohidden(ORIGINAL_PATH)
143 |
144 | for img_name in tqdm(img_list):
145 | img_path = ORIGINAL_PATH+img_name
146 | img = np.array(Image.open(img_path))
147 | img = img[:,:,:3] if img.ndim>2 else img
148 |
149 | with torch.no_grad():
150 | predict = inference_img_p3m(args, model, img)
151 | save_test_result(os.path.join(result_dir, extract_pure_name(img_name)+'.png'),predict)
152 |
153 | def test_samples(args, model):
154 | model.eval()
155 | img_list = listdir_nohidden(SAMPLES_ORIGINAL_PATH)
156 | refresh_folder(SAMPLES_RESULT_ALPHA_PATH)
157 | refresh_folder(SAMPLES_RESULT_COLOR_PATH)
158 | for img_name in tqdm(img_list):
159 | img_path = SAMPLES_ORIGINAL_PATH+img_name
160 | try:
161 | img = np.array(Image.open(img_path))[:,:,:3]
162 | except Exception as e:
163 |             print(f'Error: {str(e)} | Name: {img_name}'); continue  # skip unreadable images
164 | h, w, c = img.shape
165 | if min(h, w)>SHORTER_PATH_LIMITATION:
166 | if h>=w:
167 | new_w = SHORTER_PATH_LIMITATION
168 | new_h = int(SHORTER_PATH_LIMITATION*h/w)
169 | img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
170 | else:
171 | new_h = SHORTER_PATH_LIMITATION
172 | new_w = int(SHORTER_PATH_LIMITATION*w/h)
173 | img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
174 |
175 | with torch.no_grad():
176 | if torch.cuda.device_count() > 0:
177 | torch.cuda.empty_cache()
178 | predict = inference_img_p3m(args, model, img)
179 |
180 | composite = generate_composite_img(img, predict)
181 | cv2.imwrite(os.path.join(SAMPLES_RESULT_COLOR_PATH, extract_pure_name(img_name)+'.png'),composite)
182 | predict = predict*255.0
183 | predict = cv2.resize(predict, (w, h), interpolation=cv2.INTER_LINEAR)
184 | cv2.imwrite(os.path.join(SAMPLES_RESULT_ALPHA_PATH, extract_pure_name(img_name)+'.png'),predict.astype(np.uint8))
185 |
186 | def load_model_and_deploy(args):
187 | ### build model
188 | model = build_model(args.arch, pretrained=False)
189 |
190 | ### load ckpt
191 | ckpt = torch.load(args.model_path)
192 | model.load_state_dict(ckpt['state_dict'], strict=True)
193 | model = model.cuda()
194 |
195 | ### Test
196 | if args.dataset=='SAMPLES':
197 | test_samples(args, model)
198 | elif args.dataset in ['P3M10K', 'RWP']:
199 | test_dataset(args, model)
200 | else:
201 | print('Please input the correct dataset_choice (SAMPLES, P3M10K or RWP).')
202 |
203 | if __name__ == '__main__':
204 | args = get_args()
205 | load_model_and_deploy(args)
206 |
--------------------------------------------------------------------------------
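For reference, an illustrative sketch of calling the inference helpers above from Python rather than the command line, mirroring what load_model_and_deploy and test_samples do. The checkpoint and image paths are placeholders, and it assumes the code is run from the core/ directory (so config, util and network resolve) on a machine with a GPU.

    import numpy as np
    import torch
    from types import SimpleNamespace
    from PIL import Image

    from network import build_model
    from infer import inference_img_p3m

    args = SimpleNamespace(arch='vitae', dataset='SAMPLES', test_choice='HYBRID')
    model = build_model(args.arch, pretrained=False)
    ckpt = torch.load('path/to/checkpoint.pth')                   # placeholder checkpoint path
    model.load_state_dict(ckpt['state_dict'], strict=True)
    model = model.cuda().eval()

    img = np.array(Image.open('path/to/portrait.jpg'))[:, :, :3]  # placeholder image path
    with torch.no_grad():
        alpha = inference_img_p3m(args, model, img)               # float matte in [0, 1], same H x W as img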
/core/logger.py:
--------------------------------------------------------------------------------
1 | from termcolor import colored
2 | import logging
3 | import sys
4 |
5 |
6 | def create_logger(filename):
7 | # create logger
8 | logger = logging.getLogger('Logger')
9 | logger.setLevel(logging.DEBUG)
10 | logger.propagate = False
11 |
12 | # create formatter
13 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
14 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
15 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
16 |
17 | console_handler = logging.StreamHandler(sys.stdout)
18 | console_handler.setLevel(logging.DEBUG)
19 | console_handler.setFormatter(
20 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
21 | logger.addHandler(console_handler)
22 |
23 | # create file handlers
24 | file_handler = logging.FileHandler(filename, mode='a')
25 | file_handler.setLevel(logging.DEBUG)
26 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
27 | logger.addHandler(file_handler)
28 |
29 | return logger
--------------------------------------------------------------------------------
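A quick usage note for the helper above; the log file name is arbitrary:

    from logger import create_logger

    logger = create_logger('train.log')
    logger.info('starting epoch %d', 1)   # coloured on stdout, plain text appended to train.log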
/core/metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | Rethinking Portrait Matting with Privacy Preserving
3 |
4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
5 | Licensed under the MIT License (see LICENSE for details)
6 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net
7 | Paper link: https://arxiv.org/abs/2203.16828
8 |
9 | """
10 |
11 | import numpy as np
12 |
13 | ##############################
14 | ### Test loss for matting
15 | ##############################
16 |
17 | def calculate_sad_mse_mad(predict_old,alpha,trimap):
18 | predict = np.copy(predict_old)
19 | pixel = float((trimap == 128).sum())
20 | predict[trimap == 255] = 1.
21 | predict[trimap == 0 ] = 0.
22 | sad_diff = np.sum(np.abs(predict - alpha))/1000
23 | if pixel==0:
24 | pixel = trimap.shape[0]*trimap.shape[1]-float((trimap==255).sum())-float((trimap==0).sum())
25 | mse_diff = np.sum((predict - alpha) ** 2)/pixel
26 | mad_diff = np.sum(np.abs(predict - alpha))/pixel
27 | return sad_diff, mse_diff, mad_diff
28 |
29 | def calculate_sad_mse_mad_whole_img(predict, alpha):
30 | pixel = predict.shape[0]*predict.shape[1]
31 | sad_diff = np.sum(np.abs(predict - alpha))/1000
32 | mse_diff = np.sum((predict - alpha) ** 2)/pixel
33 | mad_diff = np.sum(np.abs(predict - alpha))/pixel
34 | return sad_diff, mse_diff, mad_diff
35 |
36 | def calculate_sad_fgbg(predict, alpha, trimap):
37 | sad_diff = np.abs(predict-alpha)
38 | weight_fg = np.zeros(predict.shape)
39 | weight_bg = np.zeros(predict.shape)
40 | weight_trimap = np.zeros(predict.shape)
41 | weight_fg[trimap==255] = 1.
42 | weight_bg[trimap==0 ] = 1.
43 | weight_trimap[trimap==128 ] = 1.
44 | sad_fg = np.sum(sad_diff*weight_fg)/1000
45 | sad_bg = np.sum(sad_diff*weight_bg)/1000
46 | sad_trimap = np.sum(sad_diff*weight_trimap)/1000
47 | return sad_fg, sad_bg
48 |
49 | def compute_gradient_whole_image(pd, gt):
50 | from scipy.ndimage import gaussian_filter
51 | pd_x = gaussian_filter(pd, sigma=1.4, order=[1, 0], output=np.float32)
52 | pd_y = gaussian_filter(pd, sigma=1.4, order=[0, 1], output=np.float32)
53 | gt_x = gaussian_filter(gt, sigma=1.4, order=[1, 0], output=np.float32)
54 | gt_y = gaussian_filter(gt, sigma=1.4, order=[0, 1], output=np.float32)
55 | pd_mag = np.sqrt(pd_x**2 + pd_y**2)
56 | gt_mag = np.sqrt(gt_x**2 + gt_y**2)
57 |
58 | error_map = np.square(pd_mag - gt_mag)
59 | loss = np.sum(error_map) / 10
60 | return loss
61 |
62 | def compute_connectivity_loss_whole_image(pd, gt, step=0.1):
63 | from scipy.ndimage import morphology
64 | from skimage.measure import label, regionprops
65 | h, w = pd.shape
66 | thresh_steps = np.arange(0, 1.1, step)
67 | l_map = -1 * np.ones((h, w), dtype=np.float32)
68 | lambda_map = np.ones((h, w), dtype=np.float32)
69 | for i in range(1, thresh_steps.size):
70 | pd_th = pd >= thresh_steps[i]
71 | gt_th = gt >= thresh_steps[i]
72 | label_image = label(pd_th & gt_th, connectivity=1)
73 | cc = regionprops(label_image)
74 | size_vec = np.array([c.area for c in cc])
75 | if len(size_vec) == 0:
76 | continue
77 | max_id = np.argmax(size_vec)
78 | coords = cc[max_id].coords
79 | omega = np.zeros((h, w), dtype=np.float32)
80 | omega[coords[:, 0], coords[:, 1]] = 1
81 | flag = (l_map == -1) & (omega == 0)
82 | l_map[flag == 1] = thresh_steps[i-1]
83 | dist_maps = morphology.distance_transform_edt(omega==0)
84 | dist_maps = dist_maps / dist_maps.max()
85 | l_map[l_map == -1] = 1
86 | d_pd = pd - l_map
87 | d_gt = gt - l_map
88 | phi_pd = 1 - d_pd * (d_pd >= 0.15).astype(np.float32)
89 | phi_gt = 1 - d_gt * (d_gt >= 0.15).astype(np.float32)
90 | loss = np.sum(np.abs(phi_pd - phi_gt)) / 1000
91 | return loss
--------------------------------------------------------------------------------
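An illustrative sketch of the inputs these metrics expect: 2-D numpy arrays with alpha values in [0, 1] and trimaps using the 0/128/255 convention. SAD values are divided by 1000, i.e. reported in thousands as is common in matting benchmarks.

    import numpy as np
    from metrics import (calculate_sad_mse_mad, calculate_sad_mse_mad_whole_img,
                         compute_gradient_whole_image, compute_connectivity_loss_whole_image)

    pred = np.random.rand(256, 256).astype(np.float32)    # predicted alpha in [0, 1]
    gt = np.random.rand(256, 256).astype(np.float32)      # ground-truth alpha in [0, 1]
    trimap = np.full((256, 256), 128, dtype=np.float32)   # treat the whole image as the unknown region

    sad, mse, mad = calculate_sad_mse_mad(pred, gt, trimap)
    sad_w, mse_w, mad_w = calculate_sad_mse_mad_whole_img(pred, gt)
    grad = compute_gradient_whole_image(pred, gt)
    conn = compute_connectivity_loss_whole_image(pred, gt)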
/core/network/RestNet34/P3mNet.py:
--------------------------------------------------------------------------------
1 | """
2 | Rethinking Portrait Matting with Privacy Preserving
3 |
4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
5 | Licensed under the MIT License (see LICENSE for details)
6 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net
7 | Paper link: https://arxiv.org/abs/2203.16828
8 |
9 | """
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from util import get_masked_local_from_global
13 | from ..modules import TFI, SBFI, DBFI
14 | from .resnet_mp import *
15 |
16 |
17 | class P3mNet(nn.Module):
18 | def __init__(self, pretrained=True):
19 | super().__init__()
20 | self.resnet = resnet34_mp(pretrained=pretrained)
21 | ############################
22 | ### Encoder part - RESNETMP
23 | ############################
24 | self.encoder0 = nn.Sequential(
25 | self.resnet.conv1,
26 | self.resnet.bn1,
27 | self.resnet.relu,
28 | )
29 | self.mp0 = self.resnet.maxpool1
30 | self.encoder1 = nn.Sequential(
31 | self.resnet.layer1)
32 | self.mp1 = self.resnet.maxpool2
33 | self.encoder2 = self.resnet.layer2
34 | self.mp2 = self.resnet.maxpool3
35 | self.encoder3 = self.resnet.layer3
36 | self.mp3 = self.resnet.maxpool4
37 | self.encoder4 = self.resnet.layer4
38 | self.mp4 = self.resnet.maxpool5
39 |
40 | self.tfi_3 = TFI(256)
41 | self.tfi_2 = TFI(128)
42 | self.tfi_1 = TFI(64)
43 | self.tfi_0 = TFI(64)
44 |
45 | self.sbfi_2 = SBFI(128, 64, 8)
46 | self.sbfi_1 = SBFI(64, 64, 4)
47 | self.sbfi_0 = SBFI(64, 64, 2)
48 |
49 | self.dbfi_2 = DBFI(128, 512, 4)
50 | self.dbfi_1 = DBFI(64, 512, 8)
51 | self.dbfi_0 = DBFI(64, 512, 16)
52 |
53 | ##########################
54 | ### Decoder part - GLOBAL
55 | ##########################
56 | self.decoder4_g = nn.Sequential(
57 | nn.Conv2d(512,512,3,padding=1),
58 | nn.BatchNorm2d(512),
59 | nn.ReLU(inplace=True),
60 | nn.Conv2d(512,512,3,padding=1),
61 | nn.BatchNorm2d(512),
62 | nn.ReLU(inplace=True),
63 | nn.Conv2d(512,256,3,padding=1),
64 | nn.BatchNorm2d(256),
65 | nn.ReLU(inplace=True),
66 | nn.Upsample(scale_factor=2, mode='bilinear') )
67 | self.decoder3_g = nn.Sequential(
68 | nn.Conv2d(256,256,3,padding=1),
69 | nn.BatchNorm2d(256),
70 | nn.ReLU(inplace=True),
71 | nn.Conv2d(256,256,3,padding=1),
72 | nn.BatchNorm2d(256),
73 | nn.ReLU(inplace=True),
74 | nn.Conv2d(256,128,3,padding=1),
75 | nn.BatchNorm2d(128),
76 | nn.ReLU(inplace=True),
77 | nn.Upsample(scale_factor=2, mode='bilinear') )
78 | self.decoder2_g = nn.Sequential(
79 | nn.Conv2d(128,128,3,padding=1),
80 | nn.BatchNorm2d(128),
81 | nn.ReLU(inplace=True),
82 | nn.Conv2d(128,128,3,padding=1),
83 | nn.BatchNorm2d(128),
84 | nn.ReLU(inplace=True),
85 | nn.Conv2d(128,64,3,padding=1),
86 | nn.BatchNorm2d(64),
87 | nn.ReLU(inplace=True),
88 | nn.Upsample(scale_factor=2, mode='bilinear'))
89 | self.decoder1_g = nn.Sequential(
90 | nn.Conv2d(64,64,3,padding=1),
91 | nn.BatchNorm2d(64),
92 | nn.ReLU(inplace=True),
93 | nn.Conv2d(64,64,3,padding=1),
94 | nn.BatchNorm2d(64),
95 | nn.ReLU(inplace=True),
96 | nn.Conv2d(64,64,3,padding=1),
97 | nn.BatchNorm2d(64),
98 | nn.ReLU(inplace=True),
99 | nn.Upsample(scale_factor=2, mode='bilinear'))
100 | self.decoder0_g = nn.Sequential(
101 | nn.Conv2d(64,64,3,padding=1),
102 | nn.BatchNorm2d(64),
103 | nn.ReLU(inplace=True),
104 | nn.Conv2d(64,64,3,padding=1),
105 | nn.BatchNorm2d(64),
106 | nn.ReLU(inplace=True),
107 | nn.Conv2d(64,3,3,padding=1),
108 | nn.Upsample(scale_factor=2, mode='bilinear'))
109 |
110 | ##########################
111 | ### Decoder part - LOCAL
112 | ##########################
113 | self.decoder4_l = nn.Sequential(
114 | nn.Conv2d(512,512,3,padding=1),
115 | nn.BatchNorm2d(512),
116 | nn.ReLU(inplace=True),
117 | nn.Conv2d(512,512,3,padding=1),
118 | nn.BatchNorm2d(512),
119 | nn.ReLU(inplace=True),
120 | nn.Conv2d(512,256,3,padding=1),
121 | nn.BatchNorm2d(256),
122 | nn.ReLU(inplace=True))
123 | self.decoder3_l = nn.Sequential(
124 | nn.Conv2d(256,256,3,padding=1),
125 | nn.BatchNorm2d(256),
126 | nn.ReLU(inplace=True),
127 | nn.Conv2d(256,256,3,padding=1),
128 | nn.BatchNorm2d(256),
129 | nn.ReLU(inplace=True),
130 | nn.Conv2d(256,128,3,padding=1),
131 | nn.BatchNorm2d(128),
132 | nn.ReLU(inplace=True))
133 | self.decoder2_l = nn.Sequential(
134 | nn.Conv2d(128,128,3,padding=1),
135 | nn.BatchNorm2d(128),
136 | nn.ReLU(inplace=True),
137 | nn.Conv2d(128,128,3,padding=1),
138 | nn.BatchNorm2d(128),
139 | nn.ReLU(inplace=True),
140 | nn.Conv2d(128,64,3,padding=1),
141 | nn.BatchNorm2d(64),
142 | nn.ReLU(inplace=True))
143 | self.decoder1_l = nn.Sequential(
144 | nn.Conv2d(64,64,3,padding=1),
145 | nn.BatchNorm2d(64),
146 | nn.ReLU(inplace=True),
147 | nn.Conv2d(64,64,3,padding=1),
148 | nn.BatchNorm2d(64),
149 | nn.ReLU(inplace=True),
150 | nn.Conv2d(64,64,3,padding=1),
151 | nn.BatchNorm2d(64),
152 | nn.ReLU(inplace=True))
153 | self.decoder0_l = nn.Sequential(
154 | nn.Conv2d(64,64,3,padding=1),
155 | nn.BatchNorm2d(64),
156 | nn.ReLU(inplace=True),
157 | nn.Conv2d(64,64,3,padding=1),
158 | nn.BatchNorm2d(64),
159 | nn.ReLU(inplace=True))
160 | self.decoder_final_l = nn.Conv2d(64,1,3,padding=1)
161 |
162 |
163 | def forward(self, input):
164 | ##########################
165 | ### Encoder part - RESNET
166 | ##########################
167 | e0 = self.encoder0(input)
168 | e0p, id0 = self.mp0(e0)
169 | e1p, id1 = self.mp1(e0p)
170 | e1 = self.encoder1(e1p)
171 | e2p, id2 = self.mp2(e1)
172 | e2 = self.encoder2(e2p)
173 | e3p, id3 = self.mp3(e2)
174 | e3 = self.encoder3(e3p)
175 | e4p, id4 = self.mp4(e3)
176 | e4 = self.encoder4(e4p)
177 | ###########################
178 | ### Decoder part - Global
179 | ###########################
180 | d4_g = self.decoder4_g(e4)
181 | d3_g = self.decoder3_g(d4_g)
182 | d2_g, global_sigmoid_side2 = self.dbfi_2(d3_g, e4)
183 | d2_g = self.decoder2_g(d2_g)
184 | d1_g, global_sigmoid_side1 = self.dbfi_1(d2_g, e4)
185 | d1_g = self.decoder1_g(d1_g)
186 | d0_g, global_sigmoid_side0 = self.dbfi_0(d1_g, e4)
187 | d0_g = self.decoder0_g(d0_g)
188 | global_sigmoid = d0_g
189 | ###########################
190 | ### Decoder part - Local
191 | ###########################
192 | d4_l = self.decoder4_l(e4)
193 | d4_l = F.max_unpool2d(d4_l, id4, kernel_size=2, stride=2)
194 | d3_l = self.tfi_3(d4_g, d4_l, e3)
195 | d3_l = self.decoder3_l(d3_l)
196 | d3_l = F.max_unpool2d(d3_l, id3, kernel_size=2, stride=2)
197 | d2_l = self.tfi_2(d3_g, d3_l, e2)
198 | d2_l = self.sbfi_2(d2_l, e0)
199 | d2_l = self.decoder2_l(d2_l)
200 | d2_l = F.max_unpool2d(d2_l, id2, kernel_size=2, stride=2)
201 | d1_l = self.tfi_1(d2_g, d2_l, e1)
202 | d1_l = self.sbfi_1(d1_l, e0)
203 | d1_l = self.decoder1_l(d1_l)
204 | d1_l = F.max_unpool2d(d1_l, id1, kernel_size=2, stride=2)
205 | d0_l = self.tfi_0(d1_g, d1_l, e0p)
206 | d0_l = self.sbfi_0(d0_l, e0)
207 | d0_l = self.decoder0_l(d0_l)
208 | d0_l = F.max_unpool2d(d0_l, id0, kernel_size=2, stride=2)
209 | d0_l = self.decoder_final_l(d0_l)
210 |         local_sigmoid = torch.sigmoid(d0_l)
211 | ##########################
212 | ### Fusion net - G/L
213 | ##########################
214 | fusion_sigmoid = get_masked_local_from_global(global_sigmoid, local_sigmoid)
215 | return global_sigmoid, local_sigmoid, fusion_sigmoid, global_sigmoid_side2, global_sigmoid_side1, global_sigmoid_side0
216 |
--------------------------------------------------------------------------------
/core/network/RestNet34/__init__.py:
--------------------------------------------------------------------------------
1 | from .P3mNet import P3mNet
2 |
3 |
4 | __all__ = ['p3mnet_resnet34']
5 |
6 | def p3mnet_resnet34(pretrained=True, **kwargs):
7 | return P3mNet(pretrained=pretrained)
--------------------------------------------------------------------------------
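A minimal shape-check sketch for the factory above, assuming a CUDA device (as elsewhere in this repository), pretrained=False so the PRETRAINED_R34_MP checkpoint is not required, and that it is run from core/ so the package imports resolve.

    import torch
    from network.RestNet34 import p3mnet_resnet34

    model = p3mnet_resnet34(pretrained=False).cuda().eval()
    x = torch.randn(1, 3, 512, 512).cuda()      # H and W must be multiples of 32 (five max-pooling stages)
    with torch.no_grad():
        global_seg, local_alpha, fused_alpha = model(x)[:3]   # plus three deep-supervision side outputs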
/core/network/RestNet34/resnet_mp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import Callable, Optional
4 | from config import PRETRAINED_R34_MP
5 | from ..modules import conv3x3, conv1x1
6 |
7 |
8 | class BasicBlock(nn.Module):
9 | expansion: int = 1
10 |
11 | def __init__(
12 | self,
13 | inplanes: int,
14 | planes: int,
15 | stride: int = 1,
16 | downsample: Optional[nn.Module] = None,
17 | groups: int = 1,
18 | base_width: int = 64,
19 | dilation: int = 1,
20 | norm_layer: Optional[Callable[..., nn.Module]] = None
21 | ) -> None:
22 | super(BasicBlock, self).__init__()
23 | if norm_layer is None:
24 | norm_layer = nn.BatchNorm2d
25 | if groups != 1 or base_width != 64:
26 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
27 | if dilation > 1:
28 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
29 |
30 | self.conv1 = conv3x3(inplanes, planes, stride)
31 | self.bn1 = norm_layer(planes)
32 | self.relu = nn.ReLU(inplace=True)
33 | self.conv2 = conv3x3(planes, planes)
34 | self.bn2 = norm_layer(planes)
35 | self.downsample = downsample
36 | self.stride = stride
37 |
38 | def forward(self, x):
39 | identity = x
40 |
41 | out = self.conv1(x)
42 | out = self.bn1(out)
43 | out = self.relu(out)
44 |
45 | out = self.conv2(out)
46 | out = self.bn2(out)
47 |
48 | if self.downsample is not None:
49 | identity = self.downsample(x)
50 | out += identity
51 | out = self.relu(out)
52 |
53 | return out
54 |
55 |
56 | class Bottleneck(nn.Module):
57 | expansion = 4
58 | __constants__ = ['downsample']
59 |
60 | def __init__(self, inplanes, planes,stride=1, downsample=None, groups=1,
61 | base_width=64, dilation=1, norm_layer=None):
62 | super(Bottleneck, self).__init__()
63 | if norm_layer is None:
64 | norm_layer = nn.BatchNorm2d
65 | width = int(planes * (base_width / 64.)) * groups
66 | self.conv1 = conv1x1(inplanes, width)
67 | self.bn1 = norm_layer(width)
68 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
69 | self.bn2 = norm_layer(width)
70 | self.conv3 = conv1x1(width, planes * self.expansion)
71 | self.bn3 = norm_layer(planes * self.expansion)
72 | self.relu = nn.ReLU(inplace=True)
73 | self.downsample = downsample
74 | self.stride = stride
75 |
76 | def forward(self, x):
77 | identity = x
78 |
79 | out = self.conv1(x)
80 | out = self.bn1(out)
81 | out = self.relu(out)
82 |
83 | out = self.conv2(out)
84 | out = self.bn2(out)
85 | out = self.relu(out)
86 |
87 |
88 | out = self.conv3(out)
89 | out = self.bn3(out)
90 | if self.downsample is not None:
91 | identity = self.downsample(x)
92 | out += identity
93 | out = self.relu(out)
94 |
95 | return out
96 |
97 |
98 | class ResNet(nn.Module):
99 |
100 | def __init__(self, block, layers, zero_init_residual=False,
101 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
102 | norm_layer=None):
103 | super(ResNet, self).__init__()
104 | if norm_layer is None:
105 | norm_layer = nn.BatchNorm2d
106 | self._norm_layer = norm_layer
107 | self.inplanes = 64
108 | self.dilation = 1
109 | if replace_stride_with_dilation is None:
110 | replace_stride_with_dilation = [False, False, False]
111 | if len(replace_stride_with_dilation) != 3:
112 | raise ValueError("replace_stride_with_dilation should be None "
113 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
114 | self.groups = groups
115 | self.base_width = width_per_group
116 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=1, padding=3,
117 | bias=False)
118 | self.bn1 = norm_layer(self.inplanes)
119 | self.relu = nn.ReLU(inplace=True)
120 | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
121 | self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
122 | self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
123 | self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
124 | self.maxpool5 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
125 | #pdb.set_trace()
126 | self.layer1 = self._make_layer(block, 64, layers[0])
127 | self.layer2 = self._make_layer(block, 128, layers[1], stride=1,
128 | dilate=replace_stride_with_dilation[0])
129 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
130 | dilate=replace_stride_with_dilation[1])
131 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
132 | dilate=replace_stride_with_dilation[2])
133 |
134 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
135 | self.fc = nn.Linear(512 * block.expansion, 1000)
136 |
137 | for m in self.modules():
138 | if isinstance(m, nn.Conv2d):
139 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
140 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
141 | nn.init.constant_(m.weight, 1)
142 | nn.init.constant_(m.bias, 0)
143 | if zero_init_residual:
144 | for m in self.modules():
145 | if isinstance(m, Bottleneck):
146 | nn.init.constant_(m.bn3.weight, 0)
147 | elif isinstance(m, BasicBlock):
148 | nn.init.constant_(m.bn2.weight, 0)
149 |
150 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
151 | norm_layer = self._norm_layer
152 | downsample = None
153 | previous_dilation = self.dilation
154 | if dilate:
155 | self.dilation *= stride
156 | stride = 1
157 | if stride != 1 or self.inplanes != planes * block.expansion:
158 | downsample = nn.Sequential(
159 | conv1x1(self.inplanes, planes * block.expansion, stride),
160 | norm_layer(planes * block.expansion),
161 | )
162 |
163 | layers = []
164 | layers.append(block(self.inplanes, planes,stride, downsample, self.groups,
165 | self.base_width, previous_dilation, norm_layer))
166 | self.inplanes = planes * block.expansion
167 | for _ in range(1, blocks):
168 | layers.append(block(self.inplanes, planes,groups=self.groups,
169 | base_width=self.base_width, dilation=self.dilation,
170 | norm_layer=norm_layer))
171 |
172 | return nn.Sequential(*layers)
173 |
174 | def _forward_impl(self, x):
175 | x1 = self.conv1(x)
176 | x1 = self.bn1(x1)
177 | x1 = self.relu(x1)
178 | x1, idx1 = self.maxpool1(x1)
179 |
180 | x2, idx2 = self.maxpool2(x1)
181 | x2 = self.layer1(x2)
182 |
183 | x3, idx3 = self.maxpool3(x2)
184 | x3 = self.layer2(x3)
185 |
186 | x4, idx4 = self.maxpool4(x3)
187 | x4 = self.layer3(x4)
188 |
189 | x5, idx5 = self.maxpool5(x4)
190 | x5 = self.layer4(x5)
191 |
192 | x_cls = self.avgpool(x5)
193 | x_cls = torch.flatten(x_cls, 1)
194 | x_cls = self.fc(x_cls)
195 |
196 | return x_cls
197 |
198 | def forward(self, x):
199 | return self._forward_impl(x)
200 |
201 |
202 | def resnet34_mp(pretrained=True, **kwargs):
203 | r"""ResNet-34 model from
204 | `"Deep Residual Learning for Image Recognition" `
205 | """
206 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
207 | if pretrained:
208 | checkpoint = torch.load(PRETRAINED_R34_MP)
209 | model.load_state_dict(checkpoint)
210 | print('Loaded pretrained model')
211 | return model
212 |
213 |
214 |
--------------------------------------------------------------------------------
/core/network/Swin_T/__init__.py:
--------------------------------------------------------------------------------
1 | from .swin_stem_pooling5_transformer import swin_stem_pooling5_encoder
2 | from .swin_stem_pooling5_transformer import SwinStemPooling5TransformerMatting
3 |
4 | from .decoder import SwinStemPooling5TransformerDecoderV1
5 |
6 |
7 | __all__ = ['p3mnet_swin_t']
8 |
9 |
10 | def p3mnet_swin_t(pretrained=True, img_size=512, **kwargs):
11 | encoder = swin_stem_pooling5_encoder(pretrained=pretrained, img_size=img_size, **kwargs)
12 | decoder = SwinStemPooling5TransformerDecoderV1()
13 | model = SwinStemPooling5TransformerMatting(encoder, decoder)
14 | return model
15 |
--------------------------------------------------------------------------------
/core/network/Swin_T/decoder.py:
--------------------------------------------------------------------------------
1 | """
2 | Rethinking Portrait Matting with Privacy Preserving
3 |
4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
5 | Licensed under the MIT License (see LICENSE for details)
6 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net
7 | Paper link: https://arxiv.org/abs/2203.16828
8 |
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 | from util import get_masked_local_from_global
16 | from ..modules import *
17 |
18 |
19 | class SwinStemPooling5TransformerDecoderV1(nn.Module):
20 | def __init__(self) -> None:
21 | super().__init__()
22 | ##########################
23 | ### Decoder part - GLOBAL
24 | ##########################
25 | self.decoder4_g = nn.Sequential( # 768 -> 384
26 | nn.Conv2d(768,384,3,padding=1),
27 | nn.BatchNorm2d(384),
28 | nn.ReLU(inplace=True),
29 | nn.Conv2d(384,384,3,padding=1),
30 | nn.BatchNorm2d(384),
31 | nn.ReLU(inplace=True),
32 | nn.Conv2d(384,384,3,padding=1),
33 | nn.BatchNorm2d(384),
34 | nn.ReLU(inplace=True),
35 | nn.Upsample(scale_factor=2, mode='bilinear') )
36 | self.decoder3_g = nn.Sequential( # 384 -> 192
37 | nn.Conv2d(384,192,3,padding=1),
38 | nn.BatchNorm2d(192),
39 | nn.ReLU(inplace=True),
40 | nn.Conv2d(192,192,3,padding=1),
41 | nn.BatchNorm2d(192),
42 | nn.ReLU(inplace=True),
43 | nn.Conv2d(192,192,3,padding=1),
44 | nn.BatchNorm2d(192),
45 | nn.ReLU(inplace=True),
46 | nn.Upsample(scale_factor=2, mode='bilinear') )
47 | self.decoder2_g = nn.Sequential( # 192 -> 96
48 | nn.Conv2d(192,96,3,padding=1),
49 | nn.BatchNorm2d(96),
50 | nn.ReLU(inplace=True),
51 | nn.Conv2d(96,96,3,padding=1),
52 | nn.BatchNorm2d(96),
53 | nn.ReLU(inplace=True),
54 | nn.Conv2d(96,96,3,padding=1),
55 | nn.BatchNorm2d(96),
56 | nn.ReLU(inplace=True),
57 | nn.Upsample(scale_factor=2, mode='bilinear'))
58 | self.decoder1_g = nn.Sequential( # 96 -> 96
59 | nn.Conv2d(96,96,3,padding=1),
60 | nn.BatchNorm2d(96),
61 | nn.ReLU(inplace=True),
62 | nn.Conv2d(96,96,3,padding=1),
63 | nn.BatchNorm2d(96),
64 | nn.ReLU(inplace=True),
65 | nn.Conv2d(96,96,3,padding=1),
66 | nn.BatchNorm2d(96),
67 | nn.ReLU(inplace=True),
68 | nn.Upsample(scale_factor=2, mode='bilinear'))
69 | self.decoder0_g = nn.Sequential( # 96 -> 48
70 | nn.Conv2d(96,48,3,padding=1),
71 | nn.BatchNorm2d(48),
72 | nn.ReLU(inplace=True),
73 | nn.Conv2d(48,48,3,padding=1),
74 | nn.BatchNorm2d(48),
75 | nn.ReLU(inplace=True),
76 | nn.Conv2d(48,48,3,padding=1),
77 | nn.Upsample(scale_factor=2, mode='bilinear'))
78 | self.decoder_final_g = nn.Conv2d(48, 3, kernel_size=3, padding=1) # 48 -> 3
79 |
80 | ##########################
81 | ### Decoder part - LOCAL
82 | ##########################
83 |
84 | self.decoder4_l = nn.Sequential( # 768 -> 384
85 | nn.Conv2d(768,384,3,padding=1),
86 | nn.BatchNorm2d(384),
87 | nn.ReLU(inplace=True),
88 | nn.Conv2d(384,384,3,padding=1),
89 | nn.BatchNorm2d(384),
90 | nn.ReLU(inplace=True),
91 | nn.Conv2d(384,384,3,padding=1),
92 | nn.BatchNorm2d(384),
93 | nn.ReLU(inplace=True))
94 | self.decoder3_l = nn.Sequential( # 384 -> 192
95 | nn.Conv2d(384,192,3,padding=1),
96 | nn.BatchNorm2d(192),
97 | nn.ReLU(inplace=True),
98 | nn.Conv2d(192,192,3,padding=1),
99 | nn.BatchNorm2d(192),
100 | nn.ReLU(inplace=True),
101 | nn.Conv2d(192,192,3,padding=1),
102 | nn.BatchNorm2d(192),
103 | nn.ReLU(inplace=True))
104 | self.decoder2_l = nn.Sequential( # 192 -> 96
105 | nn.Conv2d(192,96,3,padding=1),
106 | nn.BatchNorm2d(96),
107 | nn.ReLU(inplace=True),
108 | nn.Conv2d(96,96,3,padding=1),
109 | nn.BatchNorm2d(96),
110 | nn.ReLU(inplace=True),
111 | nn.Conv2d(96,96,3,padding=1),
112 | nn.BatchNorm2d(96),
113 | nn.ReLU(inplace=True))
114 | self.decoder1_l = nn.Sequential( # 96 -> 96
115 | nn.Conv2d(96,96,3,padding=1),
116 | nn.BatchNorm2d(96),
117 | nn.ReLU(inplace=True),
118 | nn.Conv2d(96,96,3,padding=1),
119 | nn.BatchNorm2d(96),
120 | nn.ReLU(inplace=True),
121 | nn.Conv2d(96,96,3,padding=1),
122 | nn.BatchNorm2d(96),
123 | nn.ReLU(inplace=True))
124 | self.decoder0_l = nn.Sequential( # 96 -> 48
125 | nn.Conv2d(96,48,3,padding=1),
126 | nn.BatchNorm2d(48),
127 | nn.ReLU(inplace=True),
128 | nn.Conv2d(48,48,3,padding=1),
129 | nn.BatchNorm2d(48),
130 | nn.ReLU(inplace=True))
131 |         self.decoder_final_l = nn.Conv2d(48,1,3,padding=1) # 48 -> 1
132 |
133 | ##########################
134 | ### Decoder part - MODULES
135 | ##########################
136 |
137 | self.tfi_3 = TFI(384)
138 | self.tfi_2 = TFI(192)
139 | self.tfi_1 = TFI(96)
140 | self.tfi_0 = TFI(96)
141 |
142 | self.sbfi_2 = SBFI(192, 48, 8)
143 | self.sbfi_1 = SBFI(96, 48, 4)
144 | self.sbfi_0 = SBFI(96, 48, 2)
145 |
146 | self.dbfi_2 = DBFI(192, 768, 4)
147 | self.dbfi_1 = DBFI(96, 768, 8)
148 | self.dbfi_0 = DBFI(96, 768, 16)
149 |
150 | def forward(self, x, indices, feas):
151 | r"""
152 |         x: [B, 768, 16, 16]
153 |
154 |         indices:
155 |             [None]
156 |             [None]
157 |             [B, 96, 64, 64]
158 |             [B, 192, 32, 32]
159 |             [B, 384, 16, 16]
160 |
161 |         feas:
162 |             [B, 3, 512, 512]
163 |             [B, 48, 256, 256]
164 |             [B, 96, 128, 128]
165 |             [B, 192, 64, 64]
166 |             [B, 384, 32, 32]
167 | """
168 | ###########################
169 | ### Decoder part - Global
170 | ###########################
171 | d4_g = self.decoder4_g(x) # 768 -> 384
172 | d3_g = self.decoder3_g(d4_g) # 384 -> 192
173 | d2_g, global_sigmoid_side2 = self.dbfi_2(d3_g, x) # 192, 768 -> 192
174 | d2_g = self.decoder2_g(d2_g) # 192 -> 96
175 | d1_g, global_sigmoid_side1 = self.dbfi_1(d2_g, x) # 96, 768 -> 96
176 | d1_g = self.decoder1_g(d1_g) # 96 -> 96
177 |         d0_g, global_sigmoid_side0 = self.dbfi_0(d1_g, x) # 96, 768 -> 96
178 | d0_g = self.decoder0_g(d0_g) # 96 -> 48
179 | global_sigmoid = self.decoder_final_g(d0_g) # 48 -> 3
180 | ###########################
181 | ### Decoder part - Local
182 | ###########################
183 | d4_l = self.decoder4_l(x) # 768 -> 384
184 | d4_l = F.max_unpool2d(d4_l, indices[-1], kernel_size=2, stride=2) # 384
185 | d3_l = self.tfi_3(d4_g, d4_l, feas[-1]) # 384, 384, 384 -> 384
186 | d3_l = self.decoder3_l(d3_l) # 384 -> 192
187 | d3_l = F.max_unpool2d(d3_l, indices[-2], kernel_size=2, stride=2) # 192
188 | d2_l = self.tfi_2(d3_g, d3_l, feas[-2]) # 192, 192, 192 -> 192
189 | d2_l = self.sbfi_2(d2_l, feas[-5]) # 192, 3 -> 192
190 | d2_l = self.decoder2_l(d2_l) # 192 -> 96
191 | d2_l = F.max_unpool2d(d2_l, indices[-3], kernel_size=2, stride=2) # 96
192 | d1_l = self.tfi_1(d2_g, d2_l, feas[-3]) # 96, 96, 96 -> 96
193 | d1_l = self.sbfi_1(d1_l, feas[-5]) # 96, 3 -> 96
194 | d1_l = self.decoder1_l(d1_l) # 96 -> 96
195 | d1_l = F.max_unpool2d(d1_l, indices[-4], kernel_size=2, stride=2) # 96
196 | d0_l = self.tfi_0(d1_g, d1_l, feas[-4]) # 96, 96, 96 -> 96
197 | d0_l = self.sbfi_0(d0_l, feas[-5]) # 96
198 | d0_l = self.decoder0_l(d0_l) # 96 -> 48
199 | d0_l = F.max_unpool2d(d0_l, indices[-5], kernel_size=2, stride=2) # 48
200 | d0_l = self.decoder_final_l(d0_l) # 48 -> 1
201 | local_sigmoid = torch.sigmoid(d0_l) # 1
202 | fusion_sigmoid = get_masked_local_from_global(global_sigmoid, local_sigmoid)
203 | return global_sigmoid, local_sigmoid, fusion_sigmoid, global_sigmoid_side2, global_sigmoid_side1, global_sigmoid_side0
204 |
--------------------------------------------------------------------------------
/core/network/ViTAE_S/NormalCell.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd.
2 | #
3 | # This source code is licensed under the Clear BSD License
4 | # LICENSE file in the root directory of this file
5 | # All rights reserved.
6 | """
7 | Borrowed from timm (https://github.com/rwightman/pytorch-image-models)
8 | """
9 | import torch
10 | import torch.nn as nn
11 | import numpy as np
12 | from timm.models.layers import DropPath
13 | from .SELayer import SELayer
14 | import math
15 |
16 | class Mlp(nn.Module):
17 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
18 | super().__init__()
19 | out_features = out_features or in_features
20 | hidden_features = hidden_features or in_features
21 | self.hidden_features = hidden_features
22 | self.fc1 = nn.Linear(in_features, hidden_features)
23 | self.act = act_layer()
24 | self.fc2 = nn.Linear(hidden_features, out_features)
25 | self.drop = nn.Dropout(drop)
26 |
27 | def forward(self, x):
28 | x = self.fc1(x)
29 | x = self.act(x)
30 | x = self.drop(x)
31 | x = self.fc2(x)
32 | x = self.drop(x)
33 | return x
34 |
35 | class Attention(nn.Module):
36 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
37 | super().__init__()
38 | self.num_heads = num_heads
39 | head_dim = dim // num_heads
40 |
41 | self.scale = qk_scale or head_dim ** -0.5
42 |
43 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
44 | self.attn_drop = nn.Dropout(attn_drop)
45 | self.proj = nn.Linear(dim, dim)
46 | self.proj_drop = nn.Dropout(proj_drop)
47 |
48 | def forward(self, x):
49 | B, N, C = x.shape
50 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
51 | q, k, v = qkv[0], qkv[1], qkv[2]
52 |
53 | attn = (q @ k.transpose(-2, -1)) * self.scale
54 | attn = attn.softmax(dim=-1)
55 | attn = self.attn_drop(attn)
56 |
57 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
58 | x = self.proj(x)
59 | x = self.proj_drop(x)
60 | return x
61 |
62 | class AttentionPerformer(nn.Module):
63 | def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., kernel_ratio=0.5):
64 | super().__init__()
65 | self.head_dim = dim // num_heads
66 | self.emb = dim
67 | self.kqv = nn.Linear(dim, 3 * self.emb)
68 | self.dp = nn.Dropout(proj_drop)
69 | self.proj = nn.Linear(self.emb, self.emb)
70 | self.head_cnt = num_heads
71 | self.norm1 = nn.LayerNorm(dim, eps=1e-6)
72 |         self.epsilon = 1e-8  # for stability in division
73 | self.drop_path = nn.Identity()
74 |
75 | self.m = int(self.head_dim * kernel_ratio)
76 | self.w = torch.randn(self.head_cnt, self.m, self.head_dim)
77 | for i in range(self.head_cnt):
78 | self.w[i] = nn.Parameter(nn.init.orthogonal_(self.w[i]) * math.sqrt(self.m), requires_grad=False)
79 | self.w.requires_grad_(False)
80 |
81 | def prm_exp(self, x):
82 |         # part of the function is borrowed from https://github.com/lucidrains/performer-pytorch
83 | # and Simo Ryu (https://github.com/cloneofsimo)
84 | # ==== positive random features for gaussian kernels ====
85 | # x = (B, T, hs)
86 | # w = (m, hs)
87 | # return : x : B, T, m
88 | # SM(x, y) = E_w[exp(w^T x - |x|/2) exp(w^T y - |y|/2)]
89 | # therefore return exp(w^Tx - |x|/2)/sqrt(m)
90 | xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, 1, self.m) / 2
91 | wtx = torch.einsum('bhti,hmi->bhtm', x.float(), self.w.to(x.device))
92 |
93 | return torch.exp(wtx - xd) / math.sqrt(self.m)
94 |
95 | def attn(self, x):
96 | B, N, C = x.shape
97 | kqv = self.kqv(x).reshape(B, N, 3, self.head_cnt, self.head_dim).permute(2, 0, 3, 1, 4)
98 | k, q, v = kqv[0], kqv[1], kqv[2] # (B, H, T, hs)
99 |
100 | kp, qp = self.prm_exp(k), self.prm_exp(q) # (B, H, T, m), (B, H, T, m)
101 | D = torch.einsum('bhti,bhi->bht', qp, kp.sum(dim=2)).unsqueeze(dim=-1) # (B, H, T, m) * (B, H, m) -> (B, H, T, 1)
102 | kptv = torch.einsum('bhin,bhim->bhnm', v.float(), kp) # (B, H, emb, m)
103 | y = torch.einsum('bhti,bhni->bhtn', qp, kptv) / (D.repeat(1, 1, 1, self.head_dim) + self.epsilon) # (B, H, T, emb)/Diag
104 |
105 | # skip connection
106 |
107 | y = y.permute(0, 2, 1, 3).reshape(B, N, self.emb)
108 | y = self.dp(self.proj(y)) # same as token_transformer in T2T layer, use v as skip connection
109 |
110 | return y
111 |
112 | def forward(self, x):
113 | x = self.attn(x)
114 | return x
115 |
116 | class NormalCell(nn.Module):
117 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
118 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, class_token=False, group=64, tokens_type='transformer', gamma=False, init_values=1e-4, SE=False):
119 | super().__init__()
120 | self.norm1 = norm_layer(dim)
121 | self.class_token = class_token
122 | if tokens_type == 'transformer':
123 | self.attn = Attention(
124 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
125 | elif tokens_type == 'performer':
126 | self.attn = AttentionPerformer(
127 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
128 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
129 | self.norm2 = norm_layer(dim)
130 | mlp_hidden_dim = int(dim * mlp_ratio)
131 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
132 | self.PCM = nn.Sequential(
133 | nn.Conv2d(dim, mlp_hidden_dim, 3, 1, 1, 1, group),
134 | nn.BatchNorm2d(mlp_hidden_dim),
135 | nn.SiLU(inplace=True),
136 | nn.Conv2d(mlp_hidden_dim, dim, 3, 1, 1, 1, group),
137 | nn.BatchNorm2d(dim),
138 | nn.SiLU(inplace=True),
139 | nn.Conv2d(dim, dim, 3, 1, 1, 1, group),
140 | nn.SiLU(inplace=True),
141 | )
142 | if gamma:
143 | self.gamma1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
144 | self.gamma2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
145 | self.gamma3 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
146 | else:
147 | self.gamma1 = 1
148 | self.gamma2 = 1
149 | self.gamma3 = 1
150 | if SE:
151 | self.SE = SELayer(dim)
152 | else:
153 | self.SE = nn.Identity()
154 |
155 | def forward(self, x, input_resolution=None):
156 | b, n, c = x.shape
157 | if self.class_token:
158 | n = n - 1
159 | wh = input_resolution[0] if input_resolution is not None else int(math.sqrt(n))
160 | ww = input_resolution[1] if input_resolution is not None else int(math.sqrt(n))
161 | convX = self.drop_path(self.gamma2 * self.PCM(x[:, 1:, :].view(b, wh, ww, c).permute(0, 3, 1, 2).contiguous()).permute(0, 2, 3, 1).contiguous().view(b, n, c))
162 | x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
163 | x[:, 1:] = x[:, 1:] + convX
164 | else:
165 | wh = input_resolution[0] if input_resolution is not None else int(math.sqrt(n))
166 | ww = input_resolution[1] if input_resolution is not None else int(math.sqrt(n))
167 | convX = self.drop_path(self.gamma2 * self.PCM(x.view(b, wh, ww, c).permute(0, 3, 1, 2).contiguous()).permute(0, 2, 3, 1).contiguous().view(b, n, c))
168 | x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
169 | x = x + convX
170 | x = x + self.drop_path(self.gamma3 * self.mlp(self.norm2(x)))
171 | x = self.SE(x)
172 | return x
173 |
174 | def get_sinusoid_encoding(n_position, d_hid):
175 | ''' Sinusoid position encoding table '''
176 |
177 | def get_position_angle_vec(position):
178 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
179 |
180 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
181 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
182 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
183 |
184 | return torch.FloatTensor(sinusoid_table).unsqueeze(0)
--------------------------------------------------------------------------------
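A small standalone check of the NormalCell block above. The dimensions are chosen so the grouped convolutions in PCM (default group=64) divide evenly; this is illustrative only and assumes core/ is on the Python path and timm is installed.

    import torch
    from network.ViTAE_S.NormalCell import NormalCell, get_sinusoid_encoding

    cell = NormalCell(dim=64, num_heads=4, tokens_type='transformer')
    tokens = torch.randn(2, 16 * 16, 64)                        # [B, N, C] with N = 16 x 16
    out = cell(tokens, input_resolution=(16, 16))               # same shape as the input
    pos = get_sinusoid_encoding(n_position=16 * 16, d_hid=64)   # [1, 256, 64] positional table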
/core/network/ViTAE_S/SELayer.py:
--------------------------------------------------------------------------------
1 | """
2 | Rethinking Portrait Matting with Privacy Preserving
3 |
4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
5 | Licensed under the MIT License (see LICENSE for details)
6 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net
7 | Paper link: https://arxiv.org/abs/2203.16828
8 |
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 | class SELayer(nn.Module):
15 | def __init__(self, channel, reduction=16):
16 | super(SELayer, self).__init__()
17 | self.avg_pool = nn.AdaptiveAvgPool1d(1)
18 | self.fc = nn.Sequential(
19 | nn.Linear(channel, channel // reduction, bias=False),
20 | nn.ReLU(inplace=True),
21 | nn.Linear(channel // reduction, channel, bias=False),
22 | nn.Sigmoid()
23 | )
24 |
25 | def forward(self, x): # x: [B, N, C]
26 | x = torch.transpose(x, 1, 2) # [B, C, N]
27 | b, c, _ = x.size()
28 | y = self.avg_pool(x).view(b, c)
29 | y = self.fc(y).view(b, c, 1)
30 | x = x * y.expand_as(x)
31 | x = torch.transpose(x, 1, 2) # [B, N, C]
32 | return x
--------------------------------------------------------------------------------
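The squeeze-and-excitation layer above works on token sequences ([B, N, C]) rather than 2-D feature maps; a quick illustrative check:

    import torch
    from network.ViTAE_S.SELayer import SELayer

    se = SELayer(channel=64)            # reduction defaults to 16
    tokens = torch.randn(2, 256, 64)    # [B, N, C]
    out = se(tokens)                    # channel-reweighted tokens, shape [2, 256, 64]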
/core/network/ViTAE_S/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import ViTAE_noRC_MaxPooling_bias_basic_stages4_14
2 | from .models import ViTAE_noRC_MaxPooling_DecoderV1
3 | from .models import ViTAE_noRC_MaxPooling_Matting
4 |
5 |
6 | __all__ = ['p3mnet_vitae_s']
7 |
8 |
9 | def p3mnet_vitae_s(pretrained=True, **kwargs):
10 | encoder = ViTAE_noRC_MaxPooling_bias_basic_stages4_14(pretrained=pretrained, **kwargs)
11 | decoder = ViTAE_noRC_MaxPooling_DecoderV1()
12 | model = ViTAE_noRC_MaxPooling_Matting(encoder, decoder)
13 | return model
14 |
--------------------------------------------------------------------------------
/core/network/ViTAE_S/base_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Rethinking Portrait Matting with Privacy Preserving
3 |
4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
5 | Licensed under the MIT License (see LICENSE for details)
6 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net
7 | Paper link: https://arxiv.org/abs/2203.16828
8 |
9 | """
10 |
11 | from functools import partial
12 | import torch
13 | import torch.nn as nn
14 | from timm.models.layers import trunc_normal_
15 | import numpy as np
16 | from .NormalCell import NormalCell
17 | from timm.models.layers import to_2tuple, trunc_normal_
18 |
19 | class PatchMerging(nn.Module):
20 | r""" Patch Merging Layer.
21 |
22 | Args:
23 | input_resolution (tuple[int]): Resolution of input feature.
24 | dim (int): Number of input channels.
25 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
26 | """
27 |
28 | def __init__(self, input_resolution, in_dim, out_dim, norm_layer=nn.LayerNorm):
29 | super().__init__()
30 | self.input_resolution = input_resolution
31 | self.pooling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
32 | self.in_dim = in_dim
33 | self.out_dim = out_dim
34 | # self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
35 | self.norm = norm_layer(in_dim)
36 | self.linear = nn.Linear(in_dim, out_dim, bias=False)
37 |
38 | def forward(self, x, input_resolution=None):
39 | """
40 | x: B, H*W, C
41 | """
42 | if input_resolution is None:
43 | input_resolution = self.input_resolution
44 | H, W = input_resolution
45 | B, L, C = x.shape
46 | assert L == H * W, "input feature has wrong size"
47 |         assert H % 2 == 0 and W % 2 == 0, f"input size ({H}*{W}) is not even."
48 |
49 | x = self.norm(x)
50 | x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
51 |
52 | x_fea = x # 64, 128, 256
53 | x, idx = self.pooling(x)
54 | x = x.permute(0, 2, 3, 1).reshape(B, -1, C).contiguous()
55 | x = self.linear(x)
56 |
57 | return x, [idx], [x_fea]
58 |
59 |     def extra_repr(self) -> str:
60 |         return f"input_resolution={self.input_resolution}, in_dim={self.in_dim}, out_dim={self.out_dim}"
61 |
62 |     def flops(self):
63 |         H, W = self.input_resolution
64 |         flops = H * W * self.in_dim
65 |         flops += (H // 2) * (W // 2) * self.in_dim * self.out_dim
66 |         return flops
67 |
68 |
69 | class PatchEmbed(nn.Module):
70 | r""" Image to Patch Embedding
71 |
72 | Args:
73 | img_size (int): Image size. Default: 224.
74 | patch_size (int): Patch token size. Default: 4.
75 | in_chans (int): Number of input image channels. Default: 3.
76 | embed_dim (int): Number of linear projection output channels. Default: 96.
77 | norm_layer (nn.Module, optional): Normalization layer. Default: None
78 | """
79 |
80 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
81 | super().__init__()
82 | _patch_size = patch_size
83 | img_size = to_2tuple(img_size)
84 | patch_size = to_2tuple(patch_size)
85 | strides = []
86 | while _patch_size > 1:
87 | strides.append(2)
88 | _patch_size = _patch_size // 2
89 | if len(strides) < 3:
90 | strides.append(1)
91 | self.strides = strides
92 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
93 | self.img_size = img_size
94 | self.patch_size = patch_size
95 | self.patches_resolution = patches_resolution
96 | self.num_patches = patches_resolution[0] * patches_resolution[1]
97 |
98 | self.in_chans = in_chans
99 | self.inter_chans = embed_dim // 2
100 | self.embed_dim = embed_dim
101 | # self.proj = nn.Sequential(
102 | # nn.Conv2d(self.in_chans, self.inter_chans, 3, stride=strides[0], padding=1),
103 | # nn.ReLU(),
104 | # nn.Conv2d(self.inter_chans, embed_dim, 3, stride=strides[1], padding=1),
105 | # nn.ReLU(),
106 | # nn.Conv2d(embed_dim, embed_dim, 3, stride=strides[2], padding=1)
107 | # )
108 |
109 | self.proj1 = nn.Conv2d(self.in_chans, self.inter_chans, 3, stride=1, padding=1)
110 | self.relu1 = nn.ReLU()
111 | self.proj2 = nn.Conv2d(self.inter_chans, embed_dim, 3, stride=1, padding=1)
112 | self.relu2 = nn.ReLU()
113 | self.proj3 = nn.Conv2d(embed_dim, embed_dim, 3, stride=1, padding=1)
114 |
115 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
116 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
117 |
118 | if norm_layer is not None:
119 | self.norm = norm_layer(embed_dim)
120 | else:
121 | self.norm = None
122 |
123 | def forward(self, x, input_resolution=None):
124 | x = self.relu1(self.proj1(x))
125 |
126 | x_fea1 = x # 32
127 | x, idx1 = self.maxpool1(x)
128 |
129 | x = self.relu2(self.proj2(x))
130 | x_fea2 = x # 64
131 | x, idx2 = self.maxpool2(x)
132 |
133 | x = self.proj3(x)
134 | x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
135 | if self.norm is not None:
136 | x = self.norm(x)
137 | return x, [idx1, idx2], [x_fea1, x_fea2]
138 |
139 |
140 | class BasicLayer(nn.Module):
141 | def __init__(self, img_size=224, in_chans=3, embed_dims=64, token_dims=64, downsample_ratios=4, kernel_size=7, RC_heads=1, NC_heads=6, dilations=[1, 2, 3, 4],
142 | RC_op='cat', RC_tokens_type='performer', NC_tokens_type='transformer', RC_group=1, NC_group=64, NC_depth=2, dpr=0.1, mlp_ratio=4., qkv_bias=True,
143 | qk_scale=None, drop=0, attn_drop=0., norm_layer=nn.LayerNorm, class_token=False, gamma=False, init_values=1e-4, SE=False):
144 | super().__init__()
145 | self.img_size = to_2tuple(img_size)
146 | self.in_chans = in_chans
147 | self.embed_dims = embed_dims
148 | self.token_dims = token_dims
149 | self.downsample_ratios = downsample_ratios
150 | self.out_size = to_2tuple(img_size // self.downsample_ratios)
151 | self.RC_kernel_size = kernel_size
152 | self.RC_heads = RC_heads
153 | self.NC_heads = NC_heads
154 | self.dilations = dilations
155 | self.RC_op = RC_op
156 | self.RC_tokens_type = RC_tokens_type
157 | self.RC_group = RC_group
158 | self.NC_group = NC_group
159 | self.NC_depth = NC_depth
160 | if downsample_ratios > 2:
161 | self.RC = PatchEmbed(img_size=self.img_size, embed_dim=token_dims, norm_layer=nn.LayerNorm)
162 | elif downsample_ratios == 2:
163 | self.RC = PatchMerging(input_resolution=self.img_size, in_dim=in_chans, out_dim=token_dims)
164 | else:
165 | self.RC = nn.Identity()
166 |
167 | self.NC = nn.ModuleList([
168 | NormalCell(token_dims, NC_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
169 | drop_path=dpr[i] if isinstance(dpr, list) else dpr, norm_layer=norm_layer, class_token=class_token, group=NC_group, tokens_type=NC_tokens_type,
170 | gamma=gamma, init_values=init_values, SE=SE)
171 | for i in range(NC_depth)])
172 |
173 | def forward(self, x, input_resolution=None):
174 | if input_resolution is None:
175 | input_resolution = self.img_size
176 | x, indices, feas = self.RC(x, input_resolution=input_resolution)
177 | input_resolution = [input_resolution[0]//self.downsample_ratios, input_resolution[1]//self.downsample_ratios]
178 | for nc in self.NC:
179 | x = nc(x, input_resolution=input_resolution)
180 | return x, indices, feas
181 |
182 |
183 | class ViTAE_noRC_MaxPooling_basic(nn.Module):
184 | def __init__(self, img_size=224, in_chans=3, stages=4, embed_dims=64, token_dims=64, downsample_ratios=[4, 2, 2, 2], kernel_size=[7, 3, 3, 3],
185 | RC_heads=[1, 1, 1, 1], NC_heads=4, dilations=[[1, 2, 3, 4], [1, 2, 3], [1, 2], [1, 2]],
186 | RC_op='cat', RC_tokens_type=['performer', 'transformer', 'transformer', 'transformer'], NC_tokens_type='transformer',
187 | RC_group=[1, 1, 1, 1], NC_group=[1, 32, 64, 64], NC_depth=[2, 2, 6, 2], mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0.,
188 | attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), num_classes=1000,
189 | gamma=False, init_values=1e-4, SE=False):
190 | super().__init__()
191 | self.num_classes = num_classes
192 | self.stages = stages
193 | repeatOrNot = (lambda x, y, z=list: x if isinstance(x, z) else [x for _ in range(y)])
194 | self.embed_dims = repeatOrNot(embed_dims, stages)
195 | self.tokens_dims = token_dims if isinstance(token_dims, list) else [token_dims * (2 ** i) for i in range(stages)]
196 | self.downsample_ratios = repeatOrNot(downsample_ratios, stages)
197 | self.kernel_size = repeatOrNot(kernel_size, stages)
198 | self.RC_heads = repeatOrNot(RC_heads, stages)
199 | self.NC_heads = repeatOrNot(NC_heads, stages)
200 | self.dilations = repeatOrNot(dilations, stages)
201 | self.RC_op = repeatOrNot(RC_op, stages)
202 | self.RC_tokens_type = repeatOrNot(RC_tokens_type, stages)
203 | self.NC_tokens_type = repeatOrNot(NC_tokens_type, stages)
204 | self.RC_group = repeatOrNot(RC_group, stages)
205 | self.NC_group = repeatOrNot(NC_group, stages)
206 | self.NC_depth = repeatOrNot(NC_depth, stages)
207 | self.mlp_ratio = repeatOrNot(mlp_ratio, stages)
208 | self.qkv_bias = repeatOrNot(qkv_bias, stages)
209 | self.qk_scale = repeatOrNot(qk_scale, stages)
210 | self.drop = repeatOrNot(drop_rate, stages)
211 | self.attn_drop = repeatOrNot(attn_drop_rate, stages)
212 | self.norm_layer = repeatOrNot(norm_layer, stages)
213 |
214 | self.pos_drop = nn.Dropout(p=drop_rate)
215 | depth = np.sum(self.NC_depth)
216 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
217 | Layers = []
218 | for i in range(stages):
219 | startDpr = 0 if i==0 else self.NC_depth[i - 1]
220 | Layers.append(
221 | BasicLayer(img_size, in_chans, self.embed_dims[i], self.tokens_dims[i], self.downsample_ratios[i],
222 | self.kernel_size[i], self.RC_heads[i], self.NC_heads[i], self.dilations[i], self.RC_op[i],
223 | self.RC_tokens_type[i], self.NC_tokens_type[i], self.RC_group[i], self.NC_group[i], self.NC_depth[i], dpr[startDpr:self.NC_depth[i]+startDpr],
224 | mlp_ratio=self.mlp_ratio[i], qkv_bias=self.qkv_bias[i], qk_scale=self.qk_scale[i], drop=self.drop[i], attn_drop=self.attn_drop[i],
225 | norm_layer=self.norm_layer[i], gamma=gamma, init_values=init_values, SE=SE)
226 | )
227 | img_size = img_size // self.downsample_ratios[i]
228 | in_chans = self.tokens_dims[i]
229 | self.layers = nn.ModuleList(Layers)
230 |
231 | # Classifier head
232 | self.head = nn.Linear(self.tokens_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
233 |
234 | self.apply(self._init_weights)
235 | # load state dict here TODO
236 |
237 | def _init_weights(self, m):
238 | if isinstance(m, nn.Linear):
239 | trunc_normal_(m.weight, std=.02)
240 | if isinstance(m, nn.Linear) and m.bias is not None:
241 | nn.init.constant_(m.bias, 0)
242 | elif isinstance(m, nn.LayerNorm):
243 | nn.init.constant_(m.bias, 0)
244 | nn.init.constant_(m.weight, 1.0)
245 |
246 | @torch.jit.ignore
247 | def no_weight_decay(self):
248 | return {'cls_token'}
249 |
250 | def get_classifier(self):
251 | return self.head
252 |
253 | def reset_classifier(self, num_classes):
254 | self.num_classes = num_classes
255 | self.head = nn.Linear(self.tokens_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
256 |
257 | def forward_features(self, x):
258 | indices = []
259 | feas = []
260 | B, C, H, W = x.shape
261 | input_resolution = [H, W]
262 |
263 | for layer_idx in range(len(self.layers)):
264 | x, idx, fea = self.layers[layer_idx](x, input_resolution=input_resolution)
265 | indices = indices + idx
266 | feas = feas + fea
267 | input_resolution = [input_resolution[0]//self.downsample_ratios[layer_idx], input_resolution[1]//self.downsample_ratios[layer_idx]]
268 |
269 | # x = x.view(B, -1, input_resolution[0], input_resolution[1]).contiguous()
270 | x = x.view(B, input_resolution[0], input_resolution[1], -1).permute(0,3,1,2).contiguous()
271 | return x, indices, feas
272 |
273 | def forward(self, x):
274 | return self.forward_features(x)
275 |
276 | def train(self, mode=True, tag='default'):
277 | self.training = mode
278 | if tag == 'default':
279 | for module in self.modules():
280 | if module.__class__.__name__ != 'ViTAE_noRC_MaxPooling_basic':
281 | module.train(mode)
282 | elif tag == 'linear':
283 | for module in self.modules():
284 | if module.__class__.__name__ != 'ViTAE_noRC_MaxPooling_basic':
285 | module.eval()
286 | for param in module.parameters():
287 | param.requires_grad = False
288 | elif tag == 'linearLNBN':
289 | for module in self.modules():
290 | if module.__class__.__name__ != 'ViTAE_noRC_MaxPooling_basic':
291 | if isinstance(module, nn.LayerNorm) or isinstance(module, nn.BatchNorm2d):
292 | module.train(mode)
293 | for param in module.parameters():
294 | param.requires_grad = True
295 | else:
296 | module.eval()
297 | for param in module.parameters():
298 | param.requires_grad = False
299 | self.head.train(mode)
300 | for param in self.head.parameters():
301 | param.requires_grad = True
302 | return self
303 |
--------------------------------------------------------------------------------
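A quick sanity check of the encoder above helps when wiring it to the decoder: besides the final feature map it returns the max-pooling indices and intermediate features that the matting decoder later unpools and fuses. The following is a minimal sketch under the class defaults, assuming core/ is on the import path (the repo's own code imports with core/ as the root); the printed shapes are what the defaults imply, not verified output.

import torch
from network.ViTAE_S.base_model import ViTAE_noRC_MaxPooling_basic

encoder = ViTAE_noRC_MaxPooling_basic().eval()   # default 4-stage encoder, token dims 64 -> 512
x = torch.randn(1, 3, 224, 224)                  # H and W should be divisible by 32 (4*2*2*2)
with torch.no_grad():
    feat, indices, feas = encoder(x)
print(feat.shape)                # expected torch.Size([1, 512, 7, 7])
print(len(indices), len(feas))   # expected 5 pooling-index maps and 5 skip features
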
/core/network/ViTAE_S/models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd.
2 | #
3 | # This source code is licensed under the Clear BSD License
4 | # LICENSE file in the root directory of this file
5 | # All rights reserved.
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from timm.models.registry import register_model
11 |
12 | from .base_model import ViTAE_noRC_MaxPooling_basic
13 | from ..modules import TFI, SBFI, DBFI
14 | from util import get_masked_local_from_global
15 | from config import PRETRAINED_VITAE_NORC_MAXPOOLING_BIAS_BASIC_STAGE4_14
16 |
17 |
18 | ##########################################
19 | ## Encoder
20 | ##########################################
21 |
22 | def _cfg(url='', **kwargs):
23 | return {
24 | 'url': url,
25 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
26 | 'crop_pct': .9, 'interpolation': 'bicubic',
27 | 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
28 | 'classifier': 'head',
29 | **kwargs
30 | }
31 |
32 | default_cfgs = {
33 | 'ViTAE_stages3_7': _cfg(),
34 | }
35 |
36 | @register_model
37 | def ViTAE_noRC_MaxPooling_bias_basic_stages4_14(pretrained=True, **kwargs): # adopt performer for tokens to token
38 | model = ViTAE_noRC_MaxPooling_basic(RC_tokens_type=['performer', 'transformer', 'transformer', 'transformer'], NC_tokens_type=['performer', 'transformer', 'transformer', 'transformer'], stages=4, embed_dims=[64, 64, 128, 256], token_dims=[64, 128, 256, 512], downsample_ratios=[4, 2, 2, 2],
39 | NC_depth=[2, 2, 12, 2], NC_heads=[1, 2, 4, 8], RC_heads=[1, 1, 2, 4], mlp_ratio=4., NC_group=[1, 32, 64, 128], RC_group=[1, 16, 32, 64], **kwargs)
40 | model.default_cfg = default_cfgs['ViTAE_stages3_7']
41 | if pretrained:
42 | ckpt = torch.load(PRETRAINED_VITAE_NORC_MAXPOOLING_BIAS_BASIC_STAGE4_14)['state_dict_ema']
43 | model.load_state_dict(ckpt, strict=True)
44 | return model
45 |
46 |
47 | ##########################################
48 | ## Decoder
49 | ##########################################
50 |
51 | class ViTAE_noRC_MaxPooling_DecoderV1(nn.Module):
52 | def __init__(self):
53 | super().__init__()
54 | ##########################
55 | ### Decoder part - GLOBAL
56 | ##########################
57 | in_chan = 512
58 | out_chan = 256
59 | self.decoder4_g = nn.Sequential( # 512 -> 256
60 | nn.Conv2d(in_chan,out_chan,3,padding=1),
61 | nn.BatchNorm2d(out_chan),
62 | nn.ReLU(inplace=True),
63 | nn.Conv2d(out_chan,out_chan,3,padding=1),
64 | nn.BatchNorm2d(out_chan),
65 | nn.ReLU(inplace=True),
66 | nn.Conv2d(out_chan,out_chan,3,padding=1),
67 | nn.BatchNorm2d(out_chan),
68 | nn.ReLU(inplace=True),
69 | nn.Upsample(scale_factor=2, mode='bilinear') )
70 |
71 | in_chan = 256
72 | out_chan = 128
73 | self.decoder3_g = nn.Sequential( # 256 -> 128
74 | nn.Conv2d(in_chan,out_chan,3,padding=1),
75 | nn.BatchNorm2d(out_chan),
76 | nn.ReLU(inplace=True),
77 | nn.Conv2d(out_chan,out_chan,3,padding=1),
78 | nn.BatchNorm2d(out_chan),
79 | nn.ReLU(inplace=True),
80 | nn.Conv2d(out_chan,out_chan,3,padding=1),
81 | nn.BatchNorm2d(out_chan),
82 | nn.ReLU(inplace=True),
83 | nn.Upsample(scale_factor=2, mode='bilinear') )
84 |
85 | in_chan = 128
86 | out_chan = 64
87 | self.decoder2_g = nn.Sequential( # 128 -> 64
88 | nn.Conv2d(in_chan,out_chan,3,padding=1),
89 | nn.BatchNorm2d(out_chan),
90 | nn.ReLU(inplace=True),
91 | nn.Conv2d(out_chan,out_chan,3,padding=1),
92 | nn.BatchNorm2d(out_chan),
93 | nn.ReLU(inplace=True),
94 | nn.Conv2d(out_chan,out_chan,3,padding=1),
95 | nn.BatchNorm2d(out_chan),
96 | nn.ReLU(inplace=True),
97 | nn.Upsample(scale_factor=2, mode='bilinear'))
98 |
99 | in_chan = 64
100 | out_chan = 64
101 | self.decoder1_g = nn.Sequential( # 64 -> 64
102 | nn.Conv2d(in_chan,out_chan,3,padding=1),
103 | nn.BatchNorm2d(out_chan),
104 | nn.ReLU(inplace=True),
105 | nn.Conv2d(out_chan,out_chan,3,padding=1),
106 | nn.BatchNorm2d(out_chan),
107 | nn.ReLU(inplace=True),
108 | nn.Conv2d(out_chan,out_chan,3,padding=1),
109 | nn.BatchNorm2d(out_chan),
110 | nn.ReLU(inplace=True),
111 | nn.Upsample(scale_factor=2, mode='bilinear'))
112 |
113 | in_chan = 64
114 | out_chan = 32
115 | self.decoder0_g = nn.Sequential( # 64 -> 32
116 | nn.Conv2d(in_chan,out_chan,3,padding=1),
117 | nn.BatchNorm2d(out_chan),
118 | nn.ReLU(inplace=True),
119 | nn.Conv2d(out_chan,out_chan,3,padding=1),
120 | nn.BatchNorm2d(out_chan),
121 | nn.ReLU(inplace=True),
122 | nn.Conv2d(out_chan,out_chan,3,padding=1),
123 | nn.Upsample(scale_factor=2, mode='bilinear'))
124 | self.decoder_final_g = nn.Conv2d(out_chan, 3, kernel_size=3, padding=1) # 32 -> 3
125 |
126 | ##########################
127 | ### Decoder part - LOCAL
128 | ##########################
129 |
130 | in_chan = 512
131 | out_chan = 256
132 | self.decoder4_l = nn.Sequential( # 512 -> 256
133 | nn.Conv2d(in_chan,out_chan,3,padding=1),
134 | nn.BatchNorm2d(out_chan),
135 | nn.ReLU(inplace=True),
136 | nn.Conv2d(out_chan,out_chan,3,padding=1),
137 | nn.BatchNorm2d(out_chan),
138 | nn.ReLU(inplace=True),
139 | nn.Conv2d(out_chan,out_chan,3,padding=1),
140 | nn.BatchNorm2d(out_chan),
141 | nn.ReLU(inplace=True))
142 |
143 | in_chan = 256
144 | out_chan = 128
145 | self.decoder3_l = nn.Sequential( # 256 -> 128
146 | nn.Conv2d(in_chan,out_chan,3,padding=1),
147 | nn.BatchNorm2d(out_chan),
148 | nn.ReLU(inplace=True),
149 | nn.Conv2d(out_chan,out_chan,3,padding=1),
150 | nn.BatchNorm2d(out_chan),
151 | nn.ReLU(inplace=True),
152 | nn.Conv2d(out_chan,out_chan,3,padding=1),
153 | nn.BatchNorm2d(out_chan),
154 | nn.ReLU(inplace=True))
155 |
156 | in_chan = 128
157 | out_chan = 64
158 | self.decoder2_l = nn.Sequential( # 128 -> 64
159 | nn.Conv2d(in_chan,out_chan,3,padding=1),
160 | nn.BatchNorm2d(out_chan),
161 | nn.ReLU(inplace=True),
162 | nn.Conv2d(out_chan,out_chan,3,padding=1),
163 | nn.BatchNorm2d(out_chan),
164 | nn.ReLU(inplace=True),
165 | nn.Conv2d(out_chan,out_chan,3,padding=1),
166 | nn.BatchNorm2d(out_chan),
167 | nn.ReLU(inplace=True))
168 |
169 | in_chan = 64
170 | out_chan = 64
171 | self.decoder1_l = nn.Sequential( # 64 -> 64
172 | nn.Conv2d(in_chan,out_chan,3,padding=1),
173 | nn.BatchNorm2d(out_chan),
174 | nn.ReLU(inplace=True),
175 | nn.Conv2d(out_chan,out_chan,3,padding=1),
176 | nn.BatchNorm2d(out_chan),
177 | nn.ReLU(inplace=True),
178 | nn.Conv2d(out_chan,out_chan,3,padding=1),
179 | nn.BatchNorm2d(out_chan),
180 | nn.ReLU(inplace=True))
181 |
182 | in_chan = 64
183 | out_chan = 32
184 | self.decoder0_l = nn.Sequential( # 64 -> 32
185 | nn.Conv2d(in_chan,out_chan,3,padding=1),
186 | nn.BatchNorm2d(out_chan),
187 | nn.ReLU(inplace=True),
188 | nn.Conv2d(out_chan,out_chan,3,padding=1),
189 | nn.BatchNorm2d(out_chan),
190 | nn.ReLU(inplace=True))
191 | self.decoder_final_l = nn.Conv2d(out_chan,1,3,padding=1) # 32 -> 1
192 |
193 | ##########################
194 | ### Decoder part - MODULES
195 | ##########################
196 |
197 | self.tfi_3 = TFI(256)
198 | self.tfi_2 = TFI(128)
199 | self.tfi_1 = TFI(64)
200 | self.tfi_0 = TFI(64)
201 |
202 | self.sbfi_2 = SBFI(128, 32, 8)
203 | self.sbfi_1 = SBFI(64, 32, 4)
204 | self.sbfi_0 = SBFI(64, 32, 2)
205 |
206 | self.dbfi_2 = DBFI(128, 512, 4)
207 | self.dbfi_1 = DBFI(64, 512, 8)
208 | self.dbfi_0 = DBFI(64, 512, 16)
209 |
210 | def forward(self, x, indices, feas):
211 | ###########################
212 | ### Decoder part - Global
213 | ###########################
214 | d4_g = self.decoder4_g(x) # 512 -> 256
215 | d3_g = self.decoder3_g(d4_g) # 256 -> 128
216 | d2_g, global_sigmoid_side2 = self.dbfi_2(d3_g, x) # 128, 512 -> 128
217 | d2_g = self.decoder2_g(d2_g) # 128 -> 64
218 | d1_g, global_sigmoid_side1 = self.dbfi_1(d2_g, x) # 64, 512 -> 64
219 | d1_g = self.decoder1_g(d1_g) # 64 -> 64
220 | d0_g, global_sigmoid_side0 = self.dbfi_0(d1_g, x) # 64, 512 -> 64
221 | d0_g = self.decoder0_g(d0_g) # 64 -> 32
222 | global_sigmoid = self.decoder_final_g(d0_g) # 32 -> 3
223 | ###########################
224 | ### Decoder part - Local
225 | ###########################
226 | d4_l = self.decoder4_l(x) # 512 -> 256
227 | d4_l = F.max_unpool2d(d4_l, indices[-1], kernel_size=2, stride=2) # 256
228 | d3_l = self.tfi_3(d4_g, d4_l, feas[-1]) # 256, 256, 256 -> 256
229 | d3_l = self.decoder3_l(d3_l) # 256 -> 128
230 | d3_l = F.max_unpool2d(d3_l, indices[-2], kernel_size=2, stride=2) # 128
231 | d2_l = self.tfi_2(d3_g, d3_l, feas[-2]) # 128, 128, 128 -> 128
232 | d2_l = self.sbfi_2(d2_l, feas[-5]) # 128, 32 -> 128
233 | d2_l = self.decoder2_l(d2_l) # 128 -> 64
234 | d2_l = F.max_unpool2d(d2_l, indices[-3], kernel_size=2, stride=2) # 64
235 | d1_l = self.tfi_1(d2_g, d2_l, feas[-3]) # 64, 64, 64 -> 64
236 | d1_l = self.sbfi_1(d1_l, feas[-5]) # 64, 32 -> 64
237 | d1_l = self.decoder1_l(d1_l) # 64 -> 64
238 | d1_l = F.max_unpool2d(d1_l, indices[-4], kernel_size=2, stride=2) # 64
239 | d0_l = self.tfi_0(d1_g, d1_l, feas[-4]) # 64, 64, 64 -> 64
240 | d0_l = self.sbfi_0(d0_l, feas[-5]) # 64, 32 -> 64
241 | d0_l = self.decoder0_l(d0_l) # 64 -> 32
242 | d0_l = F.max_unpool2d(d0_l, indices[-5], kernel_size=2, stride=2) # 32
243 | d0_l = self.decoder_final_l(d0_l) # 32 -> 1
244 | local_sigmoid = torch.sigmoid(d0_l) # 1
245 | fusion_sigmoid = get_masked_local_from_global(global_sigmoid, local_sigmoid)
246 | return global_sigmoid, local_sigmoid, fusion_sigmoid, global_sigmoid_side2, global_sigmoid_side1, global_sigmoid_side0
247 |
248 |
249 | ##########################################
250 | ## Matting Model
251 | ##########################################
252 |
253 | class ViTAE_noRC_MaxPooling_Matting(nn.Module):
254 | def __init__(self, encoder, decoder) -> None:
255 | super().__init__()
256 | self.encoder = encoder
257 | self.decoder = decoder
258 |
259 | def forward(self, x):
260 | embeddings, indices, feas = self.encoder(x)
261 | return self.decoder(embeddings, indices, feas)
262 |
--------------------------------------------------------------------------------
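To see how the encoder and decoder above fit together, here is a minimal end-to-end sketch. It assumes core/ is on the import path and skips the pretrained backbone checkpoint (pretrained=False) so the wiring and shapes can be checked without the weight file; the six outputs follow the decoder's return signature.

import torch
from network.ViTAE_S.models import (ViTAE_noRC_MaxPooling_bias_basic_stages4_14,
                                    ViTAE_noRC_MaxPooling_DecoderV1,
                                    ViTAE_noRC_MaxPooling_Matting)

encoder = ViTAE_noRC_MaxPooling_bias_basic_stages4_14(pretrained=False)
model = ViTAE_noRC_MaxPooling_Matting(encoder, ViTAE_noRC_MaxPooling_DecoderV1()).eval()

x = torch.randn(1, 3, 512, 512)                  # side lengths must be multiples of 32
with torch.no_grad():
    pred_global, pred_local, pred_fusion, *side_outs = model(x)
print(pred_global.shape)                         # 3-channel global (trimap) prediction, 1x3x512x512
print(pred_local.shape, pred_fusion.shape)       # 1x1x512x512 local and fused alpha mattes
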
/core/network/ViTAE_S/token_performer.py:
--------------------------------------------------------------------------------
1 | """
2 | Rethinking Portrait Matting with Privacy Preserving
3 |
4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
5 | Licensed under the MIT License (see LICENSE for details)
6 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net
7 | Paper link: https://arxiv.org/abs/2203.16828
8 |
9 | """
10 |
11 | """
12 | Take Performer as T2T Transformer
13 | """
14 | import math
15 | import torch
16 | import torch.nn as nn
17 |
18 |
19 | class Token_performer(nn.Module):
20 | def __init__(self, dim, in_dim, head_cnt=1, kernel_ratio=0.5, dp1=0.1, dp2 = 0.1, gamma=False, init_values=1e-4):
21 | super().__init__()
22 | self.head_dim = in_dim // head_cnt
23 | self.emb = in_dim
24 | self.kqv = nn.Linear(dim, 3 * self.emb)
25 | self.dp = nn.Dropout(dp1)
26 | self.proj = nn.Linear(self.emb, self.emb)
27 | self.head_cnt = head_cnt
28 | self.norm1 = nn.LayerNorm(dim)
29 | self.norm2 = nn.LayerNorm(self.emb)
30 | self.epsilon = 1e-8  # for numerical stability in the division
31 | self.drop_path = nn.Identity()
32 |
33 | self.mlp = nn.Sequential(
34 | nn.Linear(self.emb, 1 * self.emb),
35 | nn.GELU(),
36 | nn.Linear(1 * self.emb, self.emb),
37 | nn.Dropout(dp2),
38 | )
39 |
40 | self.m = int(self.head_dim * kernel_ratio)
41 | self.w = torch.randn(head_cnt, self.m, self.head_dim)
42 | for i in range(self.head_cnt):
43 | self.w[i] = nn.Parameter(nn.init.orthogonal_(self.w[i]) * math.sqrt(self.m), requires_grad=False)
44 | self.w.requires_grad_(False)
45 |
46 | if gamma:
47 | self.gamma1 = nn.Parameter(init_values * torch.ones((self.emb)),requires_grad=True)
48 | else:
49 | self.gamma1 = 1
50 |
51 | def prm_exp(self, x):
52 | # part of the function is borrowed from https://github.com/lucidrains/performer-pytorch
53 | # and Simo Ryu (https://github.com/cloneofsimo)
54 | # ==== positive random features for gaussian kernels ====
55 | # x = (B, H, N, hs)
56 | # w = (H, m, hs)
57 | # return : x : B, T, m
58 | # SM(x, y) = E_w[exp(w^T x - |x|/2) exp(w^T y - |y|/2)]
59 | # therefore return exp(w^Tx - |x|/2)/sqrt(m)
60 | xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, 1, self.m) / 2
61 | wtx = torch.einsum('bhti,hmi->bhtm', x.float(), self.w.to(x.device))
62 |
63 | return torch.exp(wtx - xd) / math.sqrt(self.m)
64 |
65 | def attn(self, x):
66 | B, N, C = x.shape
67 | kqv = self.kqv(x).reshape(B, N, 3, self.head_cnt, self.head_dim).permute(2, 0, 3, 1, 4)
68 | k, q, v = kqv[0], kqv[1], kqv[2] # (B, H, T, hs)
69 |
70 | kp, qp = self.prm_exp(k), self.prm_exp(q) # (B, H, T, m), (B, H, T, m)
71 | D = torch.einsum('bhti,bhi->bht', qp, kp.sum(dim=2)).unsqueeze(dim=-1) # (B, H, T, m) * (B, H, m) -> (B, H, T, 1)
72 | kptv = torch.einsum('bhin,bhim->bhnm', v.float(), kp) # (B, H, emb, m)
73 | y = torch.einsum('bhti,bhni->bhtn', qp, kptv) / (D.repeat(1, 1, 1, self.head_dim) + self.epsilon) # (B, H, T, emb)/Diag
74 |
75 | # skip connection
76 |
77 | y = y.permute(0, 2, 1, 3).reshape(B, N, self.emb)
78 | v = v.permute(0, 2, 1, 3).reshape(B, N, self.emb)
79 |
80 | y = v + self.dp(self.gamma1 * self.proj(y)) # same as token_transformer, use v as skip connection
81 |
82 | return y
83 |
84 | def forward(self, x):
85 | x = self.attn(self.norm1(x))
86 | x = x + self.mlp(self.norm2(x))
87 | return x
--------------------------------------------------------------------------------
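The comment block in prm_exp quotes the Performer identity SM(x, y) = E_w[exp(w^T x - |x|^2/2) exp(w^T y - |y|^2/2)]. Below is a small self-contained check of that identity with i.i.d. Gaussian features; the orthogonalization used in the class is omitted for brevity, so this only illustrates the estimator, not the exact layer.

import math
import torch

torch.manual_seed(0)
d, m = 16, 4096                       # head dimension and number of random features
w = torch.randn(m, d)                 # i.i.d. Gaussian projection (the class uses an orthogonal one)

def phi(x):                           # positive random features, mirroring prm_exp
    return torch.exp(x @ w.t() - (x * x).sum(-1, keepdim=True) / 2) / math.sqrt(m)

q, k = 0.3 * torch.randn(d), 0.3 * torch.randn(d)
approx = (phi(q) * phi(k)).sum()      # Monte-Carlo estimate of the softmax kernel
exact = torch.exp(q @ k)              # SM(q, k) = exp(q . k)
print(float(approx), float(exact))    # the two values should be close for large m
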
/core/network/ViTAE_S/token_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd.
2 | #
3 | # This source code is licensed under the Clear BSD License
4 | # LICENSE file in the root directory of this file
5 | # All rights reserved.
6 | """
7 | Take the standard Transformer as T2T Transformer
8 | """
9 | import torch
10 | import torch.nn as nn
11 | from timm.models.layers import DropPath
12 | from .NormalCell import Mlp
13 |
14 | class Attention(nn.Module):
15 | def __init__(self, dim, num_heads=8, in_dim = None, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., gamma=False, init_values=1e-4):
16 | super().__init__()
17 | self.num_heads = num_heads
18 | self.in_dim = in_dim
19 | head_dim = in_dim // num_heads
20 | self.scale = qk_scale or head_dim ** -0.5
21 |
22 | self.qkv = nn.Linear(dim, in_dim * 3, bias=qkv_bias)
23 | self.attn_drop = nn.Dropout(attn_drop)
24 | self.proj = nn.Linear(in_dim, in_dim)
25 | self.proj_drop = nn.Dropout(proj_drop)
26 | if gamma:
27 | self.gamma1 = nn.Parameter(init_values * torch.ones((in_dim)),requires_grad=True)
28 | else:
29 | self.gamma1 = 1
30 |
31 | def forward(self, x):
32 | B, N, C = x.shape
33 |
34 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.in_dim // self.num_heads).permute(2, 0, 3, 1, 4)
35 | q, k, v = qkv[0], qkv[1], qkv[2]
36 |
37 | attn = (q @ k.transpose(-2, -1)) * self.scale
38 | attn = attn.softmax(dim=-1)
39 | attn = self.attn_drop(attn)
40 |
41 | x = (attn @ v).transpose(1, 2).reshape(B, N, self.in_dim)
42 | x = self.proj(x)
43 | x = self.proj_drop(self.gamma1 * x)
44 | v = v.permute(0, 2, 1, 3).view(B, N, self.in_dim).contiguous()
45 | # skip connection
46 | x = v + x # because the original x has different size with current x, use v to do skip connection
47 |
48 | return x
49 |
50 | class Token_transformer(nn.Module):
51 |
52 | def __init__(self, dim, in_dim, num_heads, mlp_ratio=1., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
53 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, gamma=False, init_values=1e-4):
54 | super().__init__()
55 | self.norm1 = norm_layer(dim)
56 | self.attn = Attention(
57 | dim, in_dim=in_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, gamma=gamma, init_values=init_values)
58 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
59 | self.norm2 = norm_layer(in_dim)
60 | self.mlp = Mlp(in_features=in_dim, hidden_features=int(in_dim*mlp_ratio), out_features=in_dim, act_layer=act_layer, drop=drop)
61 |
62 | def forward(self, x):
63 | x = self.attn(self.norm1(x))
64 | x = x + self.drop_path(self.mlp(self.norm2(x)))
65 | return x
66 |
--------------------------------------------------------------------------------
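As the skip-connection comment above notes, the block projects tokens from dim to in_dim and uses v for the residual, so the channel dimension grows across the block. A minimal sketch, assuming core/ is on the import path:

import torch
from network.ViTAE_S.token_transformer import Token_transformer

blk = Token_transformer(dim=64, in_dim=128, num_heads=4)
tokens = torch.randn(2, 196, 64)      # (batch, tokens, dim)
print(blk(tokens).shape)              # expected torch.Size([2, 196, 128]); channels grow to in_dim
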
/core/network/__init__.py:
--------------------------------------------------------------------------------
1 | from .build_model import build_model
--------------------------------------------------------------------------------
/core/network/build_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Rethinking Portrait Matting with Privacy Preserving
3 |
4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
5 | Licensed under the MIT License (see LICENSE for details)
6 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net
7 | Paper link: https://arxiv.org/abs/2203.16828
8 |
9 | """
10 |
11 | from .RestNet34 import *
12 | from .Swin_T import *
13 | from .ViTAE_S import *
14 |
15 |
16 | def build_model(model_arch, **kwargs):
17 | if model_arch == "r34":
18 | model = p3mnet_resnet34(**kwargs)
19 | elif model_arch == "swin":
20 | model = p3mnet_swin_t(**kwargs)
21 | elif model_arch == "vitae":
22 | model = p3mnet_vitae_s(**kwargs)
23 | else:
24 | print(model_arch)
25 | raise NotImplementedError
26 |
27 | return model
--------------------------------------------------------------------------------
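build_model is the single entry point used by train.py and test.py, dispatching on the arch string. A minimal sketch of that call follows; whether a given backbone constructor additionally needs local pretrained weights referenced in config.py is an assumption that depends on the chosen architecture.

from network.build_model import build_model   # assumes core/ is the import root

model = build_model("r34").eval()             # or "swin" / "vitae"; unknown strings raise NotImplementedError
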
/core/network/modules.py:
--------------------------------------------------------------------------------
1 | """
2 | Rethinking Portrait Matting with Privacy Preserving
3 |
4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
5 | Licensed under the MIT License (see LICENSE for details)
6 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net
7 | Paper link: https://arxiv.org/abs/2203.16828
8 |
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 |
15 | def conv3x3(in_planes, out_planes, stride=1):
16 | "3x3 convolution with padding"
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
18 | padding=1, bias=False)
19 |
20 |
21 | def conv1x1(in_planes, out_planes, stride=1):
22 | """1x1 convolution"""
23 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
24 |
25 |
26 | class TFI(nn.Module):
27 | expansion = 1
28 | def __init__(self, planes,stride=1):
29 | super(TFI, self).__init__()
30 | middle_planes = int(planes/2)
31 | self.transform = conv1x1(planes, middle_planes)
32 | self.conv1 = conv3x3(middle_planes*3, planes, stride)
33 | self.bn1 = nn.BatchNorm2d(planes)
34 | self.relu = nn.ReLU(inplace=True)
35 | self.stride = stride
36 |
37 | def forward(self, input_s_guidance, input_m_decoder, input_m_encoder):
38 | input_s_guidance_transform = self.transform(input_s_guidance)
39 | input_m_decoder_transform = self.transform(input_m_decoder)
40 | input_m_encoder_transform = self.transform(input_m_encoder)
41 | x = torch.cat((input_s_guidance_transform,input_m_decoder_transform,input_m_encoder_transform),1)
42 | out = self.conv1(x)
43 | out = self.bn1(out)
44 | out = self.relu(out)
45 | return out
46 |
47 |
48 | class SBFI(nn.Module):
49 | def __init__(self, planes,planes2,stride=1):
50 | # planes2, min dim
51 | super(SBFI, self).__init__()
52 | self.stride = stride
53 | self.transform1 = conv1x1(planes, int(planes/2))
54 | self.transform2 = conv1x1(planes2, int(planes/2))
55 | self.maxpool = nn.MaxPool2d(2, stride=stride)
56 | self.conv1 = conv3x3(planes, planes, 1)
57 | self.bn1 = nn.BatchNorm2d(planes)
58 | self.relu = nn.ReLU(inplace=True)
59 |
60 | def forward(self, input_m_decoder,e0):
61 | input_m_decoder_transform = self.transform1(input_m_decoder)
62 | e0_maxpool = self.maxpool(e0)
63 | e0_transform = self.transform2(e0_maxpool)
64 | x = torch.cat((input_m_decoder_transform,e0_transform),1)
65 | out = self.conv1(x)
66 | out = self.bn1(out)
67 | out = self.relu(out)
68 | out = out+input_m_decoder
69 | return out
70 |
71 |
72 | class DBFI(nn.Module):
73 | def __init__(self, planes,planes2,stride=1):
74 | # planes2, max dim
75 | super(DBFI, self).__init__()
76 | self.stride = stride
77 | self.transform1 = conv1x1(planes, int(planes/2))
78 | self.transform2 = conv1x1(planes2, int(planes/2))
79 | self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear')
80 | self.conv1 = conv3x3(planes, planes, 1)
81 | self.bn1 = nn.BatchNorm2d(planes)
82 | self.relu = nn.ReLU(inplace=True)
83 | self.conv2 = conv3x3(planes, 3, 1)
84 | self.upsample2 = nn.Upsample(scale_factor=int(32/stride), mode='bilinear')
85 |
86 | def forward(self, input_s_decoder,e4):
87 | input_s_decoder_transform = self.transform1(input_s_decoder)
88 | e4_transform = self.transform2(e4)
89 | e4_upsample = self.upsample(e4_transform)
90 | x = torch.cat((input_s_decoder_transform,e4_upsample),1)
91 | out = self.conv1(x)
92 | out = self.bn1(out)
93 | out = self.relu(out)
94 | out = out+input_s_decoder
95 | out_side = self.conv2(out)
96 | out_side = self.upsample2(out_side)
97 | return out, out_side
98 |
--------------------------------------------------------------------------------
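The three modules above encode fixed channel contracts: TFI fuses three same-width features, SBFI pulls a shallow 32-channel encoder feature down to the decoder's resolution, and DBFI pulls the deepest 512-channel feature up and also emits a 3-channel side output at full resolution. A shape-only sketch with dummy tensors, using the same constructor arguments the ViTAE decoder uses for a 512x512 input (assuming core/ is on the import path):

import torch
from network.modules import TFI, SBFI, DBFI

tfi = TFI(256)
s, d, e = (torch.randn(1, 256, 32, 32) for _ in range(3))
print(tfi(s, d, e).shape)              # torch.Size([1, 256, 32, 32])

sbfi = SBFI(128, 32, 8)                # decoder feature at 1/8 scale, shallow 32-ch feature at full scale
out = sbfi(torch.randn(1, 128, 64, 64), torch.randn(1, 32, 512, 512))
print(out.shape)                       # torch.Size([1, 128, 64, 64])

dbfi = DBFI(128, 512, 4)               # decoder feature at 1/8 scale, deepest 512-ch feature at 1/32 scale
out, out_side = dbfi(torch.randn(1, 128, 64, 64), torch.randn(1, 512, 16, 16))
print(out.shape, out_side.shape)       # torch.Size([1, 128, 64, 64]) torch.Size([1, 3, 512, 512])
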
/core/test.py:
--------------------------------------------------------------------------------
1 | """
2 | Testing is kept completely separate from training.
3 | The model is built differently: because of cp, the train model and the test model are not the same.
4 | The paths are a separate issue too; they cannot be shared with training,
5 | and many of the args here do not exist in training.
6 | However, testing still needs the parameters under config.MODEL.
7 |
8 | """
9 | import torch
10 | import torch.nn.functional as F
11 | import ipdb
12 | import cv2
13 | import argparse
14 | import numpy as np
15 | from tqdm import tqdm
16 | from PIL import Image
17 | from skimage.transform import resize
18 | from torchvision import transforms
19 | import logging
20 | import warnings
21 | import yaml
22 | import torch.distributed as dist
23 |
24 | from config import *
25 | from util import *
26 | from evaluate import *
27 | from network.build_model import build_model
28 | from config_yacs import get_config, CN
29 | from logger import create_logger
30 |
31 |
32 | def get_args():
33 | parser = argparse.ArgumentParser(description='Arguments for the testing purpose.')
34 | parser.add_argument('--tag', type=str, required=True)
35 | parser.add_argument('--test_dataset', type=str, required=True, choices=VALID_TEST_DATA_CHOICE, help="which dataset to test")
36 | parser.add_argument('--test_ckpt', type=str, default='', required=False, help="path of model to use")
37 | parser.add_argument('--test_method', type=str, required=False, help="which test method to use (e.g. HYBRID or RESIZE)")
38 | parser.add_argument('--fast_test', action='store_true', default=False, help='skip conn and grad metrics for fast test')
39 | parser.add_argument('--test_privacy', action='store_true', default=False, help='test on the privacy content')
40 | parser.add_argument('--save_result', action='store_true')
41 | args, _ = parser.parse_known_args()
42 |
43 | if args.test_dataset == 'VAL500NP':
44 | args.test_privacy = False
45 | warnings.warn("NO FACEMASK FOR VAL500NP")
46 |
47 | args.existing_cfg = np.load(os.path.join(CKPT_SAVE_FOLDER, args.tag, 'args.npy'), allow_pickle=True).item()
48 | if type(args.existing_cfg) != CN:
49 | new_cfg = CN()
50 | new_cfg.TAG = args.tag
51 | new_cfg.MODEL = CN()
52 | new_cfg.MODEL.TYPE = args.existing_cfg.arch
53 | args.existing_cfg = new_cfg
54 |
55 | config = get_config(args)
56 | return args, config
57 |
58 | def inference_once(config, model, scale_img, scale_trimap=None):
59 | if torch.cuda.device_count() > 0:
60 | tensor_img = torch.from_numpy(scale_img.astype(np.float32)[:, :, :]).permute(2, 0, 1).cuda()
61 | else:
62 | tensor_img = torch.from_numpy(scale_img.astype(np.float32)[:, :, :]).permute(2, 0, 1)
63 | input_t = tensor_img
64 | input_t = input_t/255.0
65 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
66 | std=[0.229, 0.224, 0.225])
67 | input_t = normalize(input_t)
68 | input_t = input_t.unsqueeze(0)
69 | pred_global, pred_local, pred_fusion = model(input_t)[:3]
70 | pred_global = pred_global.data.cpu().numpy()
71 | pred_global = gen_trimap_from_segmap_e2e(pred_global)
72 | pred_local = pred_local.data.cpu().numpy()[0,0,:,:]
73 | pred_fusion = pred_fusion.data.cpu().numpy()[0,0,:,:]
74 | return pred_global, pred_local, pred_fusion
75 |
76 | def inference_img_modnet(config, model, img, *args):
77 | im_h, im_w, c = img.shape
78 |
79 | if config.TEST.TEST_METHOD=='RESIZE512':
80 | ref_size = 512
81 | if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
82 | if im_w >= im_h:
83 | im_rh = ref_size
84 | im_rw = int(im_w / im_h * ref_size)
85 | elif im_w < im_h:
86 | im_rw = ref_size
87 | im_rh = int(im_h / im_w * ref_size)
88 | else:
89 | im_rh = im_h
90 | im_rw = im_w
91 |
92 | im_rw = im_rw - im_rw % 32
93 | im_rh = im_rh - im_rh % 32
94 |
95 | scale_img = resize(img,(im_rh, im_rw))
96 |
97 | scale_img = scale_img*255.0
98 |
99 | elif config.TEST.TEST_METHOD=='ORIGIN':
100 | im_rh = min(MAX_SIZE_H, im_h - im_h % 32)
101 | im_rw = min(MAX_SIZE_W, im_w - im_w % 32)
102 |
103 | scale_img = resize(img, (im_rh, im_rw))
104 | scale_img = scale_img * 255.0
105 | else:
106 | raise NotImplementedError
107 |
108 | tensor_img = torch.from_numpy(scale_img.astype(np.float32)[:, :, :]).permute(2, 0, 1).cuda()
109 | tensor_img = (tensor_img/255.0).unsqueeze(0)
110 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
111 | std=[0.229, 0.224, 0.225])
112 | tensor_img = normalize(tensor_img)
113 |
114 | _,_,pred= model(tensor_img, inference=True)
115 | pred = pred.squeeze()
116 | pred = pred.cpu().data.numpy()
117 | pred = resize(pred,(im_h,im_w))
118 |
119 | return pred
120 |
121 | def inference_img_p3m(config, model, img, *args):
122 | h, w, c = img.shape
123 | new_h = min(MAX_SIZE_H, h - (h % 32))
124 | new_w = min(MAX_SIZE_W, w - (w % 32))
125 |
126 | if config.TEST.TEST_METHOD=='HYBRID':
127 | global_ratio = 1/2
128 | local_ratio = 1
129 | resize_h = int(h*global_ratio)
130 | resize_w = int(w*global_ratio)
131 | new_h = min(MAX_SIZE_H, resize_h - (resize_h % 32))
132 | new_w = min(MAX_SIZE_W, resize_w - (resize_w % 32))
133 | scale_img = resize(img,(new_h,new_w))*255.0
134 | pred_coutour_1, pred_retouching_1, pred_fusion_1 = inference_once(config, model, scale_img)
135 | torch.cuda.empty_cache()
136 | pred_coutour_1 = resize(pred_coutour_1,(h,w))*255.0
137 | resize_h = int(h*local_ratio)
138 | resize_w = int(w*local_ratio)
139 | new_h = min(MAX_SIZE_H, resize_h - (resize_h % 32))
140 | new_w = min(MAX_SIZE_W, resize_w - (resize_w % 32))
141 | scale_img = resize(img,(new_h,new_w))*255.0
142 | pred_coutour_2, pred_retouching_2, pred_fusion_2 = inference_once(config, model, scale_img)
143 | torch.cuda.empty_cache()
144 | pred_retouching_2 = resize(pred_retouching_2,(h,w))
145 | pred_fusion = get_masked_local_from_global_test(pred_coutour_1, pred_retouching_2)
146 | return pred_fusion
147 | elif config.TEST.TEST_METHOD=='RESIZE':
148 | resize_h = int(h/2)
149 | resize_w = int(w/2)
150 | new_h = min(MAX_SIZE_H, resize_h - (resize_h % 32))
151 | new_w = min(MAX_SIZE_W, resize_w - (resize_w % 32))
152 | scale_img = resize(img,(new_h,new_w))*255.0
153 | pred_global, pred_local, pred_fusion = inference_once(config, model, scale_img)
154 | pred_local = resize(pred_local,(h,w))
155 | pred_global = resize(pred_global,(h,w))*255.0
156 | pred_fusion = resize(pred_fusion,(h,w))
157 | return pred_fusion
158 | else:
159 | raise NotImplementedError()
160 |
161 | def test_p3m10k(config, model, logger):
162 | if torch.cuda.device_count() > 0:
163 | torch.cuda.empty_cache()
164 |
165 | arch_predict_dict = {
166 | 'r34': inference_img_p3m,
167 | 'swin': inference_img_p3m,
168 | 'vitae': inference_img_p3m,
169 | }
170 |
171 | ############################
172 | # Some initial setting for paths
173 | ############################
174 | if config.TEST.DATASET == 'VAL500P':
175 | val_option = 'P3M_500_P'
176 | elif config.TEST.DATASET == 'VAL500NP':
177 | val_option = 'P3M_500_NP'
178 | else:
179 | val_option = config.TEST.DATASET
180 | ORIGINAL_PATH = DATASET_PATHS_DICT['P3M10K'][val_option]['ORIGINAL_PATH']
181 | MASK_PATH = DATASET_PATHS_DICT['P3M10K'][val_option]['MASK_PATH']
182 | TRIMAP_PATH = DATASET_PATHS_DICT['P3M10K'][val_option]['TRIMAP_PATH']
183 | if config.TEST.TEST_PRIVACY:
184 | PRIVACY_PATH = DATASET_PATHS_DICT['P3M10K'][val_option]['PRIVACY_MASK_PATH']
185 | ############################
186 | # Start testing
187 | ############################
188 | sad_diffs = 0.
189 | mse_diffs = 0.
190 | mad_diffs = 0.
191 | sad_trimap_diffs = 0.
192 | mse_trimap_diffs = 0.
193 | mad_trimap_diffs = 0.
194 | sad_fg_diffs = 0.
195 | sad_bg_diffs = 0.
196 | conn_diffs = 0.
197 | grad_diffs = 0.
198 | sad_privacy_diffs = 0. # for test privacy only
199 | mse_privacy_diffs = 0. # for test privacy only
200 | mad_privacy_diffs = 0. # for test privacy only
201 | if config.TEST.SAVE_RESULT:
202 | result_dir = os.path.join(TEST_RESULT_FOLDER, 'test_{}_{}_{}_{}'.format(config.TAG, config.TEST.DATASET, config.TEST.TEST_METHOD, config.TEST.CKPT_NAME.replace('/', '-')))
203 | refresh_folder(result_dir)
204 | model.eval()
205 | img_list = listdir_nohidden(ORIGINAL_PATH)
206 | total_number = len(img_list)
207 | logger.info("===============================")
208 | logger.info(f'====> Start Testing\n\t--Dataset: {config.TEST.DATASET}\n\t--Test: {config.TEST.TEST_METHOD}\n\t--Number: {total_number}')
209 |
210 | if config.TAG.startswith("debug"):
211 | img_list = img_list[:10]
212 |
213 | if dist.is_initialized():
214 | img_list = img_list[GET_RANK()::GET_WORLD_SIZE()]
215 | total_number = len(img_list)
216 | print('rank {}/{}, total num: {}'.format(GET_RANK(), GET_WORLD_SIZE(), total_number))
217 |
218 | for img_name in tqdm(img_list):
219 | img_path = ORIGINAL_PATH+img_name
220 | alpha_path = MASK_PATH+extract_pure_name(img_name)+'.png'
221 | trimap_path = TRIMAP_PATH+extract_pure_name(img_name)+'.png'
222 | img = np.array(Image.open(img_path))
223 | trimap = np.array(Image.open(trimap_path))
224 | alpha = np.array(Image.open(alpha_path))/255.
225 | img = img[:,:,:3] if img.ndim>2 else img
226 | trimap = trimap[:,:,0] if trimap.ndim>2 else trimap
227 | alpha = alpha[:,:,0] if alpha.ndim>2 else alpha
228 |
229 | if config.TEST.TEST_PRIVACY:
230 | privacy_path = PRIVACY_PATH+extract_pure_name(img_name)+'.png'
231 | privacy = np.array(Image.open(privacy_path))
232 | privacy = privacy[:, :, 0] if privacy.ndim>2 else privacy
233 |
234 | with torch.no_grad():
235 | predict = arch_predict_dict.get(config.MODEL.TYPE, inference_img_p3m)(config, model, img)
236 |
237 | # test on whole image and trimap area
238 | sad_trimap_diff, mse_trimap_diff, mad_trimap_diff = calculate_sad_mse_mad(predict, alpha, trimap)
239 | sad_diff, mse_diff, mad_diff = calculate_sad_mse_mad_whole_img(predict, alpha)
240 | sad_fg_diff, sad_bg_diff = calculate_sad_fgbg(predict, alpha, trimap)
241 | if config.TEST.FAST_TEST:
242 | conn_diff = -1
243 | grad_diff = -1
244 | else:
245 | conn_diff = compute_connectivity_loss_whole_image(predict, alpha)
246 | grad_diff = compute_gradient_whole_image(predict, alpha)
247 |
248 | # test on privacy area
249 | if config.TEST.TEST_PRIVACY:
250 | sad_privacy_diff, mse_privacy_diff, mad_privacy_diff = calculate_sad_mse_mad_privacy(predict, alpha, privacy)
251 | else:
252 | sad_privacy_diff = -1
253 | mse_privacy_diff = -1
254 | mad_privacy_diff = -1
255 |
256 | logger.info(f"[{img_list.index(img_name)}/{total_number}]\nImage:{img_name}\nsad:{sad_diff}\nmse:{mse_diff}\nmad:{mad_diff}\nsad_trimap:{sad_trimap_diff}\nmse_trimap:{mse_trimap_diff}\nmad_trimap:{mad_trimap_diff}\nsad_fg:{sad_fg_diff}\nsad_bg:{sad_bg_diff}\nconn:{conn_diff}\ngrad:{grad_diff}\nsad_privacy:{sad_privacy_diff}\nmse_privacy:{mse_privacy_diff}\nmad_privacy:{mad_privacy_diff}\n-----------")
257 | sad_diffs += sad_diff
258 | mse_diffs += mse_diff
259 | mad_diffs += mad_diff
260 | mse_trimap_diffs += mse_trimap_diff
261 | sad_trimap_diffs += sad_trimap_diff
262 | mad_trimap_diffs += mad_trimap_diff
263 | sad_fg_diffs += sad_fg_diff
264 | sad_bg_diffs += sad_bg_diff
265 | conn_diffs += conn_diff
266 | grad_diffs += grad_diff
267 | sad_privacy_diffs += sad_privacy_diff
268 | mse_privacy_diffs += mse_privacy_diff
269 | mad_privacy_diffs += mad_privacy_diff
270 |
271 | if config.TEST.SAVE_RESULT:
272 | save_test_result(os.path.join(result_dir, extract_pure_name(img_name)+'.png'),predict)
273 |
274 | res_dict = {}
275 | logger.info("===============================")
276 | logger.info(f"Testing numbers: {total_number}")
277 | # res_dict['number'] = total_number
278 | logger.info("SAD: {}".format(sad_diffs / total_number))
279 | res_dict['SAD'] = sad_diffs / total_number
280 | logger.info("MSE: {}".format(mse_diffs / total_number))
281 | res_dict['MSE'] = mse_diffs / total_number
282 | logger.info("MAD: {}".format(mad_diffs / total_number))
283 | res_dict['MAD'] = mad_diffs / total_number
284 | logger.info("SAD TRIMAP: {}".format(sad_trimap_diffs / total_number))
285 | res_dict['SAD_TRIMAP'] = sad_trimap_diffs / total_number
286 | logger.info("MSE TRIMAP: {}".format(mse_trimap_diffs / total_number))
287 | res_dict['MSE_TRIMAP'] = mse_trimap_diffs / total_number
288 | logger.info("MAD TRIMAP: {}".format(mad_trimap_diffs / total_number))
289 | res_dict['MAD_TRIMAP'] = mad_trimap_diffs / total_number
290 | logger.info("SAD FG: {}".format(sad_fg_diffs / total_number))
291 | res_dict['SAD_FG'] = sad_fg_diffs / total_number
292 | logger.info("SAD BG: {}".format(sad_bg_diffs / total_number))
293 | res_dict['SAD_BG'] = sad_bg_diffs / total_number
294 | logger.info("CONN: {}".format(conn_diffs / total_number))
295 | res_dict['CONN'] = conn_diffs / total_number
296 | logger.info("GRAD: {}".format(grad_diffs / total_number))
297 | res_dict['GRAD'] = grad_diffs / total_number
298 | logger.info("SAD PRIVACY: {}".format(sad_privacy_diffs / total_number))
299 | res_dict['SAD_PRIVACY'] = sad_privacy_diffs / total_number
300 | logger.info("MSE PRIVACY: {}".format(mse_privacy_diffs / total_number))
301 | res_dict['MSE_PRIVACY'] = mse_privacy_diffs / total_number
302 | logger.info("MAD PRIVACY: {}".format(mad_privacy_diffs / total_number))
303 | res_dict['MAD_PRIVACY'] = mad_privacy_diffs / total_number
304 |
305 | # return int(sad_diffs/total_number)
306 | print("SAD: {}\nMSE: {}\nMAD: {}\n".format(res_dict['SAD'], res_dict['MSE'], res_dict['MAD']))
307 |
308 | if dist.is_initialized():
309 | for k in res_dict.keys():
310 | res_dict[k] = torch.tensor(res_dict[k]) # for reduce only
311 | print('rank')
312 |
313 | return res_dict
314 |
315 | def test_samples(config, model):
316 |
317 | arch_predict_dict = {
318 | 'r34': inference_img_p3m,
319 | 'swin': inference_img_p3m,
320 | 'vitae': inference_img_p3m,
321 | }
322 |
323 | print(f'=====> Test on samples and save alpha results')
324 | model.eval()
325 | img_list = listdir_nohidden(SAMPLES_ORIGINAL_PATH)
326 | refresh_folder(SAMPLES_RESULT_ALPHA_PATH)
327 | refresh_folder(SAMPLES_RESULT_COLOR_PATH)
328 | for img_name in tqdm(img_list):
329 | img_path = SAMPLES_ORIGINAL_PATH+img_name
330 | try:
331 | img = np.array(Image.open(img_path))[:,:,:3]
332 | except Exception as e:
333 | print(f'Error: {str(e)} | Name: {img_name}'); continue  # skip unreadable images
334 | h, w, c = img.shape
335 | if min(h, w)>SHORTER_PATH_LIMITATION:
336 | if h>=w:
337 | new_w = SHORTER_PATH_LIMITATION
338 | new_h = int(SHORTER_PATH_LIMITATION*h/w)
339 | img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
340 | else:
341 | new_h = SHORTER_PATH_LIMITATION
342 | new_w = int(SHORTER_PATH_LIMITATION*w/h)
343 | img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
344 |
345 | with torch.no_grad():
346 | if torch.cuda.device_count() > 0:
347 | torch.cuda.empty_cache()
348 | predict = arch_predict_dict.get(config.MODEL.TYPE, inference_img_p3m)(config, model, img)
349 |
350 | composite = generate_composite_img(img, predict)
351 | cv2.imwrite(os.path.join(SAMPLES_RESULT_COLOR_PATH, extract_pure_name(img_name)+'.png'),composite)
352 | predict = predict*255.0
353 | predict = cv2.resize(predict, (w, h), interpolation=cv2.INTER_LINEAR)
354 | cv2.imwrite(os.path.join(SAMPLES_RESULT_ALPHA_PATH, extract_pure_name(img_name)+'.png'),predict.astype(np.uint8))
355 |
356 | def load_model_and_deploy(config):
357 |
358 | ### build model
359 | model = build_model(config.MODEL.TYPE)
360 |
361 | ### load ckpt
362 | ckpt_path = os.path.join(CKPT_SAVE_FOLDER, '{}/{}.pth'.format(config.TAG, config.TEST.CKPT_NAME))
363 | if torch.cuda.device_count()==0:
364 | print(f'Running on CPU...')
365 | ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
366 | model.load_state_dict(ckpt['state_dict'], strict=True)
367 | else:
368 | print(f'Running on GPU with CUDA...')
369 | ckpt = torch.load(ckpt_path)
370 | model.load_state_dict(ckpt['state_dict'], strict=True)
371 | model = model.cuda()
372 | model.cuda()
373 |
374 | ### Test
375 | if config.TEST.DATASET=='SAMPLES':
376 | test_samples(config, model)
377 | elif config.TEST.DATASET in VALID_TEST_DATASET_CHOICE:
378 | logname = 'test_{}_{}_{}_{}'.format(config.TAG, config.TEST.DATASET, config.TEST.TEST_METHOD, config.TEST.CKPT_NAME.replace('/', '-'))
379 | logging_filename = TEST_LOGS_FOLDER+logname+'.log'
380 | logger = create_logger(logging_filename)
381 | test_p3m10k(config, model, logger)
382 | else:
383 | print('Please input the correct dataset_choice (SAMPLES, P3M_500_P or P3M_500_NP).')
384 |
385 | if __name__ == '__main__':
386 | args, config = get_args()
387 | load_model_and_deploy(config)
388 |
--------------------------------------------------------------------------------
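Both test strategies in test.py rely on the same resizing rule before inference: scale the image (by 1/2 for the global pass in HYBRID and RESIZE, by 1 for the local pass in HYBRID), then snap the result down to a multiple of 32 and cap it by MAX_SIZE_H/MAX_SIZE_W so it matches the encoder's total stride. A self-contained sketch of that rule; the cap values below are placeholders, the real ones come from config.py.

MAX_SIZE_H = MAX_SIZE_W = 1600         # placeholder caps; the repo defines these in config.py

def test_resolution(h, w, ratio):
    """Scale (h, w) by ratio, round down to a multiple of 32, and cap the size."""
    rh, rw = int(h * ratio), int(w * ratio)
    return min(MAX_SIZE_H, rh - rh % 32), min(MAX_SIZE_W, rw - rw % 32)

print(test_resolution(1080, 1920, 0.5))   # (512, 960): global pass of a 1080p image
print(test_resolution(1080, 1920, 1.0))   # (1056, 1600): local pass, width limited by the placeholder cap
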
/core/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Rethinking Portrait Matting with Privacy Preserving
3 | Training file.
4 |
5 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
6 | Licensed under the MIT License (see LICENSE for details)
7 | Github repo: https://github.com/ViTAE-Transformer/P3M-Net
8 | Paper link: https://arxiv.org/abs/2203.16828
9 |
10 | """
11 | import os
12 | import ipdb
13 | import random
14 | import argparse
15 | import torch
16 | import torch.optim as optim
17 | from torch.autograd import Variable
18 | from torch.utils.data import DataLoader
19 | import torch.backends.cudnn as cudnn
20 | import numpy as np
21 | import datetime
22 | import wandb
23 | import shutil
24 | from tqdm import tqdm
25 | from yacs.config import CfgNode as CN
26 |
27 | from config import *
28 | from util import *
29 | from evaluate import *
30 | from test import test_p3m10k
31 | from data import MattingDataset, MattingTransform
32 | from network.build_model import build_model
33 | from config_yacs import get_config
34 | from logger import create_logger
35 |
36 | ######### Parsing arguments #########
37 | def get_args():
38 | parser = argparse.ArgumentParser(description='Arguments for the training purpose.')
39 | parser.add_argument('--cfg', type=str, help='path to config file', )
40 | parser.add_argument(
41 | "--opts",
42 | help="Modify config options by adding 'KEY VALUE' pairs. ",
43 | default=None,
44 | nargs='+',
45 | )
46 | parser.add_argument('--arch', type=str, help="backbone architecture of the model")
47 | parser.add_argument('--train_from_scratch', action='store_true')
48 | parser.add_argument('--nEpochs', type=int, default=50, help='number of epochs to train for, 500 for ORI-Track and 100 for COMP-Track')
49 | parser.add_argument('--lr', type=float, default=0.00001, help='Learning Rate. Default=0.00001')
50 | parser.add_argument('--warmup_nEpochs', type=int, default=0, help='epochs for warming up')
51 | parser.add_argument('--lr_decay', action='store_true', default=None, help='whether to use lr decay')
52 | parser.add_argument('--clip_grad', action='store_true', default=None, help='whether to clip gradient')
53 | parser.add_argument('--threads', type=int, default=8, help='number of threads for data loader to use')
54 | parser.add_argument('--batchSize', type=int, default=8, help='training batch size')
55 | parser.add_argument('--tag', type=str, default='debug')
56 | parser.add_argument('--enable_wandb', action='store_true', default=None)
57 | parser.add_argument('--test_freq', type=int)
58 | parser.add_argument('--auto_resume', action='store_true')
59 | parser.add_argument('--test_method', default=None)
60 | parser.add_argument('--dataset', type=str)
61 | parser.add_argument('--train_set', type=str)
62 | args, _ = parser.parse_known_args()
63 |
64 | if args.tag.lower().startswith('debug'):
65 | args.enable_wandb = False
66 |
67 | if args.auto_resume:
68 | # load existing cfg
69 | args_path = os.path.join(CKPT_SAVE_FOLDER, args.tag, "args.npy")
70 | if os.path.exists(args_path):
71 | args.existing_cfg = np.load(os.path.join(CKPT_SAVE_FOLDER, args.tag, 'args.npy'), allow_pickle=True).item()
72 |
73 | config = get_config(args)
74 | if args.auto_resume:
75 | set_seed(config.SEED)
76 |
77 | print(config)
78 | print(args)
79 | return args, config
80 |
81 | def set_seed(seed=None):
82 | if seed is None:
83 | seed = torch.seed()
84 |
85 | random.seed(seed)
86 | np.random.seed(seed)
87 | os.environ['PYTHONHASHSEED'] = str(seed)
88 | torch.manual_seed(seed)
89 | torch.cuda.manual_seed(seed)
90 | torch.cuda.manual_seed_all(seed)
91 | cudnn.benchmark = True
92 | return seed
93 |
94 | def load_dataset(config):
95 | train_transform = MattingTransform(crop_size=config.DATA.CROP_SIZE, resize_size=config.DATA.RESIZE_SIZE)
96 | train_set = MattingDataset(config, train_transform)
97 | collate_fn = None
98 | train_loader = DataLoader(dataset=train_set, num_workers=config.DATA.NUM_WORKERS, batch_size=config.DATA.BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
99 | return train_loader
100 |
101 | def build_lr_scheduler(optimizer, total_epochs):
102 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epochs)
103 | return lr_scheduler
104 |
105 | def warmup_lr(initial_lr, cur_iter, total_iter):
106 | return cur_iter/total_iter*initial_lr
107 |
108 | def update_lr(cur_lr, optimizer):
109 | for param_group in optimizer.param_groups:
110 | param_group['lr'] = cur_lr
111 |
112 | def get_lr(optimizer):
113 | return optimizer.param_groups[0]['lr']
114 |
115 | def train(config, model, optimizer, train_loader, epoch, lr_scheduler, clip_grad_args):
116 | model = torch.nn.DataParallel(model).cuda()
117 | model.train()
118 | torch.cuda.empty_cache()
119 |
120 | print("===============================")
121 | print("EPOCH: {}/{}".format(epoch, config.TRAIN.EPOCHS))
122 | for iteration, batch in tqdm(enumerate(train_loader, 1)):
123 | ### update lr
124 | if config.TRAIN.WARMUP_EPOCHS > 0 and epoch <= config.TRAIN.WARMUP_EPOCHS:
125 | cur_lr = warmup_lr(config.TRAIN.LR, len(train_loader)*(epoch-1)+iteration, config.TRAIN.WARMUP_EPOCHS*len(train_loader))
126 | update_lr(cur_lr, optimizer)
127 | elif config.TRAIN.LR_DECAY is True and epoch > config.TRAIN.WARMUP_EPOCHS:
128 | lr_scheduler.step()
129 | cur_lr = lr_scheduler.get_last_lr()[0]
130 | else:
131 | cur_lr = optimizer.param_groups[0]['lr']
132 |
133 | ### get data for general model
134 | batch_new = []
135 | for item in batch:
136 | if type(item) == torch.Tensor:
137 | item = Variable(item).cuda()
138 | batch_new.append(item)
139 | [ori, mask, fg, bg, trimap] = batch_new[:5]
140 | optimizer.zero_grad()
141 |
142 | ### get model prediction
143 | if config.MODEL.CUT_AND_PASTE.TYPE.upper() == 'NONE':
144 | out_list = model(ori)
145 | else:
146 | raise NotImplementedError()
147 |
148 | ### cal loss
149 | predict_global, predict_local, predict_fusion, predict_global_side2, predict_global_side1, predict_global_side0 = out_list
150 | predict_fusion = predict_fusion.cuda()
151 | loss_global =get_crossentropy_loss(trimap, predict_global)
152 | loss_global_side2 = get_crossentropy_loss(trimap, predict_global_side2)
153 | loss_global_side1 = get_crossentropy_loss(trimap, predict_global_side1)
154 | loss_global_side0 = get_crossentropy_loss(trimap, predict_global_side0)
155 | loss_global = loss_global_side2+loss_global_side1+loss_global_side0+3*loss_global
156 | loss_local = get_alpha_loss(predict_local, mask, trimap) + get_laplacian_loss(predict_local, mask, trimap)
157 | loss_fusion_alpha = get_alpha_loss_whole_img(predict_fusion, mask) + get_laplacian_loss_whole_img(predict_fusion, mask)
158 | loss_fusion_comp = get_composition_loss_whole_img(ori, mask, fg, bg, predict_fusion)
159 | loss = loss_global/6+loss_local*2+loss_fusion_alpha*2+loss_fusion_alpha+loss_fusion_comp
160 | loss.backward()
161 |
162 | ### optimize and clip gradient
163 | if config.TRAIN.CLIP_GRAD is True:
164 | if clip_grad_args.moving_max_grad == 0:
165 | clip_grad_args.moving_max_grad = torch.nn.utils.clip_grad_norm_(model.parameters(), 1e+6).cpu().item()
166 | clip_grad_args.max_grad = clip_grad_args.moving_max_grad
167 | else:
168 | clip_grad_args.max_grad = torch.nn.utils.clip_grad_norm_(model.parameters(), 2 * clip_grad_args.moving_max_grad).cpu().item()
169 | clip_grad_args.moving_max_grad = clip_grad_args.moving_max_grad * clip_grad_args.moving_grad_moment + clip_grad_args.max_grad * (
170 | 1 - clip_grad_args.moving_grad_moment)
171 | optimizer.step()
172 |
173 | if config.ENABLE_WANDB:
174 | loss_dict = {
175 | 'train/loss': loss.item(),
176 | 'train/loss_global': loss_global.item(),
177 | 'train/loss_local': loss_local.item(),
178 | 'train/loss_fusion_alpha': loss_fusion_alpha.item(),
179 | 'train/loss_fusion_comp': loss_fusion_comp.item(),
180 | 'train/lr': cur_lr,
181 | }
182 | wandb.log(loss_dict, step=(epoch-1)*len(train_loader)+iteration, commit=False)
183 |
184 | if config.TAG.startswith("debug") and iteration > 20:
185 | break
186 |
187 | def save_latest_checkpoint(config, model, epoch, **kwargs):
188 | model_save_dir = os.path.join(CKPT_SAVE_FOLDER, config.TAG)
189 | create_folder_if_not_exists(model_save_dir)
190 | model_out_path = os.path.join(model_save_dir, 'ckpt_latest.pth')
191 | model_dict = {'state_dict':model.state_dict(), 'epoch':epoch}
192 | model_dict.update(kwargs)
193 | torch.save(model_dict, model_out_path)
194 |
195 | def save_checkpoint(config, model, epoch, prefix, **kwargs):
196 | model_save_dir = os.path.join(CKPT_SAVE_FOLDER, config.TAG)
197 | create_folder_if_not_exists(model_save_dir)
198 | model_out_path = os.path.join(model_save_dir, prefix+'_ckpt.pth')
199 | model_dict = {'state_dict':model.state_dict(), 'epoch':epoch}
200 | model_dict.update(kwargs)
201 | torch.save(model_dict, model_out_path)
202 |
203 | def save_all_ckpts(config, epoch):
204 | model_save_dir = os.path.join(CKPT_SAVE_FOLDER, config.TAG, str(epoch))
205 | create_folder_if_not_exists(model_save_dir)
206 | source_dir = os.path.join(CKPT_SAVE_FOLDER, config.TAG)
207 | file_list = os.listdir(source_dir)
208 | for name in file_list:
209 | if name.endswith('.pth'):
210 | shutil.copyfile(os.path.join(source_dir, name), os.path.join(model_save_dir, name))
211 |
212 | def save_args(config):
213 | save_dir = os.path.join(CKPT_SAVE_FOLDER, config.TAG)
214 | create_folder_if_not_exists(save_dir)
215 | np.save(os.path.join(save_dir, 'args.npy'), config)
216 |
217 | def main():
218 | _, config = get_args()
219 | save_args(config)
220 |
221 | now = datetime.datetime.now()
222 | str_time = now.strftime("%Y-%m-%d-%H:%M")
223 | logger = create_logger(os.path.join(TRAIN_LOGS_FOLDER, config.TAG+'_{}.log'.format(str_time)))
224 |
225 | if not torch.cuda.is_available():
226 | raise Exception("No CUDA-capable GPU available; training requires a GPU")
227 | gpuNums = torch.cuda.device_count()
228 | logger.info(f'Running with GPUs and the number of GPUs: {gpuNums}')
229 |
230 | # Test configs
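    | # these overrides configure the in-training validation (fast test, privacy metrics on, no result saving);
    | # TEST.DATASET is switched between P3M_500_P and P3M_500_NP inside the epoch loop below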
231 | config.defrost()
232 | config.TEST.DATASET = 'P3M_500_P'
233 | config.TEST.TEST_METHOD = config.TEST.TEST_METHOD  # keep the test method supplied via the config / command line
234 | config.TEST.CKPT_NAME = 'latest_epoch'
235 | config.TEST.FAST_TEST = True
236 | config.TEST.TEST_PRIVACY = True
237 | config.TEST.SAVE_RESULT = False
238 | config.freeze()
239 |
240 | # config wandb
241 | if config.ENABLE_WANDB:
242 | project_name = 'p3m_journal'
243 | WANDB_API_KEY = get_wandb_key(WANDB_KEY_FILE) # setup your wandb account
244 | os.environ["WANDB_API_KEY"] = WANDB_API_KEY
245 | os.environ["WANDB_DIR"] = WANDB_LOGS_FOLDER
246 |
247 | if config.AUTO_RESUME and os.path.exists(os.path.join(CKPT_SAVE_FOLDER, config.TAG, 'wandb_run_id.txt')):
248 | with open(os.path.join(CKPT_SAVE_FOLDER, config.TAG, 'wandb_run_id.txt'), 'r') as f:
249 | run_id = f.readline().strip('\n')
250 | wandb.init(id=run_id, project=project_name, resume='must')
251 | else:
252 | run_id = wandb.util.generate_id()
253 | wandb.init(project=project_name, entity='xymsh', id=run_id, resume='allow')
254 | wandb.config.update(get_wandb_config(config))
255 | wandb.run.name = '{}_{}'.format(config.TAG, str_time)
256 | with open(os.path.join(CKPT_SAVE_FOLDER, config.TAG, 'wandb_run_id.txt'), 'w') as f:
257 | f.write(run_id)
258 | logger.info('===> Enabled wandb run {} at {}'.format(run_id, WANDB_LOGS_FOLDER))
259 |
260 | # data loader
261 | logger.info('===> Load data')
262 | train_loader = load_dataset(config)
263 |
264 | # build model
265 | logger.info('===> Build the model {}'.format(config.MODEL.TYPE))
266 | model = build_model(config.MODEL.TYPE).cuda()
267 | start_epoch = config.TRAIN.START_EPOCH
268 |
269 | # build optimizer
270 | logger.info('===> Initialize optimizer {} and lr scheduler {}'.format(config.TRAIN.OPTIMIZER.TYPE, config.TRAIN.LR_DECAY))
271 | if config.TRAIN.OPTIMIZER.TYPE.upper() == 'ADAM':
272 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.TRAIN.LR)
273 | else:
274 | raise NotImplementedError
275 | lr_scheduler = build_lr_scheduler(optimizer, config.TRAIN.EPOCHS-config.TRAIN.WARMUP_EPOCHS) if config.TRAIN.LR_DECAY is True else None
276 | # init clip grad params
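    | # (a config node serves as a simple mutable container for the running clipping statistics,
    | #  so they can be saved into and restored from checkpoints together with the model)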
277 | clip_grad_args = CN()
278 | clip_grad_args.moving_max_grad = 0.0
279 | clip_grad_args.moving_grad_moment = 0.999
280 | clip_grad_args.max_grad = 0.0
281 |
282 | # train parameters
283 | best_result = {} # key format: "[metric]_[dataset]" and "[metric]_[dataset]_epoch"
284 | errors = {}
285 |
286 | # auto resume
287 | if config.AUTO_RESUME:
288 | try:
289 | # load latest ckpt
290 | ckpt_path = os.path.join(CKPT_SAVE_FOLDER, config.TAG, 'ckpt_latest.pth')
291 | ckpt_dict = torch.load(ckpt_path)
292 |
293 | model.load_state_dict(ckpt_dict['state_dict'])
294 | start_epoch = ckpt_dict['epoch'] + 1
295 | optimizer.load_state_dict(ckpt_dict['optimizer'])
296 | if config.TRAIN.LR_DECAY:
297 | raise NotImplementedError  # resuming the lr scheduler state is not supported
298 |
299 | clip_grad_args = ckpt_dict['clip_grad_args']
300 | best_result = ckpt_dict['best_result']
301 | logger.info('===> Auto resume succeeded')
302 | del ckpt_dict
303 | except Exception:
304 | logger.info('===> Auto resume failed, training from scratch')
305 |
306 | # start training
307 | logger.info('===> Start Training')
308 | for epoch in range(start_epoch, config.TRAIN.EPOCHS + 1):
309 | logger.info("TRAIN Epoch: {}/{}, Warmup: {}".format(epoch, config.TRAIN.EPOCHS, epoch<=config.TRAIN.WARMUP_EPOCHS))
310 | train(config, model, optimizer, train_loader, epoch, lr_scheduler, clip_grad_args)
311 |
312 | if (config.TEST_FREQ > 0 and epoch % config.TEST_FREQ == 0) or (epoch == config.TRAIN.EPOCHS):
313 | logger.info("TEST Epoch: {}/{}".format(epoch, config.TRAIN.EPOCHS))
314 |
315 | for test_dataset_choice in ["P3M_500_P", "P3M_500_NP"]:
316 | config.defrost()
317 | config.TEST.DATASET = test_dataset_choice
318 | if test_dataset_choice == 'P3M_500_NP':
319 | config.TEST.TEST_PRIVACY = False
320 | else:
321 | config.TEST.TEST_PRIVACY = True
322 | config.freeze()
323 | errors = test_p3m10k(config, model, logger)
324 |
325 | for m in ['SAD', 'SAD_PRIVACY']:
326 | m_dataset = "{}_{}".format(m, test_dataset_choice)
327 | if m in errors.keys():
328 | if m_dataset not in best_result.keys() or errors[m] <= best_result[m_dataset]:
329 | best_result[m_dataset] = errors[m]
330 | best_result["{}_epoch".format(m_dataset)] = epoch
331 | save_checkpoint(config, model, epoch, "best_{}".format(m_dataset), best_result=best_result)
332 |
333 | if config.ENABLE_WANDB:
334 | log_errors = {}
335 | for k, v in errors.items():
336 | log_errors['test_{}/'.format(test_dataset_choice)+k] = v
337 | wandb.log(log_errors, step=epoch*len(train_loader), commit=False)
338 |
339 | if config.ENABLE_WANDB:
340 | # record the best result
341 | wandb.run.summary.update(best_result)
342 | for k, v in best_result.items():
343 | log_errors['best_result_{}/{}'.format(config.TEST.TEST_METHOD, k)] = v
344 | wandb.log(log_errors, step=epoch*len(train_loader), commit=False)
345 |
346 | save_latest_checkpoint(config, model, epoch, optimizer=optimizer.state_dict(), clip_grad_args=clip_grad_args)
347 |
348 |
349 | if config.AUTO_RESUME:
350 | break
351 | if epoch in [50, 100]:
352 | save_all_ckpts(config, epoch)
353 |
354 |
355 | if __name__ == "__main__":
356 | main()
--------------------------------------------------------------------------------
/core/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Rethinking Portrait Matting with Privacy Preserving
3 |
4 | Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
5 | Licensed under the MIT License (see LICENSE for details)
6 | Github repo: https://github.com/ViTAE-Transformer/ViTAE-Transformer-Matting.git
7 | Paper link: https://arxiv.org/abs/2203.16828
8 |
9 | """
10 | import os
11 | import shutil
12 | import cv2
13 | import numpy as np
14 | import torch
15 | import torch.distributed as dist
16 | import glob
17 | import functools
18 | from torchvision import transforms
19 | from config import *
20 |
21 |
22 | def get_wandb_config(config):
23 | wandb_args = {}
24 | wandb_args['model'] = config.MODEL.TYPE
25 | wandb_args['logname'] = config.TAG
26 | return wandb_args
27 |
28 | def get_wandb_key(filename):
29 | f = open(filename)
30 | keys = f.readline().strip()
31 | f.close()
32 | return keys
33 |
34 | ##########################
35 | ### Pure functions
36 | ##########################
37 |
38 | def GET_RANK():
39 | if dist.is_initialized():
40 | return dist.get_rank()
41 | else:
42 | return 0
43 |
44 | def GET_WORLD_SIZE():
45 | if dist.is_initialized():
46 | return dist.get_world_size()
47 | else:
48 | return 1
49 |
50 | def is_main_process():
51 | if not dist.is_initialized():
52 | return True
53 | if dist.is_initialized() and dist.get_rank() == 0:
54 | return True
55 | return False
56 |
57 | def extract_pure_name(original_name):
58 | pure_name, extension = os.path.splitext(original_name)
59 | return pure_name
60 |
61 | def listdir_nohidden(path):
62 | new_list = []
63 | for f in os.listdir(path):
64 | if not f.startswith('.'):
65 | new_list.append(f)
66 | new_list.sort()
67 | return new_list
68 |
69 | def create_folder_if_not_exists(folder_path):
70 | if not os.path.exists(folder_path):
71 | os.makedirs(folder_path)
72 |
73 | def refresh_folder(folder_path):
74 | if not os.path.exists(folder_path):
75 | os.makedirs(folder_path)
76 | else:
77 | shutil.rmtree(folder_path)
78 | os.makedirs(folder_path)
79 |
80 | def save_test_result(save_dir, predict):
81 | predict = (predict * 255).astype(np.uint8)
82 | cv2.imwrite(save_dir, predict)
83 |
84 | def generate_composite_img(img, alpha_channel):
85 | b_channel, g_channel, r_channel = cv2.split(img)
86 | b_channel = b_channel * alpha_channel
87 | g_channel = g_channel * alpha_channel
88 | r_channel = r_channel * alpha_channel
89 | alpha_channel = (alpha_channel*255).astype(b_channel.dtype)
90 | img_BGRA = cv2.merge((r_channel,g_channel,b_channel,alpha_channel))
91 | return img_BGRA
92 |
93 | ##########################
94 | ### for dataset processing
95 | ##########################
96 | def trim_img(img):
97 | if img.ndim>2:
98 | img = img[:,:,0]
99 | return img
100 |
101 | def gen_trimap_with_dilate(alpha, kernel_size):
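    | # erode the definite-foreground (alpha==255) region and dilate the non-background (alpha>0) region;
    | # the eroded core becomes 255, the band between the two becomes 128 (unknown), everything else 0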
102 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size))
103 | fg_and_unknown = np.array(np.not_equal(alpha, 0).astype(np.float32))
104 | fg = np.array(np.equal(alpha, 255).astype(np.float32))
105 | dilate = cv2.dilate(fg_and_unknown, kernel, iterations=1)
106 | erode = cv2.erode(fg, kernel, iterations=1)
107 | trimap = erode *255 + (dilate-erode)*128
108 | return trimap.astype(np.uint8)
109 |
110 | def normalize_batch_torch(data_t):
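    | # apply the standard ImageNet mean/std normalization to every image in the batch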
111 | normalize_transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],
112 | std=[0.229, 0.224, 0.225])
113 | new_data = []
114 | for i in range(data_t.shape[0]):
115 | new_data.append(normalize_transform(data_t[i]))
116 | return torch.stack(new_data, dim=0)
117 |
118 | ##########################
119 | ### Functions for fusion
120 | ##########################
121 | def gen_trimap_from_segmap_e2e(segmap):
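    | # convert the 3-channel segmentation output (bg / transition / fg) into a trimap with values 0 / 128 / 255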
122 | trimap = np.argmax(segmap, axis=1)[0]
123 | trimap = trimap.astype(np.int64)
124 | trimap[trimap==1]=128
125 | trimap[trimap==2]=255
126 | return trimap.astype(np.uint8)
127 |
128 | def get_masked_local_from_global(global_sigmoid, local_sigmoid):
129 | values, index = torch.max(global_sigmoid,1)
130 | index = index[:,None,:,:].float()
131 | ### index <===> [0, 1, 2]
132 | ### bg_mask <===> [1, 0, 0]
133 | bg_mask = index.clone()
134 | bg_mask[bg_mask==2]=1
135 | bg_mask = 1- bg_mask
136 | ### trimap_mask <===> [0, 1, 0]
137 | trimap_mask = index.clone()
138 | trimap_mask[trimap_mask==2]=0
139 | ### fg_mask <===> [0, 0, 1]
140 | fg_mask = index.clone()
141 | fg_mask[fg_mask==1]=0
142 | fg_mask[fg_mask==2]=1
143 | fusion_sigmoid = local_sigmoid*trimap_mask+fg_mask
144 | return fusion_sigmoid
145 |
146 | def get_masked_local_from_global_test(global_result, local_result):
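    | # numpy (test-time) counterpart of the fusion above: keep the local alpha inside the unknown band of the
    | # global result and the global values (rescaled to [0, 1]) elsewhere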
147 | weighted_global = np.ones(global_result.shape)
148 | weighted_global[global_result==255] = 0
149 | weighted_global[global_result==0] = 0
150 | fusion_result = global_result*(1.-weighted_global)/255+local_result*weighted_global
151 | return fusion_result
152 |
153 | #######################################
154 | ### Function to generate training data
155 | #######################################
156 |
157 | def generate_paths_for_dataset(dataset="P3M10K", trainset="TRAIN"):
158 | ORI_PATH = DATASET_PATHS_DICT[dataset][trainset]['ORIGINAL_PATH']
159 | MASK_PATH = DATASET_PATHS_DICT[dataset][trainset]['MASK_PATH']
160 | FG_PATH = DATASET_PATHS_DICT[dataset][trainset]['FG_PATH']
161 | BG_PATH = DATASET_PATHS_DICT[dataset][trainset]['BG_PATH']
162 | FACEMASK_PATH = DATASET_PATHS_DICT[dataset][trainset]['PRIVACY_MASK_PATH']
163 | mask_list = listdir_nohidden(MASK_PATH)
164 | total_number = len(mask_list)
165 | paths_list = []
166 | for mask_name in mask_list:
167 | path_list = []
168 | ori_path = ORI_PATH+extract_pure_name(mask_name)+'.jpg'
169 | mask_path = MASK_PATH+mask_name
170 | fg_path = FG_PATH+mask_name
171 | bg_path = BG_PATH+extract_pure_name(mask_name)+'.jpg'
172 | facemask_path = FACEMASK_PATH+mask_name
173 | path_list.append(ori_path)
174 | path_list.append(mask_path)
175 | path_list.append(fg_path)
176 | path_list.append(bg_path)
177 | path_list.append(facemask_path)
178 | paths_list.append(path_list)
179 | return paths_list
180 |
181 |
182 | def get_valid_names(*dirs):
183 | # Extract valid names
184 | name_sets = [get_name_set(d) for d in dirs]
185 |
186 | # Reduce
187 | def _join_and(a, b):
188 | return a & b
189 |
190 | valid_names = list(functools.reduce(_join_and, name_sets))
191 | if len(valid_names) == 0:
192 | return None
193 |
194 | valid_names.sort()
195 |
196 | return valid_names
197 |
198 | def get_name_set(dir_name):
199 | path_list = glob.glob(os.path.join(dir_name, '*'))
200 | name_set = set()
201 | for path in path_list:
202 | name = os.path.basename(path)
203 | name = os.path.splitext(name)[0]
204 | if name.startswith(".DS"): continue
205 | name_set.add(name)
206 | return name_set
207 |
208 | def list_abspath(data_dir, ext, data_list):
209 | return [os.path.join(data_dir, name + ext)
210 | for name in data_list]
--------------------------------------------------------------------------------
/demo/README.md:
--------------------------------------------------------------------------------
1 | # More Results Demo
2 |
3 | We present more results of our P3M-Net (ViTAE-S) model on the P3M-10k dataset below.
4 |
5 | 

6 | 

7 | 

8 | 

9 | 

10 | 

11 | 

12 | 

13 | 

14 | 

15 | 

--------------------------------------------------------------------------------
/demo/face_obfuscation.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/face_obfuscation.jpg
--------------------------------------------------------------------------------
/demo/gif/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/gif/2.gif
--------------------------------------------------------------------------------
/demo/gif/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/gif/3.gif
--------------------------------------------------------------------------------
/demo/gif/4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/gif/4.gif
--------------------------------------------------------------------------------
/demo/gif/p_3cf7997c.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/gif/p_3cf7997c.gif
--------------------------------------------------------------------------------
/demo/imgs/alpha/p_07141906.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_07141906.png
--------------------------------------------------------------------------------
/demo/imgs/alpha/p_09ba26b4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_09ba26b4.png
--------------------------------------------------------------------------------
/demo/imgs/alpha/p_22fc0130.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_22fc0130.png
--------------------------------------------------------------------------------
/demo/imgs/alpha/p_514ca06a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_514ca06a.png
--------------------------------------------------------------------------------
/demo/imgs/alpha/p_51c916fc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_51c916fc.png
--------------------------------------------------------------------------------
/demo/imgs/alpha/p_818e689d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_818e689d.png
--------------------------------------------------------------------------------
/demo/imgs/alpha/p_a09f6d7a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_a09f6d7a.png
--------------------------------------------------------------------------------
/demo/imgs/alpha/p_bac3c1ff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_bac3c1ff.png
--------------------------------------------------------------------------------
/demo/imgs/alpha/p_bc5cfad1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_bc5cfad1.png
--------------------------------------------------------------------------------
/demo/imgs/alpha/p_bd6af989.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_bd6af989.png
--------------------------------------------------------------------------------
/demo/imgs/alpha/p_d684dae3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/alpha/p_d684dae3.png
--------------------------------------------------------------------------------
/demo/imgs/color/p_07141906.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_07141906.png
--------------------------------------------------------------------------------
/demo/imgs/color/p_09ba26b4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_09ba26b4.png
--------------------------------------------------------------------------------
/demo/imgs/color/p_22fc0130.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_22fc0130.png
--------------------------------------------------------------------------------
/demo/imgs/color/p_514ca06a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_514ca06a.png
--------------------------------------------------------------------------------
/demo/imgs/color/p_51c916fc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_51c916fc.png
--------------------------------------------------------------------------------
/demo/imgs/color/p_818e689d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_818e689d.png
--------------------------------------------------------------------------------
/demo/imgs/color/p_a09f6d7a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_a09f6d7a.png
--------------------------------------------------------------------------------
/demo/imgs/color/p_bac3c1ff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_bac3c1ff.png
--------------------------------------------------------------------------------
/demo/imgs/color/p_bc5cfad1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_bc5cfad1.png
--------------------------------------------------------------------------------
/demo/imgs/color/p_bd6af989.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_bd6af989.png
--------------------------------------------------------------------------------
/demo/imgs/color/p_d684dae3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/color/p_d684dae3.png
--------------------------------------------------------------------------------
/demo/imgs/original/p_07141906.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_07141906.jpg
--------------------------------------------------------------------------------
/demo/imgs/original/p_09ba26b4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_09ba26b4.jpg
--------------------------------------------------------------------------------
/demo/imgs/original/p_22fc0130.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_22fc0130.jpg
--------------------------------------------------------------------------------
/demo/imgs/original/p_514ca06a.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_514ca06a.jpg
--------------------------------------------------------------------------------
/demo/imgs/original/p_51c916fc.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_51c916fc.jpg
--------------------------------------------------------------------------------
/demo/imgs/original/p_818e689d.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_818e689d.jpg
--------------------------------------------------------------------------------
/demo/imgs/original/p_a09f6d7a.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_a09f6d7a.jpg
--------------------------------------------------------------------------------
/demo/imgs/original/p_bac3c1ff.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_bac3c1ff.jpg
--------------------------------------------------------------------------------
/demo/imgs/original/p_bc5cfad1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_bc5cfad1.jpg
--------------------------------------------------------------------------------
/demo/imgs/original/p_bd6af989.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_bd6af989.jpg
--------------------------------------------------------------------------------
/demo/imgs/original/p_d684dae3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/imgs/original/p_d684dae3.jpg
--------------------------------------------------------------------------------
/demo/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/network.png
--------------------------------------------------------------------------------
/demo/p3m-cp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/p3m-cp.png
--------------------------------------------------------------------------------
/demo/p3m-net-variants.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/p3m-net-variants.png
--------------------------------------------------------------------------------
/demo/p3m_dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/p3m_dataset.png
--------------------------------------------------------------------------------
/demo/results1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/results1.png
--------------------------------------------------------------------------------
/demo/results2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/demo/results2.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.19.2
2 | opencv-python==4.4.0.46
3 | Pillow==8.0.0
4 | scikit-image==0.14.5
5 | scipy==1.5.3
6 | tqdm==4.51.0
7 |
--------------------------------------------------------------------------------
/samples/original/p_015cd10e.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/original/p_015cd10e.jpg
--------------------------------------------------------------------------------
/samples/original/p_0865636e.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/original/p_0865636e.jpg
--------------------------------------------------------------------------------
/samples/original/p_819ea202.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/original/p_819ea202.jpg
--------------------------------------------------------------------------------
/samples/result_alpha/p_015cd10e.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/result_alpha/p_015cd10e.png
--------------------------------------------------------------------------------
/samples/result_alpha/p_0865636e.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/result_alpha/p_0865636e.png
--------------------------------------------------------------------------------
/samples/result_alpha/p_819ea202.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/result_alpha/p_819ea202.png
--------------------------------------------------------------------------------
/samples/result_color/p_015cd10e.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/result_color/p_015cd10e.png
--------------------------------------------------------------------------------
/samples/result_color/p_0865636e.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/result_color/p_0865636e.png
--------------------------------------------------------------------------------
/samples/result_color/p_819ea202.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViTAE-Transformer/P3M-Net/fbf38909d784d1a6c887e120bb14d8746fa90649/samples/result_color/p_819ea202.png
--------------------------------------------------------------------------------
/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Rethinking Portrait Matting with Privacy Preserving
4 | #
5 | # Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
6 | # Licensed under the MIT License (see LICENSE for details)
7 | # Github repo: https://github.com/ViTAE-Transformer/P3M-Net
8 |
9 | result_dir=''
10 | alpha_dir=''
11 | trimap_dir=''
12 |
13 | python core/eval.py \
14 | --pred_dir $result_dir \
15 | --alpha_dir $alpha_dir \
16 | --trimap_dir $trimap_dir
--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Rethinking Portrait Matting with Privacy Preserving
4 | #
5 | # Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
6 | # Licensed under the MIT License (see LICENSE for details)
7 | # Github repo: https://github.com/ViTAE-Transformer/P3M-Net
8 | # Paper link: https://arxiv.org/abs/2203.16828
9 |
10 | tag='' # run name
11 | dataset_choice='P3M_500_P'
12 | ckpt_name='ckpt_latest'
13 | test_choice='HYBRID'
14 |
15 | python core/test.py \
16 | --tag=$tag \
17 | --test_dataset=$dataset_choice \
18 | --test_ckpt=$ckpt_name \
19 | --test_method=$test_choice \
20 | --fast_test \
21 | --test_privacy
--------------------------------------------------------------------------------
/scripts/test_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Rethinking Portrait Matting with Privacy Preserving
4 | #
5 | # Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
6 | # Licensed under the MIT License (see LICENSE for details)
7 | # Github repo: https://github.com/ViTAE-Transformer/P3M-Net
8 | # Paper link: https://arxiv.org/abs/2203.16828
9 |
10 | arch='vitae'
11 | model_path='./models/P3M-Net_ViTAE-S_trained_on_P3M-10k.pth'
12 | dataset='P3M10K'
13 | valset='P3M_500_P'
14 | test_choice='HYBRID'
15 |
16 | nickname='test_dataset'
17 | result_dir='results/'${nickname}
18 |
19 | python core/infer.py \
20 | --arch $arch \
21 | --dataset $dataset \
22 | --test_set $valset \
23 | --model_path $model_path \
24 | --test_choice $test_choice \
25 | --test_result_dir $result_dir
26 |
27 |
--------------------------------------------------------------------------------
/scripts/test_samples.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Rethinking Portrait Matting with Privacy Preserving
4 | #
5 | # Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
6 | # Licensed under the MIT License (see LICENSE for details)
7 | # Github repo: https://github.com/ViTAE-Transformer/P3M-Net
8 | # Paper link: https://arxiv.org/abs/2203.16828
9 |
10 | arch='vitae'
11 | model_path='./models/P3M-Net_ViTAE-S_trained_on_P3M-10k.pth'
12 | dataset='SAMPLES'
13 | test_choice='RESIZE'
14 |
15 | python core/infer.py \
16 | --arch $arch \
17 | --dataset $dataset \
18 | --model_path $model_path \
19 | --test_choice $test_choice
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Rethinking Portrait Matting with Privacy Preserving
4 | #
5 | # Copyright (c) 2022, Sihan Ma (sima7436@uni.sydney.edu.au) and Jizhizi Li (jili8515@uni.sydney.edu.au)
6 | # Licensed under the MIT License (see LICENSE for details)
7 | # Github repo: https://github.com/ViTAE-Transformer/P3M-Net
8 | # Paper link: https://arxiv.org/abs/2203.16828
9 |
10 |
11 | cfg='core/configs/ViTAE_S.yaml' # change config file here
12 | nEpochs=150
13 | lr=0.00001
14 | nickname=debug_train_vitae_s # change run name here
15 | batchSize=8
16 |
17 | python core/train.py \
18 | --cfg $cfg \
19 | --tag $nickname \
20 | --nEpochs $nEpochs \
21 | --lr=$lr \
22 | --batchSize=$batchSize
--------------------------------------------------------------------------------