├── .gitignore
├── README.md
├── demo
│   ├── image_matting
│   │   ├── Inference_with_ONNX
│   │   │   ├── README.md
│   │   │   ├── export_modnet_onnx.py
│   │   │   ├── inference_onnx.py
│   │   │   └── requirements.txt
│   │   └── colab
│   │       ├── README.md
│   │       └── inference.py
│   └── video_matting
│       ├── custom
│       │   ├── README.md
│       │   ├── requirements.txt
│       │   └── run.py
│       └── webcam
│           ├── README.md
│           ├── requirements.txt
│           └── run.py
├── doc
│   └── gif
│       ├── homepage_demo.gif
│       ├── image_matting_demo.gif
│       └── video_matting_demo.gif
├── pretrained
│   └── README.md
└── src
    ├── __init__.py
    ├── models
    │   ├── __init__.py
    │   ├── backbones
    │   │   ├── __init__.py
    │   │   ├── mobilenetv2.py
    │   │   └── wrapper.py
    │   ├── modnet.py
    │   └── onnx_modnet.py
    └── trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Temporary directories and files
2 | *.ckpt
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | env/
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *,cover
49 | .hypothesis/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # IPython Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # dotenv
82 | .env
83 |
84 | # virtualenv
85 | venv/
86 | ENV/
87 |
88 | # Spyder project settings
89 | .spyderproject
90 |
91 | # Rope project settings
92 | .ropeproject
93 |
94 |
95 | # Project files
96 | .vscode
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MODNet: Is a Green Screen Really Necessary for Real-Time Portrait Matting?
2 |
3 | [Arxiv Preprint](https://arxiv.org/pdf/2011.11961.pdf) | [Supplementary Video](https://youtu.be/PqJ3BRHX3Lc)
4 |
5 | WebCam Video Demo [[Offline](demo/video_matting/webcam)][[Colab](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing)] |
6 | Custom Video Demo [[Offline](demo/video_matting/custom)] |
7 | Image Demo [[WebGUI](https://gradio.app/g/modnet)][[Colab](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing)]
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | This is the official project of our paper [Is a Green Screen Really Necessary for Real-Time Portrait Matting?](https://arxiv.org/pdf/2011.11961.pdf).
16 | MODNet is a trimap-free model for portrait matting in real time under changing scenes.
17 |
18 |
19 | ---
20 |
21 |
22 | ## News
23 | - [Jan 28 2021] Release the [code](src/trainer.py) of MODNet training iteration.
24 | - [Dec 25 2020] ***Merry Christmas!*** :christmas_tree: Release Custom Video Matting Demo [[Offline](demo/video_matting/custom)] for user videos.
25 | - [Dec 10 2020] Release WebCam Video Matting Demo [[Offline](demo/video_matting/webcam)][[Colab](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing)] and Image Matting Demo [[Colab](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing)].
26 | - [Nov 24 2020] Release [Arxiv Preprint](https://arxiv.org/pdf/2011.11961.pdf) and [Supplementary Video](https://youtu.be/PqJ3BRHX3Lc).
27 |
28 |
29 | ## Demos
30 |
31 | ### Video Matting
32 | We provide two WebCam-based real-time portrait video matting demos. When using them, you can move the WebCam around at will.
33 | If you have an Ubuntu system, we recommend the [offline demo](demo/video_matting/webcam) for a higher *fps*. Otherwise, you can use the [online Colab demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing).
34 | We also provide an [offline demo](demo/video_matting/custom) that processes custom videos.
35 |
36 |
37 |
38 |
39 | ### Image Matting
40 | We provide an [online Colab demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing) for portrait image matting.
41 | It allows you to upload portrait images and predict/visualize/download the alpha mattes.
42 |
43 |
44 |
45 |
46 | ### Community
47 | Here we share some cool applications of MODNet built by the community.
48 |
49 | - **WebGUI for Image Matting**
50 | You can try [this WebGUI](https://gradio.app/g/modnet) (hosted on [Gradio](https://www.gradio.app/)) for portrait matting from your browser without any code!
51 |
52 |
53 | - **Colab Demo of Bokeh (Blur Background)**
54 | You can try [this Colab demo](https://colab.research.google.com/github/eyaler/avatars4all/blob/master/yarok.ipynb) (built by [@eyaler](https://github.com/eyaler)) to blur the background using MODNet!
55 |
56 |
57 | ## Code
58 | We provide the [code](src/trainer.py) of MODNet training iteration, including:
59 | - **Supervised Training**: Train MODNet on a labeled matting dataset
60 | - **SOC Adaptation**: Adapt a trained MODNet to an unlabeled dataset
61 |
62 | In the comments of each function, we provide examples of how to call it; a minimal training-loop sketch is also given below.
63 |
64 |
65 | ## TODO
66 | - Release the code of One-Frame Delay (OFD)
67 | - Release PPM-100 validation benchmark (scheduled in **Feb 2021**)
68 | **NOTE**: PPM-100 is a **validation set**. Our training set will not be published.
69 |
70 |
71 | ## License
72 | This project (code, pre-trained models, demos, *etc.*) is released under the [Creative Commons Attribution NonCommercial ShareAlike 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) license.
73 |
74 |
75 | ## Acknowledgement
76 | - We thank [City University of Hong Kong](https://www.cityu.edu.hk/) and [SenseTime](https://www.sensetime.com/) for their support of this project.
77 | - We thank
78 | [the Gradio team](https://github.com/gradio-app/gradio) and [@eyaler](https://github.com/eyaler)
79 | for their cool applications based on MODNet.
80 |
81 |
82 | ## Citation
83 | If this work helps your research, please consider citing:
84 |
85 | ```bibtex
86 | @article{MODNet,
87 |   author  = {Zhanghan Ke and Kaican Li and Yurou Zhou and Qiuhua Wu and Xiangyu Mao and Qiong Yan and Rynson W.H. Lau},
88 |   title   = {Is a Green Screen Really Necessary for Real-Time Portrait Matting?},
89 |   journal = {ArXiv},
90 |   volume  = {abs/2011.11961},
91 |   year    = {2020},
92 | }
93 | ```
94 |
95 |
96 | ## Contact
97 | This project is currently maintained by Zhanghan Ke ([@ZHKKKe](https://github.com/ZHKKKe)).
98 | If you have any questions, please feel free to contact `kezhanghan@outlook.com`.
99 |
--------------------------------------------------------------------------------
/demo/image_matting/Inference_with_ONNX/README.md:
--------------------------------------------------------------------------------
1 | # Inference with onnxruntime
2 |
3 | Please try the MODNet image matting ONNX inference demo with this [Colab Notebook](https://colab.research.google.com/drive/1P3cWtg8fnmu9karZHYDAtmm1vj1rgA-f?usp=sharing).
4 |
5 | Or download [modnet.onnx](https://drive.google.com/file/d/1cgycTQlYXpTh26gB9FTnthE7AvruV8hd/view?usp=sharing) directly and skip step 1 below.
6 |
7 | ### 1. Export the ONNX model
8 |
9 | Run the following command:
10 | ```shell
11 | python export_modnet_onnx.py \
12 | --ckpt-path=pretrained/modnet_photographic_portrait_matting.ckpt \
13 | --output-path=modnet.onnx
14 | ```
15 |
16 |
17 | ### 2. Inference
18 |
19 | Run the following command:
20 | ```shell
21 | python inference_onnx.py \
22 | --image-path=PATH_TO_IMAGE \
23 | --output-path=matte.png \
24 | --model-path=modnet.onnx
25 | ```
26 |
27 |
--------------------------------------------------------------------------------
/demo/image_matting/Inference_with_ONNX/export_modnet_onnx.py:
--------------------------------------------------------------------------------
1 | """
2 | Export onnx model
3 |
4 | Arguments:
5 | --ckpt-path --> path of the checkpoint to load
6 | --output-path --> path where the ONNX model will be saved
7 |
8 | example:
9 | python export_modnet_onnx.py \
10 | --ckpt-path=modnet_photographic_portrait_matting.ckpt \
11 | --output-path=modnet.onnx
12 |
13 | output:
14 | ONNX model with dynamic input shape: (batch_size, 3, height, width) &
15 | output shape: (batch_size, 1, height, width)
16 | """
17 | import os
18 | import argparse
19 | import torch
20 | import torch.nn as nn
21 |
22 | from src.models.onnx_modnet import MODNet
23 |
24 |
25 |
26 | if __name__ == '__main__':
27 | # define cmd arguments
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument('--ckpt-path', type=str, required=True, help='path of pre-trained MODNet')
30 | parser.add_argument('--output-path', type=str, required=True, help='path of output onnx model')
31 | args = parser.parse_args()
32 |
33 | # check input arguments
34 | if not os.path.exists(args.ckpt_path):
35 | print('Cannot find checkpoint path: {0}'.format(args.ckpt_path))
36 | exit()
37 |
38 | # define model & load checkpoint
39 | modnet = MODNet(backbone_pretrained=False)
40 | modnet = nn.DataParallel(modnet).cuda()
41 | state_dict = torch.load(args.ckpt_path)
42 | modnet.load_state_dict(state_dict)
43 | modnet.eval()
44 |
45 | # prepare dummy_input
46 | batch_size = 1
47 | height = 512
48 | width = 512
49 | dummy_input = torch.randn(batch_size, 3, height, width).cuda()
50 |
51 | # export to onnx model
52 | torch.onnx.export(modnet.module, dummy_input, args.output_path, export_params = True, opset_version=11,
53 | input_names = ['input'], output_names = ['output'],
54 | dynamic_axes = {'input': {0:'batch_size', 2:'height', 3:'width'},
55 | 'output': {0: 'batch_size', 2: 'height', 3: 'width'}})
56 |
--------------------------------------------------------------------------------
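
After exporting, the generated file can be sanity-checked with `onnx` and `onnxruntime` (both listed in `requirements.txt`). This is an optional verification sketch; `modnet.onnx` is simply the example output path used above, and the 512x512 dummy input mirrors the export script.

```python
import numpy as np
import onnx
import onnxruntime

# structural check of the exported graph
onnx.checker.check_model(onnx.load('modnet.onnx'))

# dummy forward pass through onnxruntime; the demo scripts feed sizes that are multiples of 32
session = onnxruntime.InferenceSession('modnet.onnx', None)
input_name = session.get_inputs()[0].name
dummy = np.random.randn(1, 3, 512, 512).astype('float32')
matte = session.run(None, {input_name: dummy})[0]
print(matte.shape)  # expected: (1, 1, 512, 512)
```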
/demo/image_matting/Inference_with_ONNX/inference_onnx.py:
--------------------------------------------------------------------------------
1 | """
2 | Inference with onnxruntime
3 |
4 | Arguments:
5 | --image-path --> path to single input image
6 | --output-path --> path to save the generated matte
7 | --model-path --> path to onnx model file
8 |
9 | example:
10 | python inference_onnx.py \
11 | --image-path=demo.jpg \
12 | --output-path=matte.png \
13 | --model-path=modnet.onnx
14 |
15 | Optional:
16 | Generate transparent image without background
17 | """
18 | import os
19 | import argparse
20 | import cv2
21 | import numpy as np
22 | import onnx
23 | import onnxruntime
24 | from onnx import helper
25 | from PIL import Image
26 |
27 | if __name__ == '__main__':
28 | # define cmd arguments
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument('--image-path', type=str, help='path of input image')
31 | parser.add_argument('--output-path', type=str, help='path of output image')
32 | parser.add_argument('--model-path', type=str, help='path of onnx model')
33 | args = parser.parse_args()
34 |
35 | # check input arguments
36 | if not os.path.exists(args.image_path):
37 | print('Cannot find input path: {0}'.format(args.image_path))
38 | exit()
39 | if not os.path.exists(args.model_path):
40 | print('Cannot find model path: {0}'.format(args.model_path))
41 | exit()
42 |
43 | ref_size = 512
44 |
45 | # Get x_scale_factor & y_scale_factor to resize image
46 | def get_scale_factor(im_h, im_w, ref_size):
47 |
48 | if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
49 | if im_w >= im_h:
50 | im_rh = ref_size
51 | im_rw = int(im_w / im_h * ref_size)
52 | elif im_w < im_h:
53 | im_rw = ref_size
54 | im_rh = int(im_h / im_w * ref_size)
55 | else:
56 | im_rh = im_h
57 | im_rw = im_w
58 |
59 | im_rw = im_rw - im_rw % 32
60 | im_rh = im_rh - im_rh % 32
61 |
62 | x_scale_factor = im_rw / im_w
63 | y_scale_factor = im_rh / im_h
64 |
65 | return x_scale_factor, y_scale_factor
66 |
67 | ##############################################
68 | # Main Inference part
69 | ##############################################
70 |
71 | # read image
72 | im = cv2.imread(args.image_path)
73 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
74 |
75 | # unify image channels to 3
76 | if len(im.shape) == 2:
77 | im = im[:, :, None]
78 | if im.shape[2] == 1:
79 | im = np.repeat(im, 3, axis=2)
80 | elif im.shape[2] == 4:
81 | im = im[:, :, 0:3]
82 |
83 | # normalize values to the range [-1, 1]
84 | im = (im - 127.5) / 127.5
85 |
86 | im_h, im_w, im_c = im.shape
87 | x, y = get_scale_factor(im_h, im_w, ref_size)
88 |
89 | # resize image
90 | im = cv2.resize(im, None, fx = x, fy = y, interpolation = cv2.INTER_AREA)
91 |
92 | # prepare input shape: HWC -> NCHW
93 | im = np.transpose(im)
94 | im = np.swapaxes(im, 1, 2)
95 | im = np.expand_dims(im, axis = 0).astype('float32')
96 |
97 | # Initialize session and get prediction
98 | session = onnxruntime.InferenceSession(args.model_path, None)
99 | input_name = session.get_inputs()[0].name
100 | output_name = session.get_outputs()[0].name
101 | result = session.run([output_name], {input_name: im})
102 |
103 | # refine matte
104 | matte = (np.squeeze(result[0]) * 255).astype('uint8')
105 | matte = cv2.resize(matte, dsize=(im_w, im_h), interpolation = cv2.INTER_AREA)
106 |
107 | cv2.imwrite(args.output_path, matte)
108 |
109 | ##############################################
110 | # Optional - save png image without background
111 | ##############################################
112 |
113 | # im_PIL = Image.open(args.image_path)
114 | # matte = Image.fromarray(matte)
115 | # im_PIL.putalpha(matte) # add alpha channel to keep transparency
116 | # im_PIL.save('without_background.png')
--------------------------------------------------------------------------------
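
The commented-out block at the end of this script shows how to save a transparent PNG. As an alternative, the sketch below composites the foreground onto a white background, mirroring what `demo/video_matting/custom/run.py` does per frame; `demo.jpg` and `matte.png` are the placeholder file names from the example command above.

```python
import cv2
import numpy as np

# placeholders: the original image and the matte produced by inference_onnx.py
image = cv2.imread('demo.jpg').astype(float)
matte = cv2.imread('matte.png', cv2.IMREAD_GRAYSCALE).astype(float)[:, :, None] / 255.0

# alpha-blend the foreground over a white background
foreground = matte * image + (1.0 - matte) * np.full(image.shape, 255.0)
cv2.imwrite('foreground.png', foreground.astype(np.uint8))
```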
/demo/image_matting/Inference_with_ONNX/requirements.txt:
--------------------------------------------------------------------------------
1 | onnx==1.8.1
2 | onnxruntime==1.6.0
3 | opencv-python==4.5.1.48
4 | torch==1.7.1
--------------------------------------------------------------------------------
/demo/image_matting/colab/README.md:
--------------------------------------------------------------------------------
1 | ## MODNet - Portrait Image Matting Demo
2 | Please try MODNet portrait image matting demo through our [online Colab demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing).
3 |
--------------------------------------------------------------------------------
/demo/image_matting/colab/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import numpy as np
5 | from PIL import Image
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torchvision.transforms as transforms
11 |
12 | from src.models.modnet import MODNet
13 |
14 |
15 | if __name__ == '__main__':
16 | # define cmd arguments
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--input-path', type=str, help='path of input images')
19 | parser.add_argument('--output-path', type=str, help='path of output images')
20 | parser.add_argument('--ckpt-path', type=str, help='path of pre-trained MODNet')
21 | args = parser.parse_args()
22 |
23 | # check input arguments
24 | if not os.path.exists(args.input_path):
25 | print('Cannot find input path: {0}'.format(args.input_path))
26 | exit()
27 | if not os.path.exists(args.output_path):
28 | print('Cannot find output path: {0}'.format(args.output_path))
29 | exit()
30 | if not os.path.exists(args.ckpt_path):
31 | print('Cannot find ckpt path: {0}'.format(args.ckpt_path))
32 | exit()
33 |
34 | # define hyper-parameters
35 | ref_size = 512
36 |
37 | # define image to tensor transform
38 | im_transform = transforms.Compose(
39 | [
40 | transforms.ToTensor(),
41 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
42 | ]
43 | )
44 |
45 | # create MODNet and load the pre-trained ckpt
46 | modnet = MODNet(backbone_pretrained=False)
47 | modnet = nn.DataParallel(modnet).cuda()
48 | modnet.load_state_dict(torch.load(args.ckpt_path))
49 | modnet.eval()
50 |
51 | # inference images
52 | im_names = os.listdir(args.input_path)
53 | for im_name in im_names:
54 | print('Process image: {0}'.format(im_name))
55 |
56 | # read image
57 | im = Image.open(os.path.join(args.input_path, im_name))
58 |
59 | # unify image channels to 3
60 | im = np.asarray(im)
61 | if len(im.shape) == 2:
62 | im = im[:, :, None]
63 | if im.shape[2] == 1:
64 | im = np.repeat(im, 3, axis=2)
65 | elif im.shape[2] == 4:
66 | im = im[:, :, 0:3]
67 |
68 | # convert image to PyTorch tensor
69 | im = Image.fromarray(im)
70 | im = im_transform(im)
71 |
72 | # add mini-batch dim
73 | im = im[None, :, :, :]
74 |
75 | # resize image for input
76 | im_b, im_c, im_h, im_w = im.shape
77 | if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
78 | if im_w >= im_h:
79 | im_rh = ref_size
80 | im_rw = int(im_w / im_h * ref_size)
81 | elif im_w < im_h:
82 | im_rw = ref_size
83 | im_rh = int(im_h / im_w * ref_size)
84 | else:
85 | im_rh = im_h
86 | im_rw = im_w
87 |
88 | im_rw = im_rw - im_rw % 32
89 | im_rh = im_rh - im_rh % 32
90 | im = F.interpolate(im, size=(im_rh, im_rw), mode='area')
91 |
92 | # inference
93 | _, _, matte = modnet(im.cuda(), True)
94 |
95 | # resize and save matte
96 | matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
97 | matte = matte[0][0].data.cpu().numpy()
98 | matte_name = im_name.split('.')[0] + '.png'
99 | Image.fromarray(((matte * 255).astype('uint8')), mode='L').save(os.path.join(args.output_path, matte_name))
100 |
--------------------------------------------------------------------------------
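
The resizing block in this script (rescale only when the image is entirely smaller or entirely larger than `ref_size`, then snap both sides down to multiples of 32) is shared with the ONNX demo. Below is a small worked check of that rule; the helper name `target_size` and the 1080x1920 input are just for this sketch.

```python
def target_size(im_h, im_w, ref_size=512):
    # same rule as the demo scripts
    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
        if im_w >= im_h:
            im_rh, im_rw = ref_size, int(im_w / im_h * ref_size)
        else:
            im_rw, im_rh = ref_size, int(im_h / im_w * ref_size)
    else:
        im_rh, im_rw = im_h, im_w
    return im_rh - im_rh % 32, im_rw - im_rw % 32

print(target_size(1080, 1920))  # -> (512, 896): 1920 / 1080 * 512 ~= 910, snapped down to 896
```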
/demo/video_matting/custom/README.md:
--------------------------------------------------------------------------------
1 | ## MODNet - Custom Portrait Video Matting Demo
2 | This is a MODNet portrait video matting demo that allows you to process custom videos.
3 |
4 | ### 1. Requirements
5 | The basic requirements for this demo are:
6 | - Ubuntu System
7 | - Python 3+
8 |
9 |
10 | ### 2. Introduction
11 | We use ~400 unlabeled video clips (divided into ~50,000 frames) downloaded from the internet to perform SOC to adapt MODNet to the video domain. **Nonetheless, due to insufficient labeled training data (~3k labeled foregrounds), our model may still make errors in portrait semantics estimation under challenging scenes.** Besides, this demo does not currently support the OFD trick.
12 |
13 |
14 | For a better experience, please make sure your videos satisfy:
15 |
16 | * the portrait and background are distinguishable, i.e., are not similar
17 | * captured in soft and bright ambient lighting
18 | * the contents do not move too fast
19 |
20 | ### 3. Run Demo
21 | We recommend creating a new conda virtual environment to run this demo, as follows:
22 |
23 | 1. Clone the MODNet repository:
24 | ```
25 | git clone https://github.com/ZHKKKe/MODNet.git
26 | cd MODNet
27 | ```
28 |
29 | 2. Download the pre-trained model from this [link](https://drive.google.com/file/d/1Nf1ZxeJZJL8Qx9KadcYYyEmmlKhTADxX/view?usp=sharing) and put it into the folder `MODNet/pretrained/`.
30 |
31 |
32 | 3. Create a conda virtual environment named `modnet` (if it doesn't exist) and activate it. Here we use `python=3.6` as an example:
33 | ```
34 | conda create -n modnet python=3.6
35 | source activate modnet
36 | ```
37 |
38 | 4. Install the required python dependencies (please make sure your CUDA version is supported by the PyTorch version installed):
39 | ```
40 | pip install -r demo/video_matting/custom/requirements.txt
41 | ```
42 |
43 | 5. Execute the main code:
44 | ```
45 | python -m demo.video_matting.custom.run --video YOUR_VIDEO_PATH
46 | ```
47 | where `YOUR_VIDEO_PATH` is the specific path of your video.
48 | There are some optional arguments:
49 | - `--result-type (default=fg)` : matte - save the alpha matte; fg - save the foreground
50 | - `--fps (default=30)` : fps of the result video
51 |
--------------------------------------------------------------------------------
/demo/video_matting/custom/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | Pillow
3 | opencv-python
4 | torch >= 1.0.0
5 | torchvision
6 | tqdm
--------------------------------------------------------------------------------
/demo/video_matting/custom/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import argparse
4 | import numpy as np
5 | from PIL import Image
6 | from tqdm import tqdm
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torchvision.transforms as transforms
11 |
12 | from src.models.modnet import MODNet
13 |
14 |
15 | torch_transforms = transforms.Compose(
16 | [
17 | transforms.ToTensor(),
18 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
19 | ]
20 | )
21 |
22 |
23 | def matting(video, result, alpha_matte=False, fps=30):
24 | # video capture
25 | vc = cv2.VideoCapture(video)
26 |
27 | if vc.isOpened():
28 | rval, frame = vc.read()
29 | else:
30 | rval = False
31 |
32 | if not rval:
33 | print('Failed to read the video: {0}'.format(video))
34 | exit()
35 |
36 | num_frame = vc.get(cv2.CAP_PROP_FRAME_COUNT)
37 | h, w = frame.shape[:2]
38 | if w >= h:
39 | rh = 512
40 | rw = int(w / h * 512)
41 | else:
42 | rw = 512
43 | rh = int(h / w * 512)
44 | rh = rh - rh % 32
45 | rw = rw - rw % 32
46 |
47 | # video writer
48 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
49 | video_writer = cv2.VideoWriter(result, fourcc, fps, (w, h))
50 |
51 | print('Start matting...')
52 | with tqdm(range(int(num_frame))) as t:
53 | for c in t:
54 | frame_np = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
55 | frame_np = cv2.resize(frame_np, (rw, rh), interpolation=cv2.INTER_AREA)
56 |
57 | frame_PIL = Image.fromarray(frame_np)
58 | frame_tensor = torch_transforms(frame_PIL)
59 | frame_tensor = frame_tensor[None, :, :, :]
60 | if GPU:
61 | frame_tensor = frame_tensor.cuda()
62 |
63 | with torch.no_grad():
64 | _, _, matte_tensor = modnet(frame_tensor, True)
65 |
66 | matte_tensor = matte_tensor.repeat(1, 3, 1, 1)
67 | matte_np = matte_tensor[0].data.cpu().numpy().transpose(1, 2, 0)
68 | if alpha_matte:
69 | view_np = matte_np * np.full(frame_np.shape, 255.0)
70 | else:
71 | view_np = matte_np * frame_np + (1 - matte_np) * np.full(frame_np.shape, 255.0)
72 | view_np = cv2.cvtColor(view_np.astype(np.uint8), cv2.COLOR_RGB2BGR)
73 | view_np = cv2.resize(view_np, (w, h))
74 | video_writer.write(view_np)
75 |
76 | rval, frame = vc.read()
77 | c += 1
78 |
79 | video_writer.release()
80 | print('Save the result video to {0}'.format(result))
81 |
82 |
83 | if __name__ == '__main__':
84 | parser = argparse.ArgumentParser()
85 | parser.add_argument('--video', type=str, required=True, help='input video file')
86 | parser.add_argument('--result-type', type=str, default='fg', choices=['fg', 'matte'],
87 | help='matte - save the alpha matte; fg - save the foreground')
88 | parser.add_argument('--fps', type=int, default=30, help='fps of the result video')
89 |
90 | print('Get CMD Arguments...')
91 | args = parser.parse_args()
92 |
93 | if not os.path.exists(args.video):
94 | print('Cannot find the input video: {0}'.format(args.video))
95 | exit()
96 |
97 | print('Load pre-trained MODNet...')
98 | pretrained_ckpt = './pretrained/modnet_webcam_portrait_matting.ckpt'
99 | modnet = MODNet(backbone_pretrained=False)
100 | modnet = nn.DataParallel(modnet)
101 |
102 | GPU = True if torch.cuda.device_count() > 0 else False
103 | if GPU:
104 | print('Use GPU...')
105 | modnet = modnet.cuda()
106 | modnet.load_state_dict(torch.load(pretrained_ckpt))
107 | else:
108 | print('Use CPU...')
109 | modnet.load_state_dict(torch.load(pretrained_ckpt, map_location=torch.device('cpu')))
110 | modnet.eval()
111 |
112 | result = os.path.splitext(args.video)[0] + '_{0}.mp4'.format(args.result_type)
113 | alpha_matte = True if args.result_type == 'matte' else False
114 | matting(args.video, result, alpha_matte, args.fps)
115 |
--------------------------------------------------------------------------------
/demo/video_matting/webcam/README.md:
--------------------------------------------------------------------------------
1 | ## MODNet - WebCam-Based Portrait Video Matting Demo
2 | This is a MODNet portrait video matting demo based on WebCam. It will call your local WebCam and display the matting results in real time. The demo can run under CPU or GPU.
3 |
4 | ### 1. Requirements
5 | The basic requirements for this demo are:
6 | - Ubuntu System
7 | - WebCam
8 | - Python 3+
9 |
10 | **NOTE**: If your device does not satisfy the above conditions, please try our [online Colab demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing).
11 |
12 |
13 | ### 2. Introduction
14 | We use ~400 unlabeled video clips (divided into ~50,000 frames) downloaded from the internet to perform SOC to adapt MODNet to the video domain. **Nonetheless, due to insufficient labeled training data (~3k labeled foregrounds), our model may still make errors in portrait semantics estimation under challenging scenes.** Besides, this demo does not currently support the OFD trick, which will be provided soon.
15 |
16 | For a better experience, please:
17 |
18 | * make sure the portrait and background are distinguishable, i.e., are not similar
19 | * run in soft and bright ambient lighting
20 | * do not be too close or too far from the WebCam
21 | * do not move too fast
22 |
23 | ### 3. Run Demo
24 | We recommend creating a new conda virtual environment to run this demo, as follows:
25 |
26 | 1. Clone the MODNet repository:
27 | ```
28 | git clone https://github.com/ZHKKKe/MODNet.git
29 | cd MODNet
30 | ```
31 |
32 | 2. Download the pre-trained model from this [link](https://drive.google.com/file/d/1Nf1ZxeJZJL8Qx9KadcYYyEmmlKhTADxX/view?usp=sharing) and put it into the folder `MODNet/pretrained/`.
33 |
34 |
35 | 3. Create a conda virtual environment named `modnet` (if it doesn't exist) and activate it. Here we use `python=3.6` as an example:
36 | ```
37 | conda create -n modnet python=3.6
38 | source activate modnet
39 | ```
40 |
41 | 4. Install the required python dependencies (please make sure your CUDA version is supported by the PyTorch version installed):
42 | ```
43 | pip install -r demo/video_matting/webcam/requirements.txt
44 | ```
45 |
46 | 5. Execute the main code:
47 | ```
48 | python -m demo.video_matting.webcam.run
49 | ```
50 |
51 | ### 4. Acknowledgement
52 | We thank [@tkianai](https://github.com/tkianai) and [@mazhar004](https://github.com/mazhar004) for their contributions to making this demo available for CPU use.
53 |
--------------------------------------------------------------------------------
/demo/video_matting/webcam/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | Pillow
3 | opencv-python
4 | torch >= 1.0.0
5 | torchvision
--------------------------------------------------------------------------------
/demo/video_matting/webcam/run.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from PIL import Image
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torchvision.transforms as transforms
8 |
9 | from src.models.modnet import MODNet
10 |
11 |
12 | torch_transforms = transforms.Compose(
13 | [
14 | transforms.ToTensor(),
15 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
16 | ]
17 | )
18 |
19 | print('Load pre-trained MODNet...')
20 | pretrained_ckpt = './pretrained/modnet_webcam_portrait_matting.ckpt'
21 | modnet = MODNet(backbone_pretrained=False)
22 | modnet = nn.DataParallel(modnet)
23 |
24 | GPU = True if torch.cuda.device_count() > 0 else False
25 | if GPU:
26 | print('Use GPU...')
27 | modnet = modnet.cuda()
28 | modnet.load_state_dict(torch.load(pretrained_ckpt))
29 | else:
30 | print('Use CPU...')
31 | modnet.load_state_dict(torch.load(pretrained_ckpt, map_location=torch.device('cpu')))
32 |
33 | modnet.eval()
34 |
35 | print('Init WebCam...')
36 | cap = cv2.VideoCapture(0)
37 | cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
38 | cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
39 |
40 | print('Start matting...')
41 | while True:
42 | _, frame_np = cap.read()
43 | frame_np = cv2.cvtColor(frame_np, cv2.COLOR_BGR2RGB)
44 | frame_np = cv2.resize(frame_np, (910, 512), interpolation=cv2.INTER_AREA)
45 | frame_np = frame_np[:, 120:792, :]
46 | frame_np = cv2.flip(frame_np, 1)
47 |
48 | frame_PIL = Image.fromarray(frame_np)
49 | frame_tensor = torch_transforms(frame_PIL)
50 | frame_tensor = frame_tensor[None, :, :, :]
51 | if GPU:
52 | frame_tensor = frame_tensor.cuda()
53 |
54 | with torch.no_grad():
55 | _, _, matte_tensor = modnet(frame_tensor, True)
56 |
57 | matte_tensor = matte_tensor.repeat(1, 3, 1, 1)
58 | matte_np = matte_tensor[0].data.cpu().numpy().transpose(1, 2, 0)
59 | fg_np = matte_np * frame_np + (1 - matte_np) * np.full(frame_np.shape, 255.0)
60 | view_np = np.uint8(np.concatenate((frame_np, fg_np), axis=1))
61 | view_np = cv2.cvtColor(view_np, cv2.COLOR_RGB2BGR)
62 |
63 | cv2.imshow('MODNet - WebCam [Press \'Q\' To Exit]', view_np)
64 | if cv2.waitKey(1) & 0xFF == ord('q'):
65 | break
66 |
67 | print('Exit...')
68 |
--------------------------------------------------------------------------------
/doc/gif/homepage_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/manthan2305/MODNet/6b529d157c1540daca2478fddd71f4bd89fd9de9/doc/gif/homepage_demo.gif
--------------------------------------------------------------------------------
/doc/gif/image_matting_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/manthan2305/MODNet/6b529d157c1540daca2478fddd71f4bd89fd9de9/doc/gif/image_matting_demo.gif
--------------------------------------------------------------------------------
/doc/gif/video_matting_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/manthan2305/MODNet/6b529d157c1540daca2478fddd71f4bd89fd9de9/doc/gif/video_matting_demo.gif
--------------------------------------------------------------------------------
/pretrained/README.md:
--------------------------------------------------------------------------------
1 | ## MODNet - Pre-Trained Models
2 | This folder is used to save the official pre-trained models of MODNet. You can download them from this [link](https://drive.google.com/drive/folders/1umYmlCulvIFNaqPjwod1SayFmSRHziyR?usp=sharing).
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/manthan2305/MODNet/6b529d157c1540daca2478fddd71f4bd89fd9de9/src/__init__.py
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/manthan2305/MODNet/6b529d157c1540daca2478fddd71f4bd89fd9de9/src/models/__init__.py
--------------------------------------------------------------------------------
/src/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .wrapper import *
2 |
3 |
4 | #------------------------------------------------------------------------------
5 | # Replaceable Backbones
6 | #------------------------------------------------------------------------------
7 |
8 | SUPPORTED_BACKBONES = {
9 | 'mobilenetv2': MobileNetV2Backbone,
10 | }
11 |
--------------------------------------------------------------------------------
/src/models/backbones/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | """ This file is adapted from https://github.com/thuyngch/Human-Segmentation-PyTorch"""
2 |
3 | import math
4 | import json
5 | from functools import reduce
6 |
7 | import torch
8 | from torch import nn
9 |
10 |
11 | #------------------------------------------------------------------------------
12 | # Useful functions
13 | #------------------------------------------------------------------------------
14 |
15 | def _make_divisible(v, divisor, min_value=None):
16 | if min_value is None:
17 | min_value = divisor
18 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
19 | # Make sure that round down does not go down by more than 10%.
20 | if new_v < 0.9 * v:
21 | new_v += divisor
22 | return new_v
23 |
24 |
25 | def conv_bn(inp, oup, stride):
26 | return nn.Sequential(
27 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
28 | nn.BatchNorm2d(oup),
29 | nn.ReLU6(inplace=True)
30 | )
31 |
32 |
33 | def conv_1x1_bn(inp, oup):
34 | return nn.Sequential(
35 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
36 | nn.BatchNorm2d(oup),
37 | nn.ReLU6(inplace=True)
38 | )
39 |
40 |
41 | #------------------------------------------------------------------------------
42 | # Class of Inverted Residual block
43 | #------------------------------------------------------------------------------
44 |
45 | class InvertedResidual(nn.Module):
46 | def __init__(self, inp, oup, stride, expansion, dilation=1):
47 | super(InvertedResidual, self).__init__()
48 | self.stride = stride
49 | assert stride in [1, 2]
50 |
51 | hidden_dim = round(inp * expansion)
52 | self.use_res_connect = self.stride == 1 and inp == oup
53 |
54 | if expansion == 1:
55 | self.conv = nn.Sequential(
56 | # dw
57 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
58 | nn.BatchNorm2d(hidden_dim),
59 | nn.ReLU6(inplace=True),
60 | # pw-linear
61 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
62 | nn.BatchNorm2d(oup),
63 | )
64 | else:
65 | self.conv = nn.Sequential(
66 | # pw
67 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
68 | nn.BatchNorm2d(hidden_dim),
69 | nn.ReLU6(inplace=True),
70 | # dw
71 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
72 | nn.BatchNorm2d(hidden_dim),
73 | nn.ReLU6(inplace=True),
74 | # pw-linear
75 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
76 | nn.BatchNorm2d(oup),
77 | )
78 |
79 | def forward(self, x):
80 | if self.use_res_connect:
81 | return x + self.conv(x)
82 | else:
83 | return self.conv(x)
84 |
85 |
86 | #------------------------------------------------------------------------------
87 | # Class of MobileNetV2
88 | #------------------------------------------------------------------------------
89 |
90 | class MobileNetV2(nn.Module):
91 | def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000):
92 | super(MobileNetV2, self).__init__()
93 | self.in_channels = in_channels
94 | self.num_classes = num_classes
95 | input_channel = 32
96 | last_channel = 1280
97 | interverted_residual_setting = [
98 | # t, c, n, s
99 | [1 , 16, 1, 1],
100 | [expansion, 24, 2, 2],
101 | [expansion, 32, 3, 2],
102 | [expansion, 64, 4, 2],
103 | [expansion, 96, 3, 1],
104 | [expansion, 160, 3, 2],
105 | [expansion, 320, 1, 1],
106 | ]
107 |
108 | # building first layer
109 | input_channel = _make_divisible(input_channel*alpha, 8)
110 | self.last_channel = _make_divisible(last_channel*alpha, 8) if alpha > 1.0 else last_channel
111 | self.features = [conv_bn(self.in_channels, input_channel, 2)]
112 |
113 | # building inverted residual blocks
114 | for t, c, n, s in interverted_residual_setting:
115 | output_channel = _make_divisible(int(c*alpha), 8)
116 | for i in range(n):
117 | if i == 0:
118 | self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t))
119 | else:
120 | self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t))
121 | input_channel = output_channel
122 |
123 | # building last several layers
124 | self.features.append(conv_1x1_bn(input_channel, self.last_channel))
125 |
126 | # make it nn.Sequential
127 | self.features = nn.Sequential(*self.features)
128 |
129 | # building classifier
130 | if self.num_classes is not None:
131 | self.classifier = nn.Sequential(
132 | nn.Dropout(0.2),
133 | nn.Linear(self.last_channel, num_classes),
134 | )
135 |
136 | # Initialize weights
137 | self._init_weights()
138 |
139 | def forward(self, x, feature_names=None):
140 | # Stage1
141 | x = reduce(lambda x, n: self.features[n](x), list(range(0,2)), x)
142 | # Stage2
143 | x = reduce(lambda x, n: self.features[n](x), list(range(2,4)), x)
144 | # Stage3
145 | x = reduce(lambda x, n: self.features[n](x), list(range(4,7)), x)
146 | # Stage4
147 | x = reduce(lambda x, n: self.features[n](x), list(range(7,14)), x)
148 | # Stage5
149 | x = reduce(lambda x, n: self.features[n](x), list(range(14,19)), x)
150 |
151 | # Classification
152 | if self.num_classes is not None:
153 | x = x.mean(dim=(2,3))
154 | x = self.classifier(x)
155 |
156 | # Output
157 | return x
158 |
159 | def _load_pretrained_model(self, pretrained_file):
160 | pretrain_dict = torch.load(pretrained_file, map_location='cpu')
161 | model_dict = {}
162 | state_dict = self.state_dict()
163 | print("[MobileNetV2] Loading pretrained model...")
164 | for k, v in pretrain_dict.items():
165 | if k in state_dict:
166 | model_dict[k] = v
167 | else:
168 | print(k, "is ignored")
169 | state_dict.update(model_dict)
170 | self.load_state_dict(state_dict)
171 |
172 | def _init_weights(self):
173 | for m in self.modules():
174 | if isinstance(m, nn.Conv2d):
175 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
176 | m.weight.data.normal_(0, math.sqrt(2. / n))
177 | if m.bias is not None:
178 | m.bias.data.zero_()
179 | elif isinstance(m, nn.BatchNorm2d):
180 | m.weight.data.fill_(1)
181 | m.bias.data.zero_()
182 | elif isinstance(m, nn.Linear):
183 | n = m.weight.size(1)
184 | m.weight.data.normal_(0, 0.01)
185 | m.bias.data.zero_()
186 |
--------------------------------------------------------------------------------
/src/models/backbones/wrapper.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import reduce
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from .mobilenetv2 import MobileNetV2
8 |
9 |
10 | class BaseBackbone(nn.Module):
11 | """ Superclass of Replaceable Backbone Model for Semantic Estimation
12 | """
13 |
14 | def __init__(self, in_channels):
15 | super(BaseBackbone, self).__init__()
16 | self.in_channels = in_channels
17 |
18 | self.model = None
19 | self.enc_channels = []
20 |
21 | def forward(self, x):
22 | raise NotImplementedError
23 |
24 | def load_pretrained_ckpt(self):
25 | raise NotImplementedError
26 |
27 |
28 | class MobileNetV2Backbone(BaseBackbone):
29 | """ MobileNetV2 Backbone
30 | """
31 |
32 | def __init__(self, in_channels):
33 | super(MobileNetV2Backbone, self).__init__(in_channels)
34 |
35 | self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None)
36 | self.enc_channels = [16, 24, 32, 96, 1280]
37 |
38 | def forward(self, x):
39 | x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x)
40 | enc2x = x
41 | x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x)
42 | enc4x = x
43 | x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x)
44 | enc8x = x
45 | x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x)
46 | enc16x = x
47 | x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x)
48 | enc32x = x
49 | return [enc2x, enc4x, enc8x, enc16x, enc32x]
50 |
51 | def load_pretrained_ckpt(self):
52 | # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch
53 | ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt'
54 | if not os.path.exists(ckpt_path):
55 | print('cannot find the pretrained mobilenetv2 backbone')
56 | exit()
57 |
58 | ckpt = torch.load(ckpt_path)
59 | self.model.load_state_dict(ckpt)
60 |
--------------------------------------------------------------------------------
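
A quick sanity-check sketch for the backbone wrapper above, assuming the repository root is on `PYTHONPATH`; no pre-trained checkpoint is needed because `load_pretrained_ckpt()` is not called here.

```python
import torch
from src.models.backbones import SUPPORTED_BACKBONES

# 'mobilenetv2' is the only registered backbone; the argument is the number of input channels
backbone = SUPPORTED_BACKBONES['mobilenetv2'](3)
with torch.no_grad():
    features = backbone(torch.randn(1, 3, 512, 512))

# channel counts follow backbone.enc_channels = [16, 24, 32, 96, 1280]
for name, feat in zip(['enc2x', 'enc4x', 'enc8x', 'enc16x', 'enc32x'], features):
    print(name, tuple(feat.shape))
# expected: enc2x (1, 16, 256, 256) ... enc32x (1, 1280, 16, 16)
```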
/src/models/modnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .backbones import SUPPORTED_BACKBONES
6 |
7 |
8 | #------------------------------------------------------------------------------
9 | # MODNet Basic Modules
10 | #------------------------------------------------------------------------------
11 |
12 | class IBNorm(nn.Module):
13 | """ Combine Instance Norm and Batch Norm into One Layer
14 | """
15 |
16 | def __init__(self, in_channels):
17 | super(IBNorm, self).__init__()
18 | in_channels = in_channels
19 | self.bnorm_channels = int(in_channels / 2)
20 | self.inorm_channels = in_channels - self.bnorm_channels
21 |
22 | self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
23 | self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)
24 |
25 | def forward(self, x):
26 | bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
27 | in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())
28 |
29 | return torch.cat((bn_x, in_x), 1)
30 |
31 |
32 | class Conv2dIBNormRelu(nn.Module):
33 | """ Convolution + IBNorm + ReLu
34 | """
35 |
36 | def __init__(self, in_channels, out_channels, kernel_size,
37 | stride=1, padding=0, dilation=1, groups=1, bias=True,
38 | with_ibn=True, with_relu=True):
39 | super(Conv2dIBNormRelu, self).__init__()
40 |
41 | layers = [
42 | nn.Conv2d(in_channels, out_channels, kernel_size,
43 | stride=stride, padding=padding, dilation=dilation,
44 | groups=groups, bias=bias)
45 | ]
46 |
47 | if with_ibn:
48 | layers.append(IBNorm(out_channels))
49 | if with_relu:
50 | layers.append(nn.ReLU(inplace=True))
51 |
52 | self.layers = nn.Sequential(*layers)
53 |
54 | def forward(self, x):
55 | return self.layers(x)
56 |
57 |
58 | class SEBlock(nn.Module):
59 | """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
60 | """
61 |
62 | def __init__(self, in_channels, out_channels, reduction=1):
63 | super(SEBlock, self).__init__()
64 | self.pool = nn.AdaptiveAvgPool2d(1)
65 | self.fc = nn.Sequential(
66 | nn.Linear(in_channels, int(in_channels // reduction), bias=False),
67 | nn.ReLU(inplace=True),
68 | nn.Linear(int(in_channels // reduction), out_channels, bias=False),
69 | nn.Sigmoid()
70 | )
71 |
72 | def forward(self, x):
73 | b, c, _, _ = x.size()
74 | w = self.pool(x).view(b, c)
75 | w = self.fc(w).view(b, c, 1, 1)
76 |
77 | return x * w.expand_as(x)
78 |
79 |
80 | #------------------------------------------------------------------------------
81 | # MODNet Branches
82 | #------------------------------------------------------------------------------
83 |
84 | class LRBranch(nn.Module):
85 | """ Low Resolution Branch of MODNet
86 | """
87 |
88 | def __init__(self, backbone):
89 | super(LRBranch, self).__init__()
90 |
91 | enc_channels = backbone.enc_channels
92 |
93 | self.backbone = backbone
94 | self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
95 | self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
96 | self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)
97 | self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False)
98 |
99 | def forward(self, img, inference):
100 | enc_features = self.backbone.forward(img)
101 | enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]
102 |
103 | enc32x = self.se_block(enc32x)
104 | lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
105 | lr16x = self.conv_lr16x(lr16x)
106 | lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
107 | lr8x = self.conv_lr8x(lr8x)
108 |
109 | pred_semantic = None
110 | if not inference:
111 | lr = self.conv_lr(lr8x)
112 | pred_semantic = torch.sigmoid(lr)
113 |
114 | return pred_semantic, lr8x, [enc2x, enc4x]
115 |
116 |
117 | class HRBranch(nn.Module):
118 | """ High Resolution Branch of MODNet
119 | """
120 |
121 | def __init__(self, hr_channels, enc_channels):
122 | super(HRBranch, self).__init__()
123 |
124 | self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
125 | self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)
126 |
127 | self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
128 | self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)
129 |
130 | self.conv_hr4x = nn.Sequential(
131 | Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),
132 | Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
133 | Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
134 | )
135 |
136 | self.conv_hr2x = nn.Sequential(
137 | Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
138 | Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
139 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
140 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
141 | )
142 |
143 | self.conv_hr = nn.Sequential(
144 | Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
145 | Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
146 | )
147 |
148 | def forward(self, img, enc2x, enc4x, lr8x, inference):
149 | img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False)
150 | img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False)
151 |
152 | enc2x = self.tohr_enc2x(enc2x)
153 | hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))
154 |
155 | enc4x = self.tohr_enc4x(enc4x)
156 | hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))
157 |
158 | lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
159 | hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))
160 |
161 | hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)
162 | hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))
163 |
164 | pred_detail = None
165 | if not inference:
166 | hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)
167 | hr = self.conv_hr(torch.cat((hr, img), dim=1))
168 | pred_detail = torch.sigmoid(hr)
169 |
170 | return pred_detail, hr2x
171 |
172 |
173 | class FusionBranch(nn.Module):
174 | """ Fusion Branch of MODNet
175 | """
176 |
177 | def __init__(self, hr_channels, enc_channels):
178 | super(FusionBranch, self).__init__()
179 | self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)
180 |
181 | self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
182 | self.conv_f = nn.Sequential(
183 | Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
184 | Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),
185 | )
186 |
187 | def forward(self, img, lr8x, hr2x):
188 | lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
189 | lr4x = self.conv_lr4x(lr4x)
190 | lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)
191 |
192 | f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))
193 | f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)
194 | f = self.conv_f(torch.cat((f, img), dim=1))
195 | pred_matte = torch.sigmoid(f)
196 |
197 | return pred_matte
198 |
199 |
200 | #------------------------------------------------------------------------------
201 | # MODNet
202 | #------------------------------------------------------------------------------
203 |
204 | class MODNet(nn.Module):
205 | """ Architecture of MODNet
206 | """
207 |
208 | def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):
209 | super(MODNet, self).__init__()
210 |
211 | self.in_channels = in_channels
212 | self.hr_channels = hr_channels
213 | self.backbone_arch = backbone_arch
214 | self.backbone_pretrained = backbone_pretrained
215 |
216 | self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels)
217 |
218 | self.lr_branch = LRBranch(self.backbone)
219 | self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
220 | self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)
221 |
222 | for m in self.modules():
223 | if isinstance(m, nn.Conv2d):
224 | self._init_conv(m)
225 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
226 | self._init_norm(m)
227 |
228 | if self.backbone_pretrained:
229 | self.backbone.load_pretrained_ckpt()
230 |
231 | def forward(self, img, inference):
232 | pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)
233 | pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)
234 | pred_matte = self.f_branch(img, lr8x, hr2x)
235 |
236 | return pred_semantic, pred_detail, pred_matte
237 |
238 | def freeze_norm(self):
239 | norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]
240 | for m in self.modules():
241 | for n in norm_types:
242 | if isinstance(m, n):
243 | m.eval()
244 | continue
245 |
246 | def _init_conv(self, conv):
247 | nn.init.kaiming_uniform_(
248 | conv.weight, a=0, mode='fan_in', nonlinearity='relu')
249 | if conv.bias is not None:
250 | nn.init.constant_(conv.bias, 0)
251 |
252 | def _init_norm(self, norm):
253 | if norm.weight is not None:
254 | nn.init.constant_(norm.weight, 1)
255 | nn.init.constant_(norm.bias, 0)
256 |
--------------------------------------------------------------------------------
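
A minimal forward-pass sketch for the class above, runnable on CPU with `backbone_pretrained=False` (so no checkpoint is required); height and width are kept at multiples of 32, as the demo scripts do.

```python
import torch
from src.models.modnet import MODNet

modnet = MODNet(backbone_pretrained=False)
modnet.eval()

img = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    # inference=True skips the semantic and detail heads
    pred_semantic, pred_detail, pred_matte = modnet(img, True)

print(pred_semantic, pred_detail)   # both None in inference mode
print(tuple(pred_matte.shape))      # (1, 1, 512, 512)
```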
/src/models/onnx_modnet.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is a modified version of the original file modnet.py without
3 | "pred_semantic" and "pred_details" as these both returns None when "inference = True"
4 |
5 | And it does not contain "inference" argument which will make it easier to
6 | convert checkpoint into onnx model.
7 |
8 | Refer: 'demo/image_matting/inference_with_ONNX/export_modnet_onnx.py' to export model.
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 | from .backbones import SUPPORTED_BACKBONES
16 |
17 |
18 | #------------------------------------------------------------------------------
19 | # MODNet Basic Modules
20 | #------------------------------------------------------------------------------
21 |
22 | class IBNorm(nn.Module):
23 | """ Combine Instance Norm and Batch Norm into One Layer
24 | """
25 |
26 | def __init__(self, in_channels):
27 | super(IBNorm, self).__init__()
28 | in_channels = in_channels
29 | self.bnorm_channels = int(in_channels / 2)
30 | self.inorm_channels = in_channels - self.bnorm_channels
31 |
32 | self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
33 | self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)
34 |
35 | def forward(self, x):
36 | bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
37 | in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())
38 |
39 | return torch.cat((bn_x, in_x), 1)
40 |
41 |
42 | class Conv2dIBNormRelu(nn.Module):
43 | """ Convolution + IBNorm + ReLu
44 | """
45 |
46 | def __init__(self, in_channels, out_channels, kernel_size,
47 | stride=1, padding=0, dilation=1, groups=1, bias=True,
48 | with_ibn=True, with_relu=True):
49 | super(Conv2dIBNormRelu, self).__init__()
50 |
51 | layers = [
52 | nn.Conv2d(in_channels, out_channels, kernel_size,
53 | stride=stride, padding=padding, dilation=dilation,
54 | groups=groups, bias=bias)
55 | ]
56 |
57 | if with_ibn:
58 | layers.append(IBNorm(out_channels))
59 | if with_relu:
60 | layers.append(nn.ReLU(inplace=True))
61 |
62 | self.layers = nn.Sequential(*layers)
63 |
64 | def forward(self, x):
65 | return self.layers(x)
66 |
67 |
68 | class SEBlock(nn.Module):
69 | """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
70 | """
71 |
72 | def __init__(self, in_channels, out_channels, reduction=1):
73 | super(SEBlock, self).__init__()
74 | self.pool = nn.AdaptiveAvgPool2d(1)
75 | self.fc = nn.Sequential(
76 | nn.Linear(in_channels, int(in_channels // reduction), bias=False),
77 | nn.ReLU(inplace=True),
78 | nn.Linear(int(in_channels // reduction), out_channels, bias=False),
79 | nn.Sigmoid()
80 | )
81 |
82 | def forward(self, x):
83 | b, c, _, _ = x.size()
84 | w = self.pool(x).view(b, c)
85 | w = self.fc(w).view(b, c, 1, 1)
86 |
87 | return x * w.expand_as(x)
88 |
89 |
90 | #------------------------------------------------------------------------------
91 | # MODNet Branches
92 | #------------------------------------------------------------------------------
93 |
94 | class LRBranch(nn.Module):
95 | """ Low Resolution Branch of MODNet
96 | """
97 |
98 | def __init__(self, backbone):
99 | super(LRBranch, self).__init__()
100 |
101 | enc_channels = backbone.enc_channels
102 |
103 | self.backbone = backbone
104 | self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
105 | self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
106 | self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)
107 | self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False)
108 |
109 | def forward(self, img):
110 | enc_features = self.backbone.forward(img)
111 | enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]
112 |
113 | enc32x = self.se_block(enc32x)
114 | lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
115 | lr16x = self.conv_lr16x(lr16x)
116 | lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
117 | lr8x = self.conv_lr8x(lr8x)
118 |
119 | return lr8x, [enc2x, enc4x]
120 |
121 |
122 | class HRBranch(nn.Module):
123 | """ High Resolution Branch of MODNet
124 | """
125 |
126 | def __init__(self, hr_channels, enc_channels):
127 | super(HRBranch, self).__init__()
128 |
129 | self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
130 | self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)
131 |
132 | self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
133 | self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)
134 |
135 | self.conv_hr4x = nn.Sequential(
136 | Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),
137 | Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
138 | Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
139 | )
140 |
141 | self.conv_hr2x = nn.Sequential(
142 | Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
143 | Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
144 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
145 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
146 | )
147 |
148 | self.conv_hr = nn.Sequential(
149 | Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
150 | Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
151 | )
152 |
153 | def forward(self, img, enc2x, enc4x, lr8x):
154 | img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False)
155 | img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False)
156 |
157 | enc2x = self.tohr_enc2x(enc2x)
158 | hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))
159 |
160 | enc4x = self.tohr_enc4x(enc4x)
161 | hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))
162 |
163 | lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
164 | hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))
165 |
166 | hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)
167 | hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))
168 |
169 | return hr2x
170 |
171 |
172 | class FusionBranch(nn.Module):
173 | """ Fusion Branch of MODNet
174 | """
175 |
176 | def __init__(self, hr_channels, enc_channels):
177 | super(FusionBranch, self).__init__()
178 | self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)
179 |
180 | self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
181 | self.conv_f = nn.Sequential(
182 | Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
183 | Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),
184 | )
185 |
186 | def forward(self, img, lr8x, hr2x):
187 | lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
188 | lr4x = self.conv_lr4x(lr4x)
189 | lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)
190 |
191 | f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))
192 | f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)
193 | f = self.conv_f(torch.cat((f, img), dim=1))
194 | pred_matte = torch.sigmoid(f)
195 |
196 | return pred_matte
197 |
198 |
199 | #------------------------------------------------------------------------------
200 | # MODNet
201 | #------------------------------------------------------------------------------
202 |
203 | class MODNet(nn.Module):
204 | """ Architecture of MODNet
205 | """
206 |
207 | def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):
208 | super(MODNet, self).__init__()
209 |
210 | self.in_channels = in_channels
211 | self.hr_channels = hr_channels
212 | self.backbone_arch = backbone_arch
213 | self.backbone_pretrained = backbone_pretrained
214 |
215 | self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels)
216 |
217 | self.lr_branch = LRBranch(self.backbone)
218 | self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
219 | self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)
220 |
221 | for m in self.modules():
222 | if isinstance(m, nn.Conv2d):
223 | self._init_conv(m)
224 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
225 | self._init_norm(m)
226 |
227 | if self.backbone_pretrained:
228 | self.backbone.load_pretrained_ckpt()
229 |
230 | def forward(self, img):
231 | lr8x, [enc2x, enc4x] = self.lr_branch(img)
232 | hr2x = self.hr_branch(img, enc2x, enc4x, lr8x)
233 | pred_matte = self.f_branch(img, lr8x, hr2x)
234 |
235 | return pred_matte
236 |
237 | def freeze_norm(self):
238 | norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]
239 | for m in self.modules():
240 | for n in norm_types:
241 | if isinstance(m, n):
242 | m.eval()
243 | break
244 |
245 | def _init_conv(self, conv):
246 | nn.init.kaiming_uniform_(
247 | conv.weight, a=0, mode='fan_in', nonlinearity='relu')
248 | if conv.bias is not None:
249 | nn.init.constant_(conv.bias, 0)
250 |
251 | def _init_norm(self, norm):
252 | if norm.weight is not None:
253 | nn.init.constant_(norm.weight, 1)
254 | nn.init.constant_(norm.bias, 0)
255 |
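# A minimal inference sketch, assuming this listing is the ONNX-oriented variant
# at src/models/onnx_modnet.py, whose forward pass takes only an image tensor and
# returns only the alpha matte. Inputs are normalized RGB batches of shape
# (N, 3, H, W) with H and W divisible by 32; the exact preprocessing should follow
# the demo scripts elsewhere in this repository.
if __name__ == '__main__':
    modnet = MODNet(backbone_pretrained=False)  # skip the ImageNet backbone weights for a quick smoke test
    modnet.eval()
    with torch.no_grad():
        dummy_img = torch.rand(1, 3, 512, 512)  # stand-in for a normalized portrait image
        matte = modnet(dummy_img)               # alpha matte of shape (1, 1, 512, 512), values in [0, 1]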
--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import scipy
3 | import numpy as np
4 | from scipy.ndimage import grey_dilation, grey_erosion
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | __all__ = [
12 | 'supervised_training_iter',
13 | 'soc_adaptation_iter',
14 | ]
15 |
16 |
17 | # ----------------------------------------------------------------------------------
18 | # Tool Classes/Functions
19 | # ----------------------------------------------------------------------------------
20 |
21 | class GaussianBlurLayer(nn.Module):
22 | """ Add Gaussian blur to a 4D tensor
23 | This layer takes a 4D tensor of shape {N, C, H, W} as input.
24 | The Gaussian blur is applied to each of the given channels (C) separately.
25 | """
26 |
27 | def __init__(self, channels, kernel_size):
28 | """
29 | Arguments:
30 | channels (int): Number of channels of the input tensor
31 | kernel_size (int): Size of the kernel used in blurring
32 | """
33 |
34 | super(GaussianBlurLayer, self).__init__()
35 | self.channels = channels
36 | self.kernel_size = kernel_size
37 | assert self.kernel_size % 2 != 0
38 |
39 | self.op = nn.Sequential(
40 | nn.ReflectionPad2d(math.floor(self.kernel_size / 2)),
41 | nn.Conv2d(channels, channels, self.kernel_size,
42 | stride=1, padding=0, bias=False, groups=channels)
43 | )
44 |
45 | self._init_kernel()
46 |
47 | def forward(self, x):
48 | """
49 | Arguments:
50 | x (torch.Tensor): input 4D tensor
51 | Returns:
52 | torch.Tensor: Blurred version of the input
53 | """
54 |
55 | if not len(list(x.shape)) == 4:
56 | print('\'GaussianBlurLayer\' requires a 4D tensor as input\n')
57 | exit()
58 | elif not x.shape[1] == self.channels:
59 | print('In \'GaussianBlurLayer\', the required channel ({0}) is '
60 | 'not the same as the input ({1})\n'.format(self.channels, x.shape[1]))
61 | exit()
62 |
63 | return self.op(x)
64 |
65 | def _init_kernel(self):
66 | sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8
67 |
68 | n = np.zeros((self.kernel_size, self.kernel_size))
69 | i = math.floor(self.kernel_size / 2)
70 | n[i, i] = 1
71 | kernel = scipy.ndimage.gaussian_filter(n, sigma)
72 |
73 | for name, param in self.named_parameters():
74 | param.data.copy_(torch.from_numpy(kernel))
75 |
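# A minimal usage sketch for GaussianBlurLayer (shapes are illustrative only): the
# layer blurs each channel of an (N, C, H, W) tensor independently, which is how the
# training functions below smooth the downsampled matte.
#
#     blur = GaussianBlurLayer(channels=1, kernel_size=3)
#     x = torch.rand(2, 1, 64, 64)   # C must equal the `channels` the layer was built with
#     y = blur(x)                    # same shape as x, Gaussian-smoothed per channel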
76 | # ----------------------------------------------------------------------------------
77 |
78 |
79 | # ----------------------------------------------------------------------------------
80 | # MODNet Training Functions
81 | # ----------------------------------------------------------------------------------
82 |
83 | blurer = GaussianBlurLayer(1, 3).cuda()  # single-channel 3x3 Gaussian blur; created on the GPU at import time
84 |
85 |
86 | def supervised_training_iter(
87 | modnet, optimizer, image, trimap, gt_matte,
88 | semantic_scale=10.0, detail_scale=10.0, matte_scale=1.0):
89 | """ Supervised training iteration of MODNet
90 | This function trains MODNet for one iteration on a labeled dataset.
91 |
92 | Arguments:
93 | modnet (torch.nn.Module): instance of MODNet
94 | optimizer (torch.optim.Optimizer): optimizer for supervised training
95 | image (torch.autograd.Variable): input RGB image
96 | its pixel values should be normalized
97 | trimap (torch.autograd.Variable): trimap used to calculate the losses
98 | its pixel values can be 0, 0.5, or 1
99 | (foreground=1, background=0, unknown=0.5)
100 | gt_matte (torch.autograd.Variable): ground truth alpha matte
101 | its pixel values are in [0, 1]
102 | semantic_scale (float): scale of the semantic loss
103 | NOTE: please adjust according to your dataset
104 | detail_scale (float): scale of the detail loss
105 | NOTE: please adjust according to your dataset
106 | matte_scale (float): scale of the matte loss
107 | NOTE: please adjust according to your dataset
108 |
109 | Returns:
110 | semantic_loss (torch.Tensor): loss of the semantic estimation [Low-Resolution (LR) Branch]
111 | detail_loss (torch.Tensor): loss of the detail prediction [High-Resolution (HR) Branch]
112 | matte_loss (torch.Tensor): loss of the semantic-detail fusion [Fusion Branch]
113 |
114 | Example:
115 | import torch
116 | from src.models.modnet import MODNet
117 | from src.trainer import supervised_training_iter
118 |
119 | bs = 16 # batch size
120 | lr = 0.01  # learning rate
121 | epochs = 40 # total epochs
122 |
123 | modnet = torch.nn.DataParallel(MODNet()).cuda()
124 | optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
125 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)
126 |
127 | dataloader = CREATE_YOUR_DATALOADER(bs) # NOTE: please finish this function
128 |
129 | for epoch in range(0, epochs):
130 | for idx, (image, trimap, gt_matte) in enumerate(dataloader):
131 | semantic_loss, detail_loss, matte_loss = \
132 | supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
133 | lr_scheduler.step()
134 | """
135 |
136 | global blurer
137 |
138 | # set the model to train mode and clear the optimizer
139 | modnet.train()
140 | optimizer.zero_grad()
141 |
142 | # forward the model
143 | pred_semantic, pred_detail, pred_matte = modnet(image, False)
144 |
145 | # calculate the boundary mask from the trimap
146 | boundaries = (trimap < 0.5) + (trimap > 0.5)
147 |
148 | # calculate the semantic loss
149 | gt_semantic = F.interpolate(gt_matte, scale_factor=1/16, mode='bilinear')
150 | gt_semantic = blurer(gt_semantic)
151 | semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
152 | semantic_loss = semantic_scale * semantic_loss
153 |
154 | # calculate the detail loss
155 | pred_boundary_detail = torch.where(boundaries, trimap, pred_detail)
156 | gt_detail = torch.where(boundaries, trimap, gt_matte)
157 | detail_loss = torch.mean(F.l1_loss(pred_boundary_detail, gt_detail))
158 | detail_loss = detail_scale * detail_loss
159 |
160 | # calculate the matte loss
161 | pred_boundary_matte = torch.where(boundaries, trimap, pred_matte)
162 | matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)
163 | matte_compositional_loss = F.l1_loss(image * pred_matte, image * gt_matte) \
164 | + 4.0 * F.l1_loss(image * pred_boundary_matte, image * gt_matte)
165 | matte_loss = torch.mean(matte_l1_loss + matte_compositional_loss)
166 | matte_loss = matte_scale * matte_loss
167 |
168 | # calculate the final loss, backward the loss, and update the model
169 | loss = semantic_loss + detail_loss + matte_loss
170 | loss.backward()
171 | optimizer.step()
172 |
173 | # for test
174 | return semantic_loss, detail_loss, matte_loss
175 |
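# A hypothetical data-loading sketch (not from this repository) standing in for the
# CREATE_YOUR_DATALOADER placeholder used in the docstring example above. It only
# illustrates the tensor formats that supervised_training_iter expects: a normalized
# RGB image, a trimap with values in {0, 0.5, 1}, and a ground-truth matte in [0, 1],
# with H and W divisible by 32. Replace the random tensors with real file loading.
from torch.utils.data import Dataset, DataLoader

class _RandomMattingDataset(Dataset):
    """Hypothetical stand-in dataset yielding (image, trimap, gt_matte) triplets."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        gt_matte = torch.rand(1, 512, 512)           # alpha values in [0, 1]
        image = torch.rand(3, 512, 512) * 2.0 - 1.0  # stand-in for a normalized RGB image
        trimap = torch.full_like(gt_matte, 0.5)      # unknown region
        trimap[gt_matte > 0.95] = 1.0                # confident foreground
        trimap[gt_matte < 0.05] = 0.0                # confident background
        return image, trimap, gt_matte

def _create_dataloader(bs):
    """Hypothetical counterpart of CREATE_YOUR_DATALOADER(bs)."""
    return DataLoader(_RandomMattingDataset(), batch_size=bs, shuffle=True)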
176 |
177 | def soc_adaptation_iter(
178 | modnet, backup_modnet, optimizer, image,
179 | soc_semantic_scale=100.0, soc_detail_scale=1.0):
180 | """ Self-supervised sub-objective consistency (SOC) adaptation iteration of MODNet
181 | This function fine-tunes MODNet for one iteration on an unlabeled dataset.
182 | Note that SOC can only fine-tune a converged MODNet, i.e., a MODNet that has
183 | already been trained on a labeled dataset.
184 |
185 | Arguments:
186 | modnet (torch.nn.Module): instance of MODNet
187 | backup_modnet (torch.nn.Module): backup of the trained MODNet
188 | optimizer (torch.optim.Optimizer): optimizer for self-supervised SOC
189 | image (torch.autograd.Variable): input RGB image
190 | its pixel values should be normalized
191 | soc_semantic_scale (float): scale of the SOC semantic loss
192 | NOTE: please adjust according to your dataset
193 | soc_detail_scale (float): scale of the SOC detail loss
194 | NOTE: please adjust according to your dataset
195 |
196 | Returns:
197 | soc_semantic_loss (torch.Tensor): loss of the semantic SOC
198 | soc_detail_loss (torch.Tensor): loss of the detail SOC
199 |
200 | Example:
201 | import copy
202 | import torch
203 | from src.models.modnet import MODNet
204 | from src.trainer import soc_adaptation_iter
205 |
206 | bs = 1 # batch size
207 | lr = 0.00001  # learning rate
208 | epochs = 10 # total epochs
209 |
210 | modnet = torch.nn.DataParallel(MODNet()).cuda()
211 | modnet = LOAD_TRAINED_CKPT() # NOTE: please finish this function
212 |
213 | optimizer = torch.optim.Adam(modnet.parameters(), lr=lr, betas=(0.9, 0.99))
214 | dataloader = CREATE_YOUR_DATALOADER(bs) # NOTE: please finish this function
215 |
216 | for epoch in range(0, epochs):
217 | backup_modnet = copy.deepcopy(modnet)
218 | for idx, (image) in enumerate(dataloader):
219 | soc_semantic_loss, soc_detail_loss = \
220 | soc_adaptation_iter(modnet, backup_modnet, optimizer, image)
221 | """
222 |
223 | global blurer
224 |
225 | # set the backup model to eval mode
226 | backup_modnet.eval()
227 |
228 | # set the main model to train mode and freeze its norm layers
229 | modnet.train()
230 | modnet.module.freeze_norm()
231 |
232 | # clear the optimizer
233 | optimizer.zero_grad()
234 |
235 | # forward the main model
236 | pred_semantic, pred_detail, pred_matte = modnet(image, False)
237 |
238 | # forward the backup model
239 | with torch.no_grad():
240 | _, pred_backup_detail, pred_backup_matte = backup_modnet(image, False)
241 |
242 | # calculate the boundary mask from `pred_matte` and `pred_semantic`
243 | pred_matte_fg = (pred_matte.detach() > 0.1).float()
244 | pred_semantic_fg = (pred_semantic.detach() > 0.1).float()
245 | pred_semantic_fg = F.interpolate(pred_semantic_fg, scale_factor=16, mode='bilinear')
246 | pred_fg = pred_matte_fg * pred_semantic_fg
247 |
248 | n, c, h, w = pred_matte.shape
249 | np_pred_fg = pred_fg.data.cpu().numpy()
250 | np_boundaries = np.zeros([n, c, h, w])
251 | for sdx in range(0, n):
252 | sample_np_boundaries = np_boundaries[sdx, 0, ...]
253 | sample_np_pred_fg = np_pred_fg[sdx, 0, ...]
254 |
255 | side = int((h + w) / 2 * 0.05)
256 | dilated = grey_dilation(sample_np_pred_fg, size=(side, side))
257 | eroded = grey_erosion(sample_np_pred_fg, size=(side, side))
258 |
259 | sample_np_boundaries[np.where(dilated - eroded != 0)] = 1
260 | np_boundaries[sdx, 0, ...] = sample_np_boundaries
261 |
262 | boundaries = torch.tensor(np_boundaries).float().cuda()
263 |
264 | # sub-objectives consistency between `pred_semantic` and `pred_matte`
265 | # generate pseudo ground truth for `pred_semantic`
266 | downsampled_pred_matte = blurer(F.interpolate(pred_matte, scale_factor=1/16, mode='bilinear'))
267 | pseudo_gt_semantic = downsampled_pred_matte.detach()
268 | pseudo_gt_semantic = pseudo_gt_semantic * (pseudo_gt_semantic > 0.01).float()
269 |
270 | # generate pseudo ground truth for `pred_matte`
271 | pseudo_gt_matte = pred_semantic.detach()
272 | pseudo_gt_matte = pseudo_gt_matte * (pseudo_gt_matte > 0.01).float()
273 |
274 | # calculate the SOC semantic loss
275 | soc_semantic_loss = F.mse_loss(pred_semantic, pseudo_gt_semantic) + F.mse_loss(downsampled_pred_matte, pseudo_gt_matte)
276 | soc_semantic_loss = soc_semantic_scale * torch.mean(soc_semantic_loss)
277 |
278 | # NOTE: using the formulas in our paper to calculate the following losses gives similar results
279 | # sub-objectives consistency between `pred_detail` and `pred_backup_detail` (on boundaries only)
280 | backup_detail_loss = boundaries * F.l1_loss(pred_detail, pred_backup_detail)
281 | backup_detail_loss = torch.sum(backup_detail_loss, dim=(1,2,3)) / torch.sum(boundaries, dim=(1,2,3))
282 | backup_detail_loss = torch.mean(backup_detail_loss)
283 |
284 | # sub-objectives consistency between `pred_matte` and `pred_backup_matte` (on boundaries only)
285 | backup_matte_loss = boundaries * F.l1_loss(pred_matte, pred_backup_matte)
286 | backup_matte_loss = torch.sum(backup_matte_loss, dim=(1,2,3)) / torch.sum(boundaries, dim=(1,2,3))
287 | backup_matte_loss = torch.mean(backup_matte_loss)
288 |
289 | soc_detail_loss = soc_detail_scale * (backup_detail_loss + backup_matte_loss)
290 |
291 | # calculate the final loss, backward the loss, and update the model
292 | loss = soc_semantic_loss + soc_detail_loss
293 |
294 | loss.backward()
295 | optimizer.step()
296 |
297 | return soc_semantic_loss, soc_detail_loss
298 |
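# A hypothetical sketch of the LOAD_TRAINED_CKPT step referenced in the docstring
# example above: SOC must start from weights produced by the supervised stage. The
# checkpoint path and its key layout (saved from a DataParallel-wrapped model) are
# assumptions, not part of this repository.
def _load_trained_modnet(ckpt_path='pretrained/modnet_supervised.ckpt'):
    """Hypothetical counterpart of LOAD_TRAINED_CKPT() in the example above."""
    from src.models.modnet import MODNet

    modnet = torch.nn.DataParallel(MODNet()).cuda()
    state_dict = torch.load(ckpt_path, map_location='cpu')
    modnet.load_state_dict(state_dict)  # assumes keys carry the 'module.' prefix from DataParallel
    return modnet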
299 | # ----------------------------------------------------------------------------------
300 |
--------------------------------------------------------------------------------