├── .gitignore ├── README.md ├── demo ├── image_matting │ ├── Inference_with_ONNX │ │ ├── README.md │ │ ├── export_modnet_onnx.py │ │ ├── inference_onnx.py │ │ └── requirements.txt │ └── colab │ │ ├── README.md │ │ └── inference.py └── video_matting │ ├── custom │ ├── README.md │ ├── requirements.txt │ └── run.py │ └── webcam │ ├── README.md │ ├── requirements.txt │ └── run.py ├── doc └── gif │ ├── homepage_demo.gif │ ├── image_matting_demo.gif │ └── video_matting_demo.gif ├── pretrained └── README.md └── src ├── __init__.py ├── models ├── __init__.py ├── backbones │ ├── __init__.py │ ├── mobilenetv2.py │ └── wrapper.py ├── modnet.py └── onnx_modnet.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Temporary directories and files 2 | *.ckpt 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *,cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # IPython Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # dotenv 82 | .env 83 | 84 | # virtualenv 85 | venv/ 86 | ENV/ 87 | 88 | # Spyder project settings 89 | .spyderproject 90 | 91 | # Rope project settings 92 | .ropeproject 93 | 94 | 95 | # Project files 96 | .vscode -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

MODNet: Is a Green Screen Really Necessary for Real-Time Portrait Matting?

2 | 3 | 4 | 5 |

6 | Arxiv Preprint | 7 | Supplementary Video 8 |

9 | 10 |

11 | WebCam Video Demo [Offline][Colab] | Custom Video Demo [Offline] | 12 | Image Demo [WebGUI][Colab] 13 |

14 | 15 |
This is the official repository of our paper *Is a Green Screen Really Necessary for Real-Time Portrait Matting?*
16 |
MODNet is a trimap-free model for portrait matting in real time under changing scenes.
17 | 18 | 19 | --- 20 | 21 | 22 | ## News 23 | - [Jan 28 2021] Release the [code](src/trainer.py) of the MODNet training iteration. 24 | - [Dec 25 2020] ***Merry Christmas!*** :christmas_tree: Release Custom Video Matting Demo [[Offline](demo/video_matting/custom)] for user videos. 25 | - [Dec 10 2020] Release WebCam Video Matting Demo [[Offline](demo/video_matting/webcam)][[Colab](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing)] and Image Matting Demo [[Colab](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing)]. 26 | - [Nov 24 2020] Release [Arxiv Preprint](https://arxiv.org/pdf/2011.11961.pdf) and [Supplementary Video](https://youtu.be/PqJ3BRHX3Lc). 27 | 28 | 29 | ## Demos 30 | 31 | ### Video Matting 32 | We provide two real-time portrait video matting demos based on WebCam. When using the demo, you can move the WebCam around at will. 33 | If you have an Ubuntu system, we recommend trying the [offline demo](demo/video_matting/webcam) to get a higher *fps*. Otherwise, you can access the [online Colab demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing). 34 | We also provide an [offline demo](demo/video_matting/custom) that allows you to process custom videos. 35 | 36 | 37 | 38 | 39 | ### Image Matting 40 | We provide an [online Colab demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing) for portrait image matting. 41 | It allows you to upload portrait images and to predict, visualize, and download the alpha mattes. 42 | 43 | 44 | 45 | 46 | ### Community 47 | Here we share some cool applications of MODNet built by the community. 48 | 49 | - **WebGUI for Image Matting** 50 | You can try [this WebGUI](https://gradio.app/g/modnet) (hosted on [Gradio](https://www.gradio.app/)) for portrait matting in your browser without writing any code!
51 | 52 | 53 | - **Colab Demo of Bokeh (Blur Background)** 54 | You can try [this Colab demo](https://colab.research.google.com/github/eyaler/avatars4all/blob/master/yarok.ipynb) (built by [@eyaler](https://github.com/eyaler)) to blur the background based on MODNet! 55 | 56 | 57 | ## Code 58 | We provide the [code](src/trainer.py) of the MODNet training iteration, including: 59 | - **Supervised Training**: Train MODNet on a labeled matting dataset 60 | - **SOC Adaptation**: Adapt a trained MODNet to an unlabeled dataset 61 | 62 | The function comments include examples of how to call each function. 63 | 64 | 65 | ## TODO 66 | - Release the code of One-Frame Delay (OFD) 67 | - Release PPM-100 validation benchmark (scheduled in **Feb 2021**) 68 | **NOTE**: PPM-100 is a **validation set**. Our training set will not be published. 69 | 70 | 71 | ## License 72 | This project (code, pre-trained models, demos, *etc.*) is released under the [Creative Commons Attribution NonCommercial ShareAlike 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) license. 73 | 74 | 75 | ## Acknowledgement 76 | - We thank [City University of Hong Kong](https://www.cityu.edu.hk/) and [SenseTime](https://www.sensetime.com/) for their support of this project. 77 | - We thank 78 | [the Gradio team](https://github.com/gradio-app/gradio) and [@eyaler](https://github.com/eyaler) 79 | for their cool applications based on MODNet. 80 | 81 | 82 | ## Citation 83 | If this work helps your research, please consider citing: 84 | 85 | ```bibtex 86 | @article{MODNet, 87 | author = {Zhanghan Ke and Kaican Li and Yurou Zhou and Qiuhua Wu and Xiangyu Mao and Qiong Yan and Rynson W.H. Lau}, 88 | title = {Is a Green Screen Really Necessary for Real-Time Portrait Matting?}, 89 | journal={ArXiv}, 90 | volume={abs/2011.11961}, 91 | year = {2020}, 92 | } 93 | ``` 94 | 95 | 96 | ## Contact 97 | This project is currently maintained by Zhanghan Ke ([@ZHKKKe](https://github.com/ZHKKKe)).
98 | If you have any questions, please feel free to contact `kezhanghan@outlook.com`. 99 | -------------------------------------------------------------------------------- /demo/image_matting/Inference_with_ONNX/README.md: -------------------------------------------------------------------------------- 1 | # Inference with onnxruntime 2 | 3 | Please try the MODNet image matting ONNX inference demo in this [Colab Notebook](https://colab.research.google.com/drive/1P3cWtg8fnmu9karZHYDAtmm1vj1rgA-f?usp=sharing). 4 | 5 | Download [modnet.onnx](https://drive.google.com/file/d/1cgycTQlYXpTh26gB9FTnthE7AvruV8hd/view?usp=sharing). 6 | 7 | ### 1. Export the ONNX model 8 | 9 | Run the following command: 10 | ```shell 11 | python export_modnet_onnx.py \ 12 | --ckpt-path=pretrained/modnet_photographic_portrait_matting.ckpt \ 13 | --output-path=modnet.onnx 14 | ``` 15 | 16 | 17 | ### 2. Inference 18 | 19 | Run the following command: 20 | ```shell 21 | python inference_onnx.py \ 22 | --image-path=PATH_TO_IMAGE \ 23 | --output-path=matte.png \ 24 | --model-path=modnet.onnx 25 | ``` 26 | 27 | -------------------------------------------------------------------------------- /demo/image_matting/Inference_with_ONNX/export_modnet_onnx.py: -------------------------------------------------------------------------------- 1 | """ 2 | Export ONNX model 3 | 4 | Arguments: 5 | --ckpt-path --> path of the pre-trained checkpoint to load 6 | --output-path --> path to save the ONNX model 7 | 8 | example: 9 | python export_modnet_onnx.py \ 10 | --ckpt-path=modnet_photographic_portrait_matting.ckpt \ 11 | --output-path=modnet.onnx 12 | 13 | output: 14 | ONNX model with dynamic input shape: (batch_size, 3, height, width) & 15 | output shape: (batch_size, 1, height, width) 16 | """ 17 | import os 18 | import argparse 19 | import torch 20 | import torch.nn as nn 21 | from torch.autograd import Variable 22 | from src.models.onnx_modnet import MODNet 23 | 24 | 25 | 26 | if __name__ == '__main__': 27 | # define cmd arguments
28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--ckpt-path', type=str, required=True, help='path of pre-trained MODNet') 30 | parser.add_argument('--output-path', type=str, required=True, help='path of output onnx model') 31 | args = parser.parse_args() 32 | 33 | # check input arguments 34 | if not os.path.exists(args.ckpt_path): 35 | print('Cannot find checkpoint path: {0}'.format(args.ckpt_path)) 36 | exit() 37 | 38 | # define model & load checkpoint 39 | modnet = MODNet(backbone_pretrained=False) 40 | modnet = nn.DataParallel(modnet).cuda() 41 | state_dict = torch.load(args.ckpt_path) 42 | modnet.load_state_dict(state_dict) 43 | modnet.eval() 44 | 45 | # prepare dummy_input 46 | batch_size = 1 47 | height = 512 48 | width = 512 49 | dummy_input = Variable(torch.randn(batch_size, 3, height, width)).cuda() 50 | 51 | # export to onnx model 52 | torch.onnx.export(modnet.module, dummy_input, args.output_path, export_params = True, opset_version=11, 53 | input_names = ['input'], output_names = ['output'], 54 | dynamic_axes = {'input': {0:'batch_size', 2:'height', 3:'width'}, 55 | 'output': {0: 'batch_size', 2: 'height', 3: 'width'}}) 56 | -------------------------------------------------------------------------------- /demo/image_matting/Inference_with_ONNX/inference_onnx.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inference with onnxruntime 3 | 4 | Arguments: 5 | --image-path --> path to a single input image 6 | --output-path --> path to save the generated matte 7 | --model-path --> path to the ONNX model file 8 | 9 | example: 10 | python inference_onnx.py \ 11 | --image-path=demo.jpg \ 12 | --output-path=matte.png \ 13 | --model-path=modnet.onnx 14 | 15 | Optional: 16 | Generate a transparent image without background 17 | """ 18 | import os 19 | import argparse 20 | import cv2 21 | import numpy as np 22 | import onnx 23 | import onnxruntime 24 | from onnx import helper 25 | from PIL import Image 26 | 27 |
if __name__ == '__main__': 28 | # define cmd arguments 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--image-path', type=str, required=True, help='path of input image') 31 | parser.add_argument('--output-path', type=str, required=True, help='path of output matte') 32 | parser.add_argument('--model-path', type=str, required=True, help='path of onnx model') 33 | args = parser.parse_args() 34 | 35 | # check input arguments 36 | if not os.path.exists(args.image_path): 37 | print('Cannot find input path: {0}'.format(args.image_path)) 38 | exit() 39 | if not os.path.exists(args.model_path): 40 | print('Cannot find model path: {0}'.format(args.model_path)) 41 | exit() 42 | 43 | ref_size = 512 44 | 45 | # Get x_scale_factor & y_scale_factor to resize image 46 | def get_scale_factor(im_h, im_w, ref_size): 47 | 48 | if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size: 49 | if im_w >= im_h: 50 | im_rh = ref_size 51 | im_rw = int(im_w / im_h * ref_size) 52 | elif im_w < im_h: 53 | im_rw = ref_size 54 | im_rh = int(im_h / im_w * ref_size) 55 | else: 56 | im_rh = im_h 57 | im_rw = im_w 58 | 59 | im_rw = im_rw - im_rw % 32 60 | im_rh = im_rh - im_rh % 32 61 | 62 | x_scale_factor = im_rw / im_w 63 | y_scale_factor = im_rh / im_h 64 | 65 | return x_scale_factor, y_scale_factor 66 | 67 | ############################################## 68 | # Main Inference part 69 | ############################################## 70 | 71 | # read image 72 | im = cv2.imread(args.image_path) 73 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) 74 | 75 | # unify image channels to 3 76 | if len(im.shape) == 2: 77 | im = im[:, :, None] 78 | if im.shape[2] == 1: 79 | im = np.repeat(im, 3, axis=2) 80 | elif im.shape[2] == 4: 81 | im = im[:, :, 0:3] 82 | 83 | # normalize pixel values to the range [-1, 1] 84 | im = (im - 127.5) / 127.5 85 | 86 | im_h, im_w, im_c = im.shape 87 | x, y = get_scale_factor(im_h, im_w, ref_size) 88 | 89 | # resize image 90 | im = cv2.resize(im, None, fx = x, fy = y, interpolation = cv2.INTER_AREA) 91 | 92
| # prepare input shape 93 | im = np.transpose(im, (2, 0, 1)) # HWC -> CHW 94 | 95 | im = np.expand_dims(im, axis = 0).astype('float32') # add batch dim -> NCHW 96 | 97 | # Initialize session and get prediction 98 | session = onnxruntime.InferenceSession(args.model_path, None) 99 | input_name = session.get_inputs()[0].name 100 | output_name = session.get_outputs()[0].name 101 | result = session.run([output_name], {input_name: im}) 102 | 103 | # convert the output to an 8-bit matte at the original resolution 104 | matte = (np.squeeze(result[0]) * 255).astype('uint8') 105 | matte = cv2.resize(matte, dsize=(im_w, im_h), interpolation = cv2.INTER_AREA) 106 | 107 | cv2.imwrite(args.output_path, matte) 108 | 109 | ############################################## 110 | # Optional - save png image without background 111 | ############################################## 112 | 113 | # im_PIL = Image.open(args.image_path) 114 | # matte = Image.fromarray(matte) 115 | # im_PIL.putalpha(matte) # add alpha channel to keep transparency 116 | # im_PIL.save('without_background.png') -------------------------------------------------------------------------------- /demo/image_matting/Inference_with_ONNX/requirements.txt: -------------------------------------------------------------------------------- 1 | onnx==1.8.1 2 | onnxruntime==1.6.0 3 | opencv-python==4.5.1.48 4 | torch==1.7.1 -------------------------------------------------------------------------------- /demo/image_matting/colab/README.md: -------------------------------------------------------------------------------- 1 | ## MODNet - Portrait Image Matting Demo 2 | Please try the MODNet portrait image matting demo through our [online Colab demo](https://colab.research.google.com/drive/1GANpbKT06aEFiW-Ssx0DQnnEADcXwQG6?usp=sharing).
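Both image-matting scripts (`inference_onnx.py` and the Colab `inference.py`) apply the same resize rule before inference: if the image's long side is below `ref_size` or its short side is above it, the short side is scaled to `ref_size` with the aspect ratio preserved; otherwise the size is kept. Both sides are then rounded down to multiples of 32, presumably so that they divide evenly by the stride-32 MobileNetV2 backbone. The stdlib-only sketch below reproduces that arithmetic so you can check the network input size for a given image; `modnet_input_size` is a hypothetical helper, not part of the repository.

```python
def modnet_input_size(im_h, im_w, ref_size=512):
    """Return (height, width) of the network input, mirroring the demo scripts' resize rule."""
    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
        # Scale so that the shorter side equals ref_size, keeping the aspect ratio.
        if im_w >= im_h:
            im_rh = ref_size
            im_rw = int(im_w / im_h * ref_size)
        else:
            im_rw = ref_size
            im_rh = int(im_h / im_w * ref_size)
    else:
        # Short side <= ref_size <= long side: keep the original size.
        im_rh, im_rw = im_h, im_w

    # Round both sides down to a multiple of 32.
    im_rh -= im_rh % 32
    im_rw -= im_rw % 32
    return im_rh, im_rw
```

For example, `modnet_input_size(1080, 1920)` gives `(512, 896)`: the short side becomes 512, the long side becomes `int(910.2) = 910`, and 910 is rounded down to 896. The custom video demo uses a slightly different rule and always scales the short side to 512.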
3 | -------------------------------------------------------------------------------- /demo/image_matting/colab/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torchvision.transforms as transforms 11 | 12 | from src.models.modnet import MODNet 13 | 14 | 15 | if __name__ == '__main__': 16 | # define cmd arguments 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--input-path', type=str, help='path of input images') 19 | parser.add_argument('--output-path', type=str, help='path of output images') 20 | parser.add_argument('--ckpt-path', type=str, help='path of pre-trained MODNet') 21 | args = parser.parse_args() 22 | 23 | # check input arguments 24 | if not os.path.exists(args.input_path): 25 | print('Cannot find input path: {0}'.format(args.input_path)) 26 | exit() 27 | if not os.path.exists(args.output_path): 28 | print('Cannot find output path: {0}'.format(args.output_path)) 29 | exit() 30 | if not os.path.exists(args.ckpt_path): 31 | print('Cannot find ckpt path: {0}'.format(args.ckpt_path)) 32 | exit() 33 | 34 | # define hyper-parameters 35 | ref_size = 512 36 | 37 | # define image to tensor transform 38 | im_transform = transforms.Compose( 39 | [ 40 | transforms.ToTensor(), 41 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 42 | ] 43 | ) 44 | 45 | # create MODNet and load the pre-trained ckpt 46 | modnet = MODNet(backbone_pretrained=False) 47 | modnet = nn.DataParallel(modnet).cuda() 48 | modnet.load_state_dict(torch.load(args.ckpt_path)) 49 | modnet.eval() 50 | 51 | # inference images 52 | im_names = os.listdir(args.input_path) 53 | for im_name in im_names: 54 | print('Process image: {0}'.format(im_name)) 55 | 56 | # read image 57 | im = Image.open(os.path.join(args.input_path, im_name)) 58 | 59 | # 
unify image channels to 3 60 | im = np.asarray(im) 61 | if len(im.shape) == 2: 62 | im = im[:, :, None] 63 | if im.shape[2] == 1: 64 | im = np.repeat(im, 3, axis=2) 65 | elif im.shape[2] == 4: 66 | im = im[:, :, 0:3] 67 | 68 | # convert image to PyTorch tensor 69 | im = Image.fromarray(im) 70 | im = im_transform(im) 71 | 72 | # add mini-batch dim 73 | im = im[None, :, :, :] 74 | 75 | # resize image for input 76 | im_b, im_c, im_h, im_w = im.shape 77 | if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size: 78 | if im_w >= im_h: 79 | im_rh = ref_size 80 | im_rw = int(im_w / im_h * ref_size) 81 | elif im_w < im_h: 82 | im_rw = ref_size 83 | im_rh = int(im_h / im_w * ref_size) 84 | else: 85 | im_rh = im_h 86 | im_rw = im_w 87 | 88 | im_rw = im_rw - im_rw % 32 89 | im_rh = im_rh - im_rh % 32 90 | im = F.interpolate(im, size=(im_rh, im_rw), mode='area') 91 | 92 | # inference 93 | _, _, matte = modnet(im.cuda(), True) 94 | 95 | # resize and save matte 96 | matte = F.interpolate(matte, size=(im_h, im_w), mode='area') 97 | matte = matte[0][0].data.cpu().numpy() 98 | matte_name = im_name.split('.')[0] + '.png' 99 | Image.fromarray(((matte * 255).astype('uint8')), mode='L').save(os.path.join(args.output_path, matte_name)) 100 | -------------------------------------------------------------------------------- /demo/video_matting/custom/README.md: -------------------------------------------------------------------------------- 1 | ## MODNet - Custom Portrait Video Matting Demo 2 | This is a MODNet portrait video matting demo that allows you to process custom videos. 3 | 4 | ### 1. Requirements 5 | The basic requirements for this demo are: 6 | - Ubuntu System 7 | - Python 3+ 8 | 9 | 10 | ### 2. Introduction 11 | We use ~400 unlabeled video clips (divided into ~50,000 frames) downloaded from the internet to perform SOC to adapt MODNet to the video domain. 
**Nonetheless, due to insufficient labeled training data (~3k labeled foregrounds), our model may still make errors in portrait semantics estimation under challenging scenes.** In addition, this demo does not currently support the One-Frame Delay (OFD) trick. 12 | 13 | 14 | For a better experience, please make sure that: 15 | 16 | * the portrait and background are distinguishable, i.e., not similar 17 | * the video is captured in soft and bright ambient lighting 18 | * the contents do not move too fast 19 | 20 | ### 3. Run Demo 21 | We recommend creating a new conda virtual environment to run this demo, as follows: 22 | 23 | 1. Clone the MODNet repository: 24 | ``` 25 | git clone https://github.com/ZHKKKe/MODNet.git 26 | cd MODNet 27 | ``` 28 | 29 | 2. Download the pre-trained model from this [link](https://drive.google.com/file/d/1Nf1ZxeJZJL8Qx9KadcYYyEmmlKhTADxX/view?usp=sharing) and put it into the folder `MODNet/pretrained/` (the script loads `./pretrained/modnet_webcam_portrait_matting.ckpt`). 30 | 31 | 32 | 3. Create a conda virtual environment named `modnet` (if it doesn't exist) and activate it. Here we use `python=3.6` as an example: 33 | ``` 34 | conda create -n modnet python=3.6 35 | source activate modnet 36 | ``` 37 | 38 | 4. Install the required Python dependencies (please make sure your CUDA version is supported by the installed PyTorch version): 39 | ``` 40 | pip install -r demo/video_matting/custom/requirements.txt 41 | ``` 42 | 43 | 5. Execute the main code: 44 | ``` 45 | python -m demo.video_matting.custom.run --video YOUR_VIDEO_PATH 46 | ``` 47 | where `YOUR_VIDEO_PATH` is the path to your video.
48 | There are some optional arguments: 49 | - `--result-type (default=fg)` : matte - save the alpha matte; fg - save the foreground 50 | - `--fps (default=30)` : frame rate of the result video 51 | -------------------------------------------------------------------------------- /demo/video_matting/custom/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | Pillow 3 | opencv-python 4 | torch >= 1.0.0 5 | torchvision 6 | tqdm -------------------------------------------------------------------------------- /demo/video_matting/custom/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import argparse 4 | import numpy as np 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torchvision.transforms as transforms 11 | 12 | from src.models.modnet import MODNet 13 | 14 | 15 | torch_transforms = transforms.Compose( 16 | [ 17 | transforms.ToTensor(), 18 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 19 | ] 20 | ) 21 | 22 | 23 | def matting(video, result, alpha_matte=False, fps=30): 24 | # video capture 25 | vc = cv2.VideoCapture(video) 26 | 27 | if vc.isOpened(): 28 | rval, frame = vc.read() 29 | else: 30 | rval = False 31 | 32 | if not rval: 33 | print('Failed to read the video: {0}'.format(video)) 34 | exit() 35 | 36 | num_frame = vc.get(cv2.CAP_PROP_FRAME_COUNT) 37 | h, w = frame.shape[:2] 38 | if w >= h: 39 | rh = 512 40 | rw = int(w / h * 512) 41 | else: 42 | rw = 512 43 | rh = int(h / w * 512) 44 | rh = rh - rh % 32 45 | rw = rw - rw % 32 46 | 47 | # video writer 48 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 49 | video_writer = cv2.VideoWriter(result, fourcc, fps, (w, h)) 50 | 51 | print('Start matting...') 52 | with tqdm(range(int(num_frame))) as t: 53 | for c in t: 54 | frame_np = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 55 | frame_np = cv2.resize(frame_np, (rw, rh), interpolation=cv2.INTER_AREA)
56 | 57 | frame_PIL = Image.fromarray(frame_np) 58 | frame_tensor = torch_transforms(frame_PIL) 59 | frame_tensor = frame_tensor[None, :, :, :] 60 | if GPU: 61 | frame_tensor = frame_tensor.cuda() 62 | 63 | with torch.no_grad(): 64 | _, _, matte_tensor = modnet(frame_tensor, True) 65 | 66 | matte_tensor = matte_tensor.repeat(1, 3, 1, 1) 67 | matte_np = matte_tensor[0].data.cpu().numpy().transpose(1, 2, 0) 68 | if alpha_matte: 69 | view_np = matte_np * np.full(frame_np.shape, 255.0) 70 | else: 71 | view_np = matte_np * frame_np + (1 - matte_np) * np.full(frame_np.shape, 255.0) 72 | view_np = cv2.cvtColor(view_np.astype(np.uint8), cv2.COLOR_RGB2BGR) 73 | view_np = cv2.resize(view_np, (w, h)) 74 | video_writer.write(view_np) 75 | 76 | rval, frame = vc.read() 77 | if not rval: break # the reported frame count can be inaccurate; stop when no frame is left 78 | 79 | video_writer.release() 80 | print('Saved the result video to {0}'.format(result)) 81 | 82 | 83 | if __name__ == '__main__': 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('--video', type=str, required=True, help='input video file') 86 | parser.add_argument('--result-type', type=str, default='fg', choices=['fg', 'matte'], 87 | help='matte - save the alpha matte; fg - save the foreground') 88 | parser.add_argument('--fps', type=int, default=30, help='fps of the result video') 89 | 90 | print('Get CMD Arguments...') 91 | args = parser.parse_args() 92 | 93 | if not os.path.exists(args.video): 94 | print('Cannot find the input video: {0}'.format(args.video)) 95 | exit() 96 | 97 | print('Load pre-trained MODNet...') 98 | pretrained_ckpt = './pretrained/modnet_webcam_portrait_matting.ckpt' 99 | modnet = MODNet(backbone_pretrained=False) 100 | modnet = nn.DataParallel(modnet) 101 | 102 | GPU = True if torch.cuda.device_count() > 0 else False 103 | if GPU: 104 | print('Use GPU...') 105 | modnet = modnet.cuda() 106 | modnet.load_state_dict(torch.load(pretrained_ckpt)) 107 | else: 108 | print('Use CPU...') 109 | modnet.load_state_dict(torch.load(pretrained_ckpt,
map_location=torch.device('cpu'))) 110 | modnet.eval() 111 | 112 | result = os.path.splitext(args.video)[0] + '_{0}.mp4'.format(args.result_type) 113 | alpha_matte = True if args.result_type == 'matte' else False 114 | matting(args.video, result, alpha_matte, args.fps) 115 | -------------------------------------------------------------------------------- /demo/video_matting/webcam/README.md: -------------------------------------------------------------------------------- 1 | ## MODNet - WebCam-Based Portrait Video Matting Demo 2 | This is a MODNet portrait video matting demo based on WebCam. It opens your local WebCam and displays the matting results in real time. The demo can run on either CPU or GPU. 3 | 4 | ### 1. Requirements 5 | The basic requirements for this demo are: 6 | - Ubuntu System 7 | - WebCam 8 | - Python 3+ 9 | 10 | **NOTE**: If your device does not satisfy the above conditions, please try our [online Colab demo](https://colab.research.google.com/drive/1Pt3KDSc2q7WxFvekCnCLD8P0gBEbxm6J?usp=sharing). 11 | 12 | 13 | ### 2. Introduction 14 | We use ~400 unlabeled video clips (divided into ~50,000 frames) downloaded from the internet to perform SOC to adapt MODNet to the video domain. **Nonetheless, due to insufficient labeled training data (~3k labeled foregrounds), our model may still make errors in portrait semantics estimation under challenging scenes.** In addition, this demo does not currently support the One-Frame Delay (OFD) trick, which will be provided soon. 15 | 16 | For a better experience, please: 17 | 18 | * make sure the portrait and background are distinguishable, i.e., not similar 19 | * run the demo in soft and bright ambient lighting 20 | * stay neither too close to nor too far from the WebCam 21 | * avoid moving too fast 22 | 23 | ### 3. Run Demo 24 | We recommend creating a new conda virtual environment to run this demo, as follows: 25 | 26 | 1.
Clone the MODNet repository: 27 | ``` 28 | git clone https://github.com/ZHKKKe/MODNet.git 29 | cd MODNet 30 | ``` 31 | 32 | 2. Download the pre-trained model from this [link](https://drive.google.com/file/d/1Nf1ZxeJZJL8Qx9KadcYYyEmmlKhTADxX/view?usp=sharing) and put it into the folder `MODNet/pretrained/`. 33 | 34 | 35 | 3. Create a conda virtual environment named `modnet` (if it doesn't exist) and activate it. Here we use `python=3.6` as an example: 36 | ``` 37 | conda create -n modnet python=3.6 38 | source activate modnet 39 | ``` 40 | 41 | 4. Install the required python dependencies (please make sure your CUDA version is supported by the PyTorch version installed): 42 | ``` 43 | pip install -r demo/video_matting/webcam/requirements.txt 44 | ``` 45 | 46 | 5. Execute the main code: 47 | ``` 48 | python -m demo.video_matting.webcam.run 49 | ``` 50 | 51 | ### 4. Acknowledgement 52 | We thank [@tkianai](https://github.com/tkianai) and [@mazhar004](https://github.com/mazhar004) for their contributions to making this demo available for CPU use. 
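Both video demos build the foreground view by blending each frame over a white background with the predicted alpha matte, i.e. `fg = matte * frame + (1 - matte) * 255`, computed with NumPy broadcasting in `run.py` and then cast to `uint8`. The stdlib-only sketch below shows the same per-pixel arithmetic on nested lists; `composite_over_white` is a hypothetical helper (not part of the repository), and it truncates to integers the way the demos' `uint8` cast does.

```python
def composite_over_white(frame, matte, background=255.0):
    """Blend an RGB frame over a constant background using a per-pixel alpha matte.

    frame: list of rows, each row a list of (r, g, b) tuples with values in [0, 255]
    matte: list of rows, each row a list of alpha values in [0, 1] (1 = foreground)
    """
    out = []
    for frame_row, matte_row in zip(frame, matte):
        out_row = []
        for pixel, alpha in zip(frame_row, matte_row):
            # fg = alpha * I + (1 - alpha) * B, applied to each channel, then truncated
            blended = tuple(int(alpha * c + (1 - alpha) * background) for c in pixel)
            out_row.append(blended)
        out.append(out_row)
    return out
```

A fully opaque pixel (alpha 1) keeps its color, a fully transparent one (alpha 0) becomes white, and alpha 0.5 on a gray value of 100 gives `int(177.5) = 177`.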
53 | -------------------------------------------------------------------------------- /demo/video_matting/webcam/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | Pillow 3 | opencv-python 4 | torch >= 1.0.0 5 | torchvision -------------------------------------------------------------------------------- /demo/video_matting/webcam/run.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from PIL import Image 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torchvision.transforms as transforms 8 | 9 | from src.models.modnet import MODNet 10 | 11 | 12 | torch_transforms = transforms.Compose( 13 | [ 14 | transforms.ToTensor(), 15 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 16 | ] 17 | ) 18 | 19 | print('Load pre-trained MODNet...') 20 | pretrained_ckpt = './pretrained/modnet_webcam_portrait_matting.ckpt' 21 | modnet = MODNet(backbone_pretrained=False) 22 | modnet = nn.DataParallel(modnet) 23 | 24 | GPU = True if torch.cuda.device_count() > 0 else False 25 | if GPU: 26 | print('Use GPU...') 27 | modnet = modnet.cuda() 28 | modnet.load_state_dict(torch.load(pretrained_ckpt)) 29 | else: 30 | print('Use CPU...') 31 | modnet.load_state_dict(torch.load(pretrained_ckpt, map_location=torch.device('cpu'))) 32 | 33 | modnet.eval() 34 | 35 | print('Init WebCam...') 36 | cap = cv2.VideoCapture(0) 37 | cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280) 38 | cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720) 39 | 40 | print('Start matting...') 41 | while True: 42 | _, frame_np = cap.read() 43 | frame_np = cv2.cvtColor(frame_np, cv2.COLOR_BGR2RGB) 44 | frame_np = cv2.resize(frame_np, (910, 512), interpolation=cv2.INTER_AREA) 45 | frame_np = frame_np[:, 120:792, :] # crop the width to 672 px (a multiple of 32), roughly centered 46 | frame_np = cv2.flip(frame_np, 1) # mirror horizontally 47 | 48 | frame_PIL = Image.fromarray(frame_np) 49 | frame_tensor = torch_transforms(frame_PIL) 50 | frame_tensor = frame_tensor[None, :, :, :] 51 | if GPU: 52 |
frame_tensor = frame_tensor.cuda() 53 | 54 | with torch.no_grad(): 55 | _, _, matte_tensor = modnet(frame_tensor, True) 56 | 57 | matte_tensor = matte_tensor.repeat(1, 3, 1, 1) 58 | matte_np = matte_tensor[0].data.cpu().numpy().transpose(1, 2, 0) 59 | fg_np = matte_np * frame_np + (1 - matte_np) * np.full(frame_np.shape, 255.0) 60 | view_np = np.uint8(np.concatenate((frame_np, fg_np), axis=1)) 61 | view_np = cv2.cvtColor(view_np, cv2.COLOR_RGB2BGR) 62 | 63 | cv2.imshow('MODNet - WebCam [Press \'Q\' To Exit]', view_np) 64 | if cv2.waitKey(1) & 0xFF == ord('q'): 65 | break 66 | 67 | print('Exit...') 68 | -------------------------------------------------------------------------------- /doc/gif/homepage_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manthan2305/MODNet/6b529d157c1540daca2478fddd71f4bd89fd9de9/doc/gif/homepage_demo.gif -------------------------------------------------------------------------------- /doc/gif/image_matting_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manthan2305/MODNet/6b529d157c1540daca2478fddd71f4bd89fd9de9/doc/gif/image_matting_demo.gif -------------------------------------------------------------------------------- /doc/gif/video_matting_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manthan2305/MODNet/6b529d157c1540daca2478fddd71f4bd89fd9de9/doc/gif/video_matting_demo.gif -------------------------------------------------------------------------------- /pretrained/README.md: -------------------------------------------------------------------------------- 1 | ## MODNet - Pre-Trained Models 2 | This folder is used to save the official pre-trained models of MODNet. You can download them from this [link](https://drive.google.com/drive/folders/1umYmlCulvIFNaqPjwod1SayFmSRHziyR?usp=sharing). 
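All demo scripts wrap `MODNet` in `nn.DataParallel` before calling `load_state_dict`, which implies that the released checkpoints store parameter names with the `module.` prefix that `DataParallel` adds. If you want to load a checkpoint into an unwrapped model (for example, on a CPU-only machine without the wrapper), a minimal sketch, assuming that prefix, is to rename the keys first. `strip_module_prefix` is a hypothetical helper, not part of the repository.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of state_dict with a leading `prefix` removed from every key that has it."""
    stripped = OrderedDict()
    for key, value in state_dict.items():
        # Keys saved from nn.DataParallel look like 'module.<name>'; keep other keys unchanged.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped
```

With PyTorch installed, it could be used as `modnet.load_state_dict(strip_module_prefix(torch.load(ckpt_path, map_location='cpu')))` on a bare `MODNet(backbone_pretrained=False)`, where `ckpt_path` is the path of a downloaded checkpoint.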
-------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manthan2305/MODNet/6b529d157c1540daca2478fddd71f4bd89fd9de9/src/__init__.py -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manthan2305/MODNet/6b529d157c1540daca2478fddd71f4bd89fd9de9/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .wrapper import * 2 | 3 | 4 | #------------------------------------------------------------------------------ 5 | # Replaceable Backbones 6 | #------------------------------------------------------------------------------ 7 | 8 | SUPPORTED_BACKBONES = { 9 | 'mobilenetv2': MobileNetV2Backbone, 10 | } 11 | -------------------------------------------------------------------------------- /src/models/backbones/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """ This file is adapted from https://github.com/thuyngch/Human-Segmentation-PyTorch""" 2 | 3 | import math 4 | import json 5 | from functools import reduce 6 | 7 | import torch 8 | from torch import nn 9 | 10 | 11 | #------------------------------------------------------------------------------ 12 | # Useful functions 13 | #------------------------------------------------------------------------------ 14 | 15 | def _make_divisible(v, divisor, min_value=None): 16 | if min_value is None: 17 | min_value = divisor 18 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 19 | # Make sure that round down does not go down by more than 10%. 
20 | if new_v < 0.9 * v: 21 | new_v += divisor 22 | return new_v 23 | 24 | 25 | def conv_bn(inp, oup, stride): 26 | return nn.Sequential( 27 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 28 | nn.BatchNorm2d(oup), 29 | nn.ReLU6(inplace=True) 30 | ) 31 | 32 | 33 | def conv_1x1_bn(inp, oup): 34 | return nn.Sequential( 35 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 36 | nn.BatchNorm2d(oup), 37 | nn.ReLU6(inplace=True) 38 | ) 39 | 40 | 41 | #------------------------------------------------------------------------------ 42 | # Class of Inverted Residual block 43 | #------------------------------------------------------------------------------ 44 | 45 | class InvertedResidual(nn.Module): 46 | def __init__(self, inp, oup, stride, expansion, dilation=1): 47 | super(InvertedResidual, self).__init__() 48 | self.stride = stride 49 | assert stride in [1, 2] 50 | 51 | hidden_dim = round(inp * expansion) 52 | self.use_res_connect = self.stride == 1 and inp == oup 53 | 54 | if expansion == 1: 55 | self.conv = nn.Sequential( 56 | # dw 57 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False), 58 | nn.BatchNorm2d(hidden_dim), 59 | nn.ReLU6(inplace=True), 60 | # pw-linear 61 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 62 | nn.BatchNorm2d(oup), 63 | ) 64 | else: 65 | self.conv = nn.Sequential( 66 | # pw 67 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 68 | nn.BatchNorm2d(hidden_dim), 69 | nn.ReLU6(inplace=True), 70 | # dw 71 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False), 72 | nn.BatchNorm2d(hidden_dim), 73 | nn.ReLU6(inplace=True), 74 | # pw-linear 75 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 76 | nn.BatchNorm2d(oup), 77 | ) 78 | 79 | def forward(self, x): 80 | if self.use_res_connect: 81 | return x + self.conv(x) 82 | else: 83 | return self.conv(x) 84 | 85 | 86 | #------------------------------------------------------------------------------ 87 | # Class of 
MobileNetV2 88 | #------------------------------------------------------------------------------ 89 | 90 | class MobileNetV2(nn.Module): 91 | def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000): 92 | super(MobileNetV2, self).__init__() 93 | self.in_channels = in_channels 94 | self.num_classes = num_classes 95 | input_channel = 32 96 | last_channel = 1280 97 | interverted_residual_setting = [ 98 | # t, c, n, s 99 | [1 , 16, 1, 1], 100 | [expansion, 24, 2, 2], 101 | [expansion, 32, 3, 2], 102 | [expansion, 64, 4, 2], 103 | [expansion, 96, 3, 1], 104 | [expansion, 160, 3, 2], 105 | [expansion, 320, 1, 1], 106 | ] 107 | 108 | # building first layer 109 | input_channel = _make_divisible(input_channel*alpha, 8) 110 | self.last_channel = _make_divisible(last_channel*alpha, 8) if alpha > 1.0 else last_channel 111 | self.features = [conv_bn(self.in_channels, input_channel, 2)] 112 | 113 | # building inverted residual blocks 114 | for t, c, n, s in interverted_residual_setting: 115 | output_channel = _make_divisible(int(c*alpha), 8) 116 | for i in range(n): 117 | if i == 0: 118 | self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t)) 119 | else: 120 | self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t)) 121 | input_channel = output_channel 122 | 123 | # building last several layers 124 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 125 | 126 | # make it nn.Sequential 127 | self.features = nn.Sequential(*self.features) 128 | 129 | # building classifier 130 | if self.num_classes is not None: 131 | self.classifier = nn.Sequential( 132 | nn.Dropout(0.2), 133 | nn.Linear(self.last_channel, num_classes), 134 | ) 135 | 136 | # Initialize weights 137 | self._init_weights() 138 | 139 | def forward(self, x, feature_names=None): 140 | # Stage1 141 | x = reduce(lambda x, n: self.features[n](x), list(range(0,2)), x) 142 | # Stage2 143 | x = reduce(lambda x, n: self.features[n](x), 
list(range(2,4)), x) 144 | # Stage3 145 | x = reduce(lambda x, n: self.features[n](x), list(range(4,7)), x) 146 | # Stage4 147 | x = reduce(lambda x, n: self.features[n](x), list(range(7,14)), x) 148 | # Stage5 149 | x = reduce(lambda x, n: self.features[n](x), list(range(14,19)), x) 150 | 151 | # Classification 152 | if self.num_classes is not None: 153 | x = x.mean(dim=(2,3)) 154 | x = self.classifier(x) 155 | 156 | # Output 157 | return x 158 | 159 | def _load_pretrained_model(self, pretrained_file): 160 | pretrain_dict = torch.load(pretrained_file, map_location='cpu') 161 | model_dict = {} 162 | state_dict = self.state_dict() 163 | print("[MobileNetV2] Loading pretrained model...") 164 | for k, v in pretrain_dict.items(): 165 | if k in state_dict: 166 | model_dict[k] = v 167 | else: 168 | print(k, "is ignored") 169 | state_dict.update(model_dict) 170 | self.load_state_dict(state_dict) 171 | 172 | def _init_weights(self): 173 | for m in self.modules(): 174 | if isinstance(m, nn.Conv2d): 175 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 176 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 177 | if m.bias is not None: 178 | m.bias.data.zero_() 179 | elif isinstance(m, nn.BatchNorm2d): 180 | m.weight.data.fill_(1) 181 | m.bias.data.zero_() 182 | elif isinstance(m, nn.Linear): 183 | n = m.weight.size(1) 184 | m.weight.data.normal_(0, 0.01) 185 | m.bias.data.zero_() 186 | -------------------------------------------------------------------------------- /src/models/backbones/wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import reduce 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .mobilenetv2 import MobileNetV2 8 | 9 | 10 | class BaseBackbone(nn.Module): 11 | """ Superclass of Replaceable Backbone Model for Semantic Estimation 12 | """ 13 | 14 | def __init__(self, in_channels): 15 | super(BaseBackbone, self).__init__() 16 | self.in_channels = in_channels 17 | 18 | self.model = None 19 | self.enc_channels = [] 20 | 21 | def forward(self, x): 22 | raise NotImplementedError 23 | 24 | def load_pretrained_ckpt(self): 25 | raise NotImplementedError 26 | 27 | 28 | class MobileNetV2Backbone(BaseBackbone): 29 | """ MobileNetV2 Backbone 30 | """ 31 | 32 | def __init__(self, in_channels): 33 | super(MobileNetV2Backbone, self).__init__(in_channels) 34 | 35 | self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None) 36 | self.enc_channels = [16, 24, 32, 96, 1280] 37 | 38 | def forward(self, x): 39 | x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x) 40 | enc2x = x 41 | x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x) 42 | enc4x = x 43 | x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x) 44 | enc8x = x 45 | x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x) 46 | enc16x = x 47 | x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x) 48 | enc32x = x 49 | return [enc2x, enc4x, enc8x, enc16x, enc32x] 50 | 51 | def 
load_pretrained_ckpt(self): 52 | # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch 53 | ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt' 54 | if not os.path.exists(ckpt_path): 55 | print('cannot find the pretrained mobilenetv2 backbone at {}'.format(ckpt_path)) 56 | exit() 57 | 58 | ckpt = torch.load(ckpt_path, map_location='cpu') # load on CPU so GPU-saved checkpoints also work on CPU-only machines 59 | self.model.load_state_dict(ckpt) 60 | -------------------------------------------------------------------------------- /src/models/modnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .backbones import SUPPORTED_BACKBONES 6 | 7 | 8 | #------------------------------------------------------------------------------ 9 | # MODNet Basic Modules 10 | #------------------------------------------------------------------------------ 11 | 12 | class IBNorm(nn.Module): 13 | """ Combine Instance Norm and Batch Norm into One Layer 14 | """ 15 | 16 | def __init__(self, in_channels): 17 | super(IBNorm, self).__init__() 18 | in_channels = in_channels 19 | self.bnorm_channels = int(in_channels / 2) 20 | self.inorm_channels = in_channels - self.bnorm_channels 21 | 22 | self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True) 23 | self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False) 24 | 25 | def forward(self, x): 26 | bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous()) 27 | in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous()) 28 | 29 | return torch.cat((bn_x, in_x), 1) 30 | 31 | 32 | class Conv2dIBNormRelu(nn.Module): 33 | """ Convolution + IBNorm + ReLU 34 | """ 35 | 36 | def __init__(self, in_channels, out_channels, kernel_size, 37 | stride=1, padding=0, dilation=1, groups=1, bias=True, 38 | with_ibn=True, with_relu=True): 39 | super(Conv2dIBNormRelu, self).__init__() 40 | 41 | layers = [ 42 | nn.Conv2d(in_channels, out_channels, kernel_size, 43 | stride=stride,
padding=padding, dilation=dilation, 44 | groups=groups, bias=bias) 45 | ] 46 | 47 | if with_ibn: 48 | layers.append(IBNorm(out_channels)) 49 | if with_relu: 50 | layers.append(nn.ReLU(inplace=True)) 51 | 52 | self.layers = nn.Sequential(*layers) 53 | 54 | def forward(self, x): 55 | return self.layers(x) 56 | 57 | 58 | class SEBlock(nn.Module): 59 | """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf 60 | """ 61 | 62 | def __init__(self, in_channels, out_channels, reduction=1): 63 | super(SEBlock, self).__init__() 64 | self.pool = nn.AdaptiveAvgPool2d(1) 65 | self.fc = nn.Sequential( 66 | nn.Linear(in_channels, int(in_channels // reduction), bias=False), 67 | nn.ReLU(inplace=True), 68 | nn.Linear(int(in_channels // reduction), out_channels, bias=False), 69 | nn.Sigmoid() 70 | ) 71 | 72 | def forward(self, x): 73 | b, c, _, _ = x.size() 74 | w = self.pool(x).view(b, c) 75 | w = self.fc(w).view(b, c, 1, 1) 76 | 77 | return x * w.expand_as(x) 78 | 79 | 80 | #------------------------------------------------------------------------------ 81 | # MODNet Branches 82 | #------------------------------------------------------------------------------ 83 | 84 | class LRBranch(nn.Module): 85 | """ Low Resolution Branch of MODNet 86 | """ 87 | 88 | def __init__(self, backbone): 89 | super(LRBranch, self).__init__() 90 | 91 | enc_channels = backbone.enc_channels 92 | 93 | self.backbone = backbone 94 | self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4) 95 | self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2) 96 | self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2) 97 | self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False) 98 | 99 | def forward(self, img, inference): 100 | enc_features = self.backbone.forward(img) 101 | enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4] 102 | 103 | 
enc32x = self.se_block(enc32x) 104 | lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False) 105 | lr16x = self.conv_lr16x(lr16x) 106 | lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False) 107 | lr8x = self.conv_lr8x(lr8x) 108 | 109 | pred_semantic = None 110 | if not inference: 111 | lr = self.conv_lr(lr8x) 112 | pred_semantic = torch.sigmoid(lr) 113 | 114 | return pred_semantic, lr8x, [enc2x, enc4x] 115 | 116 | 117 | class HRBranch(nn.Module): 118 | """ High Resolution Branch of MODNet 119 | """ 120 | 121 | def __init__(self, hr_channels, enc_channels): 122 | super(HRBranch, self).__init__() 123 | 124 | self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0) 125 | self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1) 126 | 127 | self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0) 128 | self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1) 129 | 130 | self.conv_hr4x = nn.Sequential( 131 | Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1), 132 | Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), 133 | Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), 134 | ) 135 | 136 | self.conv_hr2x = nn.Sequential( 137 | Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), 138 | Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), 139 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), 140 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), 141 | ) 142 | 143 | self.conv_hr = nn.Sequential( 144 | Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1), 145 | Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False), 146 | ) 147 | 148 | def forward(self, img, enc2x, enc4x, 
lr8x, inference): 149 | img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False) 150 | img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False) 151 | 152 | enc2x = self.tohr_enc2x(enc2x) 153 | hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1)) 154 | 155 | enc4x = self.tohr_enc4x(enc4x) 156 | hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1)) 157 | 158 | lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) 159 | hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1)) 160 | 161 | hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False) 162 | hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1)) 163 | 164 | pred_detail = None 165 | if not inference: 166 | hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False) 167 | hr = self.conv_hr(torch.cat((hr, img), dim=1)) 168 | pred_detail = torch.sigmoid(hr) 169 | 170 | return pred_detail, hr2x 171 | 172 | 173 | class FusionBranch(nn.Module): 174 | """ Fusion Branch of MODNet 175 | """ 176 | 177 | def __init__(self, hr_channels, enc_channels): 178 | super(FusionBranch, self).__init__() 179 | self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2) 180 | 181 | self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1) 182 | self.conv_f = nn.Sequential( 183 | Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1), 184 | Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False), 185 | ) 186 | 187 | def forward(self, img, lr8x, hr2x): 188 | lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) 189 | lr4x = self.conv_lr4x(lr4x) 190 | lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False) 191 | 192 | f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1)) 193 | f = F.interpolate(f2x, scale_factor=2, mode='bilinear', 
align_corners=False) 194 | f = self.conv_f(torch.cat((f, img), dim=1)) 195 | pred_matte = torch.sigmoid(f) 196 | 197 | return pred_matte 198 | 199 | 200 | #------------------------------------------------------------------------------ 201 | # MODNet 202 | #------------------------------------------------------------------------------ 203 | 204 | class MODNet(nn.Module): 205 | """ Architecture of MODNet 206 | """ 207 | 208 | def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True): 209 | super(MODNet, self).__init__() 210 | 211 | self.in_channels = in_channels 212 | self.hr_channels = hr_channels 213 | self.backbone_arch = backbone_arch 214 | self.backbone_pretrained = backbone_pretrained 215 | 216 | self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels) 217 | 218 | self.lr_branch = LRBranch(self.backbone) 219 | self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels) 220 | self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels) 221 | 222 | for m in self.modules(): 223 | if isinstance(m, nn.Conv2d): 224 | self._init_conv(m) 225 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): 226 | self._init_norm(m) 227 | 228 | if self.backbone_pretrained: 229 | self.backbone.load_pretrained_ckpt() 230 | 231 | def forward(self, img, inference): 232 | pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference) 233 | pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference) 234 | pred_matte = self.f_branch(img, lr8x, hr2x) 235 | 236 | return pred_semantic, pred_detail, pred_matte 237 | 238 | def freeze_norm(self): 239 | norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d] 240 | for m in self.modules(): 241 | for n in norm_types: 242 | if isinstance(m, n): 243 | m.eval() 244 | continue 245 | 246 | def _init_conv(self, conv): 247 | nn.init.kaiming_uniform_( 248 | conv.weight, a=0, mode='fan_in', nonlinearity='relu') 249 | if conv.bias is not 
None: 250 | nn.init.constant_(conv.bias, 0) 251 | 252 | def _init_norm(self, norm): 253 | if norm.weight is not None: 254 | nn.init.constant_(norm.weight, 1) 255 | nn.init.constant_(norm.bias, 0) 256 | -------------------------------------------------------------------------------- /src/models/onnx_modnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is a modified version of modnet.py that drops the "pred_semantic" and 3 | "pred_detail" outputs, since both are None when "inference=True". 4 | 5 | It also removes the "inference" argument from forward(), which makes it easier to 6 | convert a checkpoint into an ONNX model. 7 | 8 | Refer to 'demo/image_matting/Inference_with_ONNX/export_modnet_onnx.py' to export the model. 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from .backbones import SUPPORTED_BACKBONES 16 | 17 | 18 | #------------------------------------------------------------------------------ 19 | # MODNet Basic Modules 20 | #------------------------------------------------------------------------------ 21 | 22 | class IBNorm(nn.Module): 23 | """ Combine Instance Norm and Batch Norm into One Layer 24 | """ 25 | 26 | def __init__(self, in_channels): 27 | super(IBNorm, self).__init__() 28 | in_channels = in_channels 29 | self.bnorm_channels = int(in_channels / 2) 30 | self.inorm_channels = in_channels - self.bnorm_channels 31 | 32 | self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True) 33 | self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False) 34 | 35 | def forward(self, x): 36 | bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous()) 37 | in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous()) 38 | 39 | return torch.cat((bn_x, in_x), 1) 40 | 41 | 42 | class Conv2dIBNormRelu(nn.Module): 43 | """ Convolution + IBNorm + ReLU 44 | """ 45 | 46 | def __init__(self, in_channels, out_channels, kernel_size, 47 | stride=1,
padding=0, dilation=1, groups=1, bias=True, 48 | with_ibn=True, with_relu=True): 49 | super(Conv2dIBNormRelu, self).__init__() 50 | 51 | layers = [ 52 | nn.Conv2d(in_channels, out_channels, kernel_size, 53 | stride=stride, padding=padding, dilation=dilation, 54 | groups=groups, bias=bias) 55 | ] 56 | 57 | if with_ibn: 58 | layers.append(IBNorm(out_channels)) 59 | if with_relu: 60 | layers.append(nn.ReLU(inplace=True)) 61 | 62 | self.layers = nn.Sequential(*layers) 63 | 64 | def forward(self, x): 65 | return self.layers(x) 66 | 67 | 68 | class SEBlock(nn.Module): 69 | """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf 70 | """ 71 | 72 | def __init__(self, in_channels, out_channels, reduction=1): 73 | super(SEBlock, self).__init__() 74 | self.pool = nn.AdaptiveAvgPool2d(1) 75 | self.fc = nn.Sequential( 76 | nn.Linear(in_channels, int(in_channels // reduction), bias=False), 77 | nn.ReLU(inplace=True), 78 | nn.Linear(int(in_channels // reduction), out_channels, bias=False), 79 | nn.Sigmoid() 80 | ) 81 | 82 | def forward(self, x): 83 | b, c, _, _ = x.size() 84 | w = self.pool(x).view(b, c) 85 | w = self.fc(w).view(b, c, 1, 1) 86 | 87 | return x * w.expand_as(x) 88 | 89 | 90 | #------------------------------------------------------------------------------ 91 | # MODNet Branches 92 | #------------------------------------------------------------------------------ 93 | 94 | class LRBranch(nn.Module): 95 | """ Low Resolution Branch of MODNet 96 | """ 97 | 98 | def __init__(self, backbone): 99 | super(LRBranch, self).__init__() 100 | 101 | enc_channels = backbone.enc_channels 102 | 103 | self.backbone = backbone 104 | self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4) 105 | self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2) 106 | self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2) 107 | self.conv_lr = Conv2dIBNormRelu(enc_channels[2], 1, kernel_size=3, stride=2, 
padding=1, with_ibn=False, with_relu=False) # not used in forward(); kept so state dicts saved from modnet.py load without key mismatches 108 | 109 | def forward(self, img): 110 | enc_features = self.backbone.forward(img) 111 | enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4] 112 | 113 | enc32x = self.se_block(enc32x) 114 | lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False) 115 | lr16x = self.conv_lr16x(lr16x) 116 | lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False) 117 | lr8x = self.conv_lr8x(lr8x) 118 | 119 | return lr8x, [enc2x, enc4x] 120 | 121 | 122 | class HRBranch(nn.Module): 123 | """ High Resolution Branch of MODNet 124 | """ 125 | 126 | def __init__(self, hr_channels, enc_channels): 127 | super(HRBranch, self).__init__() 128 | 129 | self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0) 130 | self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1) 131 | 132 | self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0) 133 | self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1) 134 | 135 | self.conv_hr4x = nn.Sequential( 136 | Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1), 137 | Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), 138 | Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), 139 | ) 140 | 141 | self.conv_hr2x = nn.Sequential( 142 | Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), 143 | Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), 144 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), 145 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), 146 | ) 147 | 148 | self.conv_hr = nn.Sequential( # not used in forward(); kept for checkpoint compatibility with modnet.py 149 | Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1), 150 | Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False,
with_relu=False), 151 | ) 152 | 153 | def forward(self, img, enc2x, enc4x, lr8x): 154 | img2x = F.interpolate(img, scale_factor=1/2, mode='bilinear', align_corners=False) 155 | img4x = F.interpolate(img, scale_factor=1/4, mode='bilinear', align_corners=False) 156 | 157 | enc2x = self.tohr_enc2x(enc2x) 158 | hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1)) 159 | 160 | enc4x = self.tohr_enc4x(enc4x) 161 | hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1)) 162 | 163 | lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) 164 | hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1)) 165 | 166 | hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False) 167 | hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1)) 168 | 169 | return hr2x 170 | 171 | 172 | class FusionBranch(nn.Module): 173 | """ Fusion Branch of MODNet 174 | """ 175 | 176 | def __init__(self, hr_channels, enc_channels): 177 | super(FusionBranch, self).__init__() 178 | self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2) 179 | 180 | self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1) 181 | self.conv_f = nn.Sequential( 182 | Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1), 183 | Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False), 184 | ) 185 | 186 | def forward(self, img, lr8x, hr2x): 187 | lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) 188 | lr4x = self.conv_lr4x(lr4x) 189 | lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False) 190 | 191 | f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1)) 192 | f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False) 193 | f = self.conv_f(torch.cat((f, img), dim=1)) 194 | pred_matte = torch.sigmoid(f) 195 | 196 | return pred_matte 197 | 198 | 199 | 
#------------------------------------------------------------------------------ 200 | # MODNet 201 | #------------------------------------------------------------------------------ 202 | 203 | class MODNet(nn.Module): 204 | """ Architecture of MODNet 205 | """ 206 | 207 | def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True): 208 | super(MODNet, self).__init__() 209 | 210 | self.in_channels = in_channels 211 | self.hr_channels = hr_channels 212 | self.backbone_arch = backbone_arch 213 | self.backbone_pretrained = backbone_pretrained 214 | 215 | self.backbone = SUPPORTED_BACKBONES[self.backbone_arch](self.in_channels) 216 | 217 | self.lr_branch = LRBranch(self.backbone) 218 | self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels) 219 | self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels) 220 | 221 | for m in self.modules(): 222 | if isinstance(m, nn.Conv2d): 223 | self._init_conv(m) 224 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): 225 | self._init_norm(m) 226 | 227 | if self.backbone_pretrained: 228 | self.backbone.load_pretrained_ckpt() 229 | 230 | def forward(self, img): 231 | lr8x, [enc2x, enc4x] = self.lr_branch(img) 232 | hr2x = self.hr_branch(img, enc2x, enc4x, lr8x) 233 | pred_matte = self.f_branch(img, lr8x, hr2x) 234 | 235 | return pred_matte 236 | 237 | def freeze_norm(self): 238 | norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d] 239 | for m in self.modules(): 240 | for n in norm_types: 241 | if isinstance(m, n): 242 | m.eval() 243 | continue 244 | 245 | def _init_conv(self, conv): 246 | nn.init.kaiming_uniform_( 247 | conv.weight, a=0, mode='fan_in', nonlinearity='relu') 248 | if conv.bias is not None: 249 | nn.init.constant_(conv.bias, 0) 250 | 251 | def _init_norm(self, norm): 252 | if norm.weight is not None: 253 | nn.init.constant_(norm.weight, 1) 254 | nn.init.constant_(norm.bias, 0) 255 | 
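Both `modnet.py` and `onnx_modnet.py` build their convolution blocks on `IBNorm`, which batch-normalizes the first `int(C / 2)` channels and instance-normalizes the rest. The NumPy sketch below reproduces that split so the two kinds of statistics are easy to compare. It assumes BatchNorm is in training mode, so it uses batch statistics rather than running averages, and that the BatchNorm affine weight and bias are still at their initial values of 1 and 0. The function name `ib_norm` is hypothetical. `eps=1e-5` matches the PyTorch default for both layers.

```python
import numpy as np


def ib_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Sketch of IBNorm's forward pass on an N x C x H x W array.

    Assumes BatchNorm in training mode (batch statistics) with its affine
    parameters at their initial values (weight 1, bias 0).
    """
    c = x.shape[1]
    bn_c = c // 2  # same split as int(in_channels / 2) for positive C
    bn_part, in_part = x[:, :bn_c], x[:, bn_c:]

    # BatchNorm: one mean/variance per channel, pooled over N, H and W
    bn_mean = bn_part.mean(axis=(0, 2, 3), keepdims=True)
    bn_var = bn_part.var(axis=(0, 2, 3), keepdims=True)  # biased variance
    bn_out = (bn_part - bn_mean) / np.sqrt(bn_var + eps)

    # InstanceNorm: one mean/variance per sample and channel, over H and W only
    in_mean = in_part.mean(axis=(2, 3), keepdims=True)
    in_var = in_part.var(axis=(2, 3), keepdims=True)  # biased variance
    in_out = (in_part - in_mean) / np.sqrt(in_var + eps)

    # concatenate along the channel axis, as torch.cat((bn_x, in_x), 1) does
    return np.concatenate([bn_out, in_out], axis=1)
```

In evaluation mode `nn.BatchNorm2d` switches to its running mean and variance. `nn.InstanceNorm2d`, with its default `track_running_stats=False`, keeps using per-sample statistics. As a result, only the batch-normalized half of the channels behaves differently between `train()` and `eval()`.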
-------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import scipy.ndimage 3 | import numpy as np 4 | from scipy.ndimage import grey_dilation, grey_erosion 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | __all__ = [ 12 | 'supervised_training_iter', 13 | 'soc_adaptation_iter', 14 | ] 15 | 16 | 17 | # ---------------------------------------------------------------------------------- 18 | # Tool Classes/Functions 19 | # ---------------------------------------------------------------------------------- 20 | 21 | class GaussianBlurLayer(nn.Module): 22 | """ Apply Gaussian blur to a 4D tensor 23 | This layer takes a 4D tensor of shape {N, C, H, W} as input. 24 | The Gaussian blur is applied to each of the C channels separately (depthwise). 25 | """ 26 | 27 | def __init__(self, channels, kernel_size): 28 | """ 29 | Arguments: 30 | channels (int): number of channels (C) of the input tensor 31 | kernel_size (int): size of the blur kernel; must be odd 32 | """ 33 | 34 | super(GaussianBlurLayer, self).__init__() 35 | self.channels = channels 36 | self.kernel_size = kernel_size 37 | assert self.kernel_size % 2 != 0 38 | 39 | self.op = nn.Sequential( 40 | nn.ReflectionPad2d(math.floor(self.kernel_size / 2)), 41 | nn.Conv2d(channels, channels, self.kernel_size, 42 | stride=1, padding=0, bias=False, groups=channels) 43 | ) 44 | 45 | self._init_kernel() 46 | 47 | def forward(self, x): 48 | """ 49 | Arguments: 50 | x (torch.Tensor): input 4D tensor 51 | Returns: 52 | torch.Tensor: Blurred version of the input 53 | """ 54 | 55 | if not len(list(x.shape)) == 4: 56 | print('\'GaussianBlurLayer\' requires a 4D tensor as input\n') 57 | exit() 58 | elif not x.shape[1] == self.channels: 59 | print('In \'GaussianBlurLayer\', the required number of channels ({0}) is ' 60 | 'not the same as the input ({1})\n'.format(self.channels, x.shape[1])) 61 | exit()
62 | 63 | return self.op(x) 64 | 65 | def _init_kernel(self): 66 | sigma = 0.3 * ((self.kernel_size - 1) * 0.5 - 1) + 0.8 # OpenCV's default sigma for a given kernel size 67 | 68 | n = np.zeros((self.kernel_size, self.kernel_size)) 69 | i = math.floor(self.kernel_size / 2) 70 | n[i, i] = 1 71 | kernel = scipy.ndimage.gaussian_filter(n, sigma) # blur a unit impulse to obtain the 2D Gaussian kernel 72 | 73 | for _, param in self.named_parameters(): # only the conv weight; the conv has no bias 74 | param.data.copy_(torch.from_numpy(kernel)) # broadcast the k x k kernel to every channel 75 | 76 | # ---------------------------------------------------------------------------------- 77 | 78 | 79 | # ---------------------------------------------------------------------------------- 80 | # MODNet Training Functions 81 | # ---------------------------------------------------------------------------------- 82 | 83 | blurer = GaussianBlurLayer(1, 3).cuda() # created at import time, so importing this module requires a CUDA device 84 | 85 | 86 | def supervised_training_iter( 87 | modnet, optimizer, image, trimap, gt_matte, 88 | semantic_scale=10.0, detail_scale=10.0, matte_scale=1.0): 89 | """ Supervised training iteration of MODNet 90 | This function trains MODNet for one iteration on a labeled dataset.
91 | 92 | Arguments: 93 | modnet (torch.nn.Module): instance of MODNet 94 | optimizer (torch.optim.Optimizer): optimizer for supervised training 95 | image (torch.Tensor): input RGB image 96 | its pixel values should be normalized 97 | trimap (torch.Tensor): trimap used to calculate the losses 98 | its pixel values can be 0, 0.5, or 1 99 | (foreground=1, background=0, unknown=0.5) 100 | gt_matte (torch.Tensor): ground truth alpha matte 101 | its pixel values are in [0, 1] 102 | semantic_scale (float): scale of the semantic loss 103 | NOTE: please adjust according to your dataset 104 | detail_scale (float): scale of the detail loss 105 | NOTE: please adjust according to your dataset 106 | matte_scale (float): scale of the matte loss 107 | NOTE: please adjust according to your dataset 108 | 109 | Returns: 110 | semantic_loss (torch.Tensor): loss of the semantic estimation [Low-Resolution (LR) Branch] 111 | detail_loss (torch.Tensor): loss of the detail prediction [High-Resolution (HR) Branch] 112 | matte_loss (torch.Tensor): loss of the semantic-detail fusion [Fusion Branch] 113 | 114 | Example: 115 | import torch 116 | from src.models.modnet import MODNet 117 | from src.trainer import supervised_training_iter 118 | 119 | bs = 16 # batch size 120 | lr = 0.01 # learning rate 121 | epochs = 40 # total epochs 122 | 123 | modnet = torch.nn.DataParallel(MODNet()).cuda() 124 | optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9) 125 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1) 126 | 127 | dataloader = CREATE_YOUR_DATALOADER(bs) # NOTE: please finish this function 128 | 129 | for epoch in range(0, epochs): 130 | for idx, (image, trimap, gt_matte) in enumerate(dataloader): 131 | semantic_loss, detail_loss, matte_loss = \ 132 | supervised_training_iter(modnet, optimizer, image, trimap, gt_matte) 133 | lr_scheduler.step() 134 | """ 135 | 136 | global blurer 137 | 138 | #
set the model to train mode and clear the optimizer 139 | modnet.train() 140 | optimizer.zero_grad() 141 | 142 | # forward the model 143 | pred_semantic, pred_detail, pred_matte = modnet(image, False) 144 | 145 | # calculate the boundary mask from the trimap 146 | boundaries = (trimap < 0.5) + (trimap > 0.5) 147 | 148 | # calculate the semantic loss 149 | gt_semantic = F.interpolate(gt_matte, scale_factor=1/16, mode='bilinear') 150 | gt_semantic = blurer(gt_semantic) 151 | semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic)) 152 | semantic_loss = semantic_scale * semantic_loss 153 | 154 | # calculate the detail loss 155 | pred_boundary_detail = torch.where(boundaries, trimap, pred_detail) 156 | gt_detail = torch.where(boundaries, trimap, gt_matte) 157 | detail_loss = torch.mean(F.l1_loss(pred_boundary_detail, gt_detail)) 158 | detail_loss = detail_scale * detail_loss 159 | 160 | # calculate the matte loss 161 | pred_boundary_matte = torch.where(boundaries, trimap, pred_matte) 162 | matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte) 163 | matte_compositional_loss = F.l1_loss(image * pred_matte, image * gt_matte) \ 164 | + 4.0 * F.l1_loss(image * pred_boundary_matte, image * gt_matte) 165 | matte_loss = torch.mean(matte_l1_loss + matte_compositional_loss) 166 | matte_loss = matte_scale * matte_loss 167 | 168 | # calculate the final loss, backward the loss, and update the model 169 | loss = semantic_loss + detail_loss + matte_loss 170 | loss.backward() 171 | optimizer.step() 172 | 173 | # for test 174 | return semantic_loss, detail_loss, matte_loss 175 | 176 | 177 | def soc_adaptation_iter( 178 | modnet, backup_modnet, optimizer, image, 179 | soc_semantic_scale=100.0, soc_detail_scale=1.0): 180 | """ Self-Supervised sub-objective consistency (SOC) adaptation iteration of MODNet 181 | This function fine-tunes MODNet for one iteration in an unlabeled dataset. 
    Note that SOC can only fine-tune a converged MODNet, i.e., a MODNet that has
    already been trained on a labeled dataset.

    Arguments:
        modnet (torch.nn.Module): instance of MODNet
        backup_modnet (torch.nn.Module): backup of the trained MODNet
        optimizer (torch.optim.Optimizer): optimizer for self-supervised SOC
        image (torch.autograd.Variable): input RGB image
                                         its pixel values should be normalized
        soc_semantic_scale (float): scale of the SOC semantic loss
                                    NOTE: please adjust according to your dataset
        soc_detail_scale (float): scale of the SOC detail loss
                                  NOTE: please adjust according to your dataset

    Returns:
        soc_semantic_loss (torch.Tensor): loss of the semantic SOC
        soc_detail_loss (torch.Tensor): loss of the detail SOC

    Example:
        import copy
        import torch
        from src.models.modnet import MODNet
        from src.trainer import soc_adaptation_iter

        bs = 1          # batch size
        lr = 0.00001    # learning rate
        epochs = 10     # total epochs

        modnet = torch.nn.DataParallel(MODNet()).cuda()
        modnet = LOAD_TRAINED_CKPT()    # NOTE: please finish this function

        optimizer = torch.optim.Adam(modnet.parameters(), lr=lr, betas=(0.9, 0.99))
        dataloader = CREATE_YOUR_DATALOADER(bs)     # NOTE: please finish this function

        for epoch in range(0, epochs):
            backup_modnet = copy.deepcopy(modnet)
            for idx, image in enumerate(dataloader):
                soc_semantic_loss, soc_detail_loss = \
                    soc_adaptation_iter(modnet, backup_modnet, optimizer, image)
    """

    global blurer

    # set the backup model to eval mode
    backup_modnet.eval()

    # set the main model to train mode and freeze its norm layers
    modnet.train()
    modnet.module.freeze_norm()

    # clear the optimizer
    optimizer.zero_grad()

    # forward the main model
    pred_semantic, pred_detail, pred_matte = modnet(image, False)

    # forward the backup model
    with torch.no_grad():
        _, pred_backup_detail, pred_backup_matte = backup_modnet(image, False)

    # calculate the boundary mask from `pred_matte` and `pred_semantic`
    pred_matte_fg = (pred_matte.detach() > 0.1).float()
    pred_semantic_fg = (pred_semantic.detach() > 0.1).float()
    pred_semantic_fg = F.interpolate(pred_semantic_fg, scale_factor=16, mode='bilinear')
    pred_fg = pred_matte_fg * pred_semantic_fg

    n, c, h, w = pred_matte.shape
    np_pred_fg = pred_fg.data.cpu().numpy()
    np_boundaries = np.zeros([n, c, h, w])
    for sdx in range(0, n):
        sample_np_boundaries = np_boundaries[sdx, 0, ...]
        sample_np_pred_fg = np_pred_fg[sdx, 0, ...]

        # the boundary band is where a dilation and an erosion of the foreground disagree;
        # the square window size is 5% of the mean of the image height and width
        side = int((h + w) / 2 * 0.05)
        dilated = grey_dilation(sample_np_pred_fg, size=(side, side))
        eroded = grey_erosion(sample_np_pred_fg, size=(side, side))

        sample_np_boundaries[np.where(dilated - eroded != 0)] = 1
        np_boundaries[sdx, 0, ...] = sample_np_boundaries

    boundaries = torch.from_numpy(np_boundaries).float().to(pred_matte.device)

    # sub-objectives consistency between `pred_semantic` and `pred_matte`
    # generate pseudo ground truth for `pred_semantic`
    downsampled_pred_matte = blurer(F.interpolate(pred_matte, scale_factor=1/16, mode='bilinear'))
    pseudo_gt_semantic = downsampled_pred_matte.detach()
    pseudo_gt_semantic = pseudo_gt_semantic * (pseudo_gt_semantic > 0.01).float()

    # generate pseudo ground truth for `pred_matte`
    pseudo_gt_matte = pred_semantic.detach()
    pseudo_gt_matte = pseudo_gt_matte * (pseudo_gt_matte > 0.01).float()

    # calculate the SOC semantic loss
    soc_semantic_loss = F.mse_loss(pred_semantic, pseudo_gt_semantic) + F.mse_loss(downsampled_pred_matte, pseudo_gt_matte)
    soc_semantic_loss = soc_semantic_scale * torch.mean(soc_semantic_loss)

    # NOTE: computing the following losses with the formulas in our paper gives similar results
    # `reduction='none'` keeps the per-pixel L1 so that the boundary mask actually selects pixels;
    # the pixel count is clamped to avoid 0/0 when a sample has no boundary pixels
    num_boundary_pixels = torch.sum(boundaries, dim=(1, 2, 3)).clamp(min=1.0)

    # sub-objectives consistency between `pred_detail` and `pred_backup_detail` (on boundaries only)
    backup_detail_loss = boundaries * F.l1_loss(pred_detail, pred_backup_detail, reduction='none')
    backup_detail_loss = torch.sum(backup_detail_loss, dim=(1, 2, 3)) / num_boundary_pixels
    backup_detail_loss = torch.mean(backup_detail_loss)

    # sub-objectives consistency between `pred_matte` and `pred_backup_matte` (on boundaries only)
    backup_matte_loss = boundaries * F.l1_loss(pred_matte, pred_backup_matte, reduction='none')
    backup_matte_loss = torch.sum(backup_matte_loss, dim=(1, 2, 3)) / num_boundary_pixels
    backup_matte_loss = torch.mean(backup_matte_loss)

    soc_detail_loss = soc_detail_scale * (backup_detail_loss + backup_matte_loss)

    # calculate the final loss, backward the loss, and update the model
    loss = soc_semantic_loss + soc_detail_loss

    loss.backward()
    optimizer.step()

    return soc_semantic_loss, soc_detail_loss

# ----------------------------------------------------------------------------------

--------------------------------------------------------------------------------
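The boundary-restricted detail losses in `soc_adaptation_iter` come down to two steps. First, build a band where a dilation and an erosion of the predicted foreground disagree. Second, average the per-pixel L1 error over that band only. The sketch below is an illustration under stated assumptions, not the repository's code:

- `band_from_mask`, `masked_l1` and `masked_scalar_l1` are hypothetical helpers.
- A NumPy max/min filter built on `sliding_window_view` stands in for SciPy's `grey_dilation`/`grey_erosion`, restricted to odd window sizes.
- Arrays use the same `(n, 1, h, w)` layout as MODNet's mattes.

The sketch also shows why the per-pixel error has to be kept until after masking. Masking a value that was already averaged over all pixels (what `reduction='mean'` produces) simply returns the unmasked mean.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def band_from_mask(fg: np.ndarray, side: int) -> np.ndarray:
    """Hypothetical helper: boundary band of a 2-D {0,1} mask, i.e. the pixels where a
    side x side max filter (dilation) and min filter (erosion) disagree.
    Assumes `side` is odd so the square window is centred on each pixel."""
    if side % 2 == 0:
        raise ValueError("side must be odd")
    r = side // 2
    padded = np.pad(fg, r, mode="edge")                  # replicate the border values
    windows = sliding_window_view(padded, (side, side))  # shape (h, w, side, side)
    dilated = windows.max(axis=(-2, -1))
    eroded = windows.min(axis=(-2, -1))
    return (dilated != eroded).astype(np.float32)


def masked_l1(pred, target, mask):
    """Hypothetical helper: per-sample L1 averaged over masked pixels, then over the batch.
    The element-wise |pred - target| plays the role of F.l1_loss(..., reduction='none')."""
    per_pixel = np.abs(pred - target)
    num = (mask * per_pixel).sum(axis=(1, 2, 3))
    den = np.maximum(mask.sum(axis=(1, 2, 3)), 1.0)      # guard against an empty band
    return float((num / den).mean())


def masked_scalar_l1(pred, target, mask):
    """Same normalisation, but applied to an already-averaged scalar, as happens when
    the L1 loss is reduced with 'mean' before the mask is applied."""
    scalar = np.abs(pred - target).mean()
    num = (mask * scalar).sum(axis=(1, 2, 3))
    den = np.maximum(mask.sum(axis=(1, 2, 3)), 1.0)
    return float((num / den).mean())


# Toy example: a 3x3 foreground square inside a 9x9 image (batch of one).
fg = np.zeros((9, 9), dtype=np.float32)
fg[3:6, 3:6] = 1.0
band = band_from_mask(fg, side=3)

target = fg[None, None].copy()
pred = target.copy()
pred[0, 0, 2, 2] = 0.5   # an error inside the band
pred[0, 0, 0, 0] = 1.0   # an error far from the band

mask = band[None, None]
print(int(band.sum()))                       # 24
print(masked_l1(pred, target, mask))         # roughly 0.0208 (= 0.5 / 24)
print(masked_scalar_l1(pred, target, mask))  # roughly 0.0185 (= 1.5 / 81, the unmasked mean)
```

In the toy example, only the error inside the band affects `masked_l1`. `masked_scalar_l1` also counts the error at the image corner, so the mask has no effect. The real function computes the window size from the image resolution rather than using a fixed `side=3`.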