├── LICENSE
├── README.md
├── README_zh_Hans.md
├── dataset
│   ├── augmentation.py
│   ├── coco.py
│   ├── imagematte.py
│   ├── spd.py
│   ├── videomatte.py
│   └── youtubevis.py
├── documentation
│   ├── image
│   │   ├── showreel.gif
│   │   └── teaser.gif
│   ├── inference.md
│   ├── inference_zh_Hans.md
│   ├── misc
│   │   ├── aim_test.txt
│   │   ├── d646_test.txt
│   │   ├── dvm_background_test_clips.txt
│   │   ├── dvm_background_train_clips.txt
│   │   ├── imagematte_train.txt
│   │   ├── imagematte_valid.txt
│   │   └── spd_preprocess.py
│   └── training.md
├── evaluation
│   ├── evaluate_hr.py
│   ├── evaluate_lr.py
│   ├── generate_imagematte_with_background_image.py
│   ├── generate_imagematte_with_background_video.py
│   ├── generate_videomatte_with_background_image.py
│   └── generate_videomatte_with_background_video.py
├── hubconf.py
├── inference.py
├── inference_speed_test.py
├── inference_utils.py
├── model
│   ├── __init__.py
│   ├── decoder.py
│   ├── deep_guided_filter.py
│   ├── fast_guided_filter.py
│   ├── lraspp.py
│   ├── mobilenetv3.py
│   ├── model.py
│   └── resnet.py
├── requirements_inference.txt
├── requirements_training.txt
├── train.py
├── train_config.py
└── train_loss.py

/README.md:
--------------------------------------------------------------------------------
1 | # Robust Video Matting (RVM)
2 | 
3 | ![Teaser](/documentation/image/teaser.gif)
4 | 
5 | 

English | 中文

6 | 
7 | Official repository for the paper [Robust High-Resolution Video Matting with Temporal Guidance](https://peterl1n.github.io/RobustVideoMatting/). RVM is specifically designed for robust human video matting. Unlike existing neural models that process frames as independent images, RVM uses a recurrent neural network to process videos with temporal memory. RVM can perform real-time matting on any video without additional inputs. It achieves **4K 76FPS** and **HD 104FPS** on an Nvidia GTX 1080 Ti GPU. The project was developed at [ByteDance Inc.](https://www.bytedance.com/)
8 | 
9 | 
10 | 11 | ## News 12 | 13 | * [Nov 03 2021] Fixed a bug in [train.py](https://github.com/PeterL1n/RobustVideoMatting/commit/48effc91576a9e0e7a8519f3da687c0d3522045f). 14 | * [Sep 16 2021] Code is re-released under GPL-3.0 license. 15 | * [Aug 25 2021] Source code and pretrained models are published. 16 | * [Jul 27 2021] Paper is accepted by WACV 2022. 17 | 18 |
19 | 20 | ## Showreel 21 | Watch the showreel video ([YouTube](https://youtu.be/Jvzltozpbpk), [Bilibili](https://www.bilibili.com/video/BV1Z3411B7g7/)) to see the model's performance. 22 | 23 |

24 | 25 | 26 | 27 |

28 | 
29 | All footage in the video is available in [Google Drive](https://drive.google.com/drive/folders/1VFnWwuu-YXDKG-N6vcjK_nL7YZMFapMU?usp=sharing).
30 | 
31 | 
32 | 33 | 34 | ## Demo 35 | * [Webcam Demo](https://peterl1n.github.io/RobustVideoMatting/#/demo): Run the model live in your browser. Visualize recurrent states. 36 | * [Colab Demo](https://colab.research.google.com/drive/10z-pNKRnVNsp0Lq9tH1J_XPZ7CBC_uHm?usp=sharing): Test our model on your own videos with free GPU. 37 | 38 |
39 | 
40 | ## Download
41 | 
42 | We recommend the MobileNetV3 models for most use cases. The ResNet50 models are larger variants with slightly better performance. Our model is available on various inference frameworks. See the [inference documentation](documentation/inference.md) for detailed instructions.
43 | 
| Framework | Download | Notes |
| --- | --- | --- |
| PyTorch | rvm_mobilenetv3.pth<br>rvm_resnet50.pth | Official weights for PyTorch. Doc |
| TorchHub | Nothing to download. | Easiest way to use our model in your PyTorch project. Doc |
| TorchScript | rvm_mobilenetv3_fp32.torchscript<br>rvm_mobilenetv3_fp16.torchscript<br>rvm_resnet50_fp32.torchscript<br>rvm_resnet50_fp16.torchscript | For inference on mobile, consider exporting int8 quantized models yourself. Doc |
| ONNX | rvm_mobilenetv3_fp32.onnx<br>rvm_mobilenetv3_fp16.onnx<br>rvm_resnet50_fp32.onnx<br>rvm_resnet50_fp16.onnx | Tested on ONNX Runtime with CPU and CUDA backends. Provided models use opset 12. Doc, Exporter |
| TensorFlow | rvm_mobilenetv3_tf.zip<br>rvm_resnet50_tf.zip | TensorFlow 2 SavedModel. Doc |
| TensorFlow.js | rvm_mobilenetv3_tfjs_int8.zip | Run the model on the web. Demo, Starter Code |
| CoreML | rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel<br>rvm_mobilenetv3_1280x720_s0.375_int8.mlmodel<br>rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodel<br>rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel | CoreML does not support dynamic resolution; you can export other resolutions yourself. Models require iOS 13+. `s` denotes `downsample_ratio`. Doc, Exporter |

All models are available in [Google Drive](https://drive.google.com/drive/folders/1pBsG-SCTatv-95SnEuxmnvvlRx208VKj?usp=sharing) and [Baidu Pan](https://pan.baidu.com/s/1puPSxQqgBFOVpW4W7AolkA) (code: gym7).

133 | 134 | ## PyTorch Example 135 | 136 | 1. Install dependencies: 137 | ```sh 138 | pip install -r requirements_inference.txt 139 | ``` 140 | 141 | 2. Load the model: 142 | 143 | ```python 144 | import torch 145 | from model import MattingNetwork 146 | 147 | model = MattingNetwork('mobilenetv3').eval().cuda() # or "resnet50" 148 | model.load_state_dict(torch.load('rvm_mobilenetv3.pth')) 149 | ``` 150 | 151 | 3. To convert videos, we provide a simple conversion API: 152 | 153 | ```python 154 | from inference import convert_video 155 | 156 | convert_video( 157 | model, # The model, can be on any device (cpu or cuda). 158 | input_source='input.mp4', # A video file or an image sequence directory. 159 | output_type='video', # Choose "video" or "png_sequence" 160 | output_composition='com.mp4', # File path if video; directory path if png sequence. 161 | output_alpha="pha.mp4", # [Optional] Output the raw alpha prediction. 162 | output_foreground="fgr.mp4", # [Optional] Output the raw foreground prediction. 163 | output_video_mbps=4, # Output video mbps. Not needed for png sequence. 164 | downsample_ratio=None, # A hyperparameter to adjust or use None for auto. 165 | seq_chunk=12, # Process n frames at once for better parallelism. 166 | ) 167 | ``` 168 | 169 | 4. Or write your own inference code: 170 | ```python 171 | from torch.utils.data import DataLoader 172 | from torchvision.transforms import ToTensor 173 | from inference_utils import VideoReader, VideoWriter 174 | 175 | reader = VideoReader('input.mp4', transform=ToTensor()) 176 | writer = VideoWriter('output.mp4', frame_rate=30) 177 | 178 | bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda() # Green background. 179 | rec = [None] * 4 # Initial recurrent states. 180 | downsample_ratio = 0.25 # Adjust based on your video. 181 | 182 | with torch.no_grad(): 183 | for src in DataLoader(reader): # RGB tensor normalized to 0 ~ 1. 184 | fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio) # Cycle the recurrent states. 185 | com = fgr * pha + bgr * (1 - pha) # Composite to green background. 186 | writer.write(com) # Write frame. 187 | ``` 188 | 189 | 5. The models and converter API are also available through TorchHub. 190 | 191 | ```python 192 | # Load the model. 193 | model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # or "resnet50" 194 | 195 | # Converter API. 196 | convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter") 197 | ``` 198 | 199 | Please see [inference documentation](documentation/inference.md) for details on `downsample_ratio` hyperparameter, more converter arguments, and more advanced usage. 200 | 201 |
202 | 203 | ## Training and Evaluation 204 | 205 | Please refer to the [training documentation](documentation/training.md) to train and evaluate your own model. 206 | 207 |
208 | 
209 | ## Speed
210 | 
211 | Speed is measured with `inference_speed_test.py` for reference. A simplified sketch of this kind of measurement is shown after the notes below.
212 | 
213 | | GPU | dType | HD (1920x1080) | 4K (3840x2160) |
214 | | -------------- | ----- | -------------- |----------------|
215 | | RTX 3090 | FP16 | 172 FPS | 154 FPS |
216 | | RTX 2060 Super | FP16 | 134 FPS | 108 FPS |
217 | | GTX 1080 Ti | FP32 | 104 FPS | 74 FPS |
218 | 
219 | * Note 1: HD uses `downsample_ratio=0.25`, 4K uses `downsample_ratio=0.125`. All tests use batch size 1 and frame chunk 1.
220 | * Note 2: GPUs before the Turing architecture do not support FP16 inference, so the GTX 1080 Ti uses FP32.
221 | * Note 3: We only measure tensor throughput. The video conversion script provided in this repo is expected to be much slower, because it does not use hardware video encoding/decoding and does not transfer tensors on parallel threads. If you are interested in implementing hardware video encoding/decoding in Python, please refer to [PyNvCodec](https://github.com/NVIDIA/VideoProcessingFramework).
222 | 
223 | 
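The throughput numbers above come from timing repeated forward passes on GPU tensors. The snippet below is a simplified sketch of such a measurement, not the actual `inference_speed_test.py`; reproducing the FP16 rows would additionally require casting the model and input to half precision.

```python
import time
import torch
from model import MattingNetwork

model = MattingNetwork('mobilenetv3').eval().cuda()  # Weights are not needed for a pure speed test.

src = torch.rand(1, 3, 1080, 1920, device='cuda')    # One random HD frame, batch size 1.
rec = [None] * 4                                     # Initial recurrent states.
downsample_ratio = 0.25                              # Use 0.125 for 4K.

with torch.no_grad():
    for _ in range(10):                              # Warm-up iterations.
        fgr, pha, *rec = model(src, *rec, downsample_ratio)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        fgr, pha, *rec = model(src, *rec, downsample_ratio)
    torch.cuda.synchronize()
    print(f'{100 / (time.time() - start):.1f} FPS')
```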
224 | 225 | ## Project Members 226 | * [Shanchuan Lin](https://www.linkedin.com/in/shanchuanlin/) 227 | * [Linjie Yang](https://sites.google.com/site/linjieyang89/) 228 | * [Imran Saleemi](https://www.linkedin.com/in/imran-saleemi/) 229 | * [Soumyadip Sengupta](https://homes.cs.washington.edu/~soumya91/) 230 | 231 |
232 | 233 | ## Third-Party Projects 234 | 235 | * [NCNN C++ Android](https://github.com/FeiGeChuanShu/ncnn_Android_RobustVideoMatting) ([@FeiGeChuanShu](https://github.com/FeiGeChuanShu)) 236 | * [lite.ai.toolkit](https://github.com/DefTruth/RobustVideoMatting.lite.ai.toolkit) ([@DefTruth](https://github.com/DefTruth)) 237 | * [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/Robust-Video-Matting) ([@AK391](https://github.com/AK391)) 238 | * [Unity Engine demo with NatML](https://hub.natml.ai/@natsuite/robust-video-matting) ([@natsuite](https://github.com/natsuite)) 239 | * [MNN C++ Demo](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/mnn/cv/mnn_rvm.cpp) ([@DefTruth](https://github.com/DefTruth)) 240 | * [TNN C++ Demo](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/tnn/cv/tnn_rvm.cpp) ([@DefTruth](https://github.com/DefTruth)) 241 | 242 | -------------------------------------------------------------------------------- /README_zh_Hans.md: -------------------------------------------------------------------------------- 1 | # 稳定视频抠像 (RVM) 2 | 3 | ![Teaser](/documentation/image/teaser.gif) 4 | 5 |

English | 中文

6 | 7 | 论文 [Robust High-Resolution Video Matting with Temporal Guidance](https://peterl1n.github.io/RobustVideoMatting/) 的官方 GitHub 库。RVM 专为稳定人物视频抠像设计。不同于现有神经网络将每一帧作为单独图片处理,RVM 使用循环神经网络,在处理视频流时有时间记忆。RVM 可在任意视频上做实时高清抠像。在 Nvidia GTX 1080Ti 上实现 **4K 76FPS** 和 **HD 104FPS**。此研究项目来自[字节跳动](https://www.bytedance.com/)。 8 | 9 |
10 | 11 | ## 更新 12 | 13 | * [2021年11月3日] 修复了 [train.py](https://github.com/PeterL1n/RobustVideoMatting/commit/48effc91576a9e0e7a8519f3da687c0d3522045f) 的 bug。 14 | * [2021年9月16日] 代码重新以 GPL-3.0 许可发布。 15 | * [2021年8月25日] 公开代码和模型。 16 | * [2021年7月27日] 论文被 WACV 2022 收录。 17 | 18 |
19 | 20 | ## 展示视频 21 | 观看展示视频 ([YouTube](https://youtu.be/Jvzltozpbpk), [Bilibili](https://www.bilibili.com/video/BV1Z3411B7g7/)),了解模型能力。 22 |

23 | 24 | 25 | 26 |

27 | 28 | 视频中的所有素材都提供下载,可用于测试模型:[Google Drive](https://drive.google.com/drive/folders/1VFnWwuu-YXDKG-N6vcjK_nL7YZMFapMU?usp=sharing) 29 | 30 |
31 | 32 | 33 | ## Demo 34 | * [网页](https://peterl1n.github.io/RobustVideoMatting/#/demo): 在浏览器里看摄像头抠像效果,展示模型内部循环记忆值。 35 | * [Colab](https://colab.research.google.com/drive/10z-pNKRnVNsp0Lq9tH1J_XPZ7CBC_uHm?usp=sharing): 用我们的模型转换你的视频。 36 | 37 |
38 | 
39 | ## 下载
40 | 
41 | 推荐在通常情况下使用 MobileNetV3 的模型。ResNet50 的模型大很多,效果稍有提高。我们的模型支持很多框架。详情请阅读[推断文档](documentation/inference_zh_Hans.md)。
42 | 
| 框架 | 下载 | 备注 |
| --- | --- | --- |
| PyTorch | rvm_mobilenetv3.pth<br>rvm_resnet50.pth | 官方 PyTorch 模型权值。文档 |
| TorchHub | 无需手动下载。 | 更方便地在你的 PyTorch 项目里使用此模型。文档 |
| TorchScript | rvm_mobilenetv3_fp32.torchscript<br>rvm_mobilenetv3_fp16.torchscript<br>rvm_resnet50_fp32.torchscript<br>rvm_resnet50_fp16.torchscript | 若需在移动端推断,可以考虑自行导出 int8 量化的模型。文档 |
| ONNX | rvm_mobilenetv3_fp32.onnx<br>rvm_mobilenetv3_fp16.onnx<br>rvm_resnet50_fp32.onnx<br>rvm_resnet50_fp16.onnx | 在 ONNX Runtime 的 CPU 和 CUDA backend 上测试过。提供的模型用 opset 12。文档、导出 |
| TensorFlow | rvm_mobilenetv3_tf.zip<br>rvm_resnet50_tf.zip | TensorFlow 2 SavedModel 格式。文档 |
| TensorFlow.js | rvm_mobilenetv3_tfjs_int8.zip | 在网页上跑模型。展示、示范代码 |
| CoreML | rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel<br>rvm_mobilenetv3_1280x720_s0.375_int8.mlmodel<br>rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodel<br>rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel | CoreML 只能导出固定分辨率,其他分辨率可自行导出。支持 iOS 13+。s 代表下采样比。文档、导出 |

所有模型可在 [Google Drive](https://drive.google.com/drive/folders/1pBsG-SCTatv-95SnEuxmnvvlRx208VKj?usp=sharing) 或[百度网盘](https://pan.baidu.com/s/1puPSxQqgBFOVpW4W7AolkA)(密码: gym7)上下载。

132 | 133 | ## PyTorch 范例 134 | 135 | 1. 安装 Python 库: 136 | ```sh 137 | pip install -r requirements_inference.txt 138 | ``` 139 | 140 | 2. 加载模型: 141 | 142 | ```python 143 | import torch 144 | from model import MattingNetwork 145 | 146 | model = MattingNetwork('mobilenetv3').eval().cuda() # 或 "resnet50" 147 | model.load_state_dict(torch.load('rvm_mobilenetv3.pth')) 148 | ``` 149 | 150 | 3. 若只需要做视频抠像处理,我们提供简单的 API: 151 | 152 | ```python 153 | from inference import convert_video 154 | 155 | convert_video( 156 | model, # 模型,可以加载到任何设备(cpu 或 cuda) 157 | input_source='input.mp4', # 视频文件,或图片序列文件夹 158 | output_type='video', # 可选 "video"(视频)或 "png_sequence"(PNG 序列) 159 | output_composition='com.mp4', # 若导出视频,提供文件路径。若导出 PNG 序列,提供文件夹路径 160 | output_alpha="pha.mp4", # [可选项] 输出透明度预测 161 | output_foreground="fgr.mp4", # [可选项] 输出前景预测 162 | output_video_mbps=4, # 若导出视频,提供视频码率 163 | downsample_ratio=None, # 下采样比,可根据具体视频调节,或 None 选择自动 164 | seq_chunk=12, # 设置多帧并行计算 165 | ) 166 | ``` 167 | 168 | 4. 或自己写推断逻辑: 169 | ```python 170 | from torch.utils.data import DataLoader 171 | from torchvision.transforms import ToTensor 172 | from inference_utils import VideoReader, VideoWriter 173 | 174 | reader = VideoReader('input.mp4', transform=ToTensor()) 175 | writer = VideoWriter('output.mp4', frame_rate=30) 176 | 177 | bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda() # 绿背景 178 | rec = [None] * 4 # 初始循环记忆(Recurrent States) 179 | downsample_ratio = 0.25 # 下采样比,根据视频调节 180 | 181 | with torch.no_grad(): 182 | for src in DataLoader(reader): # 输入张量,RGB通道,范围为 0~1 183 | fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio) # 将上一帧的记忆给下一帧 184 | com = fgr * pha + bgr * (1 - pha) # 将前景合成到绿色背景 185 | writer.write(com) # 输出帧 186 | ``` 187 | 188 | 5. 模型和 API 也可通过 TorchHub 快速载入。 189 | 190 | ```python 191 | # 加载模型 192 | model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # 或 "resnet50" 193 | 194 | # 转换 API 195 | convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter") 196 | ``` 197 | 198 | [推断文档](documentation/inference_zh_Hans.md)里有对 `downsample_ratio` 参数,API 使用,和高阶使用的讲解。 199 | 200 |
201 | 202 | ## 训练和评估 203 | 204 | 请参照[训练文档(英文)](documentation/training.md)。 205 | 206 |
207 | 208 | ## 速度 209 | 210 | 速度用 `inference_speed_test.py` 测量以供参考。 211 | 212 | | GPU | dType | HD (1920x1080) | 4K (3840x2160) | 213 | | -------------- | ----- | -------------- |----------------| 214 | | RTX 3090 | FP16 | 172 FPS | 154 FPS | 215 | | RTX 2060 Super | FP16 | 134 FPS | 108 FPS | 216 | | GTX 1080 Ti | FP32 | 104 FPS | 74 FPS | 217 | 218 | * 注释1:HD 使用 `downsample_ratio=0.25`,4K 使用 `downsample_ratio=0.125`。 所有测试都使用 batch size 1 和 frame chunk 1。 219 | * 注释2:图灵架构之前的 GPU 不支持 FP16 推理,所以 GTX 1080 Ti 使用 FP32。 220 | * 注释3:我们只测量张量吞吐量(tensor throughput)。 提供的视频转换脚本会慢得多,因为它不使用硬件视频编码/解码,也没有在并行线程上完成张量传输。如果您有兴趣在 Python 中实现硬件视频编码/解码,请参考 [PyNvCodec](https://github.com/NVIDIA/VideoProcessingFramework)。 221 | 222 |
223 | 224 | ## 项目成员 225 | * [Shanchuan Lin](https://www.linkedin.com/in/shanchuanlin/) 226 | * [Linjie Yang](https://sites.google.com/site/linjieyang89/) 227 | * [Imran Saleemi](https://www.linkedin.com/in/imran-saleemi/) 228 | * [Soumyadip Sengupta](https://homes.cs.washington.edu/~soumya91/) 229 | 230 |
231 | 232 | ## 第三方资源 233 | 234 | * [NCNN C++ Android](https://github.com/FeiGeChuanShu/ncnn_Android_RobustVideoMatting) ([@FeiGeChuanShu](https://github.com/FeiGeChuanShu)) 235 | * [lite.ai.toolkit](https://github.com/DefTruth/RobustVideoMatting.lite.ai.toolkit) ([@DefTruth](https://github.com/DefTruth)) 236 | * [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/Robust-Video-Matting) ([@AK391](https://github.com/AK391)) 237 | * [带有 NatML 的 Unity 引擎](https://hub.natml.ai/@natsuite/robust-video-matting) ([@natsuite](https://github.com/natsuite)) 238 | * [MNN C++ Demo](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/mnn/cv/mnn_rvm.cpp) ([@DefTruth](https://github.com/DefTruth)) 239 | * [TNN C++ Demo](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/tnn/cv/tnn_rvm.cpp) ([@DefTruth](https://github.com/DefTruth)) 240 | 241 | -------------------------------------------------------------------------------- /dataset/augmentation.py: -------------------------------------------------------------------------------- 1 | import easing_functions as ef 2 | import random 3 | import torch 4 | from torchvision import transforms 5 | from torchvision.transforms import functional as F 6 | 7 | 8 | class MotionAugmentation: 9 | def __init__(self, 10 | size, 11 | prob_fgr_affine, 12 | prob_bgr_affine, 13 | prob_noise, 14 | prob_color_jitter, 15 | prob_grayscale, 16 | prob_sharpness, 17 | prob_blur, 18 | prob_hflip, 19 | prob_pause, 20 | static_affine=True, 21 | aspect_ratio_range=(0.9, 1.1)): 22 | self.size = size 23 | self.prob_fgr_affine = prob_fgr_affine 24 | self.prob_bgr_affine = prob_bgr_affine 25 | self.prob_noise = prob_noise 26 | self.prob_color_jitter = prob_color_jitter 27 | self.prob_grayscale = prob_grayscale 28 | self.prob_sharpness = prob_sharpness 29 | self.prob_blur = prob_blur 30 | self.prob_hflip = prob_hflip 31 | self.prob_pause = prob_pause 32 | self.static_affine = static_affine 33 | self.aspect_ratio_range = aspect_ratio_range 34 | 35 | def __call__(self, fgrs, phas, bgrs): 36 | # Foreground affine 37 | if random.random() < self.prob_fgr_affine: 38 | fgrs, phas = self._motion_affine(fgrs, phas) 39 | 40 | # Background affine 41 | if random.random() < self.prob_bgr_affine / 2: 42 | bgrs = self._motion_affine(bgrs) 43 | if random.random() < self.prob_bgr_affine / 2: 44 | fgrs, phas, bgrs = self._motion_affine(fgrs, phas, bgrs) 45 | 46 | # Still Affine 47 | if self.static_affine: 48 | fgrs, phas = self._static_affine(fgrs, phas, scale_ranges=(0.5, 1)) 49 | bgrs = self._static_affine(bgrs, scale_ranges=(1, 1.5)) 50 | 51 | # To tensor 52 | fgrs = torch.stack([F.to_tensor(fgr) for fgr in fgrs]) 53 | phas = torch.stack([F.to_tensor(pha) for pha in phas]) 54 | bgrs = torch.stack([F.to_tensor(bgr) for bgr in bgrs]) 55 | 56 | # Resize 57 | params = transforms.RandomResizedCrop.get_params(fgrs, scale=(1, 1), ratio=self.aspect_ratio_range) 58 | fgrs = F.resized_crop(fgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) 59 | phas = F.resized_crop(phas, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) 60 | params = transforms.RandomResizedCrop.get_params(bgrs, scale=(1, 1), ratio=self.aspect_ratio_range) 61 | bgrs = F.resized_crop(bgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) 62 | 63 | # Horizontal flip 64 | if random.random() < self.prob_hflip: 65 | fgrs = F.hflip(fgrs) 66 | phas = F.hflip(phas) 67 | if random.random() < self.prob_hflip: 68 | bgrs = F.hflip(bgrs) 69 | 70 | # Noise 71 | if random.random() < self.prob_noise: 72 
| fgrs, bgrs = self._motion_noise(fgrs, bgrs) 73 | 74 | # Color jitter 75 | if random.random() < self.prob_color_jitter: 76 | fgrs = self._motion_color_jitter(fgrs) 77 | if random.random() < self.prob_color_jitter: 78 | bgrs = self._motion_color_jitter(bgrs) 79 | 80 | # Grayscale 81 | if random.random() < self.prob_grayscale: 82 | fgrs = F.rgb_to_grayscale(fgrs, num_output_channels=3).contiguous() 83 | bgrs = F.rgb_to_grayscale(bgrs, num_output_channels=3).contiguous() 84 | 85 | # Sharpen 86 | if random.random() < self.prob_sharpness: 87 | sharpness = random.random() * 8 88 | fgrs = F.adjust_sharpness(fgrs, sharpness) 89 | phas = F.adjust_sharpness(phas, sharpness) 90 | bgrs = F.adjust_sharpness(bgrs, sharpness) 91 | 92 | # Blur 93 | if random.random() < self.prob_blur / 3: 94 | fgrs, phas = self._motion_blur(fgrs, phas) 95 | if random.random() < self.prob_blur / 3: 96 | bgrs = self._motion_blur(bgrs) 97 | if random.random() < self.prob_blur / 3: 98 | fgrs, phas, bgrs = self._motion_blur(fgrs, phas, bgrs) 99 | 100 | # Pause 101 | if random.random() < self.prob_pause: 102 | fgrs, phas, bgrs = self._motion_pause(fgrs, phas, bgrs) 103 | 104 | return fgrs, phas, bgrs 105 | 106 | def _static_affine(self, *imgs, scale_ranges): 107 | params = transforms.RandomAffine.get_params( 108 | degrees=(-10, 10), translate=(0.1, 0.1), scale_ranges=scale_ranges, 109 | shears=(-5, 5), img_size=imgs[0][0].size) 110 | imgs = [[F.affine(t, *params, F.InterpolationMode.BILINEAR) for t in img] for img in imgs] 111 | return imgs if len(imgs) > 1 else imgs[0] 112 | 113 | def _motion_affine(self, *imgs): 114 | config = dict(degrees=(-10, 10), translate=(0.1, 0.1), 115 | scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=imgs[0][0].size) 116 | angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = transforms.RandomAffine.get_params(**config) 117 | angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = transforms.RandomAffine.get_params(**config) 118 | 119 | T = len(imgs[0]) 120 | easing = random_easing_fn() 121 | for t in range(T): 122 | percentage = easing(t / (T - 1)) 123 | angle = lerp(angleA, angleB, percentage) 124 | transX = lerp(transXA, transXB, percentage) 125 | transY = lerp(transYA, transYB, percentage) 126 | scale = lerp(scaleA, scaleB, percentage) 127 | shearX = lerp(shearXA, shearXB, percentage) 128 | shearY = lerp(shearYA, shearYB, percentage) 129 | for img in imgs: 130 | img[t] = F.affine(img[t], angle, (transX, transY), scale, (shearX, shearY), F.InterpolationMode.BILINEAR) 131 | return imgs if len(imgs) > 1 else imgs[0] 132 | 133 | def _motion_noise(self, *imgs): 134 | grain_size = random.random() * 3 + 1 # range 1 ~ 4 135 | monochrome = random.random() < 0.5 136 | for img in imgs: 137 | T, C, H, W = img.shape 138 | noise = torch.randn((T, 1 if monochrome else C, round(H / grain_size), round(W / grain_size))) 139 | noise.mul_(random.random() * 0.2 / grain_size) 140 | if grain_size != 1: 141 | noise = F.resize(noise, (H, W)) 142 | img.add_(noise).clamp_(0, 1) 143 | return imgs if len(imgs) > 1 else imgs[0] 144 | 145 | def _motion_color_jitter(self, *imgs): 146 | brightnessA, brightnessB, contrastA, contrastB, saturationA, saturationB, hueA, hueB \ 147 | = torch.randn(8).mul(0.1).tolist() 148 | strength = random.random() * 0.2 149 | easing = random_easing_fn() 150 | T = len(imgs[0]) 151 | for t in range(T): 152 | percentage = easing(t / (T - 1)) * strength 153 | for img in imgs: 154 | img[t] = F.adjust_brightness(img[t], max(1 + lerp(brightnessA, brightnessB, percentage), 0.1)) 155 | img[t] = 
F.adjust_contrast(img[t], max(1 + lerp(contrastA, contrastB, percentage), 0.1)) 156 | img[t] = F.adjust_saturation(img[t], max(1 + lerp(brightnessA, brightnessB, percentage), 0.1)) 157 | img[t] = F.adjust_hue(img[t], min(0.5, max(-0.5, lerp(hueA, hueB, percentage) * 0.1))) 158 | return imgs if len(imgs) > 1 else imgs[0] 159 | 160 | def _motion_blur(self, *imgs): 161 | blurA = random.random() * 10 162 | blurB = random.random() * 10 163 | 164 | T = len(imgs[0]) 165 | easing = random_easing_fn() 166 | for t in range(T): 167 | percentage = easing(t / (T - 1)) 168 | blur = max(lerp(blurA, blurB, percentage), 0) 169 | if blur != 0: 170 | kernel_size = int(blur * 2) 171 | if kernel_size % 2 == 0: 172 | kernel_size += 1 # Make kernel_size odd 173 | for img in imgs: 174 | img[t] = F.gaussian_blur(img[t], kernel_size, sigma=blur) 175 | 176 | return imgs if len(imgs) > 1 else imgs[0] 177 | 178 | def _motion_pause(self, *imgs): 179 | T = len(imgs[0]) 180 | pause_frame = random.choice(range(T - 1)) 181 | pause_length = random.choice(range(T - pause_frame)) 182 | for img in imgs: 183 | img[pause_frame + 1 : pause_frame + pause_length] = img[pause_frame] 184 | return imgs if len(imgs) > 1 else imgs[0] 185 | 186 | 187 | def lerp(a, b, percentage): 188 | return a * (1 - percentage) + b * percentage 189 | 190 | 191 | def random_easing_fn(): 192 | if random.random() < 0.2: 193 | return ef.LinearInOut() 194 | else: 195 | return random.choice([ 196 | ef.BackEaseIn, 197 | ef.BackEaseOut, 198 | ef.BackEaseInOut, 199 | ef.BounceEaseIn, 200 | ef.BounceEaseOut, 201 | ef.BounceEaseInOut, 202 | ef.CircularEaseIn, 203 | ef.CircularEaseOut, 204 | ef.CircularEaseInOut, 205 | ef.CubicEaseIn, 206 | ef.CubicEaseOut, 207 | ef.CubicEaseInOut, 208 | ef.ExponentialEaseIn, 209 | ef.ExponentialEaseOut, 210 | ef.ExponentialEaseInOut, 211 | ef.ElasticEaseIn, 212 | ef.ElasticEaseOut, 213 | ef.ElasticEaseInOut, 214 | ef.QuadEaseIn, 215 | ef.QuadEaseOut, 216 | ef.QuadEaseInOut, 217 | ef.QuarticEaseIn, 218 | ef.QuarticEaseOut, 219 | ef.QuarticEaseInOut, 220 | ef.QuinticEaseIn, 221 | ef.QuinticEaseOut, 222 | ef.QuinticEaseInOut, 223 | ef.SineEaseIn, 224 | ef.SineEaseOut, 225 | ef.SineEaseInOut, 226 | Step, 227 | ])() 228 | 229 | class Step: # Custom easing function for sudden change. 
230 | def __call__(self, value): 231 | return 0 if value < 0.5 else 1 232 | 233 | 234 | # ---------------------------- Frame Sampler ---------------------------- 235 | 236 | 237 | class TrainFrameSampler: 238 | def __init__(self, speed=[0.5, 1, 2, 3, 4, 5]): 239 | self.speed = speed 240 | 241 | def __call__(self, seq_length): 242 | frames = list(range(seq_length)) 243 | 244 | # Speed up 245 | speed = random.choice(self.speed) 246 | frames = [int(f * speed) for f in frames] 247 | 248 | # Shift 249 | shift = random.choice(range(seq_length)) 250 | frames = [f + shift for f in frames] 251 | 252 | # Reverse 253 | if random.random() < 0.5: 254 | frames = frames[::-1] 255 | 256 | return frames 257 | 258 | class ValidFrameSampler: 259 | def __call__(self, seq_length): 260 | return range(seq_length) 261 | -------------------------------------------------------------------------------- /dataset/coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | import json 5 | import os 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | from torchvision.transforms import functional as F 9 | from PIL import Image 10 | 11 | 12 | class CocoPanopticDataset(Dataset): 13 | def __init__(self, 14 | imgdir: str, 15 | anndir: str, 16 | annfile: str, 17 | transform=None): 18 | with open(annfile) as f: 19 | self.data = json.load(f)['annotations'] 20 | self.data = list(filter(lambda data: any(info['category_id'] == 1 for info in data['segments_info']), self.data)) 21 | self.imgdir = imgdir 22 | self.anndir = anndir 23 | self.transform = transform 24 | 25 | def __len__(self): 26 | return len(self.data) 27 | 28 | def __getitem__(self, idx): 29 | data = self.data[idx] 30 | img = self._load_img(data) 31 | seg = self._load_seg(data) 32 | 33 | if self.transform is not None: 34 | img, seg = self.transform(img, seg) 35 | 36 | return img, seg 37 | 38 | def _load_img(self, data): 39 | with Image.open(os.path.join(self.imgdir, data['file_name'].replace('.png', '.jpg'))) as img: 40 | return img.convert('RGB') 41 | 42 | def _load_seg(self, data): 43 | with Image.open(os.path.join(self.anndir, data['file_name'])) as ann: 44 | ann.load() 45 | 46 | ann = np.array(ann, copy=False).astype(np.int32) 47 | ann = ann[:, :, 0] + 256 * ann[:, :, 1] + 256 * 256 * ann[:, :, 2] 48 | seg = np.zeros(ann.shape, np.uint8) 49 | 50 | for segments_info in data['segments_info']: 51 | if segments_info['category_id'] in [1, 27, 32]: # person, backpack, tie 52 | seg[ann == segments_info['id']] = 255 53 | 54 | return Image.fromarray(seg) 55 | 56 | 57 | class CocoPanopticTrainAugmentation: 58 | def __init__(self, size): 59 | self.size = size 60 | self.jitter = transforms.ColorJitter(0.1, 0.1, 0.1, 0.1) 61 | 62 | def __call__(self, img, seg): 63 | # Affine 64 | params = transforms.RandomAffine.get_params(degrees=(-20, 20), translate=(0.1, 0.1), 65 | scale_ranges=(1, 1), shears=(-10, 10), img_size=img.size) 66 | img = F.affine(img, *params, interpolation=F.InterpolationMode.BILINEAR) 67 | seg = F.affine(seg, *params, interpolation=F.InterpolationMode.NEAREST) 68 | 69 | # Resize 70 | params = transforms.RandomResizedCrop.get_params(img, scale=(0.5, 1), ratio=(0.7, 1.3)) 71 | img = F.resized_crop(img, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) 72 | seg = F.resized_crop(seg, *params, self.size, interpolation=F.InterpolationMode.NEAREST) 73 | 74 | # Horizontal flip 75 | if random.random() < 0.5: 76 | img = F.hflip(img) 77 | 
seg = F.hflip(seg) 78 | 79 | # Color jitter 80 | img = self.jitter(img) 81 | 82 | # To tensor 83 | img = F.to_tensor(img) 84 | seg = F.to_tensor(seg) 85 | 86 | return img, seg 87 | 88 | 89 | class CocoPanopticValidAugmentation: 90 | def __init__(self, size): 91 | self.size = size 92 | 93 | def __call__(self, img, seg): 94 | # Resize 95 | params = transforms.RandomResizedCrop.get_params(img, scale=(1, 1), ratio=(1., 1.)) 96 | img = F.resized_crop(img, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) 97 | seg = F.resized_crop(seg, *params, self.size, interpolation=F.InterpolationMode.NEAREST) 98 | 99 | # To tensor 100 | img = F.to_tensor(img) 101 | seg = F.to_tensor(seg) 102 | 103 | return img, seg -------------------------------------------------------------------------------- /dataset/imagematte.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | 6 | from .augmentation import MotionAugmentation 7 | 8 | 9 | class ImageMatteDataset(Dataset): 10 | def __init__(self, 11 | imagematte_dir, 12 | background_image_dir, 13 | background_video_dir, 14 | size, 15 | seq_length, 16 | seq_sampler, 17 | transform): 18 | self.imagematte_dir = imagematte_dir 19 | self.imagematte_files = os.listdir(os.path.join(imagematte_dir, 'fgr')) 20 | self.background_image_dir = background_image_dir 21 | self.background_image_files = os.listdir(background_image_dir) 22 | self.background_video_dir = background_video_dir 23 | self.background_video_clips = os.listdir(background_video_dir) 24 | self.background_video_frames = [sorted(os.listdir(os.path.join(background_video_dir, clip))) 25 | for clip in self.background_video_clips] 26 | self.seq_length = seq_length 27 | self.seq_sampler = seq_sampler 28 | self.size = size 29 | self.transform = transform 30 | 31 | def __len__(self): 32 | return max(len(self.imagematte_files), len(self.background_image_files) + len(self.background_video_clips)) 33 | 34 | def __getitem__(self, idx): 35 | if random.random() < 0.5: 36 | bgrs = self._get_random_image_background() 37 | else: 38 | bgrs = self._get_random_video_background() 39 | 40 | fgrs, phas = self._get_imagematte(idx) 41 | 42 | if self.transform is not None: 43 | return self.transform(fgrs, phas, bgrs) 44 | 45 | return fgrs, phas, bgrs 46 | 47 | def _get_imagematte(self, idx): 48 | with Image.open(os.path.join(self.imagematte_dir, 'fgr', self.imagematte_files[idx % len(self.imagematte_files)])) as fgr, \ 49 | Image.open(os.path.join(self.imagematte_dir, 'pha', self.imagematte_files[idx % len(self.imagematte_files)])) as pha: 50 | fgr = self._downsample_if_needed(fgr.convert('RGB')) 51 | pha = self._downsample_if_needed(pha.convert('L')) 52 | fgrs = [fgr] * self.seq_length 53 | phas = [pha] * self.seq_length 54 | return fgrs, phas 55 | 56 | def _get_random_image_background(self): 57 | with Image.open(os.path.join(self.background_image_dir, self.background_image_files[random.choice(range(len(self.background_image_files)))])) as bgr: 58 | bgr = self._downsample_if_needed(bgr.convert('RGB')) 59 | bgrs = [bgr] * self.seq_length 60 | return bgrs 61 | 62 | def _get_random_video_background(self): 63 | clip_idx = random.choice(range(len(self.background_video_clips))) 64 | frame_count = len(self.background_video_frames[clip_idx]) 65 | frame_idx = random.choice(range(max(1, frame_count - self.seq_length))) 66 | clip = self.background_video_clips[clip_idx] 67 | bgrs = [] 68 | for i in 
self.seq_sampler(self.seq_length): 69 | frame_idx_t = frame_idx + i 70 | frame = self.background_video_frames[clip_idx][frame_idx_t % frame_count] 71 | with Image.open(os.path.join(self.background_video_dir, clip, frame)) as bgr: 72 | bgr = self._downsample_if_needed(bgr.convert('RGB')) 73 | bgrs.append(bgr) 74 | return bgrs 75 | 76 | def _downsample_if_needed(self, img): 77 | w, h = img.size 78 | if min(w, h) > self.size: 79 | scale = self.size / min(w, h) 80 | w = int(scale * w) 81 | h = int(scale * h) 82 | img = img.resize((w, h)) 83 | return img 84 | 85 | class ImageMatteAugmentation(MotionAugmentation): 86 | def __init__(self, size): 87 | super().__init__( 88 | size=size, 89 | prob_fgr_affine=0.95, 90 | prob_bgr_affine=0.3, 91 | prob_noise=0.05, 92 | prob_color_jitter=0.3, 93 | prob_grayscale=0.03, 94 | prob_sharpness=0.05, 95 | prob_blur=0.02, 96 | prob_hflip=0.5, 97 | prob_pause=0.03, 98 | ) 99 | -------------------------------------------------------------------------------- /dataset/spd.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.data import Dataset 3 | from PIL import Image 4 | 5 | 6 | class SuperviselyPersonDataset(Dataset): 7 | def __init__(self, imgdir, segdir, transform=None): 8 | self.img_dir = imgdir 9 | self.img_files = sorted(os.listdir(imgdir)) 10 | self.seg_dir = segdir 11 | self.seg_files = sorted(os.listdir(segdir)) 12 | assert len(self.img_files) == len(self.seg_files) 13 | self.transform = transform 14 | 15 | def __len__(self): 16 | return len(self.img_files) 17 | 18 | def __getitem__(self, idx): 19 | with Image.open(os.path.join(self.img_dir, self.img_files[idx])) as img, \ 20 | Image.open(os.path.join(self.seg_dir, self.seg_files[idx])) as seg: 21 | img = img.convert('RGB') 22 | seg = seg.convert('L') 23 | 24 | if self.transform is not None: 25 | img, seg = self.transform(img, seg) 26 | 27 | return img, seg 28 | -------------------------------------------------------------------------------- /dataset/videomatte.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | 6 | from .augmentation import MotionAugmentation 7 | 8 | 9 | class VideoMatteDataset(Dataset): 10 | def __init__(self, 11 | videomatte_dir, 12 | background_image_dir, 13 | background_video_dir, 14 | size, 15 | seq_length, 16 | seq_sampler, 17 | transform=None): 18 | self.background_image_dir = background_image_dir 19 | self.background_image_files = os.listdir(background_image_dir) 20 | self.background_video_dir = background_video_dir 21 | self.background_video_clips = sorted(os.listdir(background_video_dir)) 22 | self.background_video_frames = [sorted(os.listdir(os.path.join(background_video_dir, clip))) 23 | for clip in self.background_video_clips] 24 | 25 | self.videomatte_dir = videomatte_dir 26 | self.videomatte_clips = sorted(os.listdir(os.path.join(videomatte_dir, 'fgr'))) 27 | self.videomatte_frames = [sorted(os.listdir(os.path.join(videomatte_dir, 'fgr', clip))) 28 | for clip in self.videomatte_clips] 29 | self.videomatte_idx = [(clip_idx, frame_idx) 30 | for clip_idx in range(len(self.videomatte_clips)) 31 | for frame_idx in range(0, len(self.videomatte_frames[clip_idx]), seq_length)] 32 | self.size = size 33 | self.seq_length = seq_length 34 | self.seq_sampler = seq_sampler 35 | self.transform = transform 36 | 37 | def __len__(self): 38 | return len(self.videomatte_idx) 39 | 40 | def 
__getitem__(self, idx): 41 | if random.random() < 0.5: 42 | bgrs = self._get_random_image_background() 43 | else: 44 | bgrs = self._get_random_video_background() 45 | 46 | fgrs, phas = self._get_videomatte(idx) 47 | 48 | if self.transform is not None: 49 | return self.transform(fgrs, phas, bgrs) 50 | 51 | return fgrs, phas, bgrs 52 | 53 | def _get_random_image_background(self): 54 | with Image.open(os.path.join(self.background_image_dir, random.choice(self.background_image_files))) as bgr: 55 | bgr = self._downsample_if_needed(bgr.convert('RGB')) 56 | bgrs = [bgr] * self.seq_length 57 | return bgrs 58 | 59 | def _get_random_video_background(self): 60 | clip_idx = random.choice(range(len(self.background_video_clips))) 61 | frame_count = len(self.background_video_frames[clip_idx]) 62 | frame_idx = random.choice(range(max(1, frame_count - self.seq_length))) 63 | clip = self.background_video_clips[clip_idx] 64 | bgrs = [] 65 | for i in self.seq_sampler(self.seq_length): 66 | frame_idx_t = frame_idx + i 67 | frame = self.background_video_frames[clip_idx][frame_idx_t % frame_count] 68 | with Image.open(os.path.join(self.background_video_dir, clip, frame)) as bgr: 69 | bgr = self._downsample_if_needed(bgr.convert('RGB')) 70 | bgrs.append(bgr) 71 | return bgrs 72 | 73 | def _get_videomatte(self, idx): 74 | clip_idx, frame_idx = self.videomatte_idx[idx] 75 | clip = self.videomatte_clips[clip_idx] 76 | frame_count = len(self.videomatte_frames[clip_idx]) 77 | fgrs, phas = [], [] 78 | for i in self.seq_sampler(self.seq_length): 79 | frame = self.videomatte_frames[clip_idx][(frame_idx + i) % frame_count] 80 | with Image.open(os.path.join(self.videomatte_dir, 'fgr', clip, frame)) as fgr, \ 81 | Image.open(os.path.join(self.videomatte_dir, 'pha', clip, frame)) as pha: 82 | fgr = self._downsample_if_needed(fgr.convert('RGB')) 83 | pha = self._downsample_if_needed(pha.convert('L')) 84 | fgrs.append(fgr) 85 | phas.append(pha) 86 | return fgrs, phas 87 | 88 | def _downsample_if_needed(self, img): 89 | w, h = img.size 90 | if min(w, h) > self.size: 91 | scale = self.size / min(w, h) 92 | w = int(scale * w) 93 | h = int(scale * h) 94 | img = img.resize((w, h)) 95 | return img 96 | 97 | class VideoMatteTrainAugmentation(MotionAugmentation): 98 | def __init__(self, size): 99 | super().__init__( 100 | size=size, 101 | prob_fgr_affine=0.3, 102 | prob_bgr_affine=0.3, 103 | prob_noise=0.1, 104 | prob_color_jitter=0.3, 105 | prob_grayscale=0.02, 106 | prob_sharpness=0.1, 107 | prob_blur=0.02, 108 | prob_hflip=0.5, 109 | prob_pause=0.03, 110 | ) 111 | 112 | class VideoMatteValidAugmentation(MotionAugmentation): 113 | def __init__(self, size): 114 | super().__init__( 115 | size=size, 116 | prob_fgr_affine=0, 117 | prob_bgr_affine=0, 118 | prob_noise=0, 119 | prob_color_jitter=0, 120 | prob_grayscale=0, 121 | prob_sharpness=0, 122 | prob_blur=0, 123 | prob_hflip=0, 124 | prob_pause=0, 125 | ) 126 | -------------------------------------------------------------------------------- /dataset/youtubevis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import json 4 | import numpy as np 5 | import random 6 | from torch.utils.data import Dataset 7 | from PIL import Image 8 | from torchvision import transforms 9 | from torchvision.transforms import functional as F 10 | 11 | 12 | class YouTubeVISDataset(Dataset): 13 | def __init__(self, videodir, annfile, size, seq_length, seq_sampler, transform=None): 14 | self.videodir = videodir 15 | self.size = size 16 | 
self.seq_length = seq_length 17 | self.seq_sampler = seq_sampler 18 | self.transform = transform 19 | 20 | with open(annfile) as f: 21 | data = json.load(f) 22 | 23 | self.masks = {} 24 | for ann in data['annotations']: 25 | if ann['category_id'] == 26: # person 26 | video_id = ann['video_id'] 27 | if video_id not in self.masks: 28 | self.masks[video_id] = [[] for _ in range(len(ann['segmentations']))] 29 | for frame, mask in zip(self.masks[video_id], ann['segmentations']): 30 | if mask is not None: 31 | frame.append(mask) 32 | 33 | self.videos = {} 34 | for video in data['videos']: 35 | video_id = video['id'] 36 | if video_id in self.masks: 37 | self.videos[video_id] = video 38 | 39 | self.index = [] 40 | for video_id in self.videos.keys(): 41 | for frame in range(len(self.videos[video_id]['file_names'])): 42 | self.index.append((video_id, frame)) 43 | 44 | def __len__(self): 45 | return len(self.index) 46 | 47 | def __getitem__(self, idx): 48 | video_id, frame_id = self.index[idx] 49 | video = self.videos[video_id] 50 | frame_count = len(self.videos[video_id]['file_names']) 51 | H, W = video['height'], video['width'] 52 | 53 | imgs, segs = [], [] 54 | for t in self.seq_sampler(self.seq_length): 55 | frame = (frame_id + t) % frame_count 56 | 57 | filename = video['file_names'][frame] 58 | masks = self.masks[video_id][frame] 59 | 60 | with Image.open(os.path.join(self.videodir, filename)) as img: 61 | imgs.append(self._downsample_if_needed(img.convert('RGB'), Image.BILINEAR)) 62 | 63 | seg = np.zeros((H, W), dtype=np.uint8) 64 | for mask in masks: 65 | seg |= self._decode_rle(mask) 66 | segs.append(self._downsample_if_needed(Image.fromarray(seg), Image.NEAREST)) 67 | 68 | if self.transform is not None: 69 | imgs, segs = self.transform(imgs, segs) 70 | 71 | return imgs, segs 72 | 73 | def _decode_rle(self, rle): 74 | H, W = rle['size'] 75 | msk = np.zeros(H * W, dtype=np.uint8) 76 | encoding = rle['counts'] 77 | skip = 0 78 | for i in range(0, len(encoding) - 1, 2): 79 | skip += encoding[i] 80 | draw = encoding[i + 1] 81 | msk[skip : skip + draw] = 255 82 | skip += draw 83 | return msk.reshape(W, H).transpose() 84 | 85 | def _downsample_if_needed(self, img, resample): 86 | w, h = img.size 87 | if min(w, h) > self.size: 88 | scale = self.size / min(w, h) 89 | w = int(scale * w) 90 | h = int(scale * h) 91 | img = img.resize((w, h), resample) 92 | return img 93 | 94 | 95 | class YouTubeVISAugmentation: 96 | def __init__(self, size): 97 | self.size = size 98 | self.jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.15) 99 | 100 | def __call__(self, imgs, segs): 101 | 102 | # To tensor 103 | imgs = torch.stack([F.to_tensor(img) for img in imgs]) 104 | segs = torch.stack([F.to_tensor(seg) for seg in segs]) 105 | 106 | # Resize 107 | params = transforms.RandomResizedCrop.get_params(imgs, scale=(0.8, 1), ratio=(0.9, 1.1)) 108 | imgs = F.resized_crop(imgs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) 109 | segs = F.resized_crop(segs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR) 110 | 111 | # Color jitter 112 | imgs = self.jitter(imgs) 113 | 114 | # Grayscale 115 | if random.random() < 0.05: 116 | imgs = F.rgb_to_grayscale(imgs, num_output_channels=3) 117 | 118 | # Horizontal flip 119 | if random.random() < 0.5: 120 | imgs = F.hflip(imgs) 121 | segs = F.hflip(segs) 122 | 123 | return imgs, segs 124 | -------------------------------------------------------------------------------- /documentation/image/showreel.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeterL1n/RobustVideoMatting/53d74c6826735f01f4406b5ca9075eee27bec094/documentation/image/showreel.gif -------------------------------------------------------------------------------- /documentation/image/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeterL1n/RobustVideoMatting/53d74c6826735f01f4406b5ca9075eee27bec094/documentation/image/teaser.gif -------------------------------------------------------------------------------- /documentation/inference.md: -------------------------------------------------------------------------------- 1 | # Inference 2 | 3 |

English | 中文

4 | 5 | ## Content 6 | 7 | * [Concepts](#concepts) 8 | * [Downsample Ratio](#downsample-ratio) 9 | * [Recurrent States](#recurrent-states) 10 | * [PyTorch](#pytorch) 11 | * [TorchHub](#torchhub) 12 | * [TorchScript](#torchscript) 13 | * [ONNX](#onnx) 14 | * [TensorFlow](#tensorflow) 15 | * [TensorFlow.js](#tensorflowjs) 16 | * [CoreML](#coreml) 17 | 18 |
19 | 
20 | 
21 | ## Concepts
22 | 
23 | ### Downsample Ratio
24 | 
25 | The table provides a general guideline. Please adjust based on your video content.
26 | 
27 | | Resolution | Portrait | Full-Body |
28 | | ------------- | ------------- | -------------- |
29 | | <= 512x512 | 1 | 1 |
30 | | 1280x720 | 0.375 | 0.6 |
31 | | 1920x1080 | 0.25 | 0.4 |
32 | | 3840x2160 | 0.125 | 0.2 |
33 | 
34 | Internally, the model downsamples the input for a coarse first stage, then refines the result at the original high resolution in a second stage.
35 | 
36 | Set `downsample_ratio` so that the downsampled resolution is between 256 and 512 pixels. For example, a `1920x1080` input with `downsample_ratio=0.25` is resized to `480x270`, which falls within that range.
37 | 
38 | Adjust `downsample_ratio` based on the video content. If the shot is a portrait, a lower `downsample_ratio` is sufficient. If the shot contains the full human body, use a higher `downsample_ratio`. Note that a higher `downsample_ratio` is not always better.
39 | 
40 | 
41 | 
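As a starting point, the ratio can also be computed programmatically. The helper below is a hypothetical illustration (not code from this repo) that follows the converter's automatic behavior of keeping the longer downsampled side at most 512px; refine the value using the table above.

```python
def pick_downsample_ratio(height, width, max_side=512):
    """Choose a ratio so the longer side of the downsampled frame is at most `max_side` pixels."""
    return min(max_side / max(height, width), 1.0)

print(pick_downsample_ratio(1080, 1920))  # ~0.267 for HD; tune towards 0.25 (portrait) or 0.4 (full-body).
```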
42 | 43 | ### Recurrent States 44 | The model is a recurrent neural network. You must process frames sequentially and recycle its recurrent states. 45 | 46 | **Correct Way** 47 | 48 | The recurrent outputs are recycled back as input when processing the next frame. The states are essentially the model's memory. 49 | 50 | ```python 51 | rec = [None] * 4 # Initial recurrent states are None 52 | 53 | for frame in YOUR_VIDEO: 54 | fgr, pha, *rec = model(frame, *rec, downsample_ratio) 55 | ``` 56 | 57 | **Wrong Way** 58 | 59 | The model does not utilize the recurrent states. Only use it to process independent images. 60 | 61 | ```python 62 | for frame in YOUR_VIDEO: 63 | fgr, pha = model(frame, downsample_ratio)[:2] 64 | ``` 65 | 66 | More technical details are in the [paper](https://peterl1n.github.io/RobustVideoMatting/). 67 | 68 |


69 | 70 | 71 | ## PyTorch 72 | 73 | Model loading: 74 | 75 | ```python 76 | import torch 77 | from model import MattingNetwork 78 | 79 | model = MattingNetwork(variant='mobilenetv3').eval().cuda() # Or variant="resnet50" 80 | model.load_state_dict(torch.load('rvm_mobilenetv3.pth')) 81 | ``` 82 | 83 | Example inference loop: 84 | ```python 85 | rec = [None] * 4 # Set initial recurrent states to None 86 | 87 | for src in YOUR_VIDEO: # src can be [B, C, H, W] or [B, T, C, H, W] 88 | fgr, pha, *rec = model(src, *rec, downsample_ratio=0.25) 89 | ``` 90 | 91 | * `src`: Input frame. 92 | * Can be of shape `[B, C, H, W]` or `[B, T, C, H, W]`. 93 | * If `[B, T, C, H, W]`, a chunk of `T` frames can be given at once for better parallelism. 94 | * RGB input is normalized to `0~1` range. 95 | 96 | * `fgr, pha`: Foreground and alpha predictions. 97 | * Can be of shape `[B, C, H, W]` or `[B, T, C, H, W]` depends on `src`. 98 | * `fgr` has `C=3` for RGB, `pha` has `C=1`. 99 | * Outputs normalized to `0~1` range. 100 | * `rec`: Recurrent states. 101 | * Type of `List[Tensor, Tensor, Tensor, Tensor]`. 102 | * Initial `rec` can be `List[None, None, None, None]`. 103 | * It has 4 recurrent states because the model has 4 ConvGRU layers. 104 | * All tensors are rank 4 regardless of `src` rank. 105 | * If a chunk of `T` frames is given, only the last frame's recurrent states will be returned. 106 | 107 | To inference on video, here is a complete example: 108 | 109 | ```python 110 | from torch.utils.data import DataLoader 111 | from torchvision.transforms import ToTensor 112 | from inference_utils import VideoReader, VideoWriter 113 | 114 | reader = VideoReader('input.mp4', transform=ToTensor()) 115 | writer = VideoWriter('output.mp4', frame_rate=30) 116 | 117 | bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda() # Green background. 118 | rec = [None] * 4 # Initial recurrent states. 119 | 120 | with torch.no_grad(): 121 | for src in DataLoader(reader): 122 | fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio=0.25) # Cycle the recurrent states. 123 | writer.write(fgr * pha + bgr * (1 - pha)) 124 | ``` 125 | 126 | Or you can use the provided video converter: 127 | 128 | ```python 129 | from inference import convert_video 130 | 131 | convert_video( 132 | model, # The loaded model, can be on any device (cpu or cuda). 133 | input_source='input.mp4', # A video file or an image sequence directory. 134 | input_resize=(1920, 1080), # [Optional] Resize the input (also the output). 135 | downsample_ratio=0.25, # [Optional] If None, make downsampled max size be 512px. 136 | output_type='video', # Choose "video" or "png_sequence" 137 | output_composition='com.mp4', # File path if video; directory path if png sequence. 138 | output_alpha="pha.mp4", # [Optional] Output the raw alpha prediction. 139 | output_foreground="fgr.mp4", # [Optional] Output the raw foreground prediction. 140 | output_video_mbps=4, # Output video mbps. Not needed for png sequence. 141 | seq_chunk=12, # Process n frames at once for better parallelism. 142 | num_workers=1, # Only for image sequence input. Reader threads. 143 | progress=True # Print conversion progress. 
144 | ) 145 | ``` 146 | 147 | The converter can also be invoked in command line: 148 | 149 | ```sh 150 | python inference.py \ 151 | --variant mobilenetv3 \ 152 | --checkpoint "CHECKPOINT" \ 153 | --device cuda \ 154 | --input-source "input.mp4" \ 155 | --downsample-ratio 0.25 \ 156 | --output-type video \ 157 | --output-composition "composition.mp4" \ 158 | --output-alpha "alpha.mp4" \ 159 | --output-foreground "foreground.mp4" \ 160 | --output-video-mbps 4 \ 161 | --seq-chunk 12 162 | ``` 163 | 164 |


165 | 166 | ## TorchHub 167 | 168 | Model loading: 169 | 170 | ```python 171 | model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # or "resnet50" 172 | ``` 173 | 174 | Use the conversion function. Refer to the documentation for `convert_video` function above. 175 | 176 | ```python 177 | convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter") 178 | 179 | convert_video(model, ...args...) 180 | ``` 181 | 182 |


183 | 
184 | ## TorchScript
185 | 
186 | Model loading:
187 | 
188 | ```python
189 | import torch
190 | model = torch.jit.load('rvm_mobilenetv3.torchscript')
191 | ```
192 | 
193 | Optionally, freeze the model. This triggers graph optimizations, such as BatchNorm fusion. Frozen models are faster.
194 | 
195 | ```python
196 | model = torch.jit.freeze(model)
197 | ```
198 | 
199 | Then you can use the `model` exactly like a PyTorch model, except that you must manually provide `device` and `dtype` to the converter API for frozen models. For example:
200 | 
201 | ```python
202 | convert_video(frozen_model, ...args..., device='cuda', dtype=torch.float32)
203 | ```
204 | 
205 | 
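Because the TorchScript module keeps the same calling convention, a manual inference loop looks just like the PyTorch one. A minimal sketch, assuming an FP32 model on CUDA and with `YOUR_VIDEO` standing in for your own source of frames:

```python
import torch

model = torch.jit.load('rvm_mobilenetv3_fp32.torchscript').cuda()
model = torch.jit.freeze(model)  # Optional graph optimization.

rec = [None] * 4                 # Initial recurrent states.
with torch.no_grad():
    for src in YOUR_VIDEO:       # src: [B, C, H, W] RGB tensor in 0~1, on CUDA.
        fgr, pha, *rec = model(src, *rec, 0.25)
```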


206 | 207 | ## ONNX 208 | 209 | Model spec: 210 | * Inputs: [`src`, `r1i`, `r2i`, `r3i`, `r4i`, `downsample_ratio`]. 211 | * `src` is the RGB input frame of shape `[B, C, H, W]` normalized to `0~1` range. 212 | * `rXi` are the recurrent state inputs. Initial recurrent states are zero value tensors of shape `[1, 1, 1, 1]`. 213 | * `downsample_ratio` is a tensor of shape `[1]`. 214 | * Only `downsample_ratio` must have `dtype=FP32`. Other inputs must have `dtype` matching the loaded model's precision. 215 | * Outputs: [`fgr`, `pha`, `r1o`, `r2o`, `r3o`, `r4o`] 216 | * `fgr, pha` are the foreground and alpha prediction. Normalized to `0~1` range. 217 | * `rXo` are the recurrent state outputs. 218 | 219 | We only show examples of using onnxruntime CUDA backend in Python. 220 | 221 | Model loading 222 | 223 | ```python 224 | import onnxruntime as ort 225 | 226 | sess = ort.InferenceSession('rvm_mobilenetv3_fp16.onnx') 227 | ``` 228 | 229 | Naive inference loop 230 | 231 | ```python 232 | import numpy as np 233 | 234 | rec = [ np.zeros([1, 1, 1, 1], dtype=np.float16) ] * 4 # Must match dtype of the model. 235 | downsample_ratio = np.array([0.25], dtype=np.float32) # dtype always FP32 236 | 237 | for src in YOUR_VIDEO: # src is of [B, C, H, W] with dtype of the model. 238 | fgr, pha, *rec = sess.run([], { 239 | 'src': src, 240 | 'r1i': rec[0], 241 | 'r2i': rec[1], 242 | 'r3i': rec[2], 243 | 'r4i': rec[3], 244 | 'downsample_ratio': downsample_ratio 245 | }) 246 | ``` 247 | 248 | If you use GPU version of ONNX Runtime, the above naive implementation has recurrent states transferred between CPU and GPU on every frame. They could have just stayed on the GPU for better performance. Below is an example using `iobinding` to eliminate useless transfers. 249 | 250 | ```python 251 | import onnxruntime as ort 252 | import numpy as np 253 | 254 | # Load model. 255 | sess = ort.InferenceSession('rvm_mobilenetv3_fp16.onnx') 256 | 257 | # Create an io binding. 258 | io = sess.io_binding() 259 | 260 | # Create tensors on CUDA. 261 | rec = [ ort.OrtValue.ortvalue_from_numpy(np.zeros([1, 1, 1, 1], dtype=np.float16), 'cuda') ] * 4 262 | downsample_ratio = ort.OrtValue.ortvalue_from_numpy(np.asarray([0.25], dtype=np.float32), 'cuda') 263 | 264 | # Set output binding. 265 | for name in ['fgr', 'pha', 'r1o', 'r2o', 'r3o', 'r4o']: 266 | io.bind_output(name, 'cuda') 267 | 268 | # Inference loop 269 | for src in YOUR_VIDEO: 270 | io.bind_cpu_input('src', src) 271 | io.bind_ortvalue_input('r1i', rec[0]) 272 | io.bind_ortvalue_input('r2i', rec[1]) 273 | io.bind_ortvalue_input('r3i', rec[2]) 274 | io.bind_ortvalue_input('r4i', rec[3]) 275 | io.bind_ortvalue_input('downsample_ratio', downsample_ratio) 276 | 277 | sess.run_with_iobinding(io) 278 | 279 | fgr, pha, *rec = io.get_outputs() 280 | 281 | # Only transfer `fgr` and `pha` to CPU. 282 | fgr = fgr.numpy() 283 | pha = pha.numpy() 284 | ``` 285 | 286 | Note: depending on the inference tool you choose, it may not support all the operations in our official ONNX model. You are responsible for modifying the model code and exporting your own ONNX model. You can refer to our exporter code in the [onnx branch](https://github.com/PeterL1n/RobustVideoMatting/tree/onnx). 287 | 288 |


289 | 
290 | ### TensorFlow
291 | 
292 | An example usage:
293 | 
294 | ```python
295 | import tensorflow as tf
296 | 
297 | model = tf.keras.models.load_model('rvm_mobilenetv3_tf')
298 | model = tf.function(model)
299 | 
300 | rec = [ tf.constant(0.) ] * 4         # Initial recurrent states.
301 | downsample_ratio = tf.constant(0.25)  # Adjust based on your video.
302 | 
303 | for src in YOUR_VIDEO:                # src is of shape [B, H, W, C], not [B, C, H, W]!
304 |     out = model([src, *rec, downsample_ratio])
305 |     fgr, pha, *rec = out['fgr'], out['pha'], out['r1o'], out['r2o'], out['r3o'], out['r4o']
306 | ```
307 | 
308 | Note that the tensors are all channel-last. Otherwise, the inputs and outputs are exactly the same as in PyTorch.
309 | 
310 | We also provide the raw TensorFlow model code in the [tensorflow branch](https://github.com/PeterL1n/RobustVideoMatting/tree/tensorflow). You can transfer PyTorch checkpoint weights to TensorFlow models.
311 | 
312 | 
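If your frames come from a PyTorch-style pipeline in `[B, C, H, W]` layout, a simple transpose converts them to the channel-last layout used here. An illustration with NumPy:

```python
import numpy as np

src_nchw = np.random.rand(1, 3, 1080, 1920).astype(np.float32)  # PyTorch-style channel-first layout.
src_nhwc = np.transpose(src_nchw, (0, 2, 3, 1))                 # [B, H, W, C] for the TensorFlow model.
```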


313 | 314 | ### TensorFlow.js 315 | 316 | We provide a starter code in the [tfjs branch](https://github.com/PeterL1n/RobustVideoMatting/tree/tfjs). The example is very self-explanatory. It shows how to properly use the model. 317 | 318 |


319 | 320 | ### CoreML 321 | 322 | We only show example usage of the CoreML models in Python API using `coremltools`. In production, the same logic can be applied in Swift. When processing the first frame, do not provide recurrent states. CoreML will internally construct zero tensors of the correct shapes as the initial recurrent states. 323 | 324 | ```python 325 | import coremltools as ct 326 | 327 | model = ct.models.model.MLModel('rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel') 328 | 329 | r1, r2, r3, r4 = None, None, None, None 330 | 331 | for src in YOUR_VIDEO: # src is PIL.Image. 332 | 333 | if r1 is None: 334 | # Initial frame, do not provide recurrent states. 335 | inputs = {'src': src} 336 | else: 337 | # Subsequent frames, provide recurrent states. 338 | inputs = {'src': src, 'r1i': r1, 'r2i': r2, 'r3i': r3, 'r4i': r4} 339 | 340 | outputs = model.predict(inputs) 341 | 342 | fgr = outputs['fgr'] # PIL.Image. 343 | pha = outputs['pha'] # PIL.Image. 344 | 345 | r1 = outputs['r1o'] # Numpy array. 346 | r2 = outputs['r2o'] # Numpy array. 347 | r3 = outputs['r3o'] # Numpy array. 348 | r4 = outputs['r4o'] # Numpy array. 349 | 350 | ``` 351 | 352 | Our CoreML models only support fixed resolutions. If you need other resolutions, you can export them yourself. See [coreml branch](https://github.com/PeterL1n/RobustVideoMatting/tree/coreml) for model export. -------------------------------------------------------------------------------- /documentation/inference_zh_Hans.md: -------------------------------------------------------------------------------- 1 | # 推断文档 2 | 3 |

English | 中文

4 | 5 | ## 目录 6 | 7 | * [概念](#概念) 8 | * [下采样比](#下采样比) 9 | * [循环记忆](#循环记忆) 10 | * [PyTorch](#pytorch) 11 | * [TorchHub](#torchhub) 12 | * [TorchScript](#torchscript) 13 | * [ONNX](#onnx) 14 | * [TensorFlow](#tensorflow) 15 | * [TensorFlow.js](#tensorflowjs) 16 | * [CoreML](#coreml) 17 | 18 |
19 | 20 | 21 | ## 概念 22 | 23 | ### 下采样比 24 | 25 | 该表仅供参考。可根据视频内容进行调节。 26 | 27 | | 分辨率 | 人像 | 全身 | 28 | | ------------- | ------------- | -------------- | 29 | | <= 512x512 | 1 | 1 | 30 | | 1280x720 | 0.375 | 0.6 | 31 | | 1920x1080 | 0.25 | 0.4 | 32 | | 3840x2160 | 0.125 | 0.2 | 33 | 34 | 模型在内部将高分辨率输入缩小做初步的处理,然后再放大做细分处理。 35 | 36 | 建议设置 `downsample_ratio` 使缩小后的分辨率维持在 256 到 512 像素之间. 例如,`1920x1080` 的输入用 `downsample_ratio=0.25`,缩小后的分辨率 `480x270` 在 256 到 512 像素之间。 37 | 38 | 根据视频内容调整 `downsample_ratio`。若视频是上身人像,低 `downsample_ratio` 足矣。若视频是全身像,建议尝试更高的 `downsample_ratio`。但注意,过高的 `downsample_ratio` 反而会降低效果。 39 | 40 | 41 |
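作为参考示意(非官方 API,函数名为假设),可以按上述建议自动估算 `downsample_ratio`,使缩小后的较长边不超过 512 像素:

```python
def auto_downsample_ratio(h, w, target=512):
    # 使缩小后的较长边不超过 target 像素,且比例不大于 1
    return min(target / max(h, w), 1.0)

print(auto_downsample_ratio(1080, 1920))  # 约 0.267,与表中 1920x1080 人像建议的 0.25 接近
```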
42 | 43 | ### 循环记忆 44 | 此模型是循环神经网络(Recurrent Neural Network)。必须按顺序处理视频每帧,并提供网络循环记忆。 45 | 46 | **正确用法** 47 | 48 | 循环记忆输出被传递到下一帧做输入。 49 | 50 | ```python 51 | rec = [None] * 4 # 初始值设置为 None 52 | 53 | for frame in YOUR_VIDEO: 54 | fgr, pha, *rec = model(frame, *rec, downsample_ratio) 55 | ``` 56 | 57 | **错误用法** 58 | 59 | 没有使用循环记忆。此方法仅可用于处理单独的图片。 60 | 61 | ```python 62 | for frame in YOUR_VIDEO: 63 | fgr, pha = model(frame, downsample_ratio)[:2] 64 | ``` 65 | 66 | 更多技术细节见[论文](https://peterl1n.github.io/RobustVideoMatting/)。 67 | 68 |


69 | 70 | 71 | ## PyTorch 72 | 73 | 载入模型: 74 | 75 | ```python 76 | import torch 77 | from model import MattingNetwork 78 | 79 | model = MattingNetwork(variant='mobilenetv3').eval().cuda() # 或 variant="resnet50" 80 | model.load_state_dict(torch.load('rvm_mobilenetv3.pth')) 81 | ``` 82 | 83 | 推断循环: 84 | ```python 85 | rec = [None] * 4 # 初始值设置为 None 86 | 87 | for src in YOUR_VIDEO: # src 可以是 [B, C, H, W] 或 [B, T, C, H, W] 88 | fgr, pha, *rec = model(src, *rec, downsample_ratio=0.25) 89 | ``` 90 | 91 | * `src`: 输入帧(Source)。 92 | * 可以是 `[B, C, H, W]` 或 `[B, T, C, H, W]` 的张量。 93 | * 若是 `[B, T, C, H, W]`,可给模型一次 `T` 帧,做一小段一小段地处理,用于更好的并行计算。 94 | * RGB 通道输入,范围为 `0~1`。 95 | 96 | * `fgr, pha`: 前景(Foreground)和透明度通道(Alpha)的预测。 97 | * 根据`src`,可为 `[B, C, H, W]` 或 `[B, T, C, H, W]` 的输出。 98 | * `fgr` 是 RGB 三通道,`pha` 为一通道。 99 | * 输出范围为 `0~1`。 100 | * `rec`: 循环记忆(Recurrent States)。 101 | * `List[Tensor, Tensor, Tensor, Tensor]` 类型。 102 | * 初始 `rec` 为 `List[None, None, None, None]`。 103 | * 有四个记忆,因为网络使用四个 `ConvGRU` 层。 104 | * 无论 `src` 的 Rank,所有记忆张量的 Rank 为 4。 105 | * 若一次给予 `T` 帧,只返回处理完最后一帧后的记忆。 106 | 107 | 完整的推断例子: 108 | 109 | ```python 110 | from torch.utils.data import DataLoader 111 | from torchvision.transforms import ToTensor 112 | from inference_utils import VideoReader, VideoWriter 113 | 114 | reader = VideoReader('input.mp4', transform=ToTensor()) 115 | writer = VideoWriter('output.mp4', frame_rate=30) 116 | 117 | bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda() # 绿背景 118 | rec = [None] * 4 # 初始记忆 119 | 120 | with torch.no_grad(): 121 | for src in DataLoader(reader): 122 | fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio=0.25) # 将上一帧的记忆给下一帧 123 | writer.write(fgr * pha + bgr * (1 - pha)) 124 | ``` 125 | 126 | 或者使用提供的视频转换 API: 127 | 128 | ```python 129 | from inference import convert_video 130 | 131 | convert_video( 132 | model, # 模型,可以加载到任何设备(cpu 或 cuda) 133 | input_source='input.mp4', # 视频文件,或图片序列文件夹 134 | input_resize=(1920, 1080), # [可选项] 缩放视频大小 135 | downsample_ratio=0.25, # [可选项] 下采样比,若 None,自动下采样至 512px 136 | output_type='video', # 可选 "video"(视频)或 "png_sequence"(PNG 序列) 137 | output_composition='com.mp4', # 若导出视频,提供文件路径。若导出 PNG 序列,提供文件夹路径 138 | output_alpha="pha.mp4", # [可选项] 输出透明度预测 139 | output_foreground="fgr.mp4", # [可选项] 输出前景预测 140 | output_video_mbps=4, # 若导出视频,提供视频码率 141 | seq_chunk=12, # 设置多帧并行计算 142 | num_workers=1, # 只适用于图片序列输入,读取线程 143 | progress=True # 显示进度条 144 | ) 145 | ``` 146 | 147 | 也可通过命令行调用转换 API: 148 | 149 | ```sh 150 | python inference.py \ 151 | --variant mobilenetv3 \ 152 | --checkpoint "CHECKPOINT" \ 153 | --device cuda \ 154 | --input-source "input.mp4" \ 155 | --downsample-ratio 0.25 \ 156 | --output-type video \ 157 | --output-composition "composition.mp4" \ 158 | --output-alpha "alpha.mp4" \ 159 | --output-foreground "foreground.mp4" \ 160 | --output-video-mbps 4 \ 161 | --seq-chunk 12 162 | ``` 163 | 164 |


165 | 166 | ## TorchHub 167 | 168 | 载入模型: 169 | 170 | ```python 171 | model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # or "resnet50" 172 | ``` 173 | 174 | 使用转换 API,具体请参考之前对 `convert_video` 的文档。 175 | 176 | ```python 177 | convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter") 178 | 179 | convert_video(model, ...args...) 180 | ``` 181 | 182 |


183 | 184 | ## TorchScript 185 | 186 | 载入模型: 187 | 188 | ```python 189 | import torch 190 | model = torch.jit.load('rvm_mobilenetv3.torchscript') 191 | ``` 192 | 193 | 也可以选择将模型固化(Freeze)。这会对模型进行优化,例如 BatchNorm Fusion 等。固化后的模型更快。 194 | 195 | ```python 196 | model = torch.jit.freeze(model) 197 | ``` 198 | 199 | 然后,可以将 `model` 作为普通的 PyTorch 模型使用。但注意,若用固化模型调用转换 API,必须手动提供 `device` 和 `dtype`: 200 | 201 | ```python 202 | convert_video(frozen_model, ...args..., device='cuda', dtype=torch.float32) 203 | ``` 204 | 205 |


206 | 207 | ## ONNX 208 | 209 | 模型规格: 210 | * 输入: [`src`, `r1i`, `r2i`, `r3i`, `r4i`, `downsample_ratio`]. 211 | * `src`:输入帧,RGB 通道,形状为 `[B, C, H, W]`,范围为`0~1`。 212 | * `rXi`:记忆输入,初始值是是形状为 `[1, 1, 1, 1]` 的零张量。 213 | * `downsample_ratio` 下采样比,张量形状为 `[1]`。 214 | * 只有 `downsample_ratio` 必须是 `FP32`,其他输入必须和加载的模型使用一样的 `dtype`。 215 | * 输出: [`fgr`, `pha`, `r1o`, `r2o`, `r3o`, `r4o`] 216 | * `fgr, pha`:前景和透明度通道输出,范围为 `0~1`。 217 | * `rXo`:记忆输出。 218 | 219 | 我们只展示用 ONNX Runtime CUDA Backend 在 Python 上的使用范例。 220 | 221 | 载入模型: 222 | 223 | ```python 224 | import onnxruntime as ort 225 | 226 | sess = ort.InferenceSession('rvm_mobilenetv3_fp16.onnx') 227 | ``` 228 | 229 | 简单推断循环,但此方法不是最优化的: 230 | 231 | ```python 232 | import numpy as np 233 | 234 | rec = [ np.zeros([1, 1, 1, 1], dtype=np.float16) ] * 4 # 必须用模型一样的 dtype 235 | downsample_ratio = np.array([0.25], dtype=np.float32) # 必须是 FP32 236 | 237 | for src in YOUR_VIDEO: # src 张量是 [B, C, H, W] 形状,必须用模型一样的 dtype 238 | fgr, pha, *rec = sess.run([], { 239 | 'src': src, 240 | 'r1i': rec[0], 241 | 'r2i': rec[1], 242 | 'r3i': rec[2], 243 | 'r4i': rec[3], 244 | 'downsample_ratio': downsample_ratio 245 | }) 246 | ``` 247 | 248 | 若使用 GPU,上例会将记忆输出传回到 CPU,再在下一帧时传回到 GPU。这种传输是无意义的,因为记忆值可以留在 GPU 上。下例使用 `iobinding` 来杜绝无用的传输。 249 | 250 | ```python 251 | import onnxruntime as ort 252 | import numpy as np 253 | 254 | # 载入模型 255 | sess = ort.InferenceSession('rvm_mobilenetv3_fp16.onnx') 256 | 257 | # 创建 io binding. 258 | io = sess.io_binding() 259 | 260 | # 在 CUDA 上创建张量 261 | rec = [ ort.OrtValue.ortvalue_from_numpy(np.zeros([1, 1, 1, 1], dtype=np.float16), 'cuda') ] * 4 262 | downsample_ratio = ort.OrtValue.ortvalue_from_numpy(np.asarray([0.25], dtype=np.float32), 'cuda') 263 | 264 | # 设置输出项 265 | for name in ['fgr', 'pha', 'r1o', 'r2o', 'r3o', 'r4o']: 266 | io.bind_output(name, 'cuda') 267 | 268 | # 推断 269 | for src in YOUR_VIDEO: 270 | io.bind_cpu_input('src', src) 271 | io.bind_ortvalue_input('r1i', rec[0]) 272 | io.bind_ortvalue_input('r2i', rec[1]) 273 | io.bind_ortvalue_input('r3i', rec[2]) 274 | io.bind_ortvalue_input('r4i', rec[3]) 275 | io.bind_ortvalue_input('downsample_ratio', downsample_ratio) 276 | 277 | sess.run_with_iobinding(io) 278 | 279 | fgr, pha, *rec = io.get_outputs() 280 | 281 | # 只将 `fgr` 和 `pha` 回传到 CPU 282 | fgr = fgr.numpy() 283 | pha = pha.numpy() 284 | ``` 285 | 286 | 注:若你使用其他推断框架,可能有些 ONNX ops 不被支持,需被替换。可以参考 [onnx](https://github.com/PeterL1n/RobustVideoMatting/tree/onnx) 分支的代码做自行导出。 287 | 288 |
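以下为参考示意(非官方 API,函数名与 PIL 的使用均为假设),演示如何把一帧图片转换成 FP16 模型所需的 `src` 输入:

```python
import numpy as np
from PIL import Image

def to_src(img):
    # PIL RGB 图片 -> float16 的 [1, C, H, W] 数组,范围 0~1,与 FP16 模型一致
    arr = np.asarray(img.convert('RGB'), dtype=np.float16) / 255.0
    return np.ascontiguousarray(arr.transpose(2, 0, 1)[None])
```

若使用 FP32 模型,请改用 `np.float32`;`downsample_ratio` 始终为 FP32。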


289 | 290 | ### TensorFlow 291 | 292 | 范例: 293 | 294 | ```python 295 | import tensorflow as tf 296 | 297 | model = tf.keras.models.load_model('rvm_mobilenetv3_tf') 298 | model = tf.function(model) 299 | 300 | rec = [ tf.constant(0.) ] * 4  # 初始记忆 301 | downsample_ratio = tf.constant(0.25)  # 下采样率,根据视频调整 302 | 303 | for src in YOUR_VIDEO:  # src 张量是 [B, H, W, C] 的形状,而不是 [B, C, H, W]! 304 |     out = model([src, *rec, downsample_ratio]) 305 |     fgr, pha, *rec = out['fgr'], out['pha'], out['r1o'], out['r2o'], out['r3o'], out['r4o'] 306 | ``` 307 | 308 | 注意,在 TensorFlow 上,所有张量都是 Channel Last 的格式。 309 | 310 | 我们提供 TensorFlow 的原始模型代码,请参考 [tensorflow](https://github.com/PeterL1n/RobustVideoMatting/tree/tensorflow) 分支。您可自行将 PyTorch 的权值转到 TensorFlow 模型上。 311 | 312 | 313 |
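作为参考示意(以下数组仅为占位,非官方 API),若已有 PyTorch 风格的 `[B, C, H, W]` 数组,可以这样转换成 TensorFlow 所需的 `[B, H, W, C]`:

```python
import numpy as np
import tensorflow as tf

# 占位数据:PyTorch 风格的 [B, C, H, W] 数组,范围 0~1
src_chw = np.random.rand(1, 3, 288, 512).astype(np.float32)

# 转成 channel-last 的 [B, H, W, C],再转为 TensorFlow 张量
src = tf.convert_to_tensor(np.transpose(src_chw, (0, 2, 3, 1)))
```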


314 | 315 | ### TensorFlow.js 316 | 317 | 我们在 [tfjs](https://github.com/PeterL1n/RobustVideoMatting/tree/tfjs) 分支提供范例代码。代码简单易懂,解释如何正确使用模型。 318 | 319 |


320 | 321 | ### CoreML 322 | 323 | 我们只展示在 Python 下通过 `coremltools` 使用 CoreML 模型。在部署时,同样逻辑可用于 Swift。模型的循环记忆输入不需要在处理第一帧时提供。CoreML 内部会自动创建零张量作为初始记忆。 324 | 325 | ```python 326 | import coremltools as ct 327 | 328 | model = ct.models.model.MLModel('rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel') 329 | 330 | r1, r2, r3, r4 = None, None, None, None 331 | 332 | for src in YOUR_VIDEO: # src 是 PIL.Image. 333 | 334 | if r1 is None: 335 | # 初始帧, 不用提供循环记忆 336 | inputs = {'src': src} 337 | else: 338 | # 剩余帧,提供循环记忆 339 | inputs = {'src': src, 'r1i': r1, 'r2i': r2, 'r3i': r3, 'r4i': r4} 340 | 341 | outputs = model.predict(inputs) 342 | 343 | fgr = outputs['fgr'] # PIL.Image 344 | pha = outputs['pha'] # PIL.Image 345 | 346 | r1 = outputs['r1o'] # Numpy array 347 | r2 = outputs['r2o'] # Numpy array 348 | r3 = outputs['r3o'] # Numpy array 349 | r4 = outputs['r4o'] # Numpy array 350 | 351 | ``` 352 | 353 | 我们的 CoreML 模型只支持固定分辨率。如果你需要其他分辨率,可自行导出。导出代码见 [coreml](https://github.com/PeterL1n/RobustVideoMatting/tree/coreml) 分支。 -------------------------------------------------------------------------------- /documentation/misc/aim_test.txt: -------------------------------------------------------------------------------- 1 | boy-1518482_1920.png 2 | girl-1219339_1920.png 3 | girl-1467820_1280.png 4 | girl-beautiful-young-face-53000.png 5 | long-1245787_1920.png 6 | model-600238_1920.png 7 | pexels-photo-58463.png 8 | sea-sunny-person-beach.png 9 | wedding-dresses-1486260_1280.png 10 | woman-952506_1920 (1).png 11 | woman-morning-bathrobe-bathroom.png 12 | -------------------------------------------------------------------------------- /documentation/misc/d646_test.txt: -------------------------------------------------------------------------------- 1 | test_13.png 2 | test_16.png 3 | test_18.png 4 | test_22.png 5 | test_32.png 6 | test_35.png 7 | test_39.png 8 | test_42.png 9 | test_46.png 10 | test_4.png 11 | test_6.png 12 | -------------------------------------------------------------------------------- /documentation/misc/dvm_background_test_clips.txt: -------------------------------------------------------------------------------- 1 | 0000 2 | 0001 3 | 0002 4 | 0004 5 | 0005 6 | 0007 7 | 0008 8 | 0009 9 | 0010 10 | 0012 11 | 0013 12 | 0014 13 | 0015 14 | 0016 15 | 0017 16 | 0018 17 | 0019 18 | 0021 19 | 0022 20 | 0023 21 | 0024 22 | 0025 23 | 0027 24 | 0029 25 | 0030 26 | 0032 27 | 0033 28 | 0034 29 | 0035 30 | 0037 31 | 0038 32 | 0039 33 | 0040 34 | 0041 35 | 0042 36 | 0043 37 | 0045 38 | 0046 39 | 0047 40 | 0048 41 | 0050 42 | 0051 43 | 0052 44 | 0054 45 | 0055 46 | 0057 47 | 0058 48 | 0059 49 | 0060 50 | 0061 51 | 0062 52 | 0063 53 | 0064 54 | 0065 55 | 0066 56 | 0068 57 | 0070 58 | 0071 59 | 0073 60 | 0074 61 | 0075 62 | 0077 63 | 0078 64 | 0079 65 | 0080 66 | 0081 67 | 0082 68 | 0083 69 | 0084 70 | 0085 71 | 0086 72 | 0089 73 | 0097 74 | 0100 75 | 0101 76 | 0102 77 | 0103 78 | 0104 79 | 0106 80 | 0107 81 | 0109 82 | 0110 83 | 0111 84 | 0113 85 | 0115 86 | 0116 87 | 0117 88 | 0119 89 | 0120 90 | 0121 91 | 0122 92 | 0123 93 | 0124 94 | 0125 95 | 0126 96 | 0127 97 | 0128 98 | 0129 99 | 0130 100 | 0131 101 | 0132 102 | 0133 103 | 0134 104 | 0135 105 | 0136 106 | 0137 107 | 0143 108 | 0145 109 | 0147 110 | 0148 111 | 0150 112 | 0159 113 | 0160 114 | 0161 115 | 0162 116 | 0165 117 | 0166 118 | 0168 119 | 0172 120 | 0174 121 | 0175 122 | 0176 123 | 0178 124 | 0181 125 | 0182 126 | 0183 127 | 0184 128 | 0185 129 | 0187 130 | 0194 131 | 0198 132 | 0200 133 | 0201 134 | 0207 135 | 0210 136 | 0211 137 | 0212 138 | 0215 
139 | 0217 140 | 0218 141 | 0219 142 | 0220 143 | 0222 144 | 0223 145 | 0224 146 | 0225 147 | 0226 148 | 0227 149 | 0229 150 | 0230 151 | 0231 152 | 0232 153 | 0233 154 | 0234 155 | 0235 156 | 0237 157 | 0240 158 | 0241 159 | 0242 160 | 0243 161 | 0244 162 | 0245 163 | -------------------------------------------------------------------------------- /documentation/misc/imagematte_train.txt: -------------------------------------------------------------------------------- 1 | 10743257206_18e7f44f2e_b.jpg 2 | 10845279884_d2d4c7b4d1_b.jpg 3 | 1-1252426161dfXY.jpg 4 | 1-1255621189mTnS.jpg 5 | 1-1259162624NMFK.jpg 6 | 1-1259245823Un3j.jpg 7 | 11363165393_05d7a21d76_b.jpg 8 | 131686738165901828.jpg 9 | 13564741125_753939e9ce_o.jpg 10 | 14731860273_5b40b19b51_o.jpg 11 | 16087-a-young-woman-showing-a-bitten-green-apple-pv.jpg 12 | 1609484818_b9bb12b.jpg 13 | 17620-a-beautiful-woman-in-a-bikini-pv.jpg 14 | 20672673163_20c8467827_b.jpg 15 | 3262986095_2d5afe583c_b.jpg 16 | 3588101233_f91aa5e3a3.jpg 17 | 3858897226_cae5b75963_o.jpg 18 | 4889657410_2d503ca287_o.jpg 19 | 4981835627_c4e6c4ffa8_o.jpg 20 | 5025666458_576b974455_o.jpg 21 | 5149410930_3a943dc43f_b.jpg 22 | 539641011387760661.jpg 23 | 5892503248_4b882863c7_o.jpg 24 | 604673748289192179.jpg 25 | 606189768665464996.jpg 26 | 624753897218113578.jpg 27 | 657454154710122500.jpg 28 | 664308724952072193.jpg 29 | 7669262460_e4be408343_b.jpg 30 | 8244818049_dfa59a3eb8_b.jpg 31 | 8688417335_01f3bafbe5_o.jpg 32 | 9434599749_e7ccfc7812_b.jpg 33 | Aaron_Friedman_Headshot.jpg 34 | arrgh___r___28_by_mjranum_stock.jpg 35 | arrgh___r___29_by_mjranum_stock.jpg 36 | arrgh___r___30_by_mjranum_stock.jpg 37 | a-single-person-1084191_960_720.jpg 38 | ballerina-855652_1920.jpg 39 | beautiful-19075_960_720.jpg 40 | boy-454633_1920.jpg 41 | bride-2819673_1920.jpg 42 | bride-442894_1920.jpg 43 | face-1223346_960_720.jpg 44 | fashion-model-portrait.jpg 45 | fashion-model-pose.jpg 46 | girl-1535859_1920.jpg 47 | Girl_in_front_of_a_green_background.jpg 48 | goth_by_bugidifino-d4w7zms.jpg 49 | h_0.jpg 50 | h_100.jpg 51 | h_101.jpg 52 | h_102.jpg 53 | h_103.jpg 54 | h_104.jpg 55 | h_105.jpg 56 | h_106.jpg 57 | h_107.jpg 58 | h_108.jpg 59 | h_109.jpg 60 | h_10.jpg 61 | h_111.jpg 62 | h_112.jpg 63 | h_113.jpg 64 | h_114.jpg 65 | h_115.jpg 66 | h_116.jpg 67 | h_117.jpg 68 | h_118.jpg 69 | h_119.jpg 70 | h_11.jpg 71 | h_120.jpg 72 | h_121.jpg 73 | h_122.jpg 74 | h_123.jpg 75 | h_124.jpg 76 | h_125.jpg 77 | h_126.jpg 78 | h_127.jpg 79 | h_128.jpg 80 | h_129.jpg 81 | h_12.jpg 82 | h_130.jpg 83 | h_131.jpg 84 | h_132.jpg 85 | h_133.jpg 86 | h_134.jpg 87 | h_135.jpg 88 | h_136.jpg 89 | h_137.jpg 90 | h_138.jpg 91 | h_139.jpg 92 | h_13.jpg 93 | h_140.jpg 94 | h_141.jpg 95 | h_142.jpg 96 | h_143.jpg 97 | h_144.jpg 98 | h_145.jpg 99 | h_146.jpg 100 | h_147.jpg 101 | h_148.jpg 102 | h_149.jpg 103 | h_14.jpg 104 | h_151.jpg 105 | h_152.jpg 106 | h_153.jpg 107 | h_154.jpg 108 | h_155.jpg 109 | h_156.jpg 110 | h_157.jpg 111 | h_158.jpg 112 | h_159.jpg 113 | h_15.jpg 114 | h_160.jpg 115 | h_161.jpg 116 | h_162.jpg 117 | h_163.jpg 118 | h_164.jpg 119 | h_165.jpg 120 | h_166.jpg 121 | h_167.jpg 122 | h_168.jpg 123 | h_169.jpg 124 | h_170.jpg 125 | h_171.jpg 126 | h_172.jpg 127 | h_173.jpg 128 | h_174.jpg 129 | h_175.jpg 130 | h_176.jpg 131 | h_177.jpg 132 | h_178.jpg 133 | h_179.jpg 134 | h_17.jpg 135 | h_180.jpg 136 | h_181.jpg 137 | h_182.jpg 138 | h_183.jpg 139 | h_184.jpg 140 | h_185.jpg 141 | h_186.jpg 142 | h_187.jpg 143 | h_188.jpg 144 | h_189.jpg 145 | h_18.jpg 146 | h_190.jpg 147 | 
h_191.jpg 148 | h_192.jpg 149 | h_193.jpg 150 | h_194.jpg 151 | h_195.jpg 152 | h_196.jpg 153 | h_197.jpg 154 | h_198.jpg 155 | h_199.jpg 156 | h_19.jpg 157 | h_1.jpg 158 | h_200.jpg 159 | h_201.jpg 160 | h_202.jpg 161 | h_204.jpg 162 | h_205.jpg 163 | h_206.jpg 164 | h_207.jpg 165 | h_208.jpg 166 | h_209.jpg 167 | h_20.jpg 168 | h_210.jpg 169 | h_211.jpg 170 | h_212.jpg 171 | h_213.jpg 172 | h_214.jpg 173 | h_215.jpg 174 | h_216.jpg 175 | h_217.jpg 176 | h_218.jpg 177 | h_219.jpg 178 | h_21.jpg 179 | h_220.jpg 180 | h_221.jpg 181 | h_222.jpg 182 | h_223.jpg 183 | h_224.jpg 184 | h_225.jpg 185 | h_226.jpg 186 | h_227.jpg 187 | h_228.jpg 188 | h_229.jpg 189 | h_22.jpg 190 | h_230.jpg 191 | h_231.jpg 192 | h_232.jpg 193 | h_233.jpg 194 | h_234.jpg 195 | h_235.jpg 196 | h_236.jpg 197 | h_237.jpg 198 | h_238.jpg 199 | h_239.jpg 200 | h_23.jpg 201 | h_240.jpg 202 | h_241.jpg 203 | h_242.jpg 204 | h_243.jpg 205 | h_244.jpg 206 | h_245.jpg 207 | h_247.jpg 208 | h_248.jpg 209 | h_249.jpg 210 | h_24.jpg 211 | h_250.jpg 212 | h_251.jpg 213 | h_252.jpg 214 | h_253.jpg 215 | h_254.jpg 216 | h_255.jpg 217 | h_256.jpg 218 | h_257.jpg 219 | h_258.jpg 220 | h_259.jpg 221 | h_25.jpg 222 | h_260.jpg 223 | h_261.jpg 224 | h_262.jpg 225 | h_263.jpg 226 | h_264.jpg 227 | h_265.jpg 228 | h_266.jpg 229 | h_268.jpg 230 | h_269.jpg 231 | h_26.jpg 232 | h_270.jpg 233 | h_271.jpg 234 | h_272.jpg 235 | h_273.jpg 236 | h_274.jpg 237 | h_276.jpg 238 | h_277.jpg 239 | h_278.jpg 240 | h_279.jpg 241 | h_27.jpg 242 | h_280.jpg 243 | h_281.jpg 244 | h_282.jpg 245 | h_283.jpg 246 | h_284.jpg 247 | h_285.jpg 248 | h_286.jpg 249 | h_287.jpg 250 | h_288.jpg 251 | h_289.jpg 252 | h_28.jpg 253 | h_290.jpg 254 | h_291.jpg 255 | h_292.jpg 256 | h_293.jpg 257 | h_294.jpg 258 | h_295.jpg 259 | h_296.jpg 260 | h_297.jpg 261 | h_298.jpg 262 | h_299.jpg 263 | h_29.jpg 264 | h_300.jpg 265 | h_301.jpg 266 | h_302.jpg 267 | h_303.jpg 268 | h_304.jpg 269 | h_305.jpg 270 | h_307.jpg 271 | h_308.jpg 272 | h_309.jpg 273 | h_30.jpg 274 | h_310.jpg 275 | h_311.jpg 276 | h_312.jpg 277 | h_313.jpg 278 | h_314.jpg 279 | h_315.jpg 280 | h_316.jpg 281 | h_317.jpg 282 | h_318.jpg 283 | h_319.jpg 284 | h_31.jpg 285 | h_320.jpg 286 | h_321.jpg 287 | h_322.jpg 288 | h_323.jpg 289 | h_324.jpg 290 | h_325.jpg 291 | h_326.jpg 292 | h_327.jpg 293 | h_329.jpg 294 | h_32.jpg 295 | h_33.jpg 296 | h_34.jpg 297 | h_35.jpg 298 | h_36.jpg 299 | h_37.jpg 300 | h_38.jpg 301 | h_39.jpg 302 | h_3.jpg 303 | h_40.jpg 304 | h_41.jpg 305 | h_42.jpg 306 | h_43.jpg 307 | h_44.jpg 308 | h_45.jpg 309 | h_46.jpg 310 | h_47.jpg 311 | h_48.jpg 312 | h_49.jpg 313 | h_4.jpg 314 | h_50.jpg 315 | h_51.jpg 316 | h_52.jpg 317 | h_53.jpg 318 | h_54.jpg 319 | h_55.jpg 320 | h_56.jpg 321 | h_57.jpg 322 | h_58.jpg 323 | h_59.jpg 324 | h_5.jpg 325 | h_60.jpg 326 | h_61.jpg 327 | h_62.jpg 328 | h_63.jpg 329 | h_65.jpg 330 | h_67.jpg 331 | h_68.jpg 332 | h_69.jpg 333 | h_6.jpg 334 | h_70.jpg 335 | h_71.jpg 336 | h_72.jpg 337 | h_73.jpg 338 | h_74.jpg 339 | h_75.jpg 340 | h_76.jpg 341 | h_77.jpg 342 | h_78.jpg 343 | h_79.jpg 344 | h_7.jpg 345 | h_80.jpg 346 | h_81.jpg 347 | h_82.jpg 348 | h_83.jpg 349 | h_84.jpg 350 | h_85.jpg 351 | h_86.jpg 352 | h_87.jpg 353 | h_88.jpg 354 | h_89.jpg 355 | h_8.jpg 356 | h_90.jpg 357 | h_91.jpg 358 | h_92.jpg 359 | h_93.jpg 360 | h_94.jpg 361 | h_95.jpg 362 | h_96.jpg 363 | h_97.jpg 364 | h_98.jpg 365 | h_99.jpg 366 | h_9.jpg 367 | hair-flying-142210_1920.jpg 368 | headshotid_by_bokogreat_stock-d355xf3.jpg 369 | lil_white_goth_grl___23_by_mjranum_stock.jpg 
370 | lil_white_goth_grl___26_by_mjranum_stock.jpg 371 | man-388104_960_720.jpg 372 | man_headshot.jpg 373 | MFettes-headshot.jpg 374 | model-429733_960_720.jpg 375 | model-610352_960_720.jpg 376 | model-858753_960_720.jpg 377 | model-858755_960_720.jpg 378 | model-873675_960_720.jpg 379 | model-873678_960_720.jpg 380 | model-873690_960_720.jpg 381 | model-881425_960_720.jpg 382 | model-881431_960_720.jpg 383 | model-female-girl-beautiful-51969.jpg 384 | Model_in_green_dress_3.jpg 385 | Modern_shingle_bob_haircut.jpg 386 | Motivate_(Fitness_model).jpg 387 | Official_portrait_of_Barack_Obama.jpg 388 | person-woman-eyes-face.jpg 389 | pink-hair-855660_960_720.jpg 390 | portrait-750774_1920.jpg 391 | Professor_Steven_Chu_ForMemRS_headshot.jpg 392 | sailor_flying_4_by_senshistock-d4k2wmr.jpg 393 | skin-care-937667_960_720.jpg 394 | sorcery___8_by_mjranum_stock.jpg 395 | t_62.jpg 396 | t_65.jpg 397 | test_32.jpg 398 | test_8.jpg 399 | train_245.jpg 400 | train_246.jpg 401 | train_255.jpg 402 | train_304.jpg 403 | train_333.jpg 404 | train_361.jpg 405 | train_395.jpg 406 | train_480.jpg 407 | train_488.jpg 408 | train_539.jpg 409 | wedding-846926_1920.jpg 410 | Wild_hair.jpg 411 | with_wings___pose_reference_by_senshistock-d6by42n_2.jpg 412 | with_wings___pose_reference_by_senshistock-d6by42n.jpg 413 | woman-1138435_960_720.jpg 414 | woman1.jpg 415 | woman2.jpg 416 | woman-659354_960_720.jpg 417 | woman-804072_960_720.jpg 418 | woman-868519_960_720.jpg 419 | Woman_in_white_shirt_on_August_2009_02.jpg 420 | women-878869_1920.jpg 421 | -------------------------------------------------------------------------------- /documentation/misc/imagematte_valid.txt: -------------------------------------------------------------------------------- 1 | 13564741125_753939e9ce_o.jpg 2 | 3858897226_cae5b75963_o.jpg 3 | 538724499685900405.jpg 4 | ballerina-855652_1920.jpg 5 | boy-454633_1920.jpg 6 | h_110.jpg 7 | h_150.jpg 8 | h_16.jpg 9 | h_246.jpg 10 | h_267.jpg 11 | h_275.jpg 12 | h_306.jpg 13 | h_328.jpg 14 | model-610352_960_720.jpg 15 | t_66.jpg 16 | -------------------------------------------------------------------------------- /documentation/misc/spd_preprocess.py: -------------------------------------------------------------------------------- 1 | # pip install supervisely 2 | import supervisely_lib as sly 3 | import numpy as np 4 | import os 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | # Download dataset from 9 | project_root = 'PATH_TO/Supervisely Person Dataset' # <-- Configure input 10 | project = sly.Project(project_root, sly.OpenMode.READ) 11 | 12 | output_path = 'OUTPUT_DIR' # <-- Configure output 13 | os.makedirs(os.path.join(output_path, 'train', 'src')) 14 | os.makedirs(os.path.join(output_path, 'train', 'msk')) 15 | os.makedirs(os.path.join(output_path, 'valid', 'src')) 16 | os.makedirs(os.path.join(output_path, 'valid', 'msk')) 17 | 18 | max_size = 2048 # <-- Configure max size 19 | 20 | for dataset in project.datasets: 21 | for item in tqdm(dataset): 22 | ann = sly.Annotation.load_json_file(dataset.get_ann_path(item), project.meta) 23 | msk = np.zeros(ann.img_size, dtype=np.uint8) 24 | for label in ann.labels: 25 | label.geometry.draw(msk, color=[255]) 26 | msk = Image.fromarray(msk) 27 | 28 | img = Image.open(dataset.get_img_path(item)).convert('RGB') 29 | if img.size[0] > max_size or img.size[1] > max_size: 30 | scale = max_size / max(img.size) 31 | img = img.resize((int(img.size[0] * scale), int(img.size[1] * scale)), Image.BILINEAR) 32 | msk = msk.resize((int(msk.size[0] * 
scale), int(msk.size[1] * scale)), Image.NEAREST) 33 | 34 |         img.save(os.path.join(output_path, 'train', 'src', item.replace('.png', '.jpg'))) 35 |         msk.save(os.path.join(output_path, 'train', 'msk', item.replace('.png', '.jpg'))) 36 | 37 | # Move first 100 to validation set 38 | names = os.listdir(os.path.join(output_path, 'train', 'src')) 39 | for name in tqdm(names[:100]): 40 |     os.rename( 41 |         os.path.join(output_path, 'train', 'src', name), 42 |         os.path.join(output_path, 'valid', 'src', name)) 43 |     os.rename( 44 |         os.path.join(output_path, 'train', 'msk', name), 45 |         os.path.join(output_path, 'valid', 'msk', name)) -------------------------------------------------------------------------------- /documentation/training.md: -------------------------------------------------------------------------------- 1 | # Training Documentation 2 | 3 | This documentation only shows how to reproduce our [paper](https://peterl1n.github.io/RobustVideoMatting/). If you would like to remove or add a dataset to the training, you are responsible for adapting the training code yourself. 4 | 5 | ## Datasets 6 | 7 | The following datasets are used during our training. 8 | 9 | **IMPORTANT: If you choose to download our preprocessed versions, please avoid repeated downloads and cache the data locally. All traffic comes at our expense, so please be responsible. We may only provide the preprocessed versions for a limited time.** 10 | 11 | ### Matting Datasets 12 | * [VideoMatte240K](https://grail.cs.washington.edu/projects/background-matting-v2/#/datasets) 13 | * Download the JPEG SD version (6G) for stages 1 and 2. 14 | * Download the JPEG HD version (60G) for stages 3 and 4. 15 | * Manually move clips `0000`, `0100`, `0200`, `0300` from the training set to a validation set (see the sketch before Stage 1 below). 16 | * ImageMatte 17 | * ImageMatte consists of the [Distinctions-646](https://wukaoliu.github.io/HAttMatting/) and [Adobe Image Matting](https://sites.google.com/view/deepimagematting) datasets. 18 | * Only needed for stage 4. 19 | * You need to contact their authors to acquire the data. 20 | * After downloading both datasets, merge their samples together to form the ImageMatte dataset. 21 | * Only keep samples of humans. 22 | * Full list of images we used in ImageMatte for training: 23 | * [imagematte_train.txt](/documentation/misc/imagematte_train.txt) 24 | * [imagematte_valid.txt](/documentation/misc/imagematte_valid.txt) 25 | * Full list of images we used for evaluation: 26 | * [aim_test.txt](/documentation/misc/aim_test.txt) 27 | * [d646_test.txt](/documentation/misc/d646_test.txt) 28 | ### Background Datasets 29 | * Video Backgrounds 30 | * We processed the [DVM Background Set](https://github.com/nowsyn/DVM) by selecting clips without humans and extracting only the first 100 frames as a JPEG sequence. 31 | * Full list of clips we used: 32 | * [dvm_background_train_clips.txt](/documentation/misc/dvm_background_train_clips.txt) 33 | * [dvm_background_test_clips.txt](/documentation/misc/dvm_background_test_clips.txt) 34 | * You can download our preprocessed versions: 35 | * [Train set (14.6G)](https://robustvideomatting.blob.core.windows.net/data/BackgroundVideosTrain.tar) (Manually move some clips to a validation set) 36 | * [Test set (936M)](https://robustvideomatting.blob.core.windows.net/data/BackgroundVideosTest.tar) (Not needed for training. Only used for making synthetic test samples for evaluation) 37 | * Image Backgrounds 38 | * Train set: 39 | * We crawled 8000 suitable images from Google and Flickr. 40 | * We will not publish these images.
41 | * [Test set](https://grail.cs.washington.edu/projects/background-matting-v2/#/datasets) 42 | * We use the validation background set from the [BGMv2](https://grail.cs.washington.edu/projects/background-matting-v2/) project. 43 | * It contains about 200 images. 44 | * It is not used in our training. It is only used for making synthetic test samples for evaluation. 45 | * But if you just want to quickly try out training, you may use this as a temporary substitute for the train set. 46 | 47 | ### Segmentation Datasets 48 | 49 | * [COCO](https://cocodataset.org/#download) 50 | * Download [train2017.zip (18G)](http://images.cocodataset.org/zips/train2017.zip) 51 | * Download [panoptic_annotations_trainval2017.zip (821M)](http://images.cocodataset.org/annotations/panoptic_annotations_trainval2017.zip) 52 | * Note that our train script expects the panoptic version. 53 | * [YouTubeVIS 2021](https://youtube-vos.org/dataset/vis/) 54 | * Download the train set. No preprocessing needed. 55 | * [Supervisely Person Dataset](https://supervise.ly/explore/projects/supervisely-person-dataset-23304/datasets) 56 | * We used the supervisely library to convert their encoding to bitmap masks before using our script. We also resized down some of the large images to avoid a disk-loading bottleneck. 57 | * You can refer to [spd_preprocess.py](/documentation/misc/spd_preprocess.py) 58 | * Or, you can download our [preprocessed version (800M)](https://robustvideomatting.blob.core.windows.net/data/SuperviselyPersonDataset.tar) 59 | 60 | ## Training 61 | 62 | For reference, our training was done on data center machines with 48 CPU cores, 300G CPU memory, and 4 Nvidia V100 32G GPUs. 63 | 64 | During our official training, the code contained custom logic for our infrastructure. For release, the script has been cleaned up. There may be bugs in this version of the code that were not in our official training. If you find problems, please file an issue. 65 | 66 | After you have downloaded the datasets, please configure `train_config.py` to provide the paths to your datasets. 67 | 68 | The training consists of 4 stages. For details, please refer to the [paper](https://peterl1n.github.io/RobustVideoMatting/).
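Before starting Stage 1, split out the VideoMatte240K validation clips mentioned in the Datasets section. Below is a minimal sketch; the root path and the `train/{fgr,pha}/<clip>` layout are assumptions, so adjust them to however you arranged the download:

```python
import os
import shutil

root = 'PATH_TO/VideoMatte240K_JPEG_SD'  # <-- assumed download location
for kind in ['fgr', 'pha']:
    os.makedirs(os.path.join(root, 'valid', kind), exist_ok=True)
    for clip in ['0000', '0100', '0200', '0300']:
        shutil.move(os.path.join(root, 'train', kind, clip),
                    os.path.join(root, 'valid', kind, clip))
```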
69 | 70 | ### Stage 1 71 | ```sh 72 | python train.py \ 73 | --model-variant mobilenetv3 \ 74 | --dataset videomatte \ 75 | --resolution-lr 512 \ 76 | --seq-length-lr 15 \ 77 | --learning-rate-backbone 0.0001 \ 78 | --learning-rate-aspp 0.0002 \ 79 | --learning-rate-decoder 0.0002 \ 80 | --learning-rate-refiner 0 \ 81 | --checkpoint-dir checkpoint/stage1 \ 82 | --log-dir log/stage1 \ 83 | --epoch-start 0 \ 84 | --epoch-end 20 85 | ``` 86 | 87 | ### Stage 2 88 | ```sh 89 | python train.py \ 90 | --model-variant mobilenetv3 \ 91 | --dataset videomatte \ 92 | --resolution-lr 512 \ 93 | --seq-length-lr 50 \ 94 | --learning-rate-backbone 0.00005 \ 95 | --learning-rate-aspp 0.0001 \ 96 | --learning-rate-decoder 0.0001 \ 97 | --learning-rate-refiner 0 \ 98 | --checkpoint checkpoint/stage1/epoch-19.pth \ 99 | --checkpoint-dir checkpoint/stage2 \ 100 | --log-dir log/stage2 \ 101 | --epoch-start 20 \ 102 | --epoch-end 22 103 | ``` 104 | 105 | ### Stage 3 106 | ```sh 107 | python train.py \ 108 | --model-variant mobilenetv3 \ 109 | --dataset videomatte \ 110 | --train-hr \ 111 | --resolution-lr 512 \ 112 | --resolution-hr 2048 \ 113 | --seq-length-lr 40 \ 114 | --seq-length-hr 6 \ 115 | --learning-rate-backbone 0.00001 \ 116 | --learning-rate-aspp 0.00001 \ 117 | --learning-rate-decoder 0.00001 \ 118 | --learning-rate-refiner 0.0002 \ 119 | --checkpoint checkpoint/stage2/epoch-21.pth \ 120 | --checkpoint-dir checkpoint/stage3 \ 121 | --log-dir log/stage3 \ 122 | --epoch-start 22 \ 123 | --epoch-end 23 124 | ``` 125 | 126 | ### Stage 4 127 | ```sh 128 | python train.py \ 129 | --model-variant mobilenetv3 \ 130 | --dataset imagematte \ 131 | --train-hr \ 132 | --resolution-lr 512 \ 133 | --resolution-hr 2048 \ 134 | --seq-length-lr 40 \ 135 | --seq-length-hr 6 \ 136 | --learning-rate-backbone 0.00001 \ 137 | --learning-rate-aspp 0.00001 \ 138 | --learning-rate-decoder 0.00005 \ 139 | --learning-rate-refiner 0.0002 \ 140 | --checkpoint checkpoint/stage3/epoch-22.pth \ 141 | --checkpoint-dir checkpoint/stage4 \ 142 | --log-dir log/stage4 \ 143 | --epoch-start 23 \ 144 | --epoch-end 28 145 | ``` 146 | 147 |
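After Stage 4 finishes, the final checkpoint can be loaded with the PyTorch API described in the [inference documentation](/documentation/inference.md). A minimal sketch follows; the epoch number is inferred from the naming pattern used by the stage commands above, and it assumes the checkpoint stores a plain model `state_dict`:

```python
import torch
from model import MattingNetwork

model = MattingNetwork(variant='mobilenetv3').eval().cuda()
model.load_state_dict(torch.load('checkpoint/stage4/epoch-27.pth'))
```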


148 | 149 | ## Evaluation 150 | 151 | We synthetically composite test samples onto both image and video backgrounds. Image samples (from D646, AIM) are augmented with synthetic motion. 152 | 153 | We only provide the composited VideoMatte240K test set. It is used in our paper evaluation. For D646 and AIM, you need to acquire the data from their authors and composite them yourself. The composition scripts we used are saved in the `/evaluation` folder as a reference backup. You need to modify them based on your setup. 154 | 155 | * [videomatte_512x288.tar (PNG 1.8G)](https://robustvideomatting.blob.core.windows.net/eval/videomatte_512x288.tar) 156 | * [videomatte_1920x1080.tar (JPG 2.2G)](https://robustvideomatting.blob.core.windows.net/eval/videomatte_1920x1080.tar) 157 | 158 | Evaluation scripts are provided in the `/evaluation` folder. -------------------------------------------------------------------------------- /evaluation/evaluate_hr.py: -------------------------------------------------------------------------------- 1 | """ 2 | HR (High-Resolution) evaluation. We found using numpy is very slow for high resolution, so we moved it to PyTorch using CUDA. 3 | 4 | Note, the script only does evaluation. You will need to first inference yourself and save the results to disk 5 | Expected directory format for both prediction and ground-truth is: 6 | 7 | videomatte_1920x1080 8 | ├── videomatte_motion 9 | ├── pha 10 | ├── 0000 11 | ├── 0000.png 12 | ├── fgr 13 | ├── 0000 14 | ├── 0000.png 15 | ├── videomatte_static 16 | ├── pha 17 | ├── 0000 18 | ├── 0000.png 19 | ├── fgr 20 | ├── 0000 21 | ├── 0000.png 22 | 23 | Prediction must have the exact file structure and file name as the ground-truth, 24 | meaning that if the ground-truth is png/jpg, prediction should be png/jpg.
25 | 26 | Example usage: 27 | 28 | python evaluate.py \ 29 | --pred-dir pred/videomatte_1920x1080 \ 30 | --true-dir true/videomatte_1920x1080 31 | 32 | An excel sheet with evaluation results will be written to "pred/videomatte_1920x1080/videomatte_1920x1080.xlsx" 33 | """ 34 | 35 | 36 | import argparse 37 | import os 38 | import cv2 39 | import kornia 40 | import numpy as np 41 | import xlsxwriter 42 | import torch 43 | from concurrent.futures import ThreadPoolExecutor 44 | from tqdm import tqdm 45 | 46 | 47 | class Evaluator: 48 | def __init__(self): 49 | self.parse_args() 50 | self.init_metrics() 51 | self.evaluate() 52 | self.write_excel() 53 | 54 | def parse_args(self): 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument('--pred-dir', type=str, required=True) 57 | parser.add_argument('--true-dir', type=str, required=True) 58 | parser.add_argument('--num-workers', type=int, default=48) 59 | parser.add_argument('--metrics', type=str, nargs='+', default=[ 60 | 'pha_mad', 'pha_mse', 'pha_grad', 'pha_dtssd', 'fgr_mse']) 61 | self.args = parser.parse_args() 62 | 63 | def init_metrics(self): 64 | self.mad = MetricMAD() 65 | self.mse = MetricMSE() 66 | self.grad = MetricGRAD() 67 | self.dtssd = MetricDTSSD() 68 | 69 | def evaluate(self): 70 | tasks = [] 71 | position = 0 72 | 73 | with ThreadPoolExecutor(max_workers=self.args.num_workers) as executor: 74 | for dataset in sorted(os.listdir(self.args.pred_dir)): 75 | if os.path.isdir(os.path.join(self.args.pred_dir, dataset)): 76 | for clip in sorted(os.listdir(os.path.join(self.args.pred_dir, dataset))): 77 | future = executor.submit(self.evaluate_worker, dataset, clip, position) 78 | tasks.append((dataset, clip, future)) 79 | position += 1 80 | 81 | self.results = [(dataset, clip, future.result()) for dataset, clip, future in tasks] 82 | 83 | def write_excel(self): 84 | workbook = xlsxwriter.Workbook(os.path.join(self.args.pred_dir, f'{os.path.basename(self.args.pred_dir)}.xlsx')) 85 | summarysheet = workbook.add_worksheet('summary') 86 | metricsheets = [workbook.add_worksheet(metric) for metric in self.results[0][2].keys()] 87 | 88 | for i, metric in enumerate(self.results[0][2].keys()): 89 | summarysheet.write(i, 0, metric) 90 | summarysheet.write(i, 1, f'={metric}!B2') 91 | 92 | for row, (dataset, clip, metrics) in enumerate(self.results): 93 | for metricsheet, metric in zip(metricsheets, metrics.values()): 94 | # Write the header 95 | if row == 0: 96 | metricsheet.write(1, 0, 'Average') 97 | metricsheet.write(1, 1, f'=AVERAGE(C2:ZZ2)') 98 | for col in range(len(metric)): 99 | metricsheet.write(0, col + 2, col) 100 | colname = xlsxwriter.utility.xl_col_to_name(col + 2) 101 | metricsheet.write(1, col + 2, f'=AVERAGE({colname}3:{colname}9999)') 102 | 103 | metricsheet.write(row + 2, 0, dataset) 104 | metricsheet.write(row + 2, 1, clip) 105 | metricsheet.write_row(row + 2, 2, metric) 106 | 107 | workbook.close() 108 | 109 | def evaluate_worker(self, dataset, clip, position): 110 | framenames = sorted(os.listdir(os.path.join(self.args.pred_dir, dataset, clip, 'pha'))) 111 | metrics = {metric_name : [] for metric_name in self.args.metrics} 112 | 113 | pred_pha_tm1 = None 114 | true_pha_tm1 = None 115 | 116 | for i, framename in enumerate(tqdm(framenames, desc=f'{dataset} {clip}', position=position, dynamic_ncols=True)): 117 | true_pha = cv2.imread(os.path.join(self.args.true_dir, dataset, clip, 'pha', framename), cv2.IMREAD_GRAYSCALE) 118 | pred_pha = cv2.imread(os.path.join(self.args.pred_dir, dataset, clip, 'pha', framename), 
cv2.IMREAD_GRAYSCALE) 119 | 120 | true_pha = torch.from_numpy(true_pha).cuda(non_blocking=True).float().div_(255) 121 | pred_pha = torch.from_numpy(pred_pha).cuda(non_blocking=True).float().div_(255) 122 | 123 | if 'pha_mad' in self.args.metrics: 124 | metrics['pha_mad'].append(self.mad(pred_pha, true_pha)) 125 | if 'pha_mse' in self.args.metrics: 126 | metrics['pha_mse'].append(self.mse(pred_pha, true_pha)) 127 | if 'pha_grad' in self.args.metrics: 128 | metrics['pha_grad'].append(self.grad(pred_pha, true_pha)) 129 | if 'pha_conn' in self.args.metrics: 130 | metrics['pha_conn'].append(self.conn(pred_pha, true_pha)) 131 | if 'pha_dtssd' in self.args.metrics: 132 | if i == 0: 133 | metrics['pha_dtssd'].append(0) 134 | else: 135 | metrics['pha_dtssd'].append(self.dtssd(pred_pha, pred_pha_tm1, true_pha, true_pha_tm1)) 136 | 137 | pred_pha_tm1 = pred_pha 138 | true_pha_tm1 = true_pha 139 | 140 | if 'fgr_mse' in self.args.metrics: 141 | true_fgr = cv2.imread(os.path.join(self.args.true_dir, dataset, clip, 'fgr', framename), cv2.IMREAD_COLOR) 142 | pred_fgr = cv2.imread(os.path.join(self.args.pred_dir, dataset, clip, 'fgr', framename), cv2.IMREAD_COLOR) 143 | 144 | true_fgr = torch.from_numpy(true_fgr).float().div_(255) 145 | pred_fgr = torch.from_numpy(pred_fgr).float().div_(255) 146 | 147 | true_msk = true_pha > 0 148 | metrics['fgr_mse'].append(self.mse(pred_fgr[true_msk], true_fgr[true_msk])) 149 | 150 | return metrics 151 | 152 | 153 | class MetricMAD: 154 | def __call__(self, pred, true): 155 | return (pred - true).abs_().mean() * 1e3 156 | 157 | 158 | class MetricMSE: 159 | def __call__(self, pred, true): 160 | return ((pred - true) ** 2).mean() * 1e3 161 | 162 | 163 | class MetricGRAD: 164 | def __init__(self, sigma=1.4): 165 | self.filter_x, self.filter_y = self.gauss_filter(sigma) 166 | self.filter_x = torch.from_numpy(self.filter_x).unsqueeze(0).cuda() 167 | self.filter_y = torch.from_numpy(self.filter_y).unsqueeze(0).cuda() 168 | 169 | def __call__(self, pred, true): 170 | true_grad = self.gauss_gradient(true) 171 | pred_grad = self.gauss_gradient(pred) 172 | return ((true_grad - pred_grad) ** 2).sum() / 1000 173 | 174 | def gauss_gradient(self, img): 175 | img_filtered_x = kornia.filters.filter2D(img[None, None, :, :], self.filter_x, border_type='replicate')[0, 0] 176 | img_filtered_y = kornia.filters.filter2D(img[None, None, :, :], self.filter_y, border_type='replicate')[0, 0] 177 | return (img_filtered_x**2 + img_filtered_y**2).sqrt() 178 | 179 | @staticmethod 180 | def gauss_filter(sigma, epsilon=1e-2): 181 | half_size = np.ceil(sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon))) 182 | size = np.int(2 * half_size + 1) 183 | 184 | # create filter in x axis 185 | filter_x = np.zeros((size, size)) 186 | for i in range(size): 187 | for j in range(size): 188 | filter_x[i, j] = MetricGRAD.gaussian(i - half_size, sigma) * MetricGRAD.dgaussian( 189 | j - half_size, sigma) 190 | 191 | # normalize filter 192 | norm = np.sqrt((filter_x**2).sum()) 193 | filter_x = filter_x / norm 194 | filter_y = np.transpose(filter_x) 195 | 196 | return filter_x, filter_y 197 | 198 | @staticmethod 199 | def gaussian(x, sigma): 200 | return np.exp(-x**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi)) 201 | 202 | @staticmethod 203 | def dgaussian(x, sigma): 204 | return -x * MetricGRAD.gaussian(x, sigma) / sigma**2 205 | 206 | 207 | class MetricDTSSD: 208 | def __call__(self, pred_t, pred_tm1, true_t, true_tm1): 209 | dtSSD = ((pred_t - pred_tm1) - (true_t - true_tm1)) ** 2 210 | dtSSD = 
dtSSD.sum() / true_t.numel() 211 | dtSSD = dtSSD.sqrt() 212 | return dtSSD * 1e2 213 | 214 | 215 | if __name__ == '__main__': 216 | Evaluator() -------------------------------------------------------------------------------- /evaluation/evaluate_lr.py: -------------------------------------------------------------------------------- 1 | """ 2 | LR (Low-Resolution) evaluation. 3 | 4 | Note, the script only does evaluation. You will need to first inference yourself and save the results to disk 5 | Expected directory format for both prediction and ground-truth is: 6 | 7 | videomatte_512x288 8 | ├── videomatte_motion 9 | ├── pha 10 | ├── 0000 11 | ├── 0000.png 12 | ├── fgr 13 | ├── 0000 14 | ├── 0000.png 15 | ├── videomatte_static 16 | ├── pha 17 | ├── 0000 18 | ├── 0000.png 19 | ├── fgr 20 | ├── 0000 21 | ├── 0000.png 22 | 23 | Prediction must have the exact file structure and file name as the ground-truth, 24 | meaning that if the ground-truth is png/jpg, prediction should be png/jpg. 25 | 26 | Example usage: 27 | 28 | python evaluate.py \ 29 | --pred-dir PATH_TO_PREDICTIONS/videomatte_512x288 \ 30 | --true-dir PATH_TO_GROUNDTURTH/videomatte_512x288 31 | 32 | An excel sheet with evaluation results will be written to "PATH_TO_PREDICTIONS/videomatte_512x288/videomatte_512x288.xlsx" 33 | """ 34 | 35 | 36 | import argparse 37 | import os 38 | import cv2 39 | import numpy as np 40 | import xlsxwriter 41 | from concurrent.futures import ThreadPoolExecutor 42 | from tqdm import tqdm 43 | 44 | 45 | class Evaluator: 46 | def __init__(self): 47 | self.parse_args() 48 | self.init_metrics() 49 | self.evaluate() 50 | self.write_excel() 51 | 52 | def parse_args(self): 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument('--pred-dir', type=str, required=True) 55 | parser.add_argument('--true-dir', type=str, required=True) 56 | parser.add_argument('--num-workers', type=int, default=48) 57 | parser.add_argument('--metrics', type=str, nargs='+', default=[ 58 | 'pha_mad', 'pha_mse', 'pha_grad', 'pha_conn', 'pha_dtssd', 'fgr_mad', 'fgr_mse']) 59 | self.args = parser.parse_args() 60 | 61 | def init_metrics(self): 62 | self.mad = MetricMAD() 63 | self.mse = MetricMSE() 64 | self.grad = MetricGRAD() 65 | self.conn = MetricCONN() 66 | self.dtssd = MetricDTSSD() 67 | 68 | def evaluate(self): 69 | tasks = [] 70 | position = 0 71 | 72 | with ThreadPoolExecutor(max_workers=self.args.num_workers) as executor: 73 | for dataset in sorted(os.listdir(self.args.pred_dir)): 74 | if os.path.isdir(os.path.join(self.args.pred_dir, dataset)): 75 | for clip in sorted(os.listdir(os.path.join(self.args.pred_dir, dataset))): 76 | future = executor.submit(self.evaluate_worker, dataset, clip, position) 77 | tasks.append((dataset, clip, future)) 78 | position += 1 79 | 80 | self.results = [(dataset, clip, future.result()) for dataset, clip, future in tasks] 81 | 82 | def write_excel(self): 83 | workbook = xlsxwriter.Workbook(os.path.join(self.args.pred_dir, f'{os.path.basename(self.args.pred_dir)}.xlsx')) 84 | summarysheet = workbook.add_worksheet('summary') 85 | metricsheets = [workbook.add_worksheet(metric) for metric in self.results[0][2].keys()] 86 | 87 | for i, metric in enumerate(self.results[0][2].keys()): 88 | summarysheet.write(i, 0, metric) 89 | summarysheet.write(i, 1, f'={metric}!B2') 90 | 91 | for row, (dataset, clip, metrics) in enumerate(self.results): 92 | for metricsheet, metric in zip(metricsheets, metrics.values()): 93 | # Write the header 94 | if row == 0: 95 | metricsheet.write(1, 0, 'Average') 96 | 
metricsheet.write(1, 1, f'=AVERAGE(C2:ZZ2)') 97 | for col in range(len(metric)): 98 | metricsheet.write(0, col + 2, col) 99 | colname = xlsxwriter.utility.xl_col_to_name(col + 2) 100 | metricsheet.write(1, col + 2, f'=AVERAGE({colname}3:{colname}9999)') 101 | 102 | metricsheet.write(row + 2, 0, dataset) 103 | metricsheet.write(row + 2, 1, clip) 104 | metricsheet.write_row(row + 2, 2, metric) 105 | 106 | workbook.close() 107 | 108 | def evaluate_worker(self, dataset, clip, position): 109 | framenames = sorted(os.listdir(os.path.join(self.args.pred_dir, dataset, clip, 'pha'))) 110 | metrics = {metric_name : [] for metric_name in self.args.metrics} 111 | 112 | pred_pha_tm1 = None 113 | true_pha_tm1 = None 114 | 115 | for i, framename in enumerate(tqdm(framenames, desc=f'{dataset} {clip}', position=position, dynamic_ncols=True)): 116 | true_pha = cv2.imread(os.path.join(self.args.true_dir, dataset, clip, 'pha', framename), cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255 117 | pred_pha = cv2.imread(os.path.join(self.args.pred_dir, dataset, clip, 'pha', framename), cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255 118 | if 'pha_mad' in self.args.metrics: 119 | metrics['pha_mad'].append(self.mad(pred_pha, true_pha)) 120 | if 'pha_mse' in self.args.metrics: 121 | metrics['pha_mse'].append(self.mse(pred_pha, true_pha)) 122 | if 'pha_grad' in self.args.metrics: 123 | metrics['pha_grad'].append(self.grad(pred_pha, true_pha)) 124 | if 'pha_conn' in self.args.metrics: 125 | metrics['pha_conn'].append(self.conn(pred_pha, true_pha)) 126 | if 'pha_dtssd' in self.args.metrics: 127 | if i == 0: 128 | metrics['pha_dtssd'].append(0) 129 | else: 130 | metrics['pha_dtssd'].append(self.dtssd(pred_pha, pred_pha_tm1, true_pha, true_pha_tm1)) 131 | 132 | pred_pha_tm1 = pred_pha 133 | true_pha_tm1 = true_pha 134 | 135 | if 'fgr_mse' in self.args.metrics or 'fgr_mad' in self.args.metrics: 136 | true_fgr = cv2.imread(os.path.join(self.args.true_dir, dataset, clip, 'fgr', framename), cv2.IMREAD_COLOR).astype(np.float32) / 255 137 | pred_fgr = cv2.imread(os.path.join(self.args.pred_dir, dataset, clip, 'fgr', framename), cv2.IMREAD_COLOR).astype(np.float32) / 255 138 | true_msk = true_pha > 0 139 | 140 | if 'fgr_mse' in self.args.metrics: 141 | metrics['fgr_mse'].append(self.mse(pred_fgr[true_msk], true_fgr[true_msk])) 142 | if 'fgr_mad' in self.args.metrics: 143 | metrics['fgr_mad'].append(self.mad(pred_fgr[true_msk], true_fgr[true_msk])) 144 | 145 | return metrics 146 | 147 | 148 | class MetricMAD: 149 | def __call__(self, pred, true): 150 | return np.abs(pred - true).mean() * 1e3 151 | 152 | 153 | class MetricMSE: 154 | def __call__(self, pred, true): 155 | return ((pred - true) ** 2).mean() * 1e3 156 | 157 | 158 | class MetricGRAD: 159 | def __init__(self, sigma=1.4): 160 | self.filter_x, self.filter_y = self.gauss_filter(sigma) 161 | 162 | def __call__(self, pred, true): 163 | pred_normed = np.zeros_like(pred) 164 | true_normed = np.zeros_like(true) 165 | cv2.normalize(pred, pred_normed, 1., 0., cv2.NORM_MINMAX) 166 | cv2.normalize(true, true_normed, 1., 0., cv2.NORM_MINMAX) 167 | 168 | true_grad = self.gauss_gradient(true_normed).astype(np.float32) 169 | pred_grad = self.gauss_gradient(pred_normed).astype(np.float32) 170 | 171 | grad_loss = ((true_grad - pred_grad) ** 2).sum() 172 | return grad_loss / 1000 173 | 174 | def gauss_gradient(self, img): 175 | img_filtered_x = cv2.filter2D(img, -1, self.filter_x, borderType=cv2.BORDER_REPLICATE) 176 | img_filtered_y = cv2.filter2D(img, -1, self.filter_y, 
borderType=cv2.BORDER_REPLICATE) 177 | return np.sqrt(img_filtered_x**2 + img_filtered_y**2) 178 | 179 | @staticmethod 180 | def gauss_filter(sigma, epsilon=1e-2): 181 | half_size = np.ceil(sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon))) 182 | size = np.int(2 * half_size + 1) 183 | 184 | # create filter in x axis 185 | filter_x = np.zeros((size, size)) 186 | for i in range(size): 187 | for j in range(size): 188 | filter_x[i, j] = MetricGRAD.gaussian(i - half_size, sigma) * MetricGRAD.dgaussian( 189 | j - half_size, sigma) 190 | 191 | # normalize filter 192 | norm = np.sqrt((filter_x**2).sum()) 193 | filter_x = filter_x / norm 194 | filter_y = np.transpose(filter_x) 195 | 196 | return filter_x, filter_y 197 | 198 | @staticmethod 199 | def gaussian(x, sigma): 200 | return np.exp(-x**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi)) 201 | 202 | @staticmethod 203 | def dgaussian(x, sigma): 204 | return -x * MetricGRAD.gaussian(x, sigma) / sigma**2 205 | 206 | 207 | class MetricCONN: 208 | def __call__(self, pred, true): 209 | step=0.1 210 | thresh_steps = np.arange(0, 1 + step, step) 211 | round_down_map = -np.ones_like(true) 212 | for i in range(1, len(thresh_steps)): 213 | true_thresh = true >= thresh_steps[i] 214 | pred_thresh = pred >= thresh_steps[i] 215 | intersection = (true_thresh & pred_thresh).astype(np.uint8) 216 | 217 | # connected components 218 | _, output, stats, _ = cv2.connectedComponentsWithStats( 219 | intersection, connectivity=4) 220 | # start from 1 in dim 0 to exclude background 221 | size = stats[1:, -1] 222 | 223 | # largest connected component of the intersection 224 | omega = np.zeros_like(true) 225 | if len(size) != 0: 226 | max_id = np.argmax(size) 227 | # plus one to include background 228 | omega[output == max_id + 1] = 1 229 | 230 | mask = (round_down_map == -1) & (omega == 0) 231 | round_down_map[mask] = thresh_steps[i - 1] 232 | round_down_map[round_down_map == -1] = 1 233 | 234 | true_diff = true - round_down_map 235 | pred_diff = pred - round_down_map 236 | # only calculate difference larger than or equal to 0.15 237 | true_phi = 1 - true_diff * (true_diff >= 0.15) 238 | pred_phi = 1 - pred_diff * (pred_diff >= 0.15) 239 | 240 | connectivity_error = np.sum(np.abs(true_phi - pred_phi)) 241 | return connectivity_error / 1000 242 | 243 | 244 | class MetricDTSSD: 245 | def __call__(self, pred_t, pred_tm1, true_t, true_tm1): 246 | dtSSD = ((pred_t - pred_tm1) - (true_t - true_tm1)) ** 2 247 | dtSSD = np.sum(dtSSD) / true_t.size 248 | dtSSD = np.sqrt(dtSSD) 249 | return dtSSD * 1e2 250 | 251 | 252 | 253 | if __name__ == '__main__': 254 | Evaluator() -------------------------------------------------------------------------------- /evaluation/generate_imagematte_with_background_image.py: -------------------------------------------------------------------------------- 1 | """ 2 | python generate_imagematte_with_background_image.py \ 3 | --imagematte-dir ../matting-data/Distinctions/test \ 4 | --background-dir ../matting-data/Backgrounds/valid \ 5 | --resolution 512 \ 6 | --out-dir ../matting-data/evaluation/distinction_static_sd/ \ 7 | --random-seed 10 8 | 9 | Seed: 10 | 10 - distinction-static 11 | 11 - distinction-motion 12 | 12 - adobe-static 13 | 13 - adobe-motion 14 | 15 | """ 16 | 17 | import argparse 18 | import os 19 | import pims 20 | import numpy as np 21 | import random 22 | from PIL import Image 23 | from tqdm import tqdm 24 | from tqdm.contrib.concurrent import process_map 25 | from torchvision import transforms 26 | from 
torchvision.transforms import functional as F 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--imagematte-dir', type=str, required=True) 30 | parser.add_argument('--background-dir', type=str, required=True) 31 | parser.add_argument('--num-samples', type=int, default=20) 32 | parser.add_argument('--num-frames', type=int, default=100) 33 | parser.add_argument('--resolution', type=int, required=True) 34 | parser.add_argument('--out-dir', type=str, required=True) 35 | parser.add_argument('--random-seed', type=int) 36 | parser.add_argument('--extension', type=str, default='.png') 37 | args = parser.parse_args() 38 | 39 | random.seed(args.random_seed) 40 | 41 | imagematte_filenames = os.listdir(os.path.join(args.imagematte_dir, 'fgr')) 42 | background_filenames = os.listdir(args.background_dir) 43 | random.shuffle(imagematte_filenames) 44 | random.shuffle(background_filenames) 45 | 46 | 47 | def lerp(a, b, percentage): 48 | return a * (1 - percentage) + b * percentage 49 | 50 | def motion_affine(*imgs): 51 | config = dict(degrees=(-10, 10), translate=(0.1, 0.1), 52 | scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=imgs[0][0].size) 53 | angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = transforms.RandomAffine.get_params(**config) 54 | angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = transforms.RandomAffine.get_params(**config) 55 | 56 | T = len(imgs[0]) 57 | variation_over_time = random.random() 58 | for t in range(T): 59 | percentage = (t / (T - 1)) * variation_over_time 60 | angle = lerp(angleA, angleB, percentage) 61 | transX = lerp(transXA, transXB, percentage) 62 | transY = lerp(transYA, transYB, percentage) 63 | scale = lerp(scaleA, scaleB, percentage) 64 | shearX = lerp(shearXA, shearXB, percentage) 65 | shearY = lerp(shearYA, shearYB, percentage) 66 | for img in imgs: 67 | img[t] = F.affine(img[t], angle, (transX, transY), scale, (shearX, shearY), F.InterpolationMode.BILINEAR) 68 | return imgs 69 | 70 | 71 | 72 | def process(i): 73 | imagematte_filename = imagematte_filenames[i % len(imagematte_filenames)] 74 | background_filename = background_filenames[i % len(background_filenames)] 75 | 76 | out_path = os.path.join(args.out_dir, str(i).zfill(4)) 77 | os.makedirs(os.path.join(out_path, 'fgr'), exist_ok=True) 78 | os.makedirs(os.path.join(out_path, 'pha'), exist_ok=True) 79 | os.makedirs(os.path.join(out_path, 'com'), exist_ok=True) 80 | os.makedirs(os.path.join(out_path, 'bgr'), exist_ok=True) 81 | 82 | with Image.open(os.path.join(args.background_dir, background_filename)) as bgr: 83 | bgr = bgr.convert('RGB') 84 | 85 | w, h = bgr.size 86 | scale = args.resolution / min(h, w) 87 | w, h = int(w * scale), int(h * scale) 88 | bgr = bgr.resize((w, h)) 89 | bgr = F.center_crop(bgr, (args.resolution, args.resolution)) 90 | 91 | with Image.open(os.path.join(args.imagematte_dir, 'fgr', imagematte_filename)) as fgr, \ 92 | Image.open(os.path.join(args.imagematte_dir, 'pha', imagematte_filename)) as pha: 93 | fgr = fgr.convert('RGB') 94 | pha = pha.convert('L') 95 | 96 | fgrs = [fgr] * args.num_frames 97 | phas = [pha] * args.num_frames 98 | fgrs, phas = motion_affine(fgrs, phas) 99 | 100 | for t in tqdm(range(args.num_frames), desc=str(i).zfill(4)): 101 | fgr = fgrs[t] 102 | pha = phas[t] 103 | 104 | w, h = fgr.size 105 | scale = args.resolution / max(h, w) 106 | w, h = int(w * scale), int(h * scale) 107 | 108 | fgr = fgr.resize((w, h)) 109 | pha = pha.resize((w, h)) 110 | 111 | if h < args.resolution: 112 | pt = (args.resolution - h) // 2 113 | pb = 
args.resolution - h - pt 114 | else: 115 | pt = 0 116 | pb = 0 117 | 118 | if w < args.resolution: 119 | pl = (args.resolution - w) // 2 120 | pr = args.resolution - w - pl 121 | else: 122 | pl = 0 123 | pr = 0 124 | 125 | fgr = F.pad(fgr, [pl, pt, pr, pb]) 126 | pha = F.pad(pha, [pl, pt, pr, pb]) 127 | 128 | if i // len(imagematte_filenames) % 2 == 1: 129 | fgr = fgr.transpose(Image.FLIP_LEFT_RIGHT) 130 | pha = pha.transpose(Image.FLIP_LEFT_RIGHT) 131 | 132 | fgr.save(os.path.join(out_path, 'fgr', str(t).zfill(4) + args.extension)) 133 | pha.save(os.path.join(out_path, 'pha', str(t).zfill(4) + args.extension)) 134 | 135 | if t == 0: 136 | bgr.save(os.path.join(out_path, 'bgr', str(t).zfill(4) + args.extension)) 137 | else: 138 | os.symlink(str(0).zfill(4) + args.extension, os.path.join(out_path, 'bgr', str(t).zfill(4) + args.extension)) 139 | 140 | pha = np.asarray(pha).astype(float)[:, :, None] / 255 141 | com = Image.fromarray(np.uint8(np.asarray(fgr) * pha + np.asarray(bgr) * (1 - pha))) 142 | com.save(os.path.join(out_path, 'com', str(t).zfill(4) + args.extension)) 143 | 144 | 145 | if __name__ == '__main__': 146 | r = process_map(process, range(args.num_samples), max_workers=32) -------------------------------------------------------------------------------- /evaluation/generate_imagematte_with_background_video.py: -------------------------------------------------------------------------------- 1 | """ 2 | python generate_imagematte_with_background_video.py \ 3 | --imagematte-dir ../matting-data/Distinctions/test \ 4 | --background-dir ../matting-data/BackgroundVideos_mp4/test \ 5 | --resolution 512 \ 6 | --out-dir ../matting-data/evaluation/distinction_motion_sd/ \ 7 | --random-seed 11 8 | 9 | Seed: 10 | 10 - distinction-static 11 | 11 - distinction-motion 12 | 12 - adobe-static 13 | 13 - adobe-motion 14 | 15 | """ 16 | 17 | import argparse 18 | import os 19 | import pims 20 | import numpy as np 21 | import random 22 | from multiprocessing import Pool 23 | from PIL import Image 24 | # from tqdm import tqdm 25 | from tqdm.contrib.concurrent import process_map 26 | from torchvision import transforms 27 | from torchvision.transforms import functional as F 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--imagematte-dir', type=str, required=True) 31 | parser.add_argument('--background-dir', type=str, required=True) 32 | parser.add_argument('--num-samples', type=int, default=20) 33 | parser.add_argument('--num-frames', type=int, default=100) 34 | parser.add_argument('--resolution', type=int, required=True) 35 | parser.add_argument('--out-dir', type=str, required=True) 36 | parser.add_argument('--random-seed', type=int) 37 | parser.add_argument('--extension', type=str, default='.png') 38 | args = parser.parse_args() 39 | 40 | random.seed(args.random_seed) 41 | 42 | imagematte_filenames = os.listdir(os.path.join(args.imagematte_dir, 'fgr')) 43 | random.shuffle(imagematte_filenames) 44 | 45 | background_filenames = [ 46 | "0000.mp4", 47 | "0007.mp4", 48 | "0008.mp4", 49 | "0010.mp4", 50 | "0013.mp4", 51 | "0015.mp4", 52 | "0016.mp4", 53 | "0018.mp4", 54 | "0021.mp4", 55 | "0029.mp4", 56 | "0033.mp4", 57 | "0035.mp4", 58 | "0039.mp4", 59 | "0050.mp4", 60 | "0052.mp4", 61 | "0055.mp4", 62 | "0060.mp4", 63 | "0063.mp4", 64 | "0087.mp4", 65 | "0086.mp4", 66 | "0090.mp4", 67 | "0101.mp4", 68 | "0110.mp4", 69 | "0117.mp4", 70 | "0120.mp4", 71 | "0122.mp4", 72 | "0123.mp4", 73 | "0125.mp4", 74 | "0128.mp4", 75 | "0131.mp4", 76 | "0172.mp4", 77 | "0176.mp4", 78 | "0181.mp4", 
79 | "0187.mp4", 80 | "0193.mp4", 81 | "0198.mp4", 82 | "0220.mp4", 83 | "0221.mp4", 84 | "0224.mp4", 85 | "0229.mp4", 86 | "0233.mp4", 87 | "0238.mp4", 88 | "0241.mp4", 89 | "0245.mp4", 90 | "0246.mp4" 91 | ] 92 | 93 | random.shuffle(background_filenames) 94 | 95 | def lerp(a, b, percentage): 96 | return a * (1 - percentage) + b * percentage 97 | 98 | def motion_affine(*imgs): 99 | config = dict(degrees=(-10, 10), translate=(0.1, 0.1), 100 | scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=imgs[0][0].size) 101 | angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = transforms.RandomAffine.get_params(**config) 102 | angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = transforms.RandomAffine.get_params(**config) 103 | 104 | T = len(imgs[0]) 105 | variation_over_time = random.random() 106 | for t in range(T): 107 | percentage = (t / (T - 1)) * variation_over_time 108 | angle = lerp(angleA, angleB, percentage) 109 | transX = lerp(transXA, transXB, percentage) 110 | transY = lerp(transYA, transYB, percentage) 111 | scale = lerp(scaleA, scaleB, percentage) 112 | shearX = lerp(shearXA, shearXB, percentage) 113 | shearY = lerp(shearYA, shearYB, percentage) 114 | for img in imgs: 115 | img[t] = F.affine(img[t], angle, (transX, transY), scale, (shearX, shearY), F.InterpolationMode.BILINEAR) 116 | return imgs 117 | 118 | 119 | def process(i): 120 | imagematte_filename = imagematte_filenames[i % len(imagematte_filenames)] 121 | background_filename = background_filenames[i % len(background_filenames)] 122 | 123 | bgrs = pims.PyAVVideoReader(os.path.join(args.background_dir, background_filename)) 124 | 125 | out_path = os.path.join(args.out_dir, str(i).zfill(4)) 126 | os.makedirs(os.path.join(out_path, 'fgr'), exist_ok=True) 127 | os.makedirs(os.path.join(out_path, 'pha'), exist_ok=True) 128 | os.makedirs(os.path.join(out_path, 'com'), exist_ok=True) 129 | os.makedirs(os.path.join(out_path, 'bgr'), exist_ok=True) 130 | 131 | with Image.open(os.path.join(args.imagematte_dir, 'fgr', imagematte_filename)) as fgr, \ 132 | Image.open(os.path.join(args.imagematte_dir, 'pha', imagematte_filename)) as pha: 133 | fgr = fgr.convert('RGB') 134 | pha = pha.convert('L') 135 | 136 | fgrs = [fgr] * args.num_frames 137 | phas = [pha] * args.num_frames 138 | fgrs, phas = motion_affine(fgrs, phas) 139 | 140 | for t in range(args.num_frames): 141 | fgr = fgrs[t] 142 | pha = phas[t] 143 | 144 | w, h = fgr.size 145 | scale = args.resolution / max(h, w) 146 | w, h = int(w * scale), int(h * scale) 147 | 148 | fgr = fgr.resize((w, h)) 149 | pha = pha.resize((w, h)) 150 | 151 | if h < args.resolution: 152 | pt = (args.resolution - h) // 2 153 | pb = args.resolution - h - pt 154 | else: 155 | pt = 0 156 | pb = 0 157 | 158 | if w < args.resolution: 159 | pl = (args.resolution - w) // 2 160 | pr = args.resolution - w - pl 161 | else: 162 | pl = 0 163 | pr = 0 164 | 165 | fgr = F.pad(fgr, [pl, pt, pr, pb]) 166 | pha = F.pad(pha, [pl, pt, pr, pb]) 167 | 168 | if i // len(imagematte_filenames) % 2 == 1: 169 | fgr = fgr.transpose(Image.FLIP_LEFT_RIGHT) 170 | pha = pha.transpose(Image.FLIP_LEFT_RIGHT) 171 | 172 | fgr.save(os.path.join(out_path, 'fgr', str(t).zfill(4) + args.extension)) 173 | pha.save(os.path.join(out_path, 'pha', str(t).zfill(4) + args.extension)) 174 | 175 | bgr = Image.fromarray(bgrs[t]).convert('RGB') 176 | w, h = bgr.size 177 | scale = args.resolution / min(h, w) 178 | w, h = int(w * scale), int(h * scale) 179 | bgr = bgr.resize((w, h)) 180 | bgr = F.center_crop(bgr, (args.resolution, args.resolution)) 
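            # Note: the background frame above is scaled so that its shorter side matches
            # args.resolution and then center-cropped to a square, so it lines up with the
            # fgr/pha frames that were padded to (resolution, resolution) earlier in this loop.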
181 | bgr.save(os.path.join(out_path, 'bgr', str(t).zfill(4) + args.extension)) 182 | 183 | pha = np.asarray(pha).astype(float)[:, :, None] / 255 184 | com = Image.fromarray(np.uint8(np.asarray(fgr) * pha + np.asarray(bgr) * (1 - pha))) 185 | com.save(os.path.join(out_path, 'com', str(t).zfill(4) + args.extension)) 186 | 187 | if __name__ == '__main__': 188 | r = process_map(process, range(args.num_samples), max_workers=10) 189 | 190 | -------------------------------------------------------------------------------- /evaluation/generate_videomatte_with_background_image.py: -------------------------------------------------------------------------------- 1 | """ 2 | python generate_videomatte_with_background_image.py \ 3 | --videomatte-dir ../matting-data/VideoMatte240K_JPEG_HD/test \ 4 | --background-dir ../matting-data/Backgrounds/valid \ 5 | --num-samples 25 \ 6 | --resize 512 288 \ 7 | --out-dir ../matting-data/evaluation/vidematte_static_sd/ 8 | """ 9 | 10 | import argparse 11 | import os 12 | import pims 13 | import numpy as np 14 | import random 15 | from PIL import Image 16 | from tqdm import tqdm 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--videomatte-dir', type=str, required=True) 20 | parser.add_argument('--background-dir', type=str, required=True) 21 | parser.add_argument('--num-samples', type=int, default=20) 22 | parser.add_argument('--num-frames', type=int, default=100) 23 | parser.add_argument('--resize', type=int, default=None, nargs=2) 24 | parser.add_argument('--out-dir', type=str, required=True) 25 | parser.add_argument('--extension', type=str, default='.png') 26 | args = parser.parse_args() 27 | 28 | random.seed(10) 29 | 30 | videomatte_filenames = [(clipname, sorted(os.listdir(os.path.join(args.videomatte_dir, 'fgr', clipname)))) 31 | for clipname in sorted(os.listdir(os.path.join(args.videomatte_dir, 'fgr')))] 32 | 33 | background_filenames = os.listdir(args.background_dir) 34 | random.shuffle(background_filenames) 35 | 36 | for i in range(args.num_samples): 37 | 38 | clipname, framenames = videomatte_filenames[i % len(videomatte_filenames)] 39 | 40 | out_path = os.path.join(args.out_dir, str(i).zfill(4)) 41 | os.makedirs(os.path.join(out_path, 'fgr'), exist_ok=True) 42 | os.makedirs(os.path.join(out_path, 'pha'), exist_ok=True) 43 | os.makedirs(os.path.join(out_path, 'com'), exist_ok=True) 44 | os.makedirs(os.path.join(out_path, 'bgr'), exist_ok=True) 45 | 46 | with Image.open(os.path.join(args.background_dir, background_filenames[i])) as bgr: 47 | bgr = bgr.convert('RGB') 48 | 49 | 50 | base_t = random.choice(range(len(framenames) - args.num_frames)) 51 | 52 | for t in tqdm(range(args.num_frames), desc=str(i).zfill(4)): 53 | with Image.open(os.path.join(args.videomatte_dir, 'fgr', clipname, framenames[base_t + t])) as fgr, \ 54 | Image.open(os.path.join(args.videomatte_dir, 'pha', clipname, framenames[base_t + t])) as pha: 55 | fgr = fgr.convert('RGB') 56 | pha = pha.convert('L') 57 | 58 | if args.resize is not None: 59 | fgr = fgr.resize(args.resize, Image.BILINEAR) 60 | pha = pha.resize(args.resize, Image.BILINEAR) 61 | 62 | 63 | if i // len(videomatte_filenames) % 2 == 1: 64 | fgr = fgr.transpose(Image.FLIP_LEFT_RIGHT) 65 | pha = pha.transpose(Image.FLIP_LEFT_RIGHT) 66 | 67 | fgr.save(os.path.join(out_path, 'fgr', str(t).zfill(4) + args.extension)) 68 | pha.save(os.path.join(out_path, 'pha', str(t).zfill(4) + args.extension)) 69 | 70 | if t == 0: 71 | bgr = bgr.resize(fgr.size, Image.BILINEAR) 72 | bgr.save(os.path.join(out_path, 'bgr', 
str(t).zfill(4) + args.extension)) 73 | else: 74 | os.symlink(str(0).zfill(4) + args.extension, os.path.join(out_path, 'bgr', str(t).zfill(4) + args.extension)) 75 | 76 | pha = np.asarray(pha).astype(float)[:, :, None] / 255 77 | com = Image.fromarray(np.uint8(np.asarray(fgr) * pha + np.asarray(bgr) * (1 - pha))) 78 | com.save(os.path.join(out_path, 'com', str(t).zfill(4) + args.extension)) 79 | -------------------------------------------------------------------------------- /evaluation/generate_videomatte_with_background_video.py: -------------------------------------------------------------------------------- 1 | """ 2 | python generate_videomatte_with_background_video.py \ 3 | --videomatte-dir ../matting-data/VideoMatte240K_JPEG_HD/test \ 4 | --background-dir ../matting-data/BackgroundVideos_mp4/test \ 5 | --resize 512 288 \ 6 | --out-dir ../matting-data/evaluation/vidematte_motion_sd/ 7 | """ 8 | 9 | import argparse 10 | import os 11 | import pims 12 | import numpy as np 13 | import random 14 | from PIL import Image 15 | from tqdm import tqdm 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--videomatte-dir', type=str, required=True) 19 | parser.add_argument('--background-dir', type=str, required=True) 20 | parser.add_argument('--num-samples', type=int, default=20) 21 | parser.add_argument('--num-frames', type=int, default=100) 22 | parser.add_argument('--resize', type=int, default=None, nargs=2) 23 | parser.add_argument('--out-dir', type=str, required=True) 24 | args = parser.parse_args() 25 | 26 | # Hand selected a list of videos 27 | background_filenames = [ 28 | "0000.mp4", 29 | "0007.mp4", 30 | "0008.mp4", 31 | "0010.mp4", 32 | "0013.mp4", 33 | "0015.mp4", 34 | "0016.mp4", 35 | "0018.mp4", 36 | "0021.mp4", 37 | "0029.mp4", 38 | "0033.mp4", 39 | "0035.mp4", 40 | "0039.mp4", 41 | "0050.mp4", 42 | "0052.mp4", 43 | "0055.mp4", 44 | "0060.mp4", 45 | "0063.mp4", 46 | "0087.mp4", 47 | "0086.mp4", 48 | "0090.mp4", 49 | "0101.mp4", 50 | "0110.mp4", 51 | "0117.mp4", 52 | "0120.mp4", 53 | "0122.mp4", 54 | "0123.mp4", 55 | "0125.mp4", 56 | "0128.mp4", 57 | "0131.mp4", 58 | "0172.mp4", 59 | "0176.mp4", 60 | "0181.mp4", 61 | "0187.mp4", 62 | "0193.mp4", 63 | "0198.mp4", 64 | "0220.mp4", 65 | "0221.mp4", 66 | "0224.mp4", 67 | "0229.mp4", 68 | "0233.mp4", 69 | "0238.mp4", 70 | "0241.mp4", 71 | "0245.mp4", 72 | "0246.mp4" 73 | ] 74 | 75 | random.seed(10) 76 | 77 | videomatte_filenames = [(clipname, sorted(os.listdir(os.path.join(args.videomatte_dir, 'fgr', clipname)))) 78 | for clipname in sorted(os.listdir(os.path.join(args.videomatte_dir, 'fgr')))] 79 | 80 | random.shuffle(background_filenames) 81 | 82 | for i in range(args.num_samples): 83 | bgrs = pims.PyAVVideoReader(os.path.join(args.background_dir, background_filenames[i % len(background_filenames)])) 84 | clipname, framenames = videomatte_filenames[i % len(videomatte_filenames)] 85 | 86 | out_path = os.path.join(args.out_dir, str(i).zfill(4)) 87 | os.makedirs(os.path.join(out_path, 'fgr'), exist_ok=True) 88 | os.makedirs(os.path.join(out_path, 'pha'), exist_ok=True) 89 | os.makedirs(os.path.join(out_path, 'com'), exist_ok=True) 90 | os.makedirs(os.path.join(out_path, 'bgr'), exist_ok=True) 91 | 92 | base_t = random.choice(range(len(framenames) - args.num_frames)) 93 | 94 | for t in tqdm(range(args.num_frames), desc=str(i).zfill(4)): 95 | with Image.open(os.path.join(args.videomatte_dir, 'fgr', clipname, framenames[base_t + t])) as fgr, \ 96 | Image.open(os.path.join(args.videomatte_dir, 'pha', clipname, 
framenames[base_t + t])) as pha: 97 | fgr = fgr.convert('RGB') 98 | pha = pha.convert('L') 99 | 100 | if args.resize is not None: 101 | fgr = fgr.resize(args.resize, Image.BILINEAR) 102 | pha = pha.resize(args.resize, Image.BILINEAR) 103 | 104 | 105 | if i // len(videomatte_filenames) % 2 == 1: 106 | fgr = fgr.transpose(Image.FLIP_LEFT_RIGHT) 107 | pha = pha.transpose(Image.FLIP_LEFT_RIGHT) 108 | 109 | fgr.save(os.path.join(out_path, 'fgr', str(t).zfill(4) + '.png')) 110 | pha.save(os.path.join(out_path, 'pha', str(t).zfill(4) + '.png')) 111 | 112 | bgr = Image.fromarray(bgrs[t]) 113 | bgr = bgr.resize(fgr.size, Image.BILINEAR) 114 | bgr.save(os.path.join(out_path, 'bgr', str(t).zfill(4) + '.png')) 115 | 116 | pha = np.asarray(pha).astype(float)[:, :, None] / 255 117 | com = Image.fromarray(np.uint8(np.asarray(fgr) * pha + np.asarray(bgr) * (1 - pha))) 118 | com.save(os.path.join(out_path, 'com', str(t).zfill(4) + '.png')) 119 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loading model 3 | model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") 4 | model = torch.hub.load("PeterL1n/RobustVideoMatting", "resnet50") 5 | 6 | Converter API 7 | convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter") 8 | """ 9 | 10 | 11 | dependencies = ['torch', 'torchvision'] 12 | 13 | import torch 14 | from model import MattingNetwork 15 | 16 | 17 | def mobilenetv3(pretrained: bool = True, progress: bool = True): 18 | model = MattingNetwork('mobilenetv3') 19 | if pretrained: 20 | url = 'https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_mobilenetv3.pth' 21 | model.load_state_dict(torch.hub.load_state_dict_from_url(url, map_location='cpu', progress=progress)) 22 | return model 23 | 24 | 25 | def resnet50(pretrained: bool = True, progress: bool = True): 26 | model = MattingNetwork('resnet50') 27 | if pretrained: 28 | url = 'https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50.pth' 29 | model.load_state_dict(torch.hub.load_state_dict_from_url(url, map_location='cpu', progress=progress)) 30 | return model 31 | 32 | 33 | def converter(): 34 | try: 35 | from inference import convert_video 36 | return convert_video 37 | except ModuleNotFoundError as error: 38 | print(error) 39 | print('Please run "pip install av tqdm pims"') 40 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | python inference.py \ 3 | --variant mobilenetv3 \ 4 | --checkpoint "CHECKPOINT" \ 5 | --device cuda \ 6 | --input-source "input.mp4" \ 7 | --output-type video \ 8 | --output-composition "composition.mp4" \ 9 | --output-alpha "alpha.mp4" \ 10 | --output-foreground "foreground.mp4" \ 11 | --output-video-mbps 4 \ 12 | --seq-chunk 1 13 | """ 14 | 15 | import torch 16 | import os 17 | from torch.utils.data import DataLoader 18 | from torchvision import transforms 19 | from typing import Optional, Tuple 20 | from tqdm.auto import tqdm 21 | 22 | from inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter 23 | 24 | def convert_video(model, 25 | input_source: str, 26 | input_resize: Optional[Tuple[int, int]] = None, 27 | downsample_ratio: Optional[float] = None, 28 | output_type: str = 'video', 29 | output_composition: Optional[str] = None, 30 | 
output_alpha: Optional[str] = None, 31 | output_foreground: Optional[str] = None, 32 | output_video_mbps: Optional[float] = None, 33 | seq_chunk: int = 1, 34 | num_workers: int = 0, 35 | progress: bool = True, 36 | device: Optional[str] = None, 37 | dtype: Optional[torch.dtype] = None): 38 | 39 | """ 40 | Args: 41 | input_source: A video file, or an image sequence directory. Images must be sorted in ascending order; png and jpg are supported. 42 | input_resize: If provided, the input is first resized to (w, h). 43 | downsample_ratio: The model's downsample_ratio hyperparameter. If not provided, the model automatically sets one. 44 | output_type: Options: ["video", "png_sequence"]. 45 | output_composition: 46 | The composition output path. File path if output_type == 'video'. Directory path if output_type == 'png_sequence'. 47 | If output_type == 'video', the composition has a green screen background. 48 | If output_type == 'png_sequence', the composition is RGBA png images. 49 | output_alpha: The alpha output from the model. 50 | output_foreground: The foreground output from the model. 51 | seq_chunk: Number of frames to process at once. Increase it for better parallelism. 52 | num_workers: PyTorch's DataLoader workers. Only use >0 for image sequence input. 53 | progress: Show progress bar. 54 | device: Only needs to be provided manually if the model is a frozen TorchScript model. 55 | dtype: Only needs to be provided manually if the model is a frozen TorchScript model. 56 | """ 57 | 58 | assert downsample_ratio is None or (downsample_ratio > 0 and downsample_ratio <= 1), 'Downsample ratio must be between 0 (exclusive) and 1 (inclusive).' 59 | assert any([output_composition, output_alpha, output_foreground]), 'Must provide at least one output.' 60 | assert output_type in ['video', 'png_sequence'], 'Only support "video" and "png_sequence" output modes.'
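    # Example (sketch, not part of the original script): convert_video is typically
    # driven either by the CLI at the bottom of this file or from Python with a model
    # loaded via torch.hub (see hubconf.py). All paths here are placeholders.
    #
    #   model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3").cuda()
    #   convert_video(model,
    #                 input_source='input.mp4',        # video file or image-sequence folder
    #                 output_type='video',
    #                 output_composition='com.mp4',    # green-screen composition
    #                 output_alpha='pha.mp4',
    #                 output_foreground='fgr.mp4',
    #                 output_video_mbps=4,
    #                 seq_chunk=12)                    # frames processed per batch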
61 | assert seq_chunk >= 1, 'Sequence chunk must be >= 1' 62 | assert num_workers >= 0, 'Number of workers must be >= 0' 63 | 64 | # Initialize transform 65 | if input_resize is not None: 66 | transform = transforms.Compose([ 67 | transforms.Resize(input_resize[::-1]), 68 | transforms.ToTensor() 69 | ]) 70 | else: 71 | transform = transforms.ToTensor() 72 | 73 | # Initialize reader 74 | if os.path.isfile(input_source): 75 | source = VideoReader(input_source, transform) 76 | else: 77 | source = ImageSequenceReader(input_source, transform) 78 | reader = DataLoader(source, batch_size=seq_chunk, pin_memory=True, num_workers=num_workers) 79 | 80 | # Initialize writers 81 | if output_type == 'video': 82 | frame_rate = source.frame_rate if isinstance(source, VideoReader) else 30 83 | output_video_mbps = 1 if output_video_mbps is None else output_video_mbps 84 | if output_composition is not None: 85 | writer_com = VideoWriter( 86 | path=output_composition, 87 | frame_rate=frame_rate, 88 | bit_rate=int(output_video_mbps * 1000000)) 89 | if output_alpha is not None: 90 | writer_pha = VideoWriter( 91 | path=output_alpha, 92 | frame_rate=frame_rate, 93 | bit_rate=int(output_video_mbps * 1000000)) 94 | if output_foreground is not None: 95 | writer_fgr = VideoWriter( 96 | path=output_foreground, 97 | frame_rate=frame_rate, 98 | bit_rate=int(output_video_mbps * 1000000)) 99 | else: 100 | if output_composition is not None: 101 | writer_com = ImageSequenceWriter(output_composition, 'png') 102 | if output_alpha is not None: 103 | writer_pha = ImageSequenceWriter(output_alpha, 'png') 104 | if output_foreground is not None: 105 | writer_fgr = ImageSequenceWriter(output_foreground, 'png') 106 | 107 | # Inference 108 | model = model.eval() 109 | if device is None or dtype is None: 110 | param = next(model.parameters()) 111 | dtype = param.dtype 112 | device = param.device 113 | 114 | if (output_composition is not None) and (output_type == 'video'): 115 | bgr = torch.tensor([120, 255, 155], device=device, dtype=dtype).div(255).view(1, 1, 3, 1, 1) 116 | 117 | try: 118 | with torch.no_grad(): 119 | bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True) 120 | rec = [None] * 4 121 | for src in reader: 122 | 123 | if downsample_ratio is None: 124 | downsample_ratio = auto_downsample_ratio(*src.shape[2:]) 125 | 126 | src = src.to(device, dtype, non_blocking=True).unsqueeze(0) # [B, T, C, H, W] 127 | fgr, pha, *rec = model(src, *rec, downsample_ratio) 128 | 129 | if output_foreground is not None: 130 | writer_fgr.write(fgr[0]) 131 | if output_alpha is not None: 132 | writer_pha.write(pha[0]) 133 | if output_composition is not None: 134 | if output_type == 'video': 135 | com = fgr * pha + bgr * (1 - pha) 136 | else: 137 | fgr = fgr * pha.gt(0) 138 | com = torch.cat([fgr, pha], dim=-3) 139 | writer_com.write(com[0]) 140 | 141 | bar.update(src.size(1)) 142 | 143 | finally: 144 | # Clean up 145 | if output_composition is not None: 146 | writer_com.close() 147 | if output_alpha is not None: 148 | writer_pha.close() 149 | if output_foreground is not None: 150 | writer_fgr.close() 151 | 152 | 153 | def auto_downsample_ratio(h, w): 154 | """ 155 | Automatically find a downsample ratio so that the largest side of the resolution be 512px. 
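    For example, a 1920x1080 input gives min(512 / 1920, 1), i.e. a ratio of about 0.267,
    while inputs whose largest side is already 512px or smaller keep a ratio of 1.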
156 | """ 157 | return min(512 / max(h, w), 1) 158 | 159 | 160 | class Converter: 161 | def __init__(self, variant: str, checkpoint: str, device: str): 162 | self.model = MattingNetwork(variant).eval().to(device) 163 | self.model.load_state_dict(torch.load(checkpoint, map_location=device)) 164 | self.model = torch.jit.script(self.model) 165 | self.model = torch.jit.freeze(self.model) 166 | self.device = device 167 | 168 | def convert(self, *args, **kwargs): 169 | convert_video(self.model, device=self.device, dtype=torch.float32, *args, **kwargs) 170 | 171 | if __name__ == '__main__': 172 | import argparse 173 | from model import MattingNetwork 174 | 175 | parser = argparse.ArgumentParser() 176 | parser.add_argument('--variant', type=str, required=True, choices=['mobilenetv3', 'resnet50']) 177 | parser.add_argument('--checkpoint', type=str, required=True) 178 | parser.add_argument('--device', type=str, required=True) 179 | parser.add_argument('--input-source', type=str, required=True) 180 | parser.add_argument('--input-resize', type=int, default=None, nargs=2) 181 | parser.add_argument('--downsample-ratio', type=float) 182 | parser.add_argument('--output-composition', type=str) 183 | parser.add_argument('--output-alpha', type=str) 184 | parser.add_argument('--output-foreground', type=str) 185 | parser.add_argument('--output-type', type=str, required=True, choices=['video', 'png_sequence']) 186 | parser.add_argument('--output-video-mbps', type=int, default=1) 187 | parser.add_argument('--seq-chunk', type=int, default=1) 188 | parser.add_argument('--num-workers', type=int, default=0) 189 | parser.add_argument('--disable-progress', action='store_true') 190 | args = parser.parse_args() 191 | 192 | converter = Converter(args.variant, args.checkpoint, args.device) 193 | converter.convert( 194 | input_source=args.input_source, 195 | input_resize=args.input_resize, 196 | downsample_ratio=args.downsample_ratio, 197 | output_type=args.output_type, 198 | output_composition=args.output_composition, 199 | output_alpha=args.output_alpha, 200 | output_foreground=args.output_foreground, 201 | output_video_mbps=args.output_video_mbps, 202 | seq_chunk=args.seq_chunk, 203 | num_workers=args.num_workers, 204 | progress=not args.disable_progress 205 | ) 206 | 207 | 208 | -------------------------------------------------------------------------------- /inference_speed_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | python inference_speed_test.py \ 3 | --model-variant mobilenetv3 \ 4 | --resolution 1920 1080 \ 5 | --downsample-ratio 0.25 \ 6 | --precision float32 7 | """ 8 | 9 | import argparse 10 | import torch 11 | from tqdm import tqdm 12 | 13 | from model.model import MattingNetwork 14 | 15 | torch.backends.cudnn.benchmark = True 16 | 17 | class InferenceSpeedTest: 18 | def __init__(self): 19 | self.parse_args() 20 | self.init_model() 21 | self.loop() 22 | 23 | def parse_args(self): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--model-variant', type=str, required=True) 26 | parser.add_argument('--resolution', type=int, required=True, nargs=2) 27 | parser.add_argument('--downsample-ratio', type=float, required=True) 28 | parser.add_argument('--precision', type=str, default='float32') 29 | parser.add_argument('--disable-refiner', action='store_true') 30 | self.args = parser.parse_args() 31 | 32 | def init_model(self): 33 | self.device = 'cuda' 34 | self.precision = {'float32': torch.float32, 'float16': torch.float16}[self.args.precision] 
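        # Note: the benchmark below scripts and freezes the model with TorchScript and
        # feeds random frames in a tight loop. With --resolution 1920 1080 and
        # --downsample-ratio 0.25, the recurrent base network runs on 480x270 frames
        # while the refiner restores the full 1920x1080 output.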
35 | self.model = MattingNetwork(self.args.model_variant) 36 | self.model = self.model.to(device=self.device, dtype=self.precision).eval() 37 | self.model = torch.jit.script(self.model) 38 | self.model = torch.jit.freeze(self.model) 39 | 40 | def loop(self): 41 | w, h = self.args.resolution 42 | src = torch.randn((1, 3, h, w), device=self.device, dtype=self.precision) 43 | with torch.no_grad(): 44 | rec = None, None, None, None 45 | for _ in tqdm(range(1000)): 46 | fgr, pha, *rec = self.model(src, *rec, self.args.downsample_ratio) 47 | torch.cuda.synchronize() 48 | 49 | if __name__ == '__main__': 50 | InferenceSpeedTest() -------------------------------------------------------------------------------- /inference_utils.py: -------------------------------------------------------------------------------- 1 | import av 2 | import os 3 | import pims 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from torchvision.transforms.functional import to_pil_image 7 | from PIL import Image 8 | 9 | 10 | class VideoReader(Dataset): 11 | def __init__(self, path, transform=None): 12 | self.video = pims.PyAVVideoReader(path) 13 | self.rate = self.video.frame_rate 14 | self.transform = transform 15 | 16 | @property 17 | def frame_rate(self): 18 | return self.rate 19 | 20 | def __len__(self): 21 | return len(self.video) 22 | 23 | def __getitem__(self, idx): 24 | frame = self.video[idx] 25 | frame = Image.fromarray(np.asarray(frame)) 26 | if self.transform is not None: 27 | frame = self.transform(frame) 28 | return frame 29 | 30 | 31 | class VideoWriter: 32 | def __init__(self, path, frame_rate, bit_rate=1000000): 33 | self.container = av.open(path, mode='w') 34 | self.stream = self.container.add_stream('h264', rate=f'{frame_rate:.4f}') 35 | self.stream.pix_fmt = 'yuv420p' 36 | self.stream.bit_rate = bit_rate 37 | 38 | def write(self, frames): 39 | # frames: [T, C, H, W] 40 | self.stream.width = frames.size(3) 41 | self.stream.height = frames.size(2) 42 | if frames.size(1) == 1: 43 | frames = frames.repeat(1, 3, 1, 1) # convert grayscale to RGB 44 | frames = frames.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy() 45 | for t in range(frames.shape[0]): 46 | frame = frames[t] 47 | frame = av.VideoFrame.from_ndarray(frame, format='rgb24') 48 | self.container.mux(self.stream.encode(frame)) 49 | 50 | def close(self): 51 | self.container.mux(self.stream.encode()) 52 | self.container.close() 53 | 54 | 55 | class ImageSequenceReader(Dataset): 56 | def __init__(self, path, transform=None): 57 | self.path = path 58 | self.files = sorted(os.listdir(path)) 59 | self.transform = transform 60 | 61 | def __len__(self): 62 | return len(self.files) 63 | 64 | def __getitem__(self, idx): 65 | with Image.open(os.path.join(self.path, self.files[idx])) as img: 66 | img.load() 67 | if self.transform is not None: 68 | return self.transform(img) 69 | return img 70 | 71 | 72 | class ImageSequenceWriter: 73 | def __init__(self, path, extension='jpg'): 74 | self.path = path 75 | self.extension = extension 76 | self.counter = 0 77 | os.makedirs(path, exist_ok=True) 78 | 79 | def write(self, frames): 80 | # frames: [T, C, H, W] 81 | for t in range(frames.shape[0]): 82 | to_pil_image(frames[t]).save(os.path.join( 83 | self.path, str(self.counter).zfill(4) + '.' 
+ self.extension)) 84 | self.counter += 1 85 | 86 | def close(self): 87 | pass 88 | 89 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import MattingNetwork -------------------------------------------------------------------------------- /model/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from typing import Tuple, Optional 6 | 7 | class RecurrentDecoder(nn.Module): 8 | def __init__(self, feature_channels, decoder_channels): 9 | super().__init__() 10 | self.avgpool = AvgPool() 11 | self.decode4 = BottleneckBlock(feature_channels[3]) 12 | self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 3, decoder_channels[0]) 13 | self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 3, decoder_channels[1]) 14 | self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 3, decoder_channels[2]) 15 | self.decode0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3]) 16 | 17 | def forward(self, 18 | s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor, f4: Tensor, 19 | r1: Optional[Tensor], r2: Optional[Tensor], 20 | r3: Optional[Tensor], r4: Optional[Tensor]): 21 | s1, s2, s3 = self.avgpool(s0) 22 | x4, r4 = self.decode4(f4, r4) 23 | x3, r3 = self.decode3(x4, f3, s3, r3) 24 | x2, r2 = self.decode2(x3, f2, s2, r2) 25 | x1, r1 = self.decode1(x2, f1, s1, r1) 26 | x0 = self.decode0(x1, s0) 27 | return x0, r1, r2, r3, r4 28 | 29 | 30 | class AvgPool(nn.Module): 31 | def __init__(self): 32 | super().__init__() 33 | self.avgpool = nn.AvgPool2d(2, 2, count_include_pad=False, ceil_mode=True) 34 | 35 | def forward_single_frame(self, s0): 36 | s1 = self.avgpool(s0) 37 | s2 = self.avgpool(s1) 38 | s3 = self.avgpool(s2) 39 | return s1, s2, s3 40 | 41 | def forward_time_series(self, s0): 42 | B, T = s0.shape[:2] 43 | s0 = s0.flatten(0, 1) 44 | s1, s2, s3 = self.forward_single_frame(s0) 45 | s1 = s1.unflatten(0, (B, T)) 46 | s2 = s2.unflatten(0, (B, T)) 47 | s3 = s3.unflatten(0, (B, T)) 48 | return s1, s2, s3 49 | 50 | def forward(self, s0): 51 | if s0.ndim == 5: 52 | return self.forward_time_series(s0) 53 | else: 54 | return self.forward_single_frame(s0) 55 | 56 | 57 | class BottleneckBlock(nn.Module): 58 | def __init__(self, channels): 59 | super().__init__() 60 | self.channels = channels 61 | self.gru = ConvGRU(channels // 2) 62 | 63 | def forward(self, x, r: Optional[Tensor]): 64 | a, b = x.split(self.channels // 2, dim=-3) 65 | b, r = self.gru(b, r) 66 | x = torch.cat([a, b], dim=-3) 67 | return x, r 68 | 69 | 70 | class UpsamplingBlock(nn.Module): 71 | def __init__(self, in_channels, skip_channels, src_channels, out_channels): 72 | super().__init__() 73 | self.out_channels = out_channels 74 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 75 | self.conv = nn.Sequential( 76 | nn.Conv2d(in_channels + skip_channels + src_channels, out_channels, 3, 1, 1, bias=False), 77 | nn.BatchNorm2d(out_channels), 78 | nn.ReLU(True), 79 | ) 80 | self.gru = ConvGRU(out_channels // 2) 81 | 82 | def forward_single_frame(self, x, f, s, r: Optional[Tensor]): 83 | x = self.upsample(x) 84 | x = x[:, :, :s.size(2), :s.size(3)] 85 | x = torch.cat([x, f, s], dim=1) 86 | x = self.conv(x) 87 | a, b = x.split(self.out_channels // 2, dim=1) 88 | b, 
r = self.gru(b, r) 89 | x = torch.cat([a, b], dim=1) 90 | return x, r 91 | 92 | def forward_time_series(self, x, f, s, r: Optional[Tensor]): 93 | B, T, _, H, W = s.shape 94 | x = x.flatten(0, 1) 95 | f = f.flatten(0, 1) 96 | s = s.flatten(0, 1) 97 | x = self.upsample(x) 98 | x = x[:, :, :H, :W] 99 | x = torch.cat([x, f, s], dim=1) 100 | x = self.conv(x) 101 | x = x.unflatten(0, (B, T)) 102 | a, b = x.split(self.out_channels // 2, dim=2) 103 | b, r = self.gru(b, r) 104 | x = torch.cat([a, b], dim=2) 105 | return x, r 106 | 107 | def forward(self, x, f, s, r: Optional[Tensor]): 108 | if x.ndim == 5: 109 | return self.forward_time_series(x, f, s, r) 110 | else: 111 | return self.forward_single_frame(x, f, s, r) 112 | 113 | 114 | class OutputBlock(nn.Module): 115 | def __init__(self, in_channels, src_channels, out_channels): 116 | super().__init__() 117 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 118 | self.conv = nn.Sequential( 119 | nn.Conv2d(in_channels + src_channels, out_channels, 3, 1, 1, bias=False), 120 | nn.BatchNorm2d(out_channels), 121 | nn.ReLU(True), 122 | nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False), 123 | nn.BatchNorm2d(out_channels), 124 | nn.ReLU(True), 125 | ) 126 | 127 | def forward_single_frame(self, x, s): 128 | x = self.upsample(x) 129 | x = x[:, :, :s.size(2), :s.size(3)] 130 | x = torch.cat([x, s], dim=1) 131 | x = self.conv(x) 132 | return x 133 | 134 | def forward_time_series(self, x, s): 135 | B, T, _, H, W = s.shape 136 | x = x.flatten(0, 1) 137 | s = s.flatten(0, 1) 138 | x = self.upsample(x) 139 | x = x[:, :, :H, :W] 140 | x = torch.cat([x, s], dim=1) 141 | x = self.conv(x) 142 | x = x.unflatten(0, (B, T)) 143 | return x 144 | 145 | def forward(self, x, s): 146 | if x.ndim == 5: 147 | return self.forward_time_series(x, s) 148 | else: 149 | return self.forward_single_frame(x, s) 150 | 151 | 152 | class ConvGRU(nn.Module): 153 | def __init__(self, 154 | channels: int, 155 | kernel_size: int = 3, 156 | padding: int = 1): 157 | super().__init__() 158 | self.channels = channels 159 | self.ih = nn.Sequential( 160 | nn.Conv2d(channels * 2, channels * 2, kernel_size, padding=padding), 161 | nn.Sigmoid() 162 | ) 163 | self.hh = nn.Sequential( 164 | nn.Conv2d(channels * 2, channels, kernel_size, padding=padding), 165 | nn.Tanh() 166 | ) 167 | 168 | def forward_single_frame(self, x, h): 169 | r, z = self.ih(torch.cat([x, h], dim=1)).split(self.channels, dim=1) 170 | c = self.hh(torch.cat([x, r * h], dim=1)) 171 | h = (1 - z) * h + z * c 172 | return h, h 173 | 174 | def forward_time_series(self, x, h): 175 | o = [] 176 | for xt in x.unbind(dim=1): 177 | ot, h = self.forward_single_frame(xt, h) 178 | o.append(ot) 179 | o = torch.stack(o, dim=1) 180 | return o, h 181 | 182 | def forward(self, x, h: Optional[Tensor]): 183 | if h is None: 184 | h = torch.zeros((x.size(0), x.size(-3), x.size(-2), x.size(-1)), 185 | device=x.device, dtype=x.dtype) 186 | 187 | if x.ndim == 5: 188 | return self.forward_time_series(x, h) 189 | else: 190 | return self.forward_single_frame(x, h) 191 | 192 | 193 | class Projection(nn.Module): 194 | def __init__(self, in_channels, out_channels): 195 | super().__init__() 196 | self.conv = nn.Conv2d(in_channels, out_channels, 1) 197 | 198 | def forward_single_frame(self, x): 199 | return self.conv(x) 200 | 201 | def forward_time_series(self, x): 202 | B, T = x.shape[:2] 203 | return self.conv(x.flatten(0, 1)).unflatten(0, (B, T)) 204 | 205 | def forward(self, x): 206 | if x.ndim == 5: 207 | return 
self.forward_time_series(x) 208 | else: 209 | return self.forward_single_frame(x) 210 | -------------------------------------------------------------------------------- /model/deep_guided_filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | """ 6 | Adopted from 7 | """ 8 | 9 | class DeepGuidedFilterRefiner(nn.Module): 10 | def __init__(self, hid_channels=16): 11 | super().__init__() 12 | self.box_filter = nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=False, groups=4) 13 | self.box_filter.weight.data[...] = 1 / 9 14 | self.conv = nn.Sequential( 15 | nn.Conv2d(4 * 2 + hid_channels, hid_channels, kernel_size=1, bias=False), 16 | nn.BatchNorm2d(hid_channels), 17 | nn.ReLU(True), 18 | nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False), 19 | nn.BatchNorm2d(hid_channels), 20 | nn.ReLU(True), 21 | nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True) 22 | ) 23 | 24 | def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha, base_hid): 25 | fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1) 26 | base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1) 27 | base_y = torch.cat([base_fgr, base_pha], dim=1) 28 | 29 | mean_x = self.box_filter(base_x) 30 | mean_y = self.box_filter(base_y) 31 | cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y 32 | var_x = self.box_filter(base_x * base_x) - mean_x * mean_x 33 | 34 | A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1)) 35 | b = mean_y - A * mean_x 36 | 37 | H, W = fine_src.shape[2:] 38 | A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False) 39 | b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False) 40 | 41 | out = A * fine_x + b 42 | fgr, pha = out.split([3, 1], dim=1) 43 | return fgr, pha 44 | 45 | def forward_time_series(self, fine_src, base_src, base_fgr, base_pha, base_hid): 46 | B, T = fine_src.shape[:2] 47 | fgr, pha = self.forward_single_frame( 48 | fine_src.flatten(0, 1), 49 | base_src.flatten(0, 1), 50 | base_fgr.flatten(0, 1), 51 | base_pha.flatten(0, 1), 52 | base_hid.flatten(0, 1)) 53 | fgr = fgr.unflatten(0, (B, T)) 54 | pha = pha.unflatten(0, (B, T)) 55 | return fgr, pha 56 | 57 | def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid): 58 | if fine_src.ndim == 5: 59 | return self.forward_time_series(fine_src, base_src, base_fgr, base_pha, base_hid) 60 | else: 61 | return self.forward_single_frame(fine_src, base_src, base_fgr, base_pha, base_hid) 62 | -------------------------------------------------------------------------------- /model/fast_guided_filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | """ 6 | Adopted from 7 | """ 8 | 9 | class FastGuidedFilterRefiner(nn.Module): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.guilded_filter = FastGuidedFilter(1) 13 | 14 | def forward_single_frame(self, fine_src, base_src, base_fgr, base_pha): 15 | fine_src_gray = fine_src.mean(1, keepdim=True) 16 | base_src_gray = base_src.mean(1, keepdim=True) 17 | 18 | fgr, pha = self.guilded_filter( 19 | torch.cat([base_src, base_src_gray], dim=1), 20 | torch.cat([base_fgr, base_pha], dim=1), 21 | torch.cat([fine_src, fine_src_gray], dim=1)).split([3, 1], dim=1) 22 | 23 | return fgr, pha 24 | 25 | def forward_time_series(self, fine_src, base_src, base_fgr, 
base_pha): 26 | B, T = fine_src.shape[:2] 27 | fgr, pha = self.forward_single_frame( 28 | fine_src.flatten(0, 1), 29 | base_src.flatten(0, 1), 30 | base_fgr.flatten(0, 1), 31 | base_pha.flatten(0, 1)) 32 | fgr = fgr.unflatten(0, (B, T)) 33 | pha = pha.unflatten(0, (B, T)) 34 | return fgr, pha 35 | 36 | def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid): 37 | if fine_src.ndim == 5: 38 | return self.forward_time_series(fine_src, base_src, base_fgr, base_pha) 39 | else: 40 | return self.forward_single_frame(fine_src, base_src, base_fgr, base_pha) 41 | 42 | 43 | class FastGuidedFilter(nn.Module): 44 | def __init__(self, r: int, eps: float = 1e-5): 45 | super().__init__() 46 | self.r = r 47 | self.eps = eps 48 | self.boxfilter = BoxFilter(r) 49 | 50 | def forward(self, lr_x, lr_y, hr_x): 51 | mean_x = self.boxfilter(lr_x) 52 | mean_y = self.boxfilter(lr_y) 53 | cov_xy = self.boxfilter(lr_x * lr_y) - mean_x * mean_y 54 | var_x = self.boxfilter(lr_x * lr_x) - mean_x * mean_x 55 | A = cov_xy / (var_x + self.eps) 56 | b = mean_y - A * mean_x 57 | A = F.interpolate(A, hr_x.shape[2:], mode='bilinear', align_corners=False) 58 | b = F.interpolate(b, hr_x.shape[2:], mode='bilinear', align_corners=False) 59 | return A * hr_x + b 60 | 61 | 62 | class BoxFilter(nn.Module): 63 | def __init__(self, r): 64 | super(BoxFilter, self).__init__() 65 | self.r = r 66 | 67 | def forward(self, x): 68 | # Note: The original implementation at 69 | # uses faster box blur. However, it may not be friendly for ONNX export. 70 | # We are switching to use simple convolution for box blur. 71 | kernel_size = 2 * self.r + 1 72 | kernel_x = torch.full((x.data.shape[1], 1, 1, kernel_size), 1 / kernel_size, device=x.device, dtype=x.dtype) 73 | kernel_y = torch.full((x.data.shape[1], 1, kernel_size, 1), 1 / kernel_size, device=x.device, dtype=x.dtype) 74 | x = F.conv2d(x, kernel_x, padding=(0, self.r), groups=x.data.shape[1]) 75 | x = F.conv2d(x, kernel_y, padding=(self.r, 0), groups=x.data.shape[1]) 76 | return x -------------------------------------------------------------------------------- /model/lraspp.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class LRASPP(nn.Module): 4 | def __init__(self, in_channels, out_channels): 5 | super().__init__() 6 | self.aspp1 = nn.Sequential( 7 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 8 | nn.BatchNorm2d(out_channels), 9 | nn.ReLU(True) 10 | ) 11 | self.aspp2 = nn.Sequential( 12 | nn.AdaptiveAvgPool2d(1), 13 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 14 | nn.Sigmoid() 15 | ) 16 | 17 | def forward_single_frame(self, x): 18 | return self.aspp1(x) * self.aspp2(x) 19 | 20 | def forward_time_series(self, x): 21 | B, T = x.shape[:2] 22 | x = self.forward_single_frame(x.flatten(0, 1)).unflatten(0, (B, T)) 23 | return x 24 | 25 | def forward(self, x): 26 | if x.ndim == 5: 27 | return self.forward_time_series(x) 28 | else: 29 | return self.forward_single_frame(x) -------------------------------------------------------------------------------- /model/mobilenetv3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig 4 | from torchvision.transforms.functional import normalize 5 | 6 | class MobileNetV3LargeEncoder(MobileNetV3): 7 | def __init__(self, pretrained: bool = False): 8 | super().__init__( 9 | inverted_residual_setting=[ 10 | 
InvertedResidualConfig( 16, 3, 16, 16, False, "RE", 1, 1, 1), 11 | InvertedResidualConfig( 16, 3, 64, 24, False, "RE", 2, 1, 1), # C1 12 | InvertedResidualConfig( 24, 3, 72, 24, False, "RE", 1, 1, 1), 13 | InvertedResidualConfig( 24, 5, 72, 40, True, "RE", 2, 1, 1), # C2 14 | InvertedResidualConfig( 40, 5, 120, 40, True, "RE", 1, 1, 1), 15 | InvertedResidualConfig( 40, 5, 120, 40, True, "RE", 1, 1, 1), 16 | InvertedResidualConfig( 40, 3, 240, 80, False, "HS", 2, 1, 1), # C3 17 | InvertedResidualConfig( 80, 3, 200, 80, False, "HS", 1, 1, 1), 18 | InvertedResidualConfig( 80, 3, 184, 80, False, "HS", 1, 1, 1), 19 | InvertedResidualConfig( 80, 3, 184, 80, False, "HS", 1, 1, 1), 20 | InvertedResidualConfig( 80, 3, 480, 112, True, "HS", 1, 1, 1), 21 | InvertedResidualConfig(112, 3, 672, 112, True, "HS", 1, 1, 1), 22 | InvertedResidualConfig(112, 5, 672, 160, True, "HS", 2, 2, 1), # C4 23 | InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1, 2, 1), 24 | InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1, 2, 1), 25 | ], 26 | last_channel=1280 27 | ) 28 | 29 | if pretrained: 30 | self.load_state_dict(torch.hub.load_state_dict_from_url( 31 | 'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth')) 32 | 33 | del self.avgpool 34 | del self.classifier 35 | 36 | def forward_single_frame(self, x): 37 | x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 38 | 39 | x = self.features[0](x) 40 | x = self.features[1](x) 41 | f1 = x 42 | x = self.features[2](x) 43 | x = self.features[3](x) 44 | f2 = x 45 | x = self.features[4](x) 46 | x = self.features[5](x) 47 | x = self.features[6](x) 48 | f3 = x 49 | x = self.features[7](x) 50 | x = self.features[8](x) 51 | x = self.features[9](x) 52 | x = self.features[10](x) 53 | x = self.features[11](x) 54 | x = self.features[12](x) 55 | x = self.features[13](x) 56 | x = self.features[14](x) 57 | x = self.features[15](x) 58 | x = self.features[16](x) 59 | f4 = x 60 | return [f1, f2, f3, f4] 61 | 62 | def forward_time_series(self, x): 63 | B, T = x.shape[:2] 64 | features = self.forward_single_frame(x.flatten(0, 1)) 65 | features = [f.unflatten(0, (B, T)) for f in features] 66 | return features 67 | 68 | def forward(self, x): 69 | if x.ndim == 5: 70 | return self.forward_time_series(x) 71 | else: 72 | return self.forward_single_frame(x) 73 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from typing import Optional, List 6 | 7 | from .mobilenetv3 import MobileNetV3LargeEncoder 8 | from .resnet import ResNet50Encoder 9 | from .lraspp import LRASPP 10 | from .decoder import RecurrentDecoder, Projection 11 | from .fast_guided_filter import FastGuidedFilterRefiner 12 | from .deep_guided_filter import DeepGuidedFilterRefiner 13 | 14 | class MattingNetwork(nn.Module): 15 | def __init__(self, 16 | variant: str = 'mobilenetv3', 17 | refiner: str = 'deep_guided_filter', 18 | pretrained_backbone: bool = False): 19 | super().__init__() 20 | assert variant in ['mobilenetv3', 'resnet50'] 21 | assert refiner in ['fast_guided_filter', 'deep_guided_filter'] 22 | 23 | if variant == 'mobilenetv3': 24 | self.backbone = MobileNetV3LargeEncoder(pretrained_backbone) 25 | self.aspp = LRASPP(960, 128) 26 | self.decoder = RecurrentDecoder([16, 24, 40, 128], [80, 40, 32, 16]) 27 | else: 28 | self.backbone = 
ResNet50Encoder(pretrained_backbone) 29 | self.aspp = LRASPP(2048, 256) 30 | self.decoder = RecurrentDecoder([64, 256, 512, 256], [128, 64, 32, 16]) 31 | 32 | self.project_mat = Projection(16, 4) 33 | self.project_seg = Projection(16, 1) 34 | 35 | if refiner == 'deep_guided_filter': 36 | self.refiner = DeepGuidedFilterRefiner() 37 | else: 38 | self.refiner = FastGuidedFilterRefiner() 39 | 40 | def forward(self, 41 | src: Tensor, 42 | r1: Optional[Tensor] = None, 43 | r2: Optional[Tensor] = None, 44 | r3: Optional[Tensor] = None, 45 | r4: Optional[Tensor] = None, 46 | downsample_ratio: float = 1, 47 | segmentation_pass: bool = False): 48 | 49 | if downsample_ratio != 1: 50 | src_sm = self._interpolate(src, scale_factor=downsample_ratio) 51 | else: 52 | src_sm = src 53 | 54 | f1, f2, f3, f4 = self.backbone(src_sm) 55 | f4 = self.aspp(f4) 56 | hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4) 57 | 58 | if not segmentation_pass: 59 | fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3) 60 | if downsample_ratio != 1: 61 | fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid) 62 | fgr = fgr_residual + src 63 | fgr = fgr.clamp(0., 1.) 64 | pha = pha.clamp(0., 1.) 65 | return [fgr, pha, *rec] 66 | else: 67 | seg = self.project_seg(hid) 68 | return [seg, *rec] 69 | 70 | def _interpolate(self, x: Tensor, scale_factor: float): 71 | if x.ndim == 5: 72 | B, T = x.shape[:2] 73 | x = F.interpolate(x.flatten(0, 1), scale_factor=scale_factor, 74 | mode='bilinear', align_corners=False, recompute_scale_factor=False) 75 | x = x.unflatten(0, (B, T)) 76 | else: 77 | x = F.interpolate(x, scale_factor=scale_factor, 78 | mode='bilinear', align_corners=False, recompute_scale_factor=False) 79 | return x 80 | -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision.models.resnet import ResNet, Bottleneck 4 | 5 | class ResNet50Encoder(ResNet): 6 | def __init__(self, pretrained: bool = False): 7 | super().__init__( 8 | block=Bottleneck, 9 | layers=[3, 4, 6, 3], 10 | replace_stride_with_dilation=[False, False, True], 11 | norm_layer=None) 12 | 13 | if pretrained: 14 | self.load_state_dict(torch.hub.load_state_dict_from_url( 15 | 'https://download.pytorch.org/models/resnet50-0676ba61.pth')) 16 | 17 | del self.avgpool 18 | del self.fc 19 | 20 | def forward_single_frame(self, x): 21 | x = self.conv1(x) 22 | x = self.bn1(x) 23 | x = self.relu(x) 24 | f1 = x # 1/2 25 | x = self.maxpool(x) 26 | x = self.layer1(x) 27 | f2 = x # 1/4 28 | x = self.layer2(x) 29 | f3 = x # 1/8 30 | x = self.layer3(x) 31 | x = self.layer4(x) 32 | f4 = x # 1/16 33 | return [f1, f2, f3, f4] 34 | 35 | def forward_time_series(self, x): 36 | B, T = x.shape[:2] 37 | features = self.forward_single_frame(x.flatten(0, 1)) 38 | features = [f.unflatten(0, (B, T)) for f in features] 39 | return features 40 | 41 | def forward(self, x): 42 | if x.ndim == 5: 43 | return self.forward_time_series(x) 44 | else: 45 | return self.forward_single_frame(x) 46 | -------------------------------------------------------------------------------- /requirements_inference.txt: -------------------------------------------------------------------------------- 1 | av==8.0.3 2 | torch==1.9.0 3 | torchvision==0.10.0 4 | tqdm==4.61.1 5 | pims==0.5 -------------------------------------------------------------------------------- /requirements_training.txt: 
-------------------------------------------------------------------------------- 1 | easing_functions==1.0.4 2 | tensorboard==2.5.0 3 | torch==1.9.0 4 | torchvision==0.10.0 5 | tqdm==4.61.1 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | # First update `train_config.py` to set paths to your dataset locations. 3 | 4 | # You may want to change `--num-workers` according to your machine's memory. 5 | # The default num-workers=8 may cause dataloader to exit unexpectedly when 6 | # machine is out of memory. 7 | 8 | # Stage 1 9 | python train.py \ 10 | --model-variant mobilenetv3 \ 11 | --dataset videomatte \ 12 | --resolution-lr 512 \ 13 | --seq-length-lr 15 \ 14 | --learning-rate-backbone 0.0001 \ 15 | --learning-rate-aspp 0.0002 \ 16 | --learning-rate-decoder 0.0002 \ 17 | --learning-rate-refiner 0 \ 18 | --checkpoint-dir checkpoint/stage1 \ 19 | --log-dir log/stage1 \ 20 | --epoch-start 0 \ 21 | --epoch-end 20 22 | 23 | # Stage 2 24 | python train.py \ 25 | --model-variant mobilenetv3 \ 26 | --dataset videomatte \ 27 | --resolution-lr 512 \ 28 | --seq-length-lr 50 \ 29 | --learning-rate-backbone 0.00005 \ 30 | --learning-rate-aspp 0.0001 \ 31 | --learning-rate-decoder 0.0001 \ 32 | --learning-rate-refiner 0 \ 33 | --checkpoint checkpoint/stage1/epoch-19.pth \ 34 | --checkpoint-dir checkpoint/stage2 \ 35 | --log-dir log/stage2 \ 36 | --epoch-start 20 \ 37 | --epoch-end 22 38 | 39 | # Stage 3 40 | python train.py \ 41 | --model-variant mobilenetv3 \ 42 | --dataset videomatte \ 43 | --train-hr \ 44 | --resolution-lr 512 \ 45 | --resolution-hr 2048 \ 46 | --seq-length-lr 40 \ 47 | --seq-length-hr 6 \ 48 | --learning-rate-backbone 0.00001 \ 49 | --learning-rate-aspp 0.00001 \ 50 | --learning-rate-decoder 0.00001 \ 51 | --learning-rate-refiner 0.0002 \ 52 | --checkpoint checkpoint/stage2/epoch-21.pth \ 53 | --checkpoint-dir checkpoint/stage3 \ 54 | --log-dir log/stage3 \ 55 | --epoch-start 22 \ 56 | --epoch-end 23 57 | 58 | # Stage 4 59 | python train.py \ 60 | --model-variant mobilenetv3 \ 61 | --dataset imagematte \ 62 | --train-hr \ 63 | --resolution-lr 512 \ 64 | --resolution-hr 2048 \ 65 | --seq-length-lr 40 \ 66 | --seq-length-hr 6 \ 67 | --learning-rate-backbone 0.00001 \ 68 | --learning-rate-aspp 0.00001 \ 69 | --learning-rate-decoder 0.00005 \ 70 | --learning-rate-refiner 0.0002 \ 71 | --checkpoint checkpoint/stage3/epoch-22.pth \ 72 | --checkpoint-dir checkpoint/stage4 \ 73 | --log-dir log/stage4 \ 74 | --epoch-start 23 \ 75 | --epoch-end 28 76 | """ 77 | 78 | 79 | import argparse 80 | import torch 81 | import random 82 | import os 83 | from torch import nn 84 | from torch import distributed as dist 85 | from torch import multiprocessing as mp 86 | from torch.nn import functional as F 87 | from torch.nn.parallel import DistributedDataParallel as DDP 88 | from torch.optim import Adam 89 | from torch.cuda.amp import autocast, GradScaler 90 | from torch.utils.data import DataLoader, ConcatDataset 91 | from torch.utils.data.distributed import DistributedSampler 92 | from torch.utils.tensorboard import SummaryWriter 93 | from torchvision.utils import make_grid 94 | from torchvision.transforms.functional import center_crop 95 | from tqdm import tqdm 96 | 97 | from dataset.videomatte import ( 98 | VideoMatteDataset, 99 | VideoMatteTrainAugmentation, 100 | VideoMatteValidAugmentation, 101 | ) 102 | from dataset.imagematte import ( 103 | ImageMatteDataset, 104 | 
ImageMatteAugmentation 105 | ) 106 | from dataset.coco import ( 107 | CocoPanopticDataset, 108 | CocoPanopticTrainAugmentation, 109 | ) 110 | from dataset.spd import ( 111 | SuperviselyPersonDataset 112 | ) 113 | from dataset.youtubevis import ( 114 | YouTubeVISDataset, 115 | YouTubeVISAugmentation 116 | ) 117 | from dataset.augmentation import ( 118 | TrainFrameSampler, 119 | ValidFrameSampler 120 | ) 121 | from model import MattingNetwork 122 | from train_config import DATA_PATHS 123 | from train_loss import matting_loss, segmentation_loss 124 | 125 | 126 | class Trainer: 127 | def __init__(self, rank, world_size): 128 | self.parse_args() 129 | self.init_distributed(rank, world_size) 130 | self.init_datasets() 131 | self.init_model() 132 | self.init_writer() 133 | self.train() 134 | self.cleanup() 135 | 136 | def parse_args(self): 137 | parser = argparse.ArgumentParser() 138 | # Model 139 | parser.add_argument('--model-variant', type=str, required=True, choices=['mobilenetv3', 'resnet50']) 140 | # Matting dataset 141 | parser.add_argument('--dataset', type=str, required=True, choices=['videomatte', 'imagematte']) 142 | # Learning rate 143 | parser.add_argument('--learning-rate-backbone', type=float, required=True) 144 | parser.add_argument('--learning-rate-aspp', type=float, required=True) 145 | parser.add_argument('--learning-rate-decoder', type=float, required=True) 146 | parser.add_argument('--learning-rate-refiner', type=float, required=True) 147 | # Training setting 148 | parser.add_argument('--train-hr', action='store_true') 149 | parser.add_argument('--resolution-lr', type=int, default=512) 150 | parser.add_argument('--resolution-hr', type=int, default=2048) 151 | parser.add_argument('--seq-length-lr', type=int, required=True) 152 | parser.add_argument('--seq-length-hr', type=int, default=6) 153 | parser.add_argument('--downsample-ratio', type=float, default=0.25) 154 | parser.add_argument('--batch-size-per-gpu', type=int, default=1) 155 | parser.add_argument('--num-workers', type=int, default=8) 156 | parser.add_argument('--epoch-start', type=int, default=0) 157 | parser.add_argument('--epoch-end', type=int, default=16) 158 | # Tensorboard logging 159 | parser.add_argument('--log-dir', type=str, required=True) 160 | parser.add_argument('--log-train-loss-interval', type=int, default=20) 161 | parser.add_argument('--log-train-images-interval', type=int, default=500) 162 | # Checkpoint loading and saving 163 | parser.add_argument('--checkpoint', type=str) 164 | parser.add_argument('--checkpoint-dir', type=str, required=True) 165 | parser.add_argument('--checkpoint-save-interval', type=int, default=500) 166 | # Distributed 167 | parser.add_argument('--distributed-addr', type=str, default='localhost') 168 | parser.add_argument('--distributed-port', type=str, default='12355') 169 | # Debugging 170 | parser.add_argument('--disable-progress-bar', action='store_true') 171 | parser.add_argument('--disable-validation', action='store_true') 172 | parser.add_argument('--disable-mixed-precision', action='store_true') 173 | self.args = parser.parse_args() 174 | 175 | def init_distributed(self, rank, world_size): 176 | self.rank = rank 177 | self.world_size = world_size 178 | self.log('Initializing distributed') 179 | os.environ['MASTER_ADDR'] = self.args.distributed_addr 180 | os.environ['MASTER_PORT'] = self.args.distributed_port 181 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 182 | 183 | def init_datasets(self): 184 | self.log('Initializing matting datasets') 185 | 
size_hr = (self.args.resolution_hr, self.args.resolution_hr) 186 | size_lr = (self.args.resolution_lr, self.args.resolution_lr) 187 | 188 | # Matting datasets: 189 | if self.args.dataset == 'videomatte': 190 | self.dataset_lr_train = VideoMatteDataset( 191 | videomatte_dir=DATA_PATHS['videomatte']['train'], 192 | background_image_dir=DATA_PATHS['background_images']['train'], 193 | background_video_dir=DATA_PATHS['background_videos']['train'], 194 | size=self.args.resolution_lr, 195 | seq_length=self.args.seq_length_lr, 196 | seq_sampler=TrainFrameSampler(), 197 | transform=VideoMatteTrainAugmentation(size_lr)) 198 | if self.args.train_hr: 199 | self.dataset_hr_train = VideoMatteDataset( 200 | videomatte_dir=DATA_PATHS['videomatte']['train'], 201 | background_image_dir=DATA_PATHS['background_images']['train'], 202 | background_video_dir=DATA_PATHS['background_videos']['train'], 203 | size=self.args.resolution_hr, 204 | seq_length=self.args.seq_length_hr, 205 | seq_sampler=TrainFrameSampler(), 206 | transform=VideoMatteTrainAugmentation(size_hr)) 207 | self.dataset_valid = VideoMatteDataset( 208 | videomatte_dir=DATA_PATHS['videomatte']['valid'], 209 | background_image_dir=DATA_PATHS['background_images']['valid'], 210 | background_video_dir=DATA_PATHS['background_videos']['valid'], 211 | size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr, 212 | seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr, 213 | seq_sampler=ValidFrameSampler(), 214 | transform=VideoMatteValidAugmentation(size_hr if self.args.train_hr else size_lr)) 215 | else: 216 | self.dataset_lr_train = ImageMatteDataset( 217 | imagematte_dir=DATA_PATHS['imagematte']['train'], 218 | background_image_dir=DATA_PATHS['background_images']['train'], 219 | background_video_dir=DATA_PATHS['background_videos']['train'], 220 | size=self.args.resolution_lr, 221 | seq_length=self.args.seq_length_lr, 222 | seq_sampler=TrainFrameSampler(), 223 | transform=ImageMatteAugmentation(size_lr)) 224 | if self.args.train_hr: 225 | self.dataset_hr_train = ImageMatteDataset( 226 | imagematte_dir=DATA_PATHS['imagematte']['train'], 227 | background_image_dir=DATA_PATHS['background_images']['train'], 228 | background_video_dir=DATA_PATHS['background_videos']['train'], 229 | size=self.args.resolution_hr, 230 | seq_length=self.args.seq_length_hr, 231 | seq_sampler=TrainFrameSampler(), 232 | transform=ImageMatteAugmentation(size_hr)) 233 | self.dataset_valid = ImageMatteDataset( 234 | imagematte_dir=DATA_PATHS['imagematte']['valid'], 235 | background_image_dir=DATA_PATHS['background_images']['valid'], 236 | background_video_dir=DATA_PATHS['background_videos']['valid'], 237 | size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr, 238 | seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr, 239 | seq_sampler=ValidFrameSampler(), 240 | transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr)) 241 | 242 | # Matting dataloaders: 243 | self.datasampler_lr_train = DistributedSampler( 244 | dataset=self.dataset_lr_train, 245 | rank=self.rank, 246 | num_replicas=self.world_size, 247 | shuffle=True) 248 | self.dataloader_lr_train = DataLoader( 249 | dataset=self.dataset_lr_train, 250 | batch_size=self.args.batch_size_per_gpu, 251 | num_workers=self.args.num_workers, 252 | sampler=self.datasampler_lr_train, 253 | pin_memory=True) 254 | if self.args.train_hr: 255 | self.datasampler_hr_train = DistributedSampler( 256 | 
dataset=self.dataset_hr_train, 257 | rank=self.rank, 258 | num_replicas=self.world_size, 259 | shuffle=True) 260 | self.dataloader_hr_train = DataLoader( 261 | dataset=self.dataset_hr_train, 262 | batch_size=self.args.batch_size_per_gpu, 263 | num_workers=self.args.num_workers, 264 | sampler=self.datasampler_hr_train, 265 | pin_memory=True) 266 | self.dataloader_valid = DataLoader( 267 | dataset=self.dataset_valid, 268 | batch_size=self.args.batch_size_per_gpu, 269 | num_workers=self.args.num_workers, 270 | pin_memory=True) 271 | 272 | # Segmentation datasets 273 | self.log('Initializing image segmentation datasets') 274 | self.dataset_seg_image = ConcatDataset([ 275 | CocoPanopticDataset( 276 | imgdir=DATA_PATHS['coco_panoptic']['imgdir'], 277 | anndir=DATA_PATHS['coco_panoptic']['anndir'], 278 | annfile=DATA_PATHS['coco_panoptic']['annfile'], 279 | transform=CocoPanopticTrainAugmentation(size_lr)), 280 | SuperviselyPersonDataset( 281 | imgdir=DATA_PATHS['spd']['imgdir'], 282 | segdir=DATA_PATHS['spd']['segdir'], 283 | transform=CocoPanopticTrainAugmentation(size_lr)) 284 | ]) 285 | self.datasampler_seg_image = DistributedSampler( 286 | dataset=self.dataset_seg_image, 287 | rank=self.rank, 288 | num_replicas=self.world_size, 289 | shuffle=True) 290 | self.dataloader_seg_image = DataLoader( 291 | dataset=self.dataset_seg_image, 292 | batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr, 293 | num_workers=self.args.num_workers, 294 | sampler=self.datasampler_seg_image, 295 | pin_memory=True) 296 | 297 | self.log('Initializing video segmentation datasets') 298 | self.dataset_seg_video = YouTubeVISDataset( 299 | videodir=DATA_PATHS['youtubevis']['videodir'], 300 | annfile=DATA_PATHS['youtubevis']['annfile'], 301 | size=self.args.resolution_lr, 302 | seq_length=self.args.seq_length_lr, 303 | seq_sampler=TrainFrameSampler(speed=[1]), 304 | transform=YouTubeVISAugmentation(size_lr)) 305 | self.datasampler_seg_video = DistributedSampler( 306 | dataset=self.dataset_seg_video, 307 | rank=self.rank, 308 | num_replicas=self.world_size, 309 | shuffle=True) 310 | self.dataloader_seg_video = DataLoader( 311 | dataset=self.dataset_seg_video, 312 | batch_size=self.args.batch_size_per_gpu, 313 | num_workers=self.args.num_workers, 314 | sampler=self.datasampler_seg_video, 315 | pin_memory=True) 316 | 317 | def init_model(self): 318 | self.log('Initializing model') 319 | self.model = MattingNetwork(self.args.model_variant, pretrained_backbone=True).to(self.rank) 320 | 321 | if self.args.checkpoint: 322 | self.log(f'Restoring from checkpoint: {self.args.checkpoint}') 323 | self.log(self.model.load_state_dict( 324 | torch.load(self.args.checkpoint, map_location=f'cuda:{self.rank}'))) 325 | 326 | self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model) 327 | self.model_ddp = DDP(self.model, device_ids=[self.rank], broadcast_buffers=False, find_unused_parameters=True) 328 | self.optimizer = Adam([ 329 | {'params': self.model.backbone.parameters(), 'lr': self.args.learning_rate_backbone}, 330 | {'params': self.model.aspp.parameters(), 'lr': self.args.learning_rate_aspp}, 331 | {'params': self.model.decoder.parameters(), 'lr': self.args.learning_rate_decoder}, 332 | {'params': self.model.project_mat.parameters(), 'lr': self.args.learning_rate_decoder}, 333 | {'params': self.model.project_seg.parameters(), 'lr': self.args.learning_rate_decoder}, 334 | {'params': self.model.refiner.parameters(), 'lr': self.args.learning_rate_refiner}, 335 | ]) 336 | self.scaler = GradScaler() 337 | 338 | def
init_writer(self): 339 | if self.rank == 0: 340 | self.log('Initializing writer') 341 | self.writer = SummaryWriter(self.args.log_dir) 342 | 343 | def train(self): 344 | for epoch in range(self.args.epoch_start, self.args.epoch_end): 345 | self.epoch = epoch 346 | self.step = epoch * len(self.dataloader_lr_train) 347 | 348 | if not self.args.disable_validation: 349 | self.validate() 350 | 351 | self.log(f'Training epoch: {epoch}') 352 | for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_lr_train, disable=self.args.disable_progress_bar, dynamic_ncols=True): 353 | # Low resolution pass 354 | self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=1, tag='lr') 355 | 356 | # High resolution pass 357 | if self.args.train_hr: 358 | true_fgr, true_pha, true_bgr = self.load_next_mat_hr_sample() 359 | self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=self.args.downsample_ratio, tag='hr') 360 | 361 | # Segmentation pass 362 | if self.step % 2 == 0: 363 | true_img, true_seg = self.load_next_seg_video_sample() 364 | self.train_seg(true_img, true_seg, log_label='seg_video') 365 | else: 366 | true_img, true_seg = self.load_next_seg_image_sample() 367 | self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image') 368 | 369 | if self.step % self.args.checkpoint_save_interval == 0: 370 | self.save() 371 | 372 | self.step += 1 373 | 374 | def train_mat(self, true_fgr, true_pha, true_bgr, downsample_ratio, tag): 375 | true_fgr = true_fgr.to(self.rank, non_blocking=True) 376 | true_pha = true_pha.to(self.rank, non_blocking=True) 377 | true_bgr = true_bgr.to(self.rank, non_blocking=True) 378 | true_fgr, true_pha, true_bgr = self.random_crop(true_fgr, true_pha, true_bgr) 379 | true_src = true_fgr * true_pha + true_bgr * (1 - true_pha) 380 | 381 | with autocast(enabled=not self.args.disable_mixed_precision): 382 | pred_fgr, pred_pha = self.model_ddp(true_src, downsample_ratio=downsample_ratio)[:2] 383 | loss = matting_loss(pred_fgr, pred_pha, true_fgr, true_pha) 384 | 385 | self.scaler.scale(loss['total']).backward() 386 | self.scaler.step(self.optimizer) 387 | self.scaler.update() 388 | self.optimizer.zero_grad() 389 | 390 | if self.rank == 0 and self.step % self.args.log_train_loss_interval == 0: 391 | for loss_name, loss_value in loss.items(): 392 | self.writer.add_scalar(f'train_{tag}_{loss_name}', loss_value, self.step) 393 | 394 | if self.rank == 0 and self.step % self.args.log_train_images_interval == 0: 395 | self.writer.add_image(f'train_{tag}_pred_fgr', make_grid(pred_fgr.flatten(0, 1), nrow=pred_fgr.size(1)), self.step) 396 | self.writer.add_image(f'train_{tag}_pred_pha', make_grid(pred_pha.flatten(0, 1), nrow=pred_pha.size(1)), self.step) 397 | self.writer.add_image(f'train_{tag}_true_fgr', make_grid(true_fgr.flatten(0, 1), nrow=true_fgr.size(1)), self.step) 398 | self.writer.add_image(f'train_{tag}_true_pha', make_grid(true_pha.flatten(0, 1), nrow=true_pha.size(1)), self.step) 399 | self.writer.add_image(f'train_{tag}_true_src', make_grid(true_src.flatten(0, 1), nrow=true_src.size(1)), self.step) 400 | 401 | def train_seg(self, true_img, true_seg, log_label): 402 | true_img = true_img.to(self.rank, non_blocking=True) 403 | true_seg = true_seg.to(self.rank, non_blocking=True) 404 | 405 | true_img, true_seg = self.random_crop(true_img, true_seg) 406 | 407 | with autocast(enabled=not self.args.disable_mixed_precision): 408 | pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0] 409 | loss = segmentation_loss(pred_seg, true_seg) 410 | 411 | 
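        # Same mixed-precision update as in train_mat: scale the loss before backward, let the
        # GradScaler step the optimizer and update its scale factor, then clear the gradients.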
self.scaler.scale(loss).backward() 412 | self.scaler.step(self.optimizer) 413 | self.scaler.update() 414 | self.optimizer.zero_grad() 415 | 416 | if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_loss_interval == 0: 417 | self.writer.add_scalar(f'{log_label}_loss', loss, self.step) 418 | 419 | if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_images_interval == 0: 420 | self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step) 421 | self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step) 422 | self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step) 423 | 424 | def load_next_mat_hr_sample(self): 425 | try: 426 | sample = next(self.dataiterator_mat_hr) 427 | except: 428 | self.datasampler_hr_train.set_epoch(self.datasampler_hr_train.epoch + 1) 429 | self.dataiterator_mat_hr = iter(self.dataloader_hr_train) 430 | sample = next(self.dataiterator_mat_hr) 431 | return sample 432 | 433 | def load_next_seg_video_sample(self): 434 | try: 435 | sample = next(self.dataiterator_seg_video) 436 | except: 437 | self.datasampler_seg_video.set_epoch(self.datasampler_seg_video.epoch + 1) 438 | self.dataiterator_seg_video = iter(self.dataloader_seg_video) 439 | sample = next(self.dataiterator_seg_video) 440 | return sample 441 | 442 | def load_next_seg_image_sample(self): 443 | try: 444 | sample = next(self.dataiterator_seg_image) 445 | except: 446 | self.datasampler_seg_image.set_epoch(self.datasampler_seg_image.epoch + 1) 447 | self.dataiterator_seg_image = iter(self.dataloader_seg_image) 448 | sample = next(self.dataiterator_seg_image) 449 | return sample 450 | 451 | def validate(self): 452 | if self.rank == 0: 453 | self.log(f'Validating at the start of epoch: {self.epoch}') 454 | self.model_ddp.eval() 455 | total_loss, total_count = 0, 0 456 | with torch.no_grad(): 457 | with autocast(enabled=not self.args.disable_mixed_precision): 458 | for true_fgr, true_pha, true_bgr in tqdm(self.dataloader_valid, disable=self.args.disable_progress_bar, dynamic_ncols=True): 459 | true_fgr = true_fgr.to(self.rank, non_blocking=True) 460 | true_pha = true_pha.to(self.rank, non_blocking=True) 461 | true_bgr = true_bgr.to(self.rank, non_blocking=True) 462 | true_src = true_fgr * true_pha + true_bgr * (1 - true_pha) 463 | batch_size = true_src.size(0) 464 | pred_fgr, pred_pha = self.model(true_src)[:2] 465 | total_loss += matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)['total'].item() * batch_size 466 | total_count += batch_size 467 | avg_loss = total_loss / total_count 468 | self.log(f'Validation set average loss: {avg_loss}') 469 | self.writer.add_scalar('valid_loss', avg_loss, self.step) 470 | self.model_ddp.train() 471 | dist.barrier() 472 | 473 | def random_crop(self, *imgs): 474 | h, w = imgs[0].shape[-2:] 475 | w = random.choice(range(w // 2, w)) 476 | h = random.choice(range(h // 2, h)) 477 | results = [] 478 | for img in imgs: 479 | B, T = img.shape[:2] 480 | img = img.flatten(0, 1) 481 | img = F.interpolate(img, (max(h, w), max(h, w)), mode='bilinear', align_corners=False) 482 | img = center_crop(img, (h, w)) 483 | img = img.reshape(B, T, *img.shape[1:]) 484 | results.append(img) 485 | return results 486 | 487 | def save(self): 488 | if self.rank == 0: 489 | os.makedirs(self.args.checkpoint_dir, exist_ok=True) 490 | torch.save(self.model.state_dict(), 
os.path.join(self.args.checkpoint_dir, f'epoch-{self.epoch}.pth')) 491 | self.log('Model saved') 492 | dist.barrier() 493 | 494 | def cleanup(self): 495 | dist.destroy_process_group() 496 | 497 | def log(self, msg): 498 | print(f'[GPU{self.rank}] {msg}') 499 | 500 | if __name__ == '__main__': 501 | world_size = torch.cuda.device_count() 502 | mp.spawn( 503 | Trainer, 504 | nprocs=world_size, 505 | args=(world_size,), 506 | join=True) 507 | -------------------------------------------------------------------------------- /train_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Expected directory format: 3 | 4 | VideoMatte Train/Valid: 5 | ├──fgr/ 6 | ├── 0001/ 7 | ├── 00000.jpg 8 | ├── 00001.jpg 9 | ├── pha/ 10 | ├── 0001/ 11 | ├── 00000.jpg 12 | ├── 00001.jpg 13 | 14 | ImageMatte Train/Valid: 15 | ├── fgr/ 16 | ├── sample1.jpg 17 | ├── sample2.jpg 18 | ├── pha/ 19 | ├── sample1.jpg 20 | ├── sample2.jpg 21 | 22 | Background Image Train/Valid 23 | ├── sample1.png 24 | ├── sample2.png 25 | 26 | Background Video Train/Valid 27 | ├── 0000/ 28 | ├── 0000.jpg/ 29 | ├── 0001.jpg/ 30 | 31 | """ 32 | 33 | 34 | DATA_PATHS = { 35 | 36 | 'videomatte': { 37 | 'train': '../matting-data/VideoMatte240K_JPEG_SD/train', 38 | 'valid': '../matting-data/VideoMatte240K_JPEG_SD/valid', 39 | }, 40 | 'imagematte': { 41 | 'train': '../matting-data/ImageMatte/train', 42 | 'valid': '../matting-data/ImageMatte/valid', 43 | }, 44 | 'background_images': { 45 | 'train': '../matting-data/Backgrounds/train', 46 | 'valid': '../matting-data/Backgrounds/valid', 47 | }, 48 | 'background_videos': { 49 | 'train': '../matting-data/BackgroundVideos/train', 50 | 'valid': '../matting-data/BackgroundVideos/valid', 51 | }, 52 | 53 | 54 | 'coco_panoptic': { 55 | 'imgdir': '../matting-data/coco/train2017/', 56 | 'anndir': '../matting-data/coco/panoptic_train2017/', 57 | 'annfile': '../matting-data/coco/annotations/panoptic_train2017.json', 58 | }, 59 | 'spd': { 60 | 'imgdir': '../matting-data/SuperviselyPersonDataset/img', 61 | 'segdir': '../matting-data/SuperviselyPersonDataset/seg', 62 | }, 63 | 'youtubevis': { 64 | 'videodir': '../matting-data/YouTubeVIS/train/JPEGImages', 65 | 'annfile': '../matting-data/YouTubeVIS/train/instances.json', 66 | } 67 | 68 | } 69 | -------------------------------------------------------------------------------- /train_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | # --------------------------------------------------------------------------------- Train Loss 5 | 6 | 7 | def matting_loss(pred_fgr, pred_pha, true_fgr, true_pha): 8 | """ 9 | Args: 10 | pred_fgr: Shape(B, T, 3, H, W) 11 | pred_pha: Shape(B, T, 1, H, W) 12 | true_fgr: Shape(B, T, 3, H, W) 13 | true_pha: Shape(B, T, 1, H, W) 14 | """ 15 | loss = dict() 16 | # Alpha losses 17 | loss['pha_l1'] = F.l1_loss(pred_pha, true_pha) 18 | loss['pha_laplacian'] = laplacian_loss(pred_pha.flatten(0, 1), true_pha.flatten(0, 1)) 19 | loss['pha_coherence'] = F.mse_loss(pred_pha[:, 1:] - pred_pha[:, :-1], 20 | true_pha[:, 1:] - true_pha[:, :-1]) * 5 21 | # Foreground losses 22 | true_msk = true_pha.gt(0) 23 | pred_fgr = pred_fgr * true_msk 24 | true_fgr = true_fgr * true_msk 25 | loss['fgr_l1'] = F.l1_loss(pred_fgr, true_fgr) 26 | loss['fgr_coherence'] = F.mse_loss(pred_fgr[:, 1:] - pred_fgr[:, :-1], 27 | true_fgr[:, 1:] - true_fgr[:, :-1]) * 5 28 | # Total 29 | loss['total'] = loss['pha_l1'] + 
loss['pha_coherence'] + loss['pha_laplacian'] \ 30 | + loss['fgr_l1'] + loss['fgr_coherence'] 31 | return loss 32 | 33 | def segmentation_loss(pred_seg, true_seg): 34 | """ 35 | Args: 36 | pred_seg: Shape(B, T, 1, H, W) 37 | true_seg: Shape(B, T, 1, H, W) 38 | """ 39 | return F.binary_cross_entropy_with_logits(pred_seg, true_seg) 40 | 41 | 42 | # ----------------------------------------------------------------------------- Laplacian Loss 43 | 44 | 45 | def laplacian_loss(pred, true, max_levels=5): 46 | kernel = gauss_kernel(device=pred.device, dtype=pred.dtype) 47 | pred_pyramid = laplacian_pyramid(pred, kernel, max_levels) 48 | true_pyramid = laplacian_pyramid(true, kernel, max_levels) 49 | loss = 0 50 | for level in range(max_levels): 51 | loss += (2 ** level) * F.l1_loss(pred_pyramid[level], true_pyramid[level]) 52 | return loss / max_levels 53 | 54 | def laplacian_pyramid(img, kernel, max_levels): 55 | current = img 56 | pyramid = [] 57 | for _ in range(max_levels): 58 | current = crop_to_even_size(current) 59 | down = downsample(current, kernel) 60 | up = upsample(down, kernel) 61 | diff = current - up 62 | pyramid.append(diff) 63 | current = down 64 | return pyramid 65 | 66 | def gauss_kernel(device='cpu', dtype=torch.float32): 67 | kernel = torch.tensor([[1, 4, 6, 4, 1], 68 | [4, 16, 24, 16, 4], 69 | [6, 24, 36, 24, 6], 70 | [4, 16, 24, 16, 4], 71 | [1, 4, 6, 4, 1]], device=device, dtype=dtype) 72 | kernel /= 256 73 | kernel = kernel[None, None, :, :] 74 | return kernel 75 | 76 | def gauss_convolution(img, kernel): 77 | B, C, H, W = img.shape 78 | img = img.reshape(B * C, 1, H, W) 79 | img = F.pad(img, (2, 2, 2, 2), mode='reflect') 80 | img = F.conv2d(img, kernel) 81 | img = img.reshape(B, C, H, W) 82 | return img 83 | 84 | def downsample(img, kernel): 85 | img = gauss_convolution(img, kernel) 86 | img = img[:, :, ::2, ::2] 87 | return img 88 | 89 | def upsample(img, kernel): 90 | B, C, H, W = img.shape 91 | out = torch.zeros((B, C, H * 2, W * 2), device=img.device, dtype=img.dtype) 92 | out[:, :, ::2, ::2] = img * 4 93 | out = gauss_convolution(out, kernel) 94 | return out 95 | 96 | def crop_to_even_size(img): 97 | H, W = img.shape[2:] 98 | H = H - H % 2 99 | W = W - W % 2 100 | return img[:, :, :H, :W] 101 | 102 | --------------------------------------------------------------------------------
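The loss functions in train_loss.py can be smoke-tested in isolation with random tensors of the shapes documented in their docstrings. The sketch below is not part of the repository; it only assumes it is run from the repository root so that `train_loss` is importable, and it uses T >= 2 because the temporal-coherence terms compare adjacent frames.

```python
import torch

from train_loss import matting_loss, segmentation_loss

# Dummy batch with the documented shapes (B, T, C, H, W).
B, T, H, W = 2, 4, 64, 64
pred_fgr, true_fgr = torch.rand(B, T, 3, H, W), torch.rand(B, T, 3, H, W)
pred_pha, true_pha = torch.rand(B, T, 1, H, W), torch.rand(B, T, 1, H, W)

losses = matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)
# Keys: pha_l1, pha_laplacian, pha_coherence, fgr_l1, fgr_coherence, total.
print({name: round(value.item(), 4) for name, value in losses.items()})

pred_seg = torch.randn(B, T, 1, H, W)                  # raw logits
true_seg = torch.rand(B, T, 1, H, W).gt(0.5).float()   # binary targets
print(round(segmentation_loss(pred_seg, true_seg).item(), 4))
```

H and W must be large enough to survive the five Laplacian pyramid levels used by `laplacian_loss` (each level halves the resolution); 64×64 is comfortably sufficient.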
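train_config.py is the single place where the datasets are wired in: DATA_PATHS maps each dataset name to the directories laid out in its docstring, and train.py imports it directly. The default entries point to ../matting-data relative to where train.py is run, so they must exist before a run starts. A small, hypothetical pre-flight check, not part of the repository, can catch a mislabeled path before a multi-GPU job launches:

```python
# Hypothetical helper: verify that every path configured in train_config.DATA_PATHS exists.
import os

from train_config import DATA_PATHS

missing = []
for dataset, entries in DATA_PATHS.items():
    for key, path in entries.items():
        if not os.path.exists(path):
            missing.append(f'{dataset}.{key}: {path}')

if missing:
    raise SystemExit('Missing data paths:\n' + '\n'.join(missing))
print('All DATA_PATHS entries found.')
```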