├── 20250206workflow.json ├── LICENSE ├── README.md ├── __init__.py ├── diffueraser_node.py ├── example.png ├── examplea.png ├── exampleb.png ├── examples ├── __init__.py ├── mask.mp4 └── video.mp4 ├── libs ├── __init__.py ├── brushnet_CA.py ├── diffueraser.py ├── pipeline_diffueraser.py ├── transformer_temporal.py ├── unet_2d_blocks.py ├── unet_2d_condition.py ├── unet_3d_blocks.py └── unet_motion_model.py ├── node_utils.py ├── propainter ├── RAFT │ ├── __init__.py │ ├── corr.py │ ├── datasets.py │ ├── demo.py │ ├── extractor.py │ ├── raft.py │ ├── update.py │ └── utils │ │ ├── __init__.py │ │ ├── augmentor.py │ │ ├── flow_viz.py │ │ ├── flow_viz_pt.py │ │ ├── frame_utils.py │ │ └── utils.py ├── core │ ├── __init__.py │ ├── dataset.py │ ├── dist.py │ ├── loss.py │ ├── lr_scheduler.py │ ├── metrics.py │ ├── prefetch_dataloader.py │ ├── trainer.py │ ├── trainer_flow_w_edge.py │ └── utils.py ├── inference.py ├── model │ ├── __init__.py │ ├── canny │ │ ├── __init__.py │ │ ├── canny_filter.py │ │ ├── filter.py │ │ ├── gaussian.py │ │ ├── kernels.py │ │ └── sobel.py │ ├── misc.py │ ├── modules │ │ ├── __init__.py │ │ ├── base_module.py │ │ ├── deformconv.py │ │ ├── flow_comp_raft.py │ │ ├── flow_loss_utils.py │ │ ├── sparse_transformer.py │ │ └── spectral_norm.py │ ├── propainter.py │ ├── recurrent_flow_completion.py │ └── vgg_arch.py └── utils │ ├── __init__.py │ ├── download_util.py │ ├── file_client.py │ ├── flow_util.py │ └── img_util.py ├── pyproject.toml ├── requirements.txt ├── run_diffueraser.py └── sd15_repo ├── feature_extractor └── preprocessor_config.json ├── model_index.json ├── safety_checker └── config.json ├── scheduler └── scheduler_config.json ├── text_encoder └── config.json ├── tokenizer ├── merges.txt ├── special_tokens_map.json ├── tokenizer_config.json └── vocab.json ├── unet └── config.json └── vae └── config.json /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 smthemex 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI_DiffuEraser 2 | [DiffuEraser](https://github.com/lixiaowen-xw/DiffuEraser) is a diffusion model for video inpainting; this node lets you use it in ComfyUI. 3 | 4 | # Update 5 | * Added a single-mask mode for handling fixed watermarks; see the new example image. 6 | * The video pre-processed by ProPainter can look better than the final result (possibly a code issue), so the pre-processed video is now also output. 7 | 8 | 9 | 10 | 11 | # 1. Installation 12 | 13 | In the ./ComfyUI/custom_nodes directory, run the following: 14 | ``` 15 | git clone https://github.com/smthemex/ComfyUI_DiffuEraser.git 16 | ``` 17 | --- 18 | 19 | # 2. Requirements 20 | * Usually nothing extra is needed because the node is based on SD 1.5; if a library turns out to be missing, run: 21 | ``` 22 | pip install -r requirements.txt 23 | ``` 24 | # 3. Models 25 | * sd1.5 [address](https://modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5/files) v1-5-pruned-emaonly.safetensors #example 26 | * pcm 1.5 lora [address](https://huggingface.co/wangfuyun/PCM_Weights/tree/main/sd15) pcm_sd15_smallcfg_2step_converted.safetensors #example 27 | * ProPainter [address](https://github.com/sczhou/ProPainter/releases/tag/v0.1.0) # below example 28 | * unet and brushnet [address](https://huggingface.co/lixiaowen/diffuEraser/tree/main) # below example 29 | 30 | ``` 31 | -- ComfyUI/models/checkpoints 32 | |-- any sd1.5 safetensors # any SD 1.5 checkpoint 33 | -- ComfyUI/models/DiffuEraser 34 | |--brushnet 35 | |-- config.json 36 | |-- diffusion_pytorch_model.safetensors 37 | |--unet_main 38 | |-- config.json 39 | |-- diffusion_pytorch_model.safetensors 40 | |--propainter 41 | |-- ProPainter.pth 42 | |-- raft-things.pth 43 | |-- recurrent_flow_completion.pth 44 | ``` 45 | * If you use video-to-mask # the background can be segmented with an RMBG or BiRefNet model 46 | ``` 47 | -- any/path/briaai/RMBG-2.0 # or auto download 48 | |--config.json 49 | |--model.safetensors 50 | |--birefnet.py 51 | |--BiRefNet_config.py 52 | Or 53 | -- any/path/ZhengPeng7/BiRefNet # or auto download 54 | |--config.json 55 | |--model.safetensors 56 | |--birefnet.py 57 | |--BiRefNet_config.py 58 | |--handler.py 59 | ``` 60 | 61 | # 4. Tips 62 | * video2mask: If only the input video is available, enable this option (it generates the mask video). 63 | 64 | # 5. Example 65 | * Use one mask 66 | ![](https://github.com/smthemex/ComfyUI_DiffuEraser/blob/main/examplea.png) 67 | * Use RMBG or BiRefNet to turn the input video into a mask (note that RMBG is not licensed for commercial use); a minimal sketch of this step is shown below.
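A rough sketch of what the video2mask path does internally (it mirrors `image2masks` in `node_utils.py`): each frame is segmented with RMBG-2.0 or BiRefNet and the sigmoid output is kept as the mask. The `frame_to_mask` helper and the `frames` list below are placeholder names for illustration only; `briaai/RMBG-2.0` is simply the node's default `seg_repo`.
```
# Sketch only: mirrors image2masks in node_utils.py, not a drop-in replacement for it.
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

device = "cuda" if torch.cuda.is_available() else "cpu"
# "briaai/RMBG-2.0" is the node's default seg_repo; "ZhengPeng7/BiRefNet" also works.
model = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0", trust_remote_code=True
).to(device).eval()

# Same preprocessing the node uses: 1024x1024 resize + ImageNet normalization.
transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

def frame_to_mask(frame: Image.Image) -> Image.Image:
    """Segment one frame; the segmented foreground becomes the region to erase."""
    with torch.no_grad():
        pred = model(transform(frame).unsqueeze(0).to(device))[-1].sigmoid().cpu()
    mask = transforms.ToPILImage()(pred[0].squeeze())
    return mask.resize(frame.size).convert("RGB")

# frames: a list of PIL.Image frames decoded from the input video (placeholder).
# masks = [frame_to_mask(f) for f in frames]
```
Inside ComfyUI you normally just enable video2mask on the sampler node and set seg_repo; running a snippet like this standalone is only useful for checking what kind of mask a given segmentation model produces.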
68 | ![](https://github.com/smthemex/ComfyUI_DiffuEraser/blob/main/example.png) 69 | * Use a mask video (the mask can also be generated with other methods, such as SAM2): 70 | ![](https://github.com/smthemex/ComfyUI_DiffuEraser/blob/main/exampleb.png) 71 | 72 | 73 | # 6. Citation 74 | ``` 75 | @misc{li2025diffueraserdiffusionmodelvideo, 76 | title={DiffuEraser: A Diffusion Model for Video Inpainting}, 77 | author={Xiaowen Li and Haolan Xue and Peiran Ren and Liefeng Bo}, 78 | year={2025}, 79 | eprint={2501.10018}, 80 | archivePrefix={arXiv}, 81 | primaryClass={cs.CV}, 82 | url={https://arxiv.org/abs/2501.10018}, 83 | } 84 | ``` 85 | ``` 86 | @inproceedings{zhou2023propainter, 87 | title={{ProPainter}: Improving Propagation and Transformer for Video Inpainting}, 88 | author={Zhou, Shangchen and Li, Chongyi and Chan, Kelvin C.K and Loy, Chen Change}, 89 | booktitle={Proceedings of IEEE International Conference on Computer Vision (ICCV)}, 90 | year={2023} 91 | } 92 | ``` 93 | ``` 94 | @misc{ju2024brushnet, 95 | title={BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion}, 96 | author={Xuan Ju and Xian Liu and Xintao Wang and Yuxuan Bian and Ying Shan and Qiang Xu}, 97 | year={2024}, 98 | eprint={2403.06976}, 99 | archivePrefix={arXiv}, 100 | primaryClass={cs.CV} 101 | } 102 | ``` 103 | ``` 104 | @article{BiRefNet, 105 | title={Bilateral Reference for High-Resolution Dichotomous Image Segmentation}, 106 | author={Zheng, Peng and Gao, Dehong and Fan, Deng-Ping and Liu, Li and Laaksonen, Jorma and Ouyang, Wanli and Sebe, Nicu}, 107 | journal={CAAI Artificial Intelligence Research}, 108 | year={2024} 109 | } 110 | 111 | ``` 112 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .diffueraser_node import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 3 | 4 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 5 | -------------------------------------------------------------------------------- /diffueraser_node.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | import os 4 | import torch 5 | import gc 6 | import numpy as np 7 | from .node_utils import load_images,tensor2pil_list,file_exists,download_weights,image2masks 8 | import folder_paths 9 | from .run_diffueraser import load_diffueraser,diffueraser_inference 10 | 11 | 12 | MAX_SEED = np.iinfo(np.int32).max 13 | current_node_path = os.path.dirname(os.path.abspath(__file__)) 14 | device = torch.device( 15 | "cuda") if torch.cuda.is_available() else torch.device("cpu") 16 | 17 | # add checkpoints dir 18 | DiffuEraser_weigths_path = os.path.join(folder_paths.models_dir, "DiffuEraser") 19 | if not os.path.exists(DiffuEraser_weigths_path): 20 | os.makedirs(DiffuEraser_weigths_path) 21 | folder_paths.add_model_folder_path("DiffuEraser", DiffuEraser_weigths_path) 22 | 23 | 24 | 25 | class DiffuEraserLoader: 26 | def __init__(self): 27 | pass 28 | 29 | @classmethod 30 | def INPUT_TYPES(s): 31 | return { 32 | "required": { 33 | "checkpoint": (["none"] + folder_paths.get_filename_list("checkpoints"),), 34 | "lora": (["none"] + folder_paths.get_filename_list("loras"),), 35 | }, 36 | } 37 | 38 | RETURN_TYPES = ("MODEL_DiffuEraser",) 39 | RETURN_NAMES = ("model",) 40 | FUNCTION = "loader_main" 41 | CATEGORY = "DiffuEraser" 42 | 43 | def loader_main(self,checkpoint,lora): 44 | 45 | # check whether the models exist; if not, auto
download 46 | 47 | brushnet_weigths_path = os.path.join(DiffuEraser_weigths_path, "brushnet") 48 | if not os.path.exists(brushnet_weigths_path): 49 | os.makedirs(brushnet_weigths_path) 50 | 51 | unet_weigths_path = os.path.join(DiffuEraser_weigths_path, "unet_main") 52 | if not os.path.exists(unet_weigths_path): 53 | os.makedirs(unet_weigths_path) 54 | 55 | 56 | if not file_exists(brushnet_weigths_path,"config.json") : 57 | download_weights(DiffuEraser_weigths_path,"lixiaowen/diffuEraser",subfolder="brushnet",pt_name="config.json") 58 | if not file_exists(brushnet_weigths_path,"diffusion_pytorch_model.safetensors"): 59 | download_weights(DiffuEraser_weigths_path,"lixiaowen/diffuEraser",subfolder="brushnet",pt_name="diffusion_pytorch_model.safetensors") 60 | 61 | if not file_exists(unet_weigths_path,"diffusion_pytorch_model.safetensors"): 62 | download_weights(DiffuEraser_weigths_path,"lixiaowen/diffuEraser",subfolder="unet_main",pt_name="diffusion_pytorch_model.safetensors") 63 | if not file_exists(unet_weigths_path,"config.json"): 64 | download_weights(DiffuEraser_weigths_path,"lixiaowen/diffuEraser",subfolder="unet_main",pt_name="config.json") 65 | 66 | # load model 67 | original_config_file=os.path.join(folder_paths.models_dir,"configs","v1-inference.yaml") 68 | sd_repo=os.path.join(current_node_path,"sd15_repo") 69 | if checkpoint!="none": 70 | ckpt_path=folder_paths.get_full_path("checkpoints",checkpoint) 71 | else: 72 | raise ValueError("no sd1.5 checkpoint") 73 | 74 | if lora!="none": 75 | pcm_lora_path=folder_paths.get_full_path("loras",lora) 76 | else: 77 | raise ValueError("no pcm lora checkpoint") 78 | # if vae!="none" : 79 | # vae_path=folder_paths.get_full_path("vae",vae) 80 | # else: 81 | # raise "no sd1.5 vae" 82 | 83 | 84 | model=load_diffueraser(DiffuEraser_weigths_path, pcm_lora_path,sd_repo,ckpt_path,original_config_file,device) 85 | 86 | gc.collect() 87 | torch.cuda.empty_cache() 88 | return (model,) 89 | 90 | class DiffuEraserSampler: 91 | def __init__(self): 92 | pass 93 | 94 | @classmethod 95 | def INPUT_TYPES(s): 96 | return { 97 | "required": { 98 | "model": ("MODEL_DiffuEraser",), 99 | "images": ("IMAGE",), #[b,h,w,c] 100 | "fps": ("FLOAT", {"forceInput": True,}), 101 | "seed": ("INT", {"default": -1, "min": -1, "max": MAX_SEED}), 102 | "num_inference_steps": ("INT", { 103 | "default": 2, 104 | "min": 1, # Minimum value 105 | "max": 120, # Maximum value 106 | "step": 1, # Slider's step 107 | "display": "number", # Cosmetic only: display as "number" or "slider" 108 | }), 109 | "guidance_scale": ("FLOAT", {"default": 0, "min": 0, "max": 10., "step": 0.1, "display": "number"}), 110 | "video_length": ("INT", { 111 | "default": 10, 112 | "min": 1, # Minimum value 113 | "max": 1024, # Maximum value 114 | "step": 1, # Slider's step 115 | "display": "number", # Cosmetic only: display as "number" or "slider" 116 | }), 117 | "mask_dilation_iter": ("INT", { 118 | "default": 8, 119 | "min": 1, # Minimum value 120 | "max": 1024, # Maximum value 121 | "step": 1, # Slider's step 122 | "display": "number", # Cosmetic only: display as "number" or "slider" 123 | }), 124 | "ref_stride": ("INT", { 125 | "default": 10, 126 | "min": 1, # Minimum value 127 | "max": 1024, # Maximum value 128 | "step": 1, # Slider's step 129 | "display": "number", # Cosmetic only: display as "number" or "slider" 130 | }), 131 | "neighbor_length": ("INT", { 132 | "default": 10, 133 | "min": 1, # Minimum value 134 | "max": 1024, # Maximum value 135 | "step": 1, # Slider's step 136 | "display": "number", # Cosmetic only: display
as "number" or "slider" 137 | }), 138 | "subvideo_length": ("INT", { 139 | "default": 50, 140 | "min": 1, # Minimum value 141 | "max": 1024, # Maximum value 142 | "step": 1, # Slider's step 143 | "display": "number", # Cosmetic only: display as "number" or "slider" 144 | }), 145 | "video2mask":("BOOLEAN", {"default": False},), 146 | "seg_repo": ("STRING", {"default": "briaai/RMBG-2.0"},), 147 | "save_result_video":("BOOLEAN", {"default": False},),}, 148 | "optional": { 149 | "video_mask": ("IMAGE",), 150 | 151 | } 152 | 153 | } 154 | 155 | RETURN_TYPES = ("IMAGE","IMAGE","STRING", ) 156 | RETURN_NAMES = ("images","propainter_img","output_path", ) 157 | FUNCTION = "sampler_main" 158 | CATEGORY = "DiffuEraser" 159 | 160 | def sampler_main(self, model,images,fps,seed,num_inference_steps,guidance_scale,video_length,mask_dilation_iter,ref_stride,neighbor_length,subvideo_length,video2mask,seg_repo,save_result_video,**kwargs): 161 | 162 | video_inpainting_sd=model.get("video_inpainting_sd") 163 | propainter=model.get("propainter") 164 | 165 | max_img_size=1920 166 | _,height,width,_ = images.size() 167 | video_image=tensor2pil_list(images,width,height) 168 | if video2mask and seg_repo: 169 | print("***********Start video to masks infer **************") 170 | video_mask=image2masks(seg_repo,video_image)# use rmbg or BiRefNet to make video to masks 171 | else: 172 | if isinstance(kwargs.get("video_mask"),torch.Tensor): 173 | video_mask=tensor2pil_list(kwargs.get("video_mask"),width,height) 174 | else: 175 | raise "no video_mask,you can enable video2mask and fill a rmbg or BiRefNet repo to generate mask from video_image,or link video_mask from other node" 176 | 177 | seeds=None if seed==-1 else seed 178 | 179 | print("frame_length:",len(video_image),"mask_length:",len(video_mask),"fps:",fps) 180 | if len(video_mask)!=len(video_image) and len(video_mask)==1: 181 | video_mask=video_mask*len(video_image) # if use one mask to inpaint all frames 182 | assert len(video_image) == len(video_mask), "Length of video_image and video_mask must be equal" 183 | 184 | print("***********Start DiffuEraser Sampler**************") 185 | video_inpainting_sd.to(device) 186 | propainter.to(device) 187 | output_path,image_list,Propainter_list=diffueraser_inference(video_inpainting_sd,propainter,video_image,video_mask,video_length,width,height, 188 | mask_dilation_iter,max_img_size,ref_stride,neighbor_length,subvideo_length,guidance_scale,num_inference_steps,seeds,fps,save_result_video) 189 | video_inpainting_sd.to("cpu") 190 | #propainter.to("cpu") 191 | 192 | images=load_images(image_list) 193 | Propainter_img=load_images(Propainter_list) 194 | gc.collect() 195 | torch.cuda.empty_cache() 196 | return (images,Propainter_img,output_path,) 197 | 198 | 199 | 200 | NODE_CLASS_MAPPINGS = { 201 | "DiffuEraserLoader":DiffuEraserLoader, 202 | "DiffuEraserSampler":DiffuEraserSampler, 203 | } 204 | 205 | NODE_DISPLAY_NAME_MAPPINGS = { 206 | "DiffuEraserLoader":"DiffuEraserLoader", 207 | "DiffuEraserSampler":"DiffuEraserSampler", 208 | } 209 | -------------------------------------------------------------------------------- /example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_DiffuEraser/a59995e30f6b7a78f4606a7773c6031752dd738d/example.png -------------------------------------------------------------------------------- /examplea.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/smthemex/ComfyUI_DiffuEraser/a59995e30f6b7a78f4606a7773c6031752dd738d/examplea.png -------------------------------------------------------------------------------- /exampleb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_DiffuEraser/a59995e30f6b7a78f4606a7773c6031752dd738d/exampleb.png -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /examples/mask.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_DiffuEraser/a59995e30f6b7a78f4606a7773c6031752dd738d/examples/mask.mp4 -------------------------------------------------------------------------------- /examples/video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smthemex/ComfyUI_DiffuEraser/a59995e30f6b7a78f4606a7773c6031752dd738d/examples/video.mp4 -------------------------------------------------------------------------------- /libs/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /node_utils.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | import os 4 | import torch 5 | from PIL import Image 6 | import numpy as np 7 | import cv2 8 | import time 9 | from comfy.utils import common_upscale,ProgressBar 10 | from huggingface_hub import hf_hub_download 11 | import torchvision.transforms as transforms 12 | from transformers import AutoModelForImageSegmentation 13 | import folder_paths 14 | import gc 15 | cur_path = os.path.dirname(os.path.abspath(__file__)) 16 | device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 17 | 18 | 19 | 20 | def image2masks(repo,video_image): 21 | start_time = time.time() 22 | model = AutoModelForImageSegmentation.from_pretrained(repo, trust_remote_code=True) 23 | torch.set_float32_matmul_precision(['high', 'highest'][0]) 24 | model.to('cuda') 25 | model.eval() 26 | # Data settings 27 | image_size = (1024, 1024) 28 | transform_image = transforms.Compose([ 29 | transforms.Resize(image_size), 30 | transforms.ToTensor(), 31 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 32 | ]) 33 | masks=[] 34 | for img in video_image: 35 | input_images = transform_image(img).unsqueeze(0).to('cuda') 36 | # Prediction 37 | with torch.no_grad(): 38 | preds = model(input_images)[-1].sigmoid().cpu() 39 | pred = preds[0].squeeze() 40 | pred_pil = transforms.ToPILImage()(pred) 41 | mask = pred_pil.resize(img.size) 42 | #img.putalpha(mask) 43 | masks.append(mask.convert('RGB')) 44 | end_time = time.time() 45 | load_time = end_time - start_time 46 | print(f"image2masks infer time: {load_time:.4f} s") 47 | model.to('cpu') 48 | gc.collect() 49 | torch.cuda.empty_cache() 50 | return masks 51 | 52 | 53 | def resize_and_center_paste(image_list, target_size=(1024, 1024)): 54 | # 定义转换为张量的变换 55 | to_tensor = transforms.ToTensor() 56 | 57 | # 处理每张图像 58 | images = [] 59 | for img in image_list: 60 | # 获取原始图像尺寸 61 | img_width, img_height = img.size 62 | 63 | # 
计算缩放比例 64 | scale_factor = target_size[0] / max(img_width, img_height) 65 | 66 | # 计算新的尺寸 67 | new_width = int(img_width * scale_factor) 68 | new_height = int(img_height * scale_factor) 69 | 70 | # 缩放图像 71 | resized_img = img.resize((new_width, new_height), Image.BICUBIC) 72 | 73 | # 创建空白画布 74 | canvas = Image.new('RGB', target_size, (0, 0, 0)) 75 | 76 | # 计算粘贴位置 77 | paste_x = (target_size[0] - new_width) // 2 78 | paste_y = (target_size[1] - new_height) // 2 79 | 80 | # 粘贴图像到画布中心 81 | canvas.paste(resized_img, (paste_x, paste_y)) 82 | 83 | # 转换为张量 84 | tensor_img = to_tensor(canvas) 85 | images.append(tensor_img) 86 | 87 | # 堆叠所有张量 88 | images_tensor = torch.stack(images) 89 | return images_tensor 90 | 91 | 92 | 93 | def center_paste_and_resize(image_list, target_size=(1024, 1024)): 94 | # 定义转换为张量的变换 95 | to_tensor = transforms.ToTensor() 96 | 97 | # 处理每张图像 98 | images = [] 99 | for img in image_list: 100 | # 创建空白画布 101 | canvas = Image.new('RGB', target_size, (0, 0, 0)) 102 | 103 | # 计算粘贴位置 104 | img_width, img_height = img.size 105 | paste_x = (target_size[0] - img_width) // 2 106 | paste_y = (target_size[1] - img_height) // 2 107 | 108 | # 粘贴图像到画布中心 109 | canvas.paste(img, (paste_x, paste_y)) 110 | 111 | # 转换为张量 112 | tensor_img = to_tensor(canvas) 113 | images.append(tensor_img) 114 | 115 | # 堆叠所有张量 116 | images_tensor = torch.stack(images) 117 | return images_tensor 118 | 119 | 120 | def tensor_to_pil(tensor): 121 | image_np = tensor.squeeze().mul(255).clamp(0, 255).byte().numpy() 122 | image = Image.fromarray(image_np, mode='RGB') 123 | return image 124 | 125 | def tensor2pil_list(image,width,height): 126 | B,_,_,_=image.size() 127 | if B==1: 128 | ref_image_list=[tensor2pil_upscale(image,width,height)] 129 | else: 130 | img_list = list(torch.chunk(image, chunks=B)) 131 | ref_image_list = [tensor2pil_upscale(img,width,height) for img in img_list] 132 | return ref_image_list 133 | 134 | def tensor2pil_upscale(img_tensor, width, height): 135 | samples = img_tensor.movedim(-1, 1) 136 | img = common_upscale(samples, width, height, "nearest-exact", "center") 137 | samples = img.movedim(1, -1) 138 | img_pil = tensor_to_pil(samples) 139 | return img_pil 140 | 141 | 142 | def tensor2cv(tensor_image): 143 | if len(tensor_image.shape)==4:#bhwc to hwc 144 | tensor_image=tensor_image.squeeze(0) 145 | if tensor_image.is_cuda: 146 | tensor_image = tensor_image.cpu().detach() 147 | tensor_image=tensor_image.numpy() 148 | #反归一化 149 | maxValue=tensor_image.max() 150 | tensor_image=tensor_image*255/maxValue 151 | img_cv2=np.uint8(tensor_image)#32 to uint8 152 | img_cv2=cv2.cvtColor(img_cv2,cv2.COLOR_RGB2BGR) 153 | return img_cv2 154 | 155 | def cvargb2tensor(img): 156 | assert type(img) == np.ndarray, 'the img type is {}, but ndarry expected'.format(type(img)) 157 | img = torch.from_numpy(img.transpose((2, 0, 1))) 158 | return img.float().div(255).unsqueeze(0) # 255也可以改为256 159 | 160 | def cv2tensor(img): 161 | assert type(img) == np.ndarray, 'the img type is {}, but ndarry expected'.format(type(img)) 162 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 163 | img = torch.from_numpy(img.transpose((2, 0, 1))) 164 | return img.float().div(255).unsqueeze(0) # 255也可以改为256 165 | 166 | def images_generator(img_list: list,): 167 | #get img size 168 | sizes = {} 169 | for image_ in img_list: 170 | if isinstance(image_,Image.Image): 171 | count = sizes.get(image_.size, 0) 172 | sizes[image_.size] = count + 1 173 | elif isinstance(image_,np.ndarray): 174 | count = sizes.get(image_.shape[:2][::-1], 0) 175 | 
sizes[image_.shape[:2][::-1]] = count + 1 176 | else: 177 | raise "unsupport image list,must be pil or cv2!!!" 178 | size = max(sizes.items(), key=lambda x: x[1])[0] 179 | yield size[0], size[1] 180 | 181 | # any to tensor 182 | def load_image(img_in): 183 | if isinstance(img_in, Image.Image): 184 | img_in=img_in.convert("RGB") 185 | i = np.array(img_in, dtype=np.float32) 186 | i = torch.from_numpy(i).div_(255) 187 | if i.shape[0] != size[1] or i.shape[1] != size[0]: 188 | i = torch.from_numpy(i).movedim(-1, 0).unsqueeze(0) 189 | i = common_upscale(i, size[0], size[1], "lanczos", "center") 190 | i = i.squeeze(0).movedim(0, -1).numpy() 191 | return i 192 | elif isinstance(img_in,np.ndarray): 193 | i=cv2.cvtColor(img_in,cv2.COLOR_BGR2RGB).astype(np.float32) 194 | i = torch.from_numpy(i).div_(255) 195 | #print(i.shape) 196 | return i 197 | else: 198 | raise "unsupport image list,must be pil,cv2 or tensor!!!" 199 | 200 | total_images = len(img_list) 201 | processed_images = 0 202 | pbar = ProgressBar(total_images) 203 | images = map(load_image, img_list) 204 | try: 205 | prev_image = next(images) 206 | while True: 207 | next_image = next(images) 208 | yield prev_image 209 | processed_images += 1 210 | pbar.update_absolute(processed_images, total_images) 211 | prev_image = next_image 212 | except StopIteration: 213 | pass 214 | if prev_image is not None: 215 | yield prev_image 216 | 217 | def load_images(img_list: list,): 218 | gen = images_generator(img_list) 219 | (width, height) = next(gen) 220 | images = torch.from_numpy(np.fromiter(gen, np.dtype((np.float32, (height, width, 3))))) 221 | if len(images) == 0: 222 | raise FileNotFoundError(f"No images could be loaded .") 223 | return images 224 | 225 | def tensor2pil(tensor): 226 | image_np = tensor.squeeze().mul(255).clamp(0, 255).byte().numpy() 227 | image = Image.fromarray(image_np, mode='RGB') 228 | return image 229 | 230 | def pil2narry(img): 231 | narry = torch.from_numpy(np.array(img).astype(np.float32) / 255.0).unsqueeze(0) 232 | return narry 233 | 234 | def equalize_lists(list1, list2): 235 | """ 236 | 比较两个列表的长度,如果不一致,则将较短的列表复制以匹配较长列表的长度。 237 | 238 | 参数: 239 | list1 (list): 第一个列表 240 | list2 (list): 第二个列表 241 | 242 | 返回: 243 | tuple: 包含两个长度相等的列表的元组 244 | """ 245 | len1 = len(list1) 246 | len2 = len(list2) 247 | 248 | if len1 == len2: 249 | pass 250 | elif len1 < len2: 251 | print("list1 is shorter than list2, copying list1 to match list2's length.") 252 | list1.extend(list1 * ((len2 // len1) + 1)) # 复制list1以匹配list2的长度 253 | list1 = list1[:len2] # 确保长度一致 254 | else: 255 | print("list2 is shorter than list1, copying list2 to match list1's length.") 256 | list2.extend(list2 * ((len1 // len2) + 1)) # 复制list2以匹配list1的长度 257 | list2 = list2[:len1] # 确保长度一致 258 | 259 | return list1, list2 260 | 261 | def file_exists(directory, filename): 262 | # 构建文件的完整路径 263 | file_path = os.path.join(directory, filename) 264 | # 检查文件是否存在 265 | return os.path.isfile(file_path) 266 | 267 | def download_weights(file_dir,repo_id,subfolder="",pt_name=""): 268 | if subfolder: 269 | file_path = os.path.join(file_dir,subfolder, pt_name) 270 | sub_dir=os.path.join(file_dir,subfolder) 271 | if not os.path.exists(sub_dir): 272 | os.makedirs(sub_dir) 273 | if not os.path.exists(file_path): 274 | file_path = hf_hub_download( 275 | repo_id=repo_id, 276 | subfolder=subfolder, 277 | filename=pt_name, 278 | local_dir = file_dir, 279 | ) 280 | return file_path 281 | else: 282 | file_path = os.path.join(file_dir, pt_name) 283 | if not os.path.exists(file_dir): 284 | 
os.makedirs(file_dir) 285 | if not os.path.exists(file_path): 286 | file_path = hf_hub_download( 287 | repo_id=repo_id, 288 | filename=pt_name, 289 | local_dir=file_dir, 290 | ) 291 | return file_path -------------------------------------------------------------------------------- /propainter/RAFT/__init__.py: -------------------------------------------------------------------------------- 1 | # from .demo import RAFT_infer 2 | from .raft import RAFT 3 | 4 | -------------------------------------------------------------------------------- /propainter/RAFT/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1) 38 | dy = torch.linspace(-r, r, 2*r+1) 39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class CorrLayer(torch.autograd.Function): 64 | @staticmethod 65 | def forward(ctx, fmap1, fmap2, coords, r): 66 | fmap1 = fmap1.contiguous() 67 | fmap2 = fmap2.contiguous() 68 | coords = coords.contiguous() 69 | ctx.save_for_backward(fmap1, fmap2, coords) 70 | ctx.r = r 71 | corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r) 72 | return corr 73 | 74 | @staticmethod 75 | def backward(ctx, grad_corr): 76 | fmap1, fmap2, coords = ctx.saved_tensors 77 | grad_corr = grad_corr.contiguous() 78 | fmap1_grad, fmap2_grad, coords_grad = \ 79 | correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r) 80 | return fmap1_grad, fmap2_grad, coords_grad, None 81 | 82 | 83 | class AlternateCorrBlock: 84 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 85 | self.num_levels = num_levels 86 | self.radius = radius 87 | 88 | self.pyramid = [(fmap1, fmap2)] 89 | for i in range(self.num_levels): 90 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 91 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 92 | self.pyramid.append((fmap1, 
fmap2)) 93 | 94 | def __call__(self, coords): 95 | 96 | coords = coords.permute(0, 2, 3, 1) 97 | B, H, W, _ = coords.shape 98 | 99 | corr_list = [] 100 | for i in range(self.num_levels): 101 | r = self.radius 102 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1) 103 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1) 104 | 105 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 106 | corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r) 107 | corr_list.append(corr.squeeze(1)) 108 | 109 | corr = torch.stack(corr_list, dim=1) 110 | corr = corr.reshape(B, -1, H, W) 111 | return corr / 16.0 112 | -------------------------------------------------------------------------------- /propainter/RAFT/datasets.py: -------------------------------------------------------------------------------- 1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | import torch.nn.functional as F 7 | 8 | import os 9 | import math 10 | import random 11 | from glob import glob 12 | import os.path as osp 13 | 14 | from .utils import frame_utils 15 | from .utils.augmentor import FlowAugmentor, SparseFlowAugmentor 16 | 17 | 18 | class FlowDataset(data.Dataset): 19 | def __init__(self, aug_params=None, sparse=False): 20 | self.augmentor = None 21 | self.sparse = sparse 22 | if aug_params is not None: 23 | if sparse: 24 | self.augmentor = SparseFlowAugmentor(**aug_params) 25 | else: 26 | self.augmentor = FlowAugmentor(**aug_params) 27 | 28 | self.is_test = False 29 | self.init_seed = False 30 | self.flow_list = [] 31 | self.image_list = [] 32 | self.extra_info = [] 33 | 34 | def __getitem__(self, index): 35 | 36 | if self.is_test: 37 | img1 = frame_utils.read_gen(self.image_list[index][0]) 38 | img2 = frame_utils.read_gen(self.image_list[index][1]) 39 | img1 = np.array(img1).astype(np.uint8)[..., :3] 40 | img2 = np.array(img2).astype(np.uint8)[..., :3] 41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 43 | return img1, img2, self.extra_info[index] 44 | 45 | if not self.init_seed: 46 | worker_info = torch.utils.data.get_worker_info() 47 | if worker_info is not None: 48 | torch.manual_seed(worker_info.id) 49 | np.random.seed(worker_info.id) 50 | random.seed(worker_info.id) 51 | self.init_seed = True 52 | 53 | index = index % len(self.image_list) 54 | valid = None 55 | if self.sparse: 56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) 57 | else: 58 | flow = frame_utils.read_gen(self.flow_list[index]) 59 | 60 | img1 = frame_utils.read_gen(self.image_list[index][0]) 61 | img2 = frame_utils.read_gen(self.image_list[index][1]) 62 | 63 | flow = np.array(flow).astype(np.float32) 64 | img1 = np.array(img1).astype(np.uint8) 65 | img2 = np.array(img2).astype(np.uint8) 66 | 67 | # grayscale images 68 | if len(img1.shape) == 2: 69 | img1 = np.tile(img1[...,None], (1, 1, 3)) 70 | img2 = np.tile(img2[...,None], (1, 1, 3)) 71 | else: 72 | img1 = img1[..., :3] 73 | img2 = img2[..., :3] 74 | 75 | if self.augmentor is not None: 76 | if self.sparse: 77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) 78 | else: 79 | img1, img2, flow = self.augmentor(img1, img2, flow) 80 | 81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 84 | 85 | if valid is not None: 86 | valid = torch.from_numpy(valid) 87 | else: 88 | 
valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 89 | 90 | return img1, img2, flow, valid.float() 91 | 92 | 93 | def __rmul__(self, v): 94 | self.flow_list = v * self.flow_list 95 | self.image_list = v * self.image_list 96 | return self 97 | 98 | def __len__(self): 99 | return len(self.image_list) 100 | 101 | 102 | class MpiSintel(FlowDataset): 103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): 104 | super(MpiSintel, self).__init__(aug_params) 105 | flow_root = osp.join(root, split, 'flow') 106 | image_root = osp.join(root, split, dstype) 107 | 108 | if split == 'test': 109 | self.is_test = True 110 | 111 | for scene in os.listdir(image_root): 112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) 113 | for i in range(len(image_list)-1): 114 | self.image_list += [ [image_list[i], image_list[i+1]] ] 115 | self.extra_info += [ (scene, i) ] # scene and frame_id 116 | 117 | if split != 'test': 118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) 119 | 120 | 121 | class FlyingChairs(FlowDataset): 122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): 123 | super(FlyingChairs, self).__init__(aug_params) 124 | 125 | images = sorted(glob(osp.join(root, '*.ppm'))) 126 | flows = sorted(glob(osp.join(root, '*.flo'))) 127 | assert (len(images)//2 == len(flows)) 128 | 129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) 130 | for i in range(len(flows)): 131 | xid = split_list[i] 132 | if (split=='training' and xid==1) or (split=='validation' and xid==2): 133 | self.flow_list += [ flows[i] ] 134 | self.image_list += [ [images[2*i], images[2*i+1]] ] 135 | 136 | 137 | class FlyingThings3D(FlowDataset): 138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): 139 | super(FlyingThings3D, self).__init__(aug_params) 140 | 141 | for cam in ['left']: 142 | for direction in ['into_future', 'into_past']: 143 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) 144 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 145 | 146 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) 147 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 148 | 149 | for idir, fdir in zip(image_dirs, flow_dirs): 150 | images = sorted(glob(osp.join(idir, '*.png')) ) 151 | flows = sorted(glob(osp.join(fdir, '*.pfm')) ) 152 | for i in range(len(flows)-1): 153 | if direction == 'into_future': 154 | self.image_list += [ [images[i], images[i+1]] ] 155 | self.flow_list += [ flows[i] ] 156 | elif direction == 'into_past': 157 | self.image_list += [ [images[i+1], images[i]] ] 158 | self.flow_list += [ flows[i+1] ] 159 | 160 | 161 | class KITTI(FlowDataset): 162 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): 163 | super(KITTI, self).__init__(aug_params, sparse=True) 164 | if split == 'testing': 165 | self.is_test = True 166 | 167 | root = osp.join(root, split) 168 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) 169 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) 170 | 171 | for img1, img2 in zip(images1, images2): 172 | frame_id = img1.split('/')[-1] 173 | self.extra_info += [ [frame_id] ] 174 | self.image_list += [ [img1, img2] ] 175 | 176 | if split == 'training': 177 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) 178 | 179 | 180 | class HD1K(FlowDataset): 181 | def __init__(self, aug_params=None, 
root='datasets/HD1k'): 182 | super(HD1K, self).__init__(aug_params, sparse=True) 183 | 184 | seq_ix = 0 185 | while 1: 186 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 187 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 188 | 189 | if len(flows) == 0: 190 | break 191 | 192 | for i in range(len(flows)-1): 193 | self.flow_list += [flows[i]] 194 | self.image_list += [ [images[i], images[i+1]] ] 195 | 196 | seq_ix += 1 197 | 198 | 199 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): 200 | """ Create the data loader for the corresponding trainign set """ 201 | 202 | if args.stage == 'chairs': 203 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} 204 | train_dataset = FlyingChairs(aug_params, split='training') 205 | 206 | elif args.stage == 'things': 207 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} 208 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') 209 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') 210 | train_dataset = clean_dataset + final_dataset 211 | 212 | elif args.stage == 'sintel': 213 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} 214 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass') 215 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') 216 | sintel_final = MpiSintel(aug_params, split='training', dstype='final') 217 | 218 | if TRAIN_DS == 'C+T+K+S+H': 219 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) 220 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) 221 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things 222 | 223 | elif TRAIN_DS == 'C+T+K/S': 224 | train_dataset = 100*sintel_clean + 100*sintel_final + things 225 | 226 | elif args.stage == 'kitti': 227 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} 228 | train_dataset = KITTI(aug_params, split='training') 229 | 230 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, 231 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True) 232 | 233 | print('Training with %d image pairs' % len(train_dataset)) 234 | return train_loader 235 | 236 | -------------------------------------------------------------------------------- /propainter/RAFT/demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import os 4 | import cv2 5 | import glob 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | 10 | from .raft import RAFT 11 | from .utils import flow_viz 12 | from .utils.utils import InputPadder 13 | 14 | 15 | 16 | DEVICE = 'cuda' 17 | 18 | def load_image(imfile): 19 | img = np.array(Image.open(imfile)).astype(np.uint8) 20 | img = torch.from_numpy(img).permute(2, 0, 1).float() 21 | return img 22 | 23 | 24 | def load_image_list(image_files): 25 | images = [] 26 | for imfile in sorted(image_files): 27 | images.append(load_image(imfile)) 28 | 29 | images = torch.stack(images, dim=0) 30 | images = images.to(DEVICE) 31 | 32 | padder = InputPadder(images.shape) 33 | return padder.pad(images)[0] 34 | 35 | 36 | def viz(img, flo): 37 | img = img[0].permute(1,2,0).cpu().numpy() 38 | flo = flo[0].permute(1,2,0).cpu().numpy() 
39 | 40 | # map flow to rgb image 41 | flo = flow_viz.flow_to_image(flo) 42 | # img_flo = np.concatenate([img, flo], axis=0) 43 | img_flo = flo 44 | 45 | cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]]) 46 | # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) 47 | # cv2.waitKey() 48 | 49 | 50 | def demo(args): 51 | model = torch.nn.DataParallel(RAFT(args)) 52 | model.load_state_dict(torch.load(args.model)) 53 | 54 | model = model.module 55 | model.to(DEVICE) 56 | model.eval() 57 | 58 | with torch.no_grad(): 59 | images = glob.glob(os.path.join(args.path, '*.png')) + \ 60 | glob.glob(os.path.join(args.path, '*.jpg')) 61 | 62 | images = load_image_list(images) 63 | for i in range(images.shape[0]-1): 64 | image1 = images[i,None] 65 | image2 = images[i+1,None] 66 | 67 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 68 | viz(image1, flow_up) 69 | 70 | 71 | def RAFT_infer(args): 72 | model = torch.nn.DataParallel(RAFT(args)) 73 | model.load_state_dict(torch.load(args.model)) 74 | 75 | model = model.module 76 | model.to(DEVICE) 77 | model.eval() 78 | 79 | return model 80 | -------------------------------------------------------------------------------- /propainter/RAFT/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not stride == 1: 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not stride == 1: 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | 48 | def forward(self, x): 49 | y = x 50 | y = self.relu(self.norm1(self.conv1(y))) 51 | y = self.relu(self.norm2(self.conv2(y))) 52 | 53 | if self.downsample is not None: 54 | x = self.downsample(x) 55 | 56 | return self.relu(x+y) 57 | 58 | 59 | 60 | class BottleneckBlock(nn.Module): 61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 62 | super(BottleneckBlock, self).__init__() 63 | 64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | num_groups = planes // 8 70 | 71 | if norm_fn == 
'group': 72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 75 | if not stride == 1: 76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 77 | 78 | elif norm_fn == 'batch': 79 | self.norm1 = nn.BatchNorm2d(planes//4) 80 | self.norm2 = nn.BatchNorm2d(planes//4) 81 | self.norm3 = nn.BatchNorm2d(planes) 82 | if not stride == 1: 83 | self.norm4 = nn.BatchNorm2d(planes) 84 | 85 | elif norm_fn == 'instance': 86 | self.norm1 = nn.InstanceNorm2d(planes//4) 87 | self.norm2 = nn.InstanceNorm2d(planes//4) 88 | self.norm3 = nn.InstanceNorm2d(planes) 89 | if not stride == 1: 90 | self.norm4 = nn.InstanceNorm2d(planes) 91 | 92 | elif norm_fn == 'none': 93 | self.norm1 = nn.Sequential() 94 | self.norm2 = nn.Sequential() 95 | self.norm3 = nn.Sequential() 96 | if not stride == 1: 97 | self.norm4 = nn.Sequential() 98 | 99 | if stride == 1: 100 | self.downsample = None 101 | 102 | else: 103 | self.downsample = nn.Sequential( 104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 105 | 106 | 107 | def forward(self, x): 108 | y = x 109 | y = self.relu(self.norm1(self.conv1(y))) 110 | y = self.relu(self.norm2(self.conv2(y))) 111 | y = self.relu(self.norm3(self.conv3(y))) 112 | 113 | if self.downsample is not None: 114 | x = self.downsample(x) 115 | 116 | return self.relu(x+y) 117 | 118 | class BasicEncoder(nn.Module): 119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 120 | super(BasicEncoder, self).__init__() 121 | self.norm_fn = norm_fn 122 | 123 | if self.norm_fn == 'group': 124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 125 | 126 | elif self.norm_fn == 'batch': 127 | self.norm1 = nn.BatchNorm2d(64) 128 | 129 | elif self.norm_fn == 'instance': 130 | self.norm1 = nn.InstanceNorm2d(64) 131 | 132 | elif self.norm_fn == 'none': 133 | self.norm1 = nn.Sequential() 134 | 135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 136 | self.relu1 = nn.ReLU(inplace=True) 137 | 138 | self.in_planes = 64 139 | self.layer1 = self._make_layer(64, stride=1) 140 | self.layer2 = self._make_layer(96, stride=2) 141 | self.layer3 = self._make_layer(128, stride=2) 142 | 143 | # output convolution 144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 145 | 146 | self.dropout = None 147 | if dropout > 0: 148 | self.dropout = nn.Dropout2d(p=dropout) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 154 | if m.weight is not None: 155 | nn.init.constant_(m.weight, 1) 156 | if m.bias is not None: 157 | nn.init.constant_(m.bias, 0) 158 | 159 | def _make_layer(self, dim, stride=1): 160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 162 | layers = (layer1, layer2) 163 | 164 | self.in_planes = dim 165 | return nn.Sequential(*layers) 166 | 167 | 168 | def forward(self, x): 169 | 170 | # if input is list, combine batch dimension 171 | is_list = isinstance(x, tuple) or isinstance(x, list) 172 | if is_list: 173 | batch_dim = x[0].shape[0] 174 | x = torch.cat(x, dim=0) 175 | 176 | x = self.conv1(x) 177 | x = self.norm1(x) 178 | x = self.relu1(x) 179 | 180 | x = self.layer1(x) 181 | x = self.layer2(x) 182 
| x = self.layer3(x) 183 | 184 | x = self.conv2(x) 185 | 186 | if self.training and self.dropout is not None: 187 | x = self.dropout(x) 188 | 189 | if is_list: 190 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 191 | 192 | return x 193 | 194 | 195 | class SmallEncoder(nn.Module): 196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 197 | super(SmallEncoder, self).__init__() 198 | self.norm_fn = norm_fn 199 | 200 | if self.norm_fn == 'group': 201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 202 | 203 | elif self.norm_fn == 'batch': 204 | self.norm1 = nn.BatchNorm2d(32) 205 | 206 | elif self.norm_fn == 'instance': 207 | self.norm1 = nn.InstanceNorm2d(32) 208 | 209 | elif self.norm_fn == 'none': 210 | self.norm1 = nn.Sequential() 211 | 212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) 213 | self.relu1 = nn.ReLU(inplace=True) 214 | 215 | self.in_planes = 32 216 | self.layer1 = self._make_layer(32, stride=1) 217 | self.layer2 = self._make_layer(64, stride=2) 218 | self.layer3 = self._make_layer(96, stride=2) 219 | 220 | self.dropout = None 221 | if dropout > 0: 222 | self.dropout = nn.Dropout2d(p=dropout) 223 | 224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 225 | 226 | for m in self.modules(): 227 | if isinstance(m, nn.Conv2d): 228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 230 | if m.weight is not None: 231 | nn.init.constant_(m.weight, 1) 232 | if m.bias is not None: 233 | nn.init.constant_(m.bias, 0) 234 | 235 | def _make_layer(self, dim, stride=1): 236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) 237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 238 | layers = (layer1, layer2) 239 | 240 | self.in_planes = dim 241 | return nn.Sequential(*layers) 242 | 243 | 244 | def forward(self, x): 245 | 246 | # if input is list, combine batch dimension 247 | is_list = isinstance(x, tuple) or isinstance(x, list) 248 | if is_list: 249 | batch_dim = x[0].shape[0] 250 | x = torch.cat(x, dim=0) 251 | 252 | x = self.conv1(x) 253 | x = self.norm1(x) 254 | x = self.relu1(x) 255 | 256 | x = self.layer1(x) 257 | x = self.layer2(x) 258 | x = self.layer3(x) 259 | x = self.conv2(x) 260 | 261 | if self.training and self.dropout is not None: 262 | x = self.dropout(x) 263 | 264 | if is_list: 265 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 266 | 267 | return x 268 | -------------------------------------------------------------------------------- /propainter/RAFT/raft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .update import BasicUpdateBlock, SmallUpdateBlock 7 | from .extractor import BasicEncoder, SmallEncoder 8 | from .corr import CorrBlock, AlternateCorrBlock 9 | from .utils.utils import bilinear_sampler, coords_grid, upflow8 10 | 11 | try: 12 | autocast = torch.cuda.amp.autocast 13 | except: 14 | # dummy autocast for PyTorch < 1.6 15 | class autocast: 16 | def __init__(self, enabled): 17 | pass 18 | def __enter__(self): 19 | pass 20 | def __exit__(self, *args): 21 | pass 22 | 23 | 24 | class RAFT(nn.Module): 25 | def __init__(self, args): 26 | super(RAFT, self).__init__() 27 | self.args = args 28 | 29 | if args.small: 30 | self.hidden_dim = hdim = 96 31 | self.context_dim = cdim = 64 32 | args.corr_levels = 4 33 | args.corr_radius = 
3 34 | 35 | else: 36 | self.hidden_dim = hdim = 128 37 | self.context_dim = cdim = 128 38 | args.corr_levels = 4 39 | args.corr_radius = 4 40 | 41 | if 'dropout' not in args._get_kwargs(): 42 | args.dropout = 0 43 | 44 | if 'alternate_corr' not in args._get_kwargs(): 45 | args.alternate_corr = False 46 | 47 | # feature network, context network, and update block 48 | if args.small: 49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) 50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) 51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 52 | 53 | else: 54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) 56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 57 | 58 | 59 | def freeze_bn(self): 60 | for m in self.modules(): 61 | if isinstance(m, nn.BatchNorm2d): 62 | m.eval() 63 | 64 | def initialize_flow(self, img): 65 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 66 | N, C, H, W = img.shape 67 | coords0 = coords_grid(N, H//8, W//8).to(img.device) 68 | coords1 = coords_grid(N, H//8, W//8).to(img.device) 69 | 70 | # optical flow computed as difference: flow = coords1 - coords0 71 | return coords0, coords1 72 | 73 | def upsample_flow(self, flow, mask): 74 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 75 | N, _, H, W = flow.shape 76 | mask = mask.view(N, 1, 9, 8, 8, H, W) 77 | mask = torch.softmax(mask, dim=2) 78 | 79 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 80 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 81 | 82 | up_flow = torch.sum(mask * up_flow, dim=2) 83 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 84 | return up_flow.reshape(N, 2, 8*H, 8*W) 85 | 86 | 87 | def forward(self, image1, image2, iters=12, flow_init=None, test_mode=True): 88 | """ Estimate optical flow between pair of frames """ 89 | 90 | # image1 = 2 * (image1 / 255.0) - 1.0 91 | # image2 = 2 * (image2 / 255.0) - 1.0 92 | 93 | image1 = image1.contiguous() 94 | image2 = image2.contiguous() 95 | 96 | hdim = self.hidden_dim 97 | cdim = self.context_dim 98 | 99 | # run the feature network 100 | with autocast(enabled=self.args.mixed_precision): 101 | fmap1, fmap2 = self.fnet([image1, image2]) 102 | 103 | fmap1 = fmap1.float() 104 | fmap2 = fmap2.float() 105 | 106 | if self.args.alternate_corr: 107 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 108 | else: 109 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 110 | 111 | # run the context network 112 | with autocast(enabled=self.args.mixed_precision): 113 | cnet = self.cnet(image1) 114 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 115 | net = torch.tanh(net) 116 | inp = torch.relu(inp) 117 | 118 | coords0, coords1 = self.initialize_flow(image1) 119 | 120 | if flow_init is not None: 121 | coords1 = coords1 + flow_init 122 | 123 | flow_predictions = [] 124 | for itr in range(iters): 125 | coords1 = coords1.detach() 126 | corr = corr_fn(coords1) # index correlation volume 127 | 128 | flow = coords1 - coords0 129 | with autocast(enabled=self.args.mixed_precision): 130 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 131 | 132 | # F(t+1) = F(t) + \Delta(t) 133 | coords1 = coords1 + delta_flow 134 | 135 | # upsample predictions 136 | if up_mask is None: 137 | flow_up = upflow8(coords1 - coords0) 
138 | else: 139 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 140 | 141 | flow_predictions.append(flow_up) 142 | 143 | if test_mode: 144 | return coords1 - coords0, flow_up 145 | 146 | return flow_predictions 147 | -------------------------------------------------------------------------------- /propainter/RAFT/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self, args): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self, args): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 
256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 88 | 89 | def forward(self, flow, corr): 90 | cor = F.relu(self.convc1(corr)) 91 | cor = F.relu(self.convc2(cor)) 92 | flo = F.relu(self.convf1(flow)) 93 | flo = F.relu(self.convf2(flo)) 94 | 95 | cor_flo = torch.cat([cor, flo], dim=1) 96 | out = F.relu(self.conv(cor_flo)) 97 | return torch.cat([out, flow], dim=1) 98 | 99 | class SmallUpdateBlock(nn.Module): 100 | def __init__(self, args, hidden_dim=96): 101 | super(SmallUpdateBlock, self).__init__() 102 | self.encoder = SmallMotionEncoder(args) 103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | inp = torch.cat([inp, motion_features], dim=1) 109 | net = self.gru(net, inp) 110 | delta_flow = self.flow_head(net) 111 | 112 | return net, None, delta_flow 113 | 114 | class BasicUpdateBlock(nn.Module): 115 | def __init__(self, args, hidden_dim=128, input_dim=128): 116 | super(BasicUpdateBlock, self).__init__() 117 | self.args = args 118 | self.encoder = BasicMotionEncoder(args) 119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 121 | 122 | self.mask = nn.Sequential( 123 | nn.Conv2d(128, 256, 3, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(256, 64*9, 1, padding=0)) 126 | 127 | def forward(self, net, inp, corr, flow, upsample=True): 128 | motion_features = self.encoder(flow, corr) 129 | inp = torch.cat([inp, motion_features], dim=1) 130 | 131 | net = self.gru(net, inp) 132 | delta_flow = self.flow_head(net) 133 | 134 | # scale mask to balence gradients 135 | mask = .25 * self.mask(net) 136 | return net, mask, delta_flow 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /propainter/RAFT/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .flow_viz import flow_to_image 2 | from .frame_utils import writeFlow 3 | -------------------------------------------------------------------------------- /propainter/RAFT/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import math 4 | from PIL import Image 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | import torch 11 | from torchvision.transforms import ColorJitter 12 | import torch.nn.functional as F 13 | 14 | 15 | class FlowAugmentor: 16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): 17 | 18 | # spatial augmentation params 19 | self.crop_size = crop_size 20 | self.min_scale = min_scale 21 | self.max_scale = max_scale 22 | self.spatial_aug_prob = 0.8 23 | self.stretch_prob = 0.8 24 | self.max_stretch = 0.2 25 | 26 | # flip augmentation params 27 | self.do_flip = do_flip 28 | self.h_flip_prob = 0.5 29 | self.v_flip_prob = 0.1 30 | 31 | # photometric augmentation params 32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) 33 | self.asymmetric_color_aug_prob = 0.2 34 | self.eraser_aug_prob = 0.5 35 | 36 | def color_transform(self, img1, img2): 37 | """ Photometric augmentation """ 38 | 39 | # 
asymmetric 40 | if np.random.rand() < self.asymmetric_color_aug_prob: 41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 43 | 44 | # symmetric 45 | else: 46 | image_stack = np.concatenate([img1, img2], axis=0) 47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 48 | img1, img2 = np.split(image_stack, 2, axis=0) 49 | 50 | return img1, img2 51 | 52 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 53 | """ Occlusion augmentation """ 54 | 55 | ht, wd = img1.shape[:2] 56 | if np.random.rand() < self.eraser_aug_prob: 57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 58 | for _ in range(np.random.randint(1, 3)): 59 | x0 = np.random.randint(0, wd) 60 | y0 = np.random.randint(0, ht) 61 | dx = np.random.randint(bounds[0], bounds[1]) 62 | dy = np.random.randint(bounds[0], bounds[1]) 63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 64 | 65 | return img1, img2 66 | 67 | def spatial_transform(self, img1, img2, flow): 68 | # randomly sample scale 69 | ht, wd = img1.shape[:2] 70 | min_scale = np.maximum( 71 | (self.crop_size[0] + 8) / float(ht), 72 | (self.crop_size[1] + 8) / float(wd)) 73 | 74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 75 | scale_x = scale 76 | scale_y = scale 77 | if np.random.rand() < self.stretch_prob: 78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 80 | 81 | scale_x = np.clip(scale_x, min_scale, None) 82 | scale_y = np.clip(scale_y, min_scale, None) 83 | 84 | if np.random.rand() < self.spatial_aug_prob: 85 | # rescale the images 86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 89 | flow = flow * [scale_x, scale_y] 90 | 91 | if self.do_flip: 92 | if np.random.rand() < self.h_flip_prob: # h-flip 93 | img1 = img1[:, ::-1] 94 | img2 = img2[:, ::-1] 95 | flow = flow[:, ::-1] * [-1.0, 1.0] 96 | 97 | if np.random.rand() < self.v_flip_prob: # v-flip 98 | img1 = img1[::-1, :] 99 | img2 = img2[::-1, :] 100 | flow = flow[::-1, :] * [1.0, -1.0] 101 | 102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 104 | 105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 106 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 108 | 109 | return img1, img2, flow 110 | 111 | def __call__(self, img1, img2, flow): 112 | img1, img2 = self.color_transform(img1, img2) 113 | img1, img2 = self.eraser_transform(img1, img2) 114 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 115 | 116 | img1 = np.ascontiguousarray(img1) 117 | img2 = np.ascontiguousarray(img2) 118 | flow = np.ascontiguousarray(flow) 119 | 120 | return img1, img2, flow 121 | 122 | class SparseFlowAugmentor: 123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): 124 | # spatial augmentation params 125 | self.crop_size = crop_size 126 | self.min_scale = min_scale 127 | self.max_scale = max_scale 128 | self.spatial_aug_prob = 0.8 129 | self.stretch_prob = 0.8 130 | self.max_stretch = 0.2 131 | 132 | # flip augmentation params 133 | self.do_flip = 
do_flip 134 | self.h_flip_prob = 0.5 135 | self.v_flip_prob = 0.1 136 | 137 | # photometric augmentation params 138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) 139 | self.asymmetric_color_aug_prob = 0.2 140 | self.eraser_aug_prob = 0.5 141 | 142 | def color_transform(self, img1, img2): 143 | image_stack = np.concatenate([img1, img2], axis=0) 144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 145 | img1, img2 = np.split(image_stack, 2, axis=0) 146 | return img1, img2 147 | 148 | def eraser_transform(self, img1, img2): 149 | ht, wd = img1.shape[:2] 150 | if np.random.rand() < self.eraser_aug_prob: 151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 152 | for _ in range(np.random.randint(1, 3)): 153 | x0 = np.random.randint(0, wd) 154 | y0 = np.random.randint(0, ht) 155 | dx = np.random.randint(50, 100) 156 | dy = np.random.randint(50, 100) 157 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 158 | 159 | return img1, img2 160 | 161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 162 | ht, wd = flow.shape[:2] 163 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 164 | coords = np.stack(coords, axis=-1) 165 | 166 | coords = coords.reshape(-1, 2).astype(np.float32) 167 | flow = flow.reshape(-1, 2).astype(np.float32) 168 | valid = valid.reshape(-1).astype(np.float32) 169 | 170 | coords0 = coords[valid>=1] 171 | flow0 = flow[valid>=1] 172 | 173 | ht1 = int(round(ht * fy)) 174 | wd1 = int(round(wd * fx)) 175 | 176 | coords1 = coords0 * [fx, fy] 177 | flow1 = flow0 * [fx, fy] 178 | 179 | xx = np.round(coords1[:,0]).astype(np.int32) 180 | yy = np.round(coords1[:,1]).astype(np.int32) 181 | 182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 183 | xx = xx[v] 184 | yy = yy[v] 185 | flow1 = flow1[v] 186 | 187 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 189 | 190 | flow_img[yy, xx] = flow1 191 | valid_img[yy, xx] = 1 192 | 193 | return flow_img, valid_img 194 | 195 | def spatial_transform(self, img1, img2, flow, valid): 196 | # randomly sample scale 197 | 198 | ht, wd = img1.shape[:2] 199 | min_scale = np.maximum( 200 | (self.crop_size[0] + 1) / float(ht), 201 | (self.crop_size[1] + 1) / float(wd)) 202 | 203 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 204 | scale_x = np.clip(scale, min_scale, None) 205 | scale_y = np.clip(scale, min_scale, None) 206 | 207 | if np.random.rand() < self.spatial_aug_prob: 208 | # rescale the images 209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 212 | 213 | if self.do_flip: 214 | if np.random.rand() < 0.5: # h-flip 215 | img1 = img1[:, ::-1] 216 | img2 = img2[:, ::-1] 217 | flow = flow[:, ::-1] * [-1.0, 1.0] 218 | valid = valid[:, ::-1] 219 | 220 | margin_y = 20 221 | margin_x = 50 222 | 223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 225 | 226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 228 | 229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 231 | flow = flow[y0:y0+self.crop_size[0], 
x0:x0+self.crop_size[1]] 232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 233 | return img1, img2, flow, valid 234 | 235 | 236 | def __call__(self, img1, img2, flow, valid): 237 | img1, img2 = self.color_transform(img1, img2) 238 | img1, img2 = self.eraser_transform(img1, img2) 239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 240 | 241 | img1 = np.ascontiguousarray(img1) 242 | img2 = np.ascontiguousarray(img2) 243 | flow = np.ascontiguousarray(flow) 244 | valid = np.ascontiguousarray(valid) 245 | 246 | return img1, img2, flow, valid 247 | -------------------------------------------------------------------------------- /propainter/RAFT/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /propainter/RAFT/utils/flow_viz_pt.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization 2 | import torch 3 | torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 4 | 5 | @torch.no_grad() 6 | def flow_to_image(flow: torch.Tensor) -> torch.Tensor: 7 | 8 | """ 9 | Converts a flow to an RGB image. 10 | 11 | Args: 12 | flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float. 13 | 14 | Returns: 15 | img (Tensor): Image Tensor of dtype uint8 where each color corresponds 16 | to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input. 17 | """ 18 | 19 | if flow.dtype != torch.float: 20 | raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.") 21 | 22 | orig_shape = flow.shape 23 | if flow.ndim == 3: 24 | flow = flow[None] # Add batch dim 25 | 26 | if flow.ndim != 4 or flow.shape[1] != 2: 27 | raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.") 28 | 29 | max_norm = torch.sum(flow**2, dim=1).sqrt().max() 30 | epsilon = torch.finfo((flow).dtype).eps 31 | normalized_flow = flow / (max_norm + epsilon) 32 | img = _normalized_flow_to_image(normalized_flow) 33 | 34 | if len(orig_shape) == 3: 35 | img = img[0] # Remove batch dim 36 | return img 37 | 38 | @torch.no_grad() 39 | def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor: 40 | 41 | """ 42 | Converts a batch of normalized flow to an RGB image. 
43 | 44 | Args: 45 | normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W) 46 | Returns: 47 | img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8. 48 | """ 49 | 50 | N, _, H, W = normalized_flow.shape 51 | device = normalized_flow.device 52 | flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device) 53 | colorwheel = _make_colorwheel().to(device) # shape [55x3] 54 | num_cols = colorwheel.shape[0] 55 | norm = torch.sum(normalized_flow**2, dim=1).sqrt() 56 | a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi 57 | fk = (a + 1) / 2 * (num_cols - 1) 58 | k0 = torch.floor(fk).to(torch.long) 59 | k1 = k0 + 1 60 | k1[k1 == num_cols] = 0 61 | f = fk - k0 62 | 63 | for c in range(colorwheel.shape[1]): 64 | tmp = colorwheel[:, c] 65 | col0 = tmp[k0] / 255.0 66 | col1 = tmp[k1] / 255.0 67 | col = (1 - f) * col0 + f * col1 68 | col = 1 - norm * (1 - col) 69 | flow_image[:, c, :, :] = torch.floor(255. * col) 70 | return flow_image 71 | 72 | 73 | @torch.no_grad() 74 | def _make_colorwheel() -> torch.Tensor: 75 | """ 76 | Generates a color wheel for optical flow visualization as presented in: 77 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 78 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf. 79 | 80 | Returns: 81 | colorwheel (Tensor[55, 3]): Colorwheel Tensor. 82 | """ 83 | 84 | RY = 15 85 | YG = 6 86 | GC = 4 87 | CB = 11 88 | BM = 13 89 | MR = 6 90 | 91 | ncols = RY + YG + GC + CB + BM + MR 92 | colorwheel = torch.zeros((ncols, 3)) 93 | col = 0 94 | 95 | # RY 96 | colorwheel[0:RY, 0] = 255 97 | colorwheel[0:RY, 1] = torch.floor(255. * torch.arange(0., RY) / RY) 98 | col = col + RY 99 | # YG 100 | colorwheel[col : col + YG, 0] = 255 - torch.floor(255. * torch.arange(0., YG) / YG) 101 | colorwheel[col : col + YG, 1] = 255 102 | col = col + YG 103 | # GC 104 | colorwheel[col : col + GC, 1] = 255 105 | colorwheel[col : col + GC, 2] = torch.floor(255. * torch.arange(0., GC) / GC) 106 | col = col + GC 107 | # CB 108 | colorwheel[col : col + CB, 1] = 255 - torch.floor(255. * torch.arange(CB) / CB) 109 | colorwheel[col : col + CB, 2] = 255 110 | col = col + CB 111 | # BM 112 | colorwheel[col : col + BM, 2] = 255 113 | colorwheel[col : col + BM, 0] = torch.floor(255. * torch.arange(0., BM) / BM) 114 | col = col + BM 115 | # MR 116 | colorwheel[col : col + MR, 2] = 255 - torch.floor(255. * torch.arange(MR) / MR) 117 | colorwheel[col : col + MR, 0] = 255 118 | return colorwheel 119 | -------------------------------------------------------------------------------- /propainter/RAFT/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. 
Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /propainter/RAFT/utils/utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd): 75 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /propainter/core/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /propainter/core/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | 5 | import cv2 6 | from PIL import Image 7 | import numpy as np 8 | 9 | import torch 10 | import torchvision.transforms as transforms 11 | 12 | from ..utils.file_client import FileClient 13 | from ..utils.img_util import imfrombytes 14 | from ..utils.flow_util import resize_flow, flowread 15 | from ..core.utils import (create_random_shape_with_random_motion, Stack, 16 | ToTorchFormatTensor, GroupRandomHorizontalFlip,GroupRandomHorizontalFlowFlip) 17 | 18 | 19 | class 
TrainDataset(torch.utils.data.Dataset): 20 | def __init__(self, args: dict): 21 | self.args = args 22 | self.video_root = args['video_root'] 23 | self.flow_root = args['flow_root'] 24 | self.num_local_frames = args['num_local_frames'] 25 | self.num_ref_frames = args['num_ref_frames'] 26 | self.size = self.w, self.h = (args['w'], args['h']) 27 | 28 | self.load_flow = args['load_flow'] 29 | if self.load_flow: 30 | assert os.path.exists(self.flow_root) 31 | 32 | json_path = os.path.join('./datasets', args['name'], 'train.json') 33 | 34 | with open(json_path, 'r') as f: 35 | self.video_train_dict = json.load(f) 36 | self.video_names = sorted(list(self.video_train_dict.keys())) 37 | 38 | # self.video_names = sorted(os.listdir(self.video_root)) 39 | self.video_dict = {} 40 | self.frame_dict = {} 41 | 42 | for v in self.video_names: 43 | frame_list = sorted(os.listdir(os.path.join(self.video_root, v))) 44 | v_len = len(frame_list) 45 | if v_len > self.num_local_frames + self.num_ref_frames: 46 | self.video_dict[v] = v_len 47 | self.frame_dict[v] = frame_list 48 | 49 | 50 | self.video_names = list(self.video_dict.keys()) # update names 51 | 52 | self._to_tensors = transforms.Compose([ 53 | Stack(), 54 | ToTorchFormatTensor(), 55 | ]) 56 | self.file_client = FileClient('disk') 57 | 58 | def __len__(self): 59 | return len(self.video_names) 60 | 61 | def _sample_index(self, length, sample_length, num_ref_frame=3): 62 | complete_idx_set = list(range(length)) 63 | pivot = random.randint(0, length - sample_length) 64 | local_idx = complete_idx_set[pivot:pivot + sample_length] 65 | remain_idx = list(set(complete_idx_set) - set(local_idx)) 66 | ref_index = sorted(random.sample(remain_idx, num_ref_frame)) 67 | 68 | return local_idx + ref_index 69 | 70 | def __getitem__(self, index): 71 | video_name = self.video_names[index] 72 | # create masks 73 | all_masks = create_random_shape_with_random_motion( 74 | self.video_dict[video_name], imageHeight=self.h, imageWidth=self.w) 75 | 76 | # create sample index 77 | selected_index = self._sample_index(self.video_dict[video_name], 78 | self.num_local_frames, 79 | self.num_ref_frames) 80 | 81 | # read video frames 82 | frames = [] 83 | masks = [] 84 | flows_f, flows_b = [], [] 85 | for idx in selected_index: 86 | frame_list = self.frame_dict[video_name] 87 | img_path = os.path.join(self.video_root, video_name, frame_list[idx]) 88 | img_bytes = self.file_client.get(img_path, 'img') 89 | img = imfrombytes(img_bytes, float32=False) 90 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 91 | img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR) 92 | img = Image.fromarray(img) 93 | 94 | frames.append(img) 95 | masks.append(all_masks[idx]) 96 | 97 | if len(frames) <= self.num_local_frames-1 and self.load_flow: 98 | current_n = frame_list[idx][:-4] 99 | next_n = frame_list[idx+1][:-4] 100 | flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo') 101 | flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo') 102 | flow_f = flowread(flow_f_path, quantize=False) 103 | flow_b = flowread(flow_b_path, quantize=False) 104 | flow_f = resize_flow(flow_f, self.h, self.w) 105 | flow_b = resize_flow(flow_b, self.h, self.w) 106 | flows_f.append(flow_f) 107 | flows_b.append(flow_b) 108 | 109 | if len(frames) == self.num_local_frames: # random reverse 110 | if random.random() < 0.5: 111 | frames.reverse() 112 | masks.reverse() 113 | if self.load_flow: 114 | flows_f.reverse() 115 | flows_b.reverse() 116 | flows_ = flows_f 117 | 
flows_f = flows_b 118 | flows_b = flows_ 119 | 120 | if self.load_flow: 121 | frames, flows_f, flows_b = GroupRandomHorizontalFlowFlip()(frames, flows_f, flows_b) 122 | else: 123 | frames = GroupRandomHorizontalFlip()(frames) 124 | 125 | # normalizate, to tensors 126 | frame_tensors = self._to_tensors(frames) * 2.0 - 1.0 127 | mask_tensors = self._to_tensors(masks) 128 | if self.load_flow: 129 | flows_f = np.stack(flows_f, axis=-1) # H W 2 T-1 130 | flows_b = np.stack(flows_b, axis=-1) 131 | flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float() 132 | flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float() 133 | 134 | # img [-1,1] mask [0,1] 135 | if self.load_flow: 136 | return frame_tensors, mask_tensors, flows_f, flows_b, video_name 137 | else: 138 | return frame_tensors, mask_tensors, 'None', 'None', video_name 139 | 140 | 141 | class TestDataset(torch.utils.data.Dataset): 142 | def __init__(self, args): 143 | self.args = args 144 | self.size = self.w, self.h = args['size'] 145 | 146 | self.video_root = args['video_root'] 147 | self.mask_root = args['mask_root'] 148 | self.flow_root = args['flow_root'] 149 | 150 | self.load_flow = args['load_flow'] 151 | if self.load_flow: 152 | assert os.path.exists(self.flow_root) 153 | self.video_names = sorted(os.listdir(self.mask_root)) 154 | 155 | self.video_dict = {} 156 | self.frame_dict = {} 157 | 158 | for v in self.video_names: 159 | frame_list = sorted(os.listdir(os.path.join(self.video_root, v))) 160 | v_len = len(frame_list) 161 | self.video_dict[v] = v_len 162 | self.frame_dict[v] = frame_list 163 | 164 | self._to_tensors = transforms.Compose([ 165 | Stack(), 166 | ToTorchFormatTensor(), 167 | ]) 168 | self.file_client = FileClient('disk') 169 | 170 | def __len__(self): 171 | return len(self.video_names) 172 | 173 | def __getitem__(self, index): 174 | video_name = self.video_names[index] 175 | selected_index = list(range(self.video_dict[video_name])) 176 | 177 | # read video frames 178 | frames = [] 179 | masks = [] 180 | flows_f, flows_b = [], [] 181 | for idx in selected_index: 182 | frame_list = self.frame_dict[video_name] 183 | frame_path = os.path.join(self.video_root, video_name, frame_list[idx]) 184 | 185 | img_bytes = self.file_client.get(frame_path, 'input') 186 | img = imfrombytes(img_bytes, float32=False) 187 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 188 | img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR) 189 | img = Image.fromarray(img) 190 | 191 | frames.append(img) 192 | 193 | mask_path = os.path.join(self.mask_root, video_name, str(idx).zfill(5) + '.png') 194 | mask = Image.open(mask_path).resize(self.size, Image.NEAREST).convert('L') 195 | 196 | # origin: 0 indicates missing. 
now: 1 indicates missing 197 | mask = np.asarray(mask) 198 | m = np.array(mask > 0).astype(np.uint8) 199 | 200 | m = cv2.dilate(m, 201 | cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)), 202 | iterations=4) 203 | mask = Image.fromarray(m * 255) 204 | masks.append(mask) 205 | 206 | if len(frames) <= len(selected_index)-1 and self.load_flow: 207 | current_n = frame_list[idx][:-4] 208 | next_n = frame_list[idx+1][:-4] 209 | flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo') 210 | flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo') 211 | flow_f = flowread(flow_f_path, quantize=False) 212 | flow_b = flowread(flow_b_path, quantize=False) 213 | flow_f = resize_flow(flow_f, self.h, self.w) 214 | flow_b = resize_flow(flow_b, self.h, self.w) 215 | flows_f.append(flow_f) 216 | flows_b.append(flow_b) 217 | 218 | # normalizate, to tensors 219 | frames_PIL = [np.array(f).astype(np.uint8) for f in frames] 220 | frame_tensors = self._to_tensors(frames) * 2.0 - 1.0 221 | mask_tensors = self._to_tensors(masks) 222 | 223 | if self.load_flow: 224 | flows_f = np.stack(flows_f, axis=-1) # H W 2 T-1 225 | flows_b = np.stack(flows_b, axis=-1) 226 | flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float() 227 | flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float() 228 | 229 | if self.load_flow: 230 | return frame_tensors, mask_tensors, flows_f, flows_b, video_name, frames_PIL 231 | else: 232 | return frame_tensors, mask_tensors, 'None', 'None', video_name -------------------------------------------------------------------------------- /propainter/core/dist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | def get_world_size(): 6 | """Find OMPI world size without calling mpi functions 7 | :rtype: int 8 | """ 9 | if os.environ.get('PMI_SIZE') is not None: 10 | return int(os.environ.get('PMI_SIZE') or 1) 11 | elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: 12 | return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) 13 | else: 14 | return torch.cuda.device_count() 15 | 16 | 17 | def get_global_rank(): 18 | """Find OMPI world rank without calling mpi functions 19 | :rtype: int 20 | """ 21 | if os.environ.get('PMI_RANK') is not None: 22 | return int(os.environ.get('PMI_RANK') or 0) 23 | elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None: 24 | return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) 25 | else: 26 | return 0 27 | 28 | 29 | def get_local_rank(): 30 | """Find OMPI local rank without calling mpi functions 31 | :rtype: int 32 | """ 33 | if os.environ.get('MPI_LOCALRANKID') is not None: 34 | return int(os.environ.get('MPI_LOCALRANKID') or 0) 35 | elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: 36 | return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) 37 | else: 38 | return 0 39 | 40 | 41 | def get_master_ip(): 42 | if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: 43 | return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] 44 | elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: 45 | return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') 46 | else: 47 | return "127.0.0.1" 48 | -------------------------------------------------------------------------------- /propainter/core/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import lpips 4 | from ..model.vgg_arch import 
VGGFeatureExtractor 5 | 6 | class PerceptualLoss(nn.Module): 7 | """Perceptual loss with commonly used style loss. 8 | 9 | Args: 10 | layer_weights (dict): The weight for each layer of vgg feature. 11 | Here is an example: {'conv5_4': 1.}, which means the conv5_4 12 | feature layer (before relu5_4) will be extracted with weight 13 | 1.0 in calculting losses. 14 | vgg_type (str): The type of vgg network used as feature extractor. 15 | Default: 'vgg19'. 16 | use_input_norm (bool): If True, normalize the input image in vgg. 17 | Default: True. 18 | range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. 19 | Default: False. 20 | perceptual_weight (float): If `perceptual_weight > 0`, the perceptual 21 | loss will be calculated and the loss will multiplied by the 22 | weight. Default: 1.0. 23 | style_weight (float): If `style_weight > 0`, the style loss will be 24 | calculated and the loss will multiplied by the weight. 25 | Default: 0. 26 | criterion (str): Criterion used for perceptual loss. Default: 'l1'. 27 | """ 28 | 29 | def __init__(self, 30 | layer_weights, 31 | vgg_type='vgg19', 32 | use_input_norm=True, 33 | range_norm=False, 34 | perceptual_weight=1.0, 35 | style_weight=0., 36 | criterion='l1'): 37 | super(PerceptualLoss, self).__init__() 38 | self.perceptual_weight = perceptual_weight 39 | self.style_weight = style_weight 40 | self.layer_weights = layer_weights 41 | self.vgg = VGGFeatureExtractor( 42 | layer_name_list=list(layer_weights.keys()), 43 | vgg_type=vgg_type, 44 | use_input_norm=use_input_norm, 45 | range_norm=range_norm) 46 | 47 | self.criterion_type = criterion 48 | if self.criterion_type == 'l1': 49 | self.criterion = torch.nn.L1Loss() 50 | elif self.criterion_type == 'l2': 51 | self.criterion = torch.nn.L2loss() 52 | elif self.criterion_type == 'mse': 53 | self.criterion = torch.nn.MSELoss(reduction='mean') 54 | elif self.criterion_type == 'fro': 55 | self.criterion = None 56 | else: 57 | raise NotImplementedError(f'{criterion} criterion has not been supported.') 58 | 59 | def forward(self, x, gt): 60 | """Forward function. 61 | 62 | Args: 63 | x (Tensor): Input tensor with shape (n, c, h, w). 64 | gt (Tensor): Ground-truth tensor with shape (n, c, h, w). 65 | 66 | Returns: 67 | Tensor: Forward results. 68 | """ 69 | # extract vgg features 70 | x_features = self.vgg(x) 71 | gt_features = self.vgg(gt.detach()) 72 | 73 | # calculate perceptual loss 74 | if self.perceptual_weight > 0: 75 | percep_loss = 0 76 | for k in x_features.keys(): 77 | if self.criterion_type == 'fro': 78 | percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k] 79 | else: 80 | percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] 81 | percep_loss *= self.perceptual_weight 82 | else: 83 | percep_loss = None 84 | 85 | # calculate style loss 86 | if self.style_weight > 0: 87 | style_loss = 0 88 | for k in x_features.keys(): 89 | if self.criterion_type == 'fro': 90 | style_loss += torch.norm( 91 | self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k] 92 | else: 93 | style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat( 94 | gt_features[k])) * self.layer_weights[k] 95 | style_loss *= self.style_weight 96 | else: 97 | style_loss = None 98 | 99 | return percep_loss, style_loss 100 | 101 | def _gram_mat(self, x): 102 | """Calculate Gram matrix. 103 | 104 | Args: 105 | x (torch.Tensor): Tensor with shape of (n, c, h, w). 
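                The matrix computed below is normalized by c * h * w, which is the
                form consumed by the style-loss term in forward().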
106 | 107 | Returns: 108 | torch.Tensor: Gram matrix. 109 | """ 110 | n, c, h, w = x.size() 111 | features = x.view(n, c, w * h) 112 | features_t = features.transpose(1, 2) 113 | gram = features.bmm(features_t) / (c * h * w) 114 | return gram 115 | 116 | class LPIPSLoss(nn.Module): 117 | def __init__(self, 118 | loss_weight=1.0, 119 | use_input_norm=True, 120 | range_norm=False,): 121 | super(LPIPSLoss, self).__init__() 122 | self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval() 123 | self.loss_weight = loss_weight 124 | self.use_input_norm = use_input_norm 125 | self.range_norm = range_norm 126 | 127 | if self.use_input_norm: 128 | # the mean is for image with range [0, 1] 129 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 130 | # the std is for image with range [0, 1] 131 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 132 | 133 | def forward(self, pred, target): 134 | if self.range_norm: 135 | pred = (pred + 1) / 2 136 | target = (target + 1) / 2 137 | if self.use_input_norm: 138 | pred = (pred - self.mean) / self.std 139 | target = (target - self.mean) / self.std 140 | lpips_loss = self.perceptual(target.contiguous(), pred.contiguous()) 141 | return self.loss_weight * lpips_loss.mean(), None 142 | 143 | 144 | class AdversarialLoss(nn.Module): 145 | r""" 146 | Adversarial loss 147 | https://arxiv.org/abs/1711.10337 148 | """ 149 | def __init__(self, 150 | type='nsgan', 151 | target_real_label=1.0, 152 | target_fake_label=0.0): 153 | r""" 154 | type = nsgan | lsgan | hinge 155 | """ 156 | super(AdversarialLoss, self).__init__() 157 | self.type = type 158 | self.register_buffer('real_label', torch.tensor(target_real_label)) 159 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 160 | 161 | if type == 'nsgan': 162 | self.criterion = nn.BCELoss() 163 | elif type == 'lsgan': 164 | self.criterion = nn.MSELoss() 165 | elif type == 'hinge': 166 | self.criterion = nn.ReLU() 167 | 168 | def __call__(self, outputs, is_real, is_disc=None): 169 | if self.type == 'hinge': 170 | if is_disc: 171 | if is_real: 172 | outputs = -outputs 173 | return self.criterion(1 + outputs).mean() 174 | else: 175 | return (-outputs).mean() 176 | else: 177 | labels = (self.real_label 178 | if is_real else self.fake_label).expand_as(outputs) 179 | loss = self.criterion(outputs, labels) 180 | return loss 181 | -------------------------------------------------------------------------------- /propainter/core/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | LR scheduler from BasicSR https://github.com/xinntao/BasicSR 3 | """ 4 | import math 5 | from collections import Counter 6 | from torch.optim.lr_scheduler import _LRScheduler 7 | 8 | 9 | class MultiStepRestartLR(_LRScheduler): 10 | """ MultiStep with restarts learning rate scheme. 11 | Args: 12 | optimizer (torch.nn.optimizer): Torch optimizer. 13 | milestones (list): Iterations that will decrease learning rate. 14 | gamma (float): Decrease ratio. Default: 0.1. 15 | restarts (list): Restart iterations. Default: [0]. 16 | restart_weights (list): Restart weights at each restart iteration. 17 | Default: [1]. 18 | last_epoch (int): Used in _LRScheduler. Default: -1. 
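        Example (illustrative values, not taken from this repo's configs):
            milestones=[50000, 100000] with gamma=0.5 halves the learning rate at
            iteration 50k and again at 100k; the defaults restarts=(0,) and
            restart_weights=(1,) give a plain multi-step schedule with no restart.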
19 | """ 20 | def __init__(self, 21 | optimizer, 22 | milestones, 23 | gamma=0.1, 24 | restarts=(0, ), 25 | restart_weights=(1, ), 26 | last_epoch=-1): 27 | self.milestones = Counter(milestones) 28 | self.gamma = gamma 29 | self.restarts = restarts 30 | self.restart_weights = restart_weights 31 | assert len(self.restarts) == len( 32 | self.restart_weights), 'restarts and their weights do not match.' 33 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 34 | 35 | def get_lr(self): 36 | if self.last_epoch in self.restarts: 37 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 38 | return [ 39 | group['initial_lr'] * weight 40 | for group in self.optimizer.param_groups 41 | ] 42 | if self.last_epoch not in self.milestones: 43 | return [group['lr'] for group in self.optimizer.param_groups] 44 | return [ 45 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 46 | for group in self.optimizer.param_groups 47 | ] 48 | 49 | 50 | def get_position_from_periods(iteration, cumulative_period): 51 | """Get the position from a period list. 52 | It will return the index of the right-closest number in the period list. 53 | For example, the cumulative_period = [100, 200, 300, 400], 54 | if iteration == 50, return 0; 55 | if iteration == 210, return 2; 56 | if iteration == 300, return 2. 57 | Args: 58 | iteration (int): Current iteration. 59 | cumulative_period (list[int]): Cumulative period list. 60 | Returns: 61 | int: The position of the right-closest number in the period list. 62 | """ 63 | for i, period in enumerate(cumulative_period): 64 | if iteration <= period: 65 | return i 66 | 67 | 68 | class CosineAnnealingRestartLR(_LRScheduler): 69 | """ Cosine annealing with restarts learning rate scheme. 70 | An example of config: 71 | periods = [10, 10, 10, 10] 72 | restart_weights = [1, 0.5, 0.5, 0.5] 73 | eta_min=1e-7 74 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 75 | scheduler will restart with the weights in restart_weights. 76 | Args: 77 | optimizer (torch.nn.optimizer): Torch optimizer. 78 | periods (list): Period for each cosine anneling cycle. 79 | restart_weights (list): Restart weights at each restart iteration. 80 | Default: [1]. 81 | eta_min (float): The mimimum lr. Default: 0. 82 | last_epoch (int): Used in _LRScheduler. Default: -1. 83 | """ 84 | def __init__(self, 85 | optimizer, 86 | periods, 87 | restart_weights=(1, ), 88 | eta_min=1e-7, 89 | last_epoch=-1): 90 | self.periods = periods 91 | self.restart_weights = restart_weights 92 | self.eta_min = eta_min 93 | assert (len(self.periods) == len(self.restart_weights) 94 | ), 'periods and restart_weights should have the same length.' 
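        # Precompute restart boundaries as running sums of the periods; with the
        # docstring's example periods=[10, 10, 10, 10] this gives
        # cumulative_period=[10, 20, 30, 40], which get_position_from_periods()
        # uses in get_lr() to locate the current cosine cycle.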
95 | self.cumulative_period = [ 96 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) 97 | ] 98 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 99 | 100 | def get_lr(self): 101 | idx = get_position_from_periods(self.last_epoch, 102 | self.cumulative_period) 103 | current_weight = self.restart_weights[idx] 104 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 105 | current_period = self.periods[idx] 106 | 107 | return [ 108 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 109 | (1 + math.cos(math.pi * ( 110 | (self.last_epoch - nearest_restart) / current_period))) 111 | for base_lr in self.base_lrs 112 | ] 113 | -------------------------------------------------------------------------------- /propainter/core/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Ref: 11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | 13 | Args: 14 | generator: Python generator. 15 | num_prefetch_queue (int): Number of prefetch queue. 16 | """ 17 | 18 | def __init__(self, generator, num_prefetch_queue): 19 | threading.Thread.__init__(self) 20 | self.queue = Queue.Queue(num_prefetch_queue) 21 | self.generator = generator 22 | self.daemon = True 23 | self.start() 24 | 25 | def run(self): 26 | for item in self.generator: 27 | self.queue.put(item) 28 | self.queue.put(None) 29 | 30 | def __next__(self): 31 | next_item = self.queue.get() 32 | if next_item is None: 33 | raise StopIteration 34 | return next_item 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | 40 | class PrefetchDataLoader(DataLoader): 41 | """Prefetch version of dataloader. 42 | 43 | Ref: 44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 45 | 46 | TODO: 47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 48 | ddp. 49 | 50 | Args: 51 | num_prefetch_queue (int): Number of prefetch queue. 52 | kwargs (dict): Other arguments for dataloader. 53 | """ 54 | 55 | def __init__(self, num_prefetch_queue, **kwargs): 56 | self.num_prefetch_queue = num_prefetch_queue 57 | super(PrefetchDataLoader, self).__init__(**kwargs) 58 | 59 | def __iter__(self): 60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 61 | 62 | 63 | class CPUPrefetcher(): 64 | """CPU prefetcher. 65 | 66 | Args: 67 | loader: Dataloader. 68 | """ 69 | 70 | def __init__(self, loader): 71 | self.ori_loader = loader 72 | self.loader = iter(loader) 73 | 74 | def next(self): 75 | try: 76 | return next(self.loader) 77 | except StopIteration: 78 | return None 79 | 80 | def reset(self): 81 | self.loader = iter(self.ori_loader) 82 | 83 | 84 | class CUDAPrefetcher(): 85 | """CUDA prefetcher. 86 | 87 | Ref: 88 | https://github.com/NVIDIA/apex/issues/304# 89 | 90 | It may consums more GPU memory. 91 | 92 | Args: 93 | loader: Dataloader. 94 | opt (dict): Options. 
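            In this implementation only opt['num_gpu'] is consulted, to choose
            between a CUDA and a CPU target device for the prefetched tensors.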
95 | """ 96 | 97 | def __init__(self, loader, opt): 98 | self.ori_loader = loader 99 | self.loader = iter(loader) 100 | self.opt = opt 101 | self.stream = torch.cuda.Stream() 102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 103 | self.preload() 104 | 105 | def preload(self): 106 | try: 107 | self.batch = next(self.loader) # self.batch is a dict 108 | except StopIteration: 109 | self.batch = None 110 | return None 111 | # put tensors to gpu 112 | with torch.cuda.stream(self.stream): 113 | for k, v in self.batch.items(): 114 | if torch.is_tensor(v): 115 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 116 | 117 | def next(self): 118 | torch.cuda.current_stream().wait_stream(self.stream) 119 | batch = self.batch 120 | self.preload() 121 | return batch 122 | 123 | def reset(self): 124 | self.loader = iter(self.ori_loader) 125 | self.preload() 126 | -------------------------------------------------------------------------------- /propainter/model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /propainter/model/canny/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /propainter/model/canny/canny_filter.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .gaussian import gaussian_blur2d 9 | from .kernels import get_canny_nms_kernel, get_hysteresis_kernel 10 | from .sobel import spatial_gradient 11 | 12 | def rgb_to_grayscale(image, rgb_weights = None): 13 | if len(image.shape) < 3 or image.shape[-3] != 3: 14 | raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}") 15 | 16 | if rgb_weights is None: 17 | # 8 bit images 18 | if image.dtype == torch.uint8: 19 | rgb_weights = torch.tensor([76, 150, 29], device=image.device, dtype=torch.uint8) 20 | # floating point images 21 | elif image.dtype in (torch.float16, torch.float32, torch.float64): 22 | rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype) 23 | else: 24 | raise TypeError(f"Unknown data type: {image.dtype}") 25 | else: 26 | # is tensor that we make sure is in the same device/dtype 27 | rgb_weights = rgb_weights.to(image) 28 | 29 | # unpack the color image channels with RGB order 30 | r = image[..., 0:1, :, :] 31 | g = image[..., 1:2, :, :] 32 | b = image[..., 2:3, :, :] 33 | 34 | w_r, w_g, w_b = rgb_weights.unbind() 35 | return w_r * r + w_g * g + w_b * b 36 | 37 | 38 | def canny( 39 | input: torch.Tensor, 40 | low_threshold: float = 0.1, 41 | high_threshold: float = 0.2, 42 | kernel_size: Tuple[int, int] = (5, 5), 43 | sigma: Tuple[float, float] = (1, 1), 44 | hysteresis: bool = True, 45 | eps: float = 1e-6, 46 | ) -> Tuple[torch.Tensor, torch.Tensor]: 47 | r"""Find edges of the input image and filters them using the Canny algorithm. 48 | 49 | .. image:: _static/img/canny.png 50 | 51 | Args: 52 | input: input image tensor with shape :math:`(B,C,H,W)`. 53 | low_threshold: lower threshold for the hysteresis procedure. 54 | high_threshold: upper threshold for the hysteresis procedure. 55 | kernel_size: the size of the kernel for the gaussian blur. 
56 | sigma: the standard deviation of the kernel for the gaussian blur. 57 | hysteresis: if True, applies the hysteresis edge tracking. 58 | Otherwise, the edges are divided between weak (0.5) and strong (1) edges. 59 | eps: regularization number to avoid NaN during backprop. 60 | 61 | Returns: 62 | - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`. 63 | - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`. 64 | 65 | .. note:: 66 | See a working example `here `__. 68 | 69 | Example: 70 | >>> input = torch.rand(5, 3, 4, 4) 71 | >>> magnitude, edges = canny(input) # 5x3x4x4 72 | >>> magnitude.shape 73 | torch.Size([5, 1, 4, 4]) 74 | >>> edges.shape 75 | torch.Size([5, 1, 4, 4]) 76 | """ 77 | if not isinstance(input, torch.Tensor): 78 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") 79 | 80 | if not len(input.shape) == 4: 81 | raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") 82 | 83 | if low_threshold > high_threshold: 84 | raise ValueError( 85 | "Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: {}>{}".format( 86 | low_threshold, high_threshold 87 | ) 88 | ) 89 | 90 | if low_threshold < 0 and low_threshold > 1: 91 | raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}") 92 | 93 | if high_threshold < 0 and high_threshold > 1: 94 | raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}") 95 | 96 | device: torch.device = input.device 97 | dtype: torch.dtype = input.dtype 98 | 99 | # To Grayscale 100 | if input.shape[1] == 3: 101 | input = rgb_to_grayscale(input) 102 | 103 | # Gaussian filter 104 | blurred: torch.Tensor = gaussian_blur2d(input, kernel_size, sigma) 105 | 106 | # Compute the gradients 107 | gradients: torch.Tensor = spatial_gradient(blurred, normalized=False) 108 | 109 | # Unpack the edges 110 | gx: torch.Tensor = gradients[:, :, 0] 111 | gy: torch.Tensor = gradients[:, :, 1] 112 | 113 | # Compute gradient magnitude and angle 114 | magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps) 115 | angle: torch.Tensor = torch.atan2(gy, gx) 116 | 117 | # Radians to Degrees 118 | angle = 180.0 * angle / math.pi 119 | 120 | # Round angle to the nearest 45 degree 121 | angle = torch.round(angle / 45) * 45 122 | 123 | # Non-maximal suppression 124 | nms_kernels: torch.Tensor = get_canny_nms_kernel(device, dtype) 125 | nms_magnitude: torch.Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2) 126 | 127 | # Get the indices for both directions 128 | positive_idx: torch.Tensor = (angle / 45) % 8 129 | positive_idx = positive_idx.long() 130 | 131 | negative_idx: torch.Tensor = ((angle / 45) + 4) % 8 132 | negative_idx = negative_idx.long() 133 | 134 | # Apply the non-maximum suppression to the different directions 135 | channel_select_filtered_positive: torch.Tensor = torch.gather(nms_magnitude, 1, positive_idx) 136 | channel_select_filtered_negative: torch.Tensor = torch.gather(nms_magnitude, 1, negative_idx) 137 | 138 | channel_select_filtered: torch.Tensor = torch.stack( 139 | [channel_select_filtered_positive, channel_select_filtered_negative], 1 140 | ) 141 | 142 | is_max: torch.Tensor = channel_select_filtered.min(dim=1)[0] > 0.0 143 | 144 | magnitude = magnitude * is_max 145 | 146 | # Threshold 147 | edges: torch.Tensor = F.threshold(magnitude, low_threshold, 0.0) 148 | 149 | low: torch.Tensor = magnitude > 
low_threshold 150 | high: torch.Tensor = magnitude > high_threshold 151 | 152 | edges = low * 0.5 + high * 0.5 153 | edges = edges.to(dtype) 154 | 155 | # Hysteresis 156 | if hysteresis: 157 | edges_old: torch.Tensor = -torch.ones(edges.shape, device=edges.device, dtype=dtype) 158 | hysteresis_kernels: torch.Tensor = get_hysteresis_kernel(device, dtype) 159 | 160 | while ((edges_old - edges).abs() != 0).any(): 161 | weak: torch.Tensor = (edges == 0.5).float() 162 | strong: torch.Tensor = (edges == 1).float() 163 | 164 | hysteresis_magnitude: torch.Tensor = F.conv2d( 165 | edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2 166 | ) 167 | hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype) 168 | hysteresis_magnitude = hysteresis_magnitude * weak + strong 169 | 170 | edges_old = edges.clone() 171 | edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5 172 | 173 | edges = hysteresis_magnitude 174 | 175 | return magnitude, edges 176 | 177 | 178 | class Canny(nn.Module): 179 | r"""Module that finds edges of the input image and filters them using the Canny algorithm. 180 | 181 | Args: 182 | input: input image tensor with shape :math:`(B,C,H,W)`. 183 | low_threshold: lower threshold for the hysteresis procedure. 184 | high_threshold: upper threshold for the hysteresis procedure. 185 | kernel_size: the size of the kernel for the gaussian blur. 186 | sigma: the standard deviation of the kernel for the gaussian blur. 187 | hysteresis: if True, applies the hysteresis edge tracking. 188 | Otherwise, the edges are divided between weak (0.5) and strong (1) edges. 189 | eps: regularization number to avoid NaN during backprop. 190 | 191 | Returns: 192 | - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`. 193 | - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`. 194 | 195 | Example: 196 | >>> input = torch.rand(5, 3, 4, 4) 197 | >>> magnitude, edges = Canny()(input) # 5x3x4x4 198 | >>> magnitude.shape 199 | torch.Size([5, 1, 4, 4]) 200 | >>> edges.shape 201 | torch.Size([5, 1, 4, 4]) 202 | """ 203 | 204 | def __init__( 205 | self, 206 | low_threshold: float = 0.1, 207 | high_threshold: float = 0.2, 208 | kernel_size: Tuple[int, int] = (5, 5), 209 | sigma: Tuple[float, float] = (1, 1), 210 | hysteresis: bool = True, 211 | eps: float = 1e-6, 212 | ) -> None: 213 | super().__init__() 214 | 215 | if low_threshold > high_threshold: 216 | raise ValueError( 217 | "Invalid input thresholds. low_threshold should be\ 218 | smaller than the high_threshold. Got: {}>{}".format( 219 | low_threshold, high_threshold 220 | ) 221 | ) 222 | 223 | if low_threshold < 0 or low_threshold > 1: 224 | raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}") 225 | 226 | if high_threshold < 0 or high_threshold > 1: 227 | raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). 
Got: {high_threshold}") 228 | 229 | # Gaussian blur parameters 230 | self.kernel_size = kernel_size 231 | self.sigma = sigma 232 | 233 | # Double threshold 234 | self.low_threshold = low_threshold 235 | self.high_threshold = high_threshold 236 | 237 | # Hysteresis 238 | self.hysteresis = hysteresis 239 | 240 | self.eps: float = eps 241 | 242 | def __repr__(self) -> str: 243 | return ''.join( 244 | ( 245 | f'{type(self).__name__}(', 246 | ', '.join( 247 | f'{name}={getattr(self, name)}' for name in sorted(self.__dict__) if not name.startswith('_') 248 | ), 249 | ')', 250 | ) 251 | ) 252 | 253 | def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 254 | return canny( 255 | input, self.low_threshold, self.high_threshold, self.kernel_size, self.sigma, self.hysteresis, self.eps 256 | ) -------------------------------------------------------------------------------- /propainter/model/canny/gaussian.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .filter import filter2d, filter2d_separable 7 | from .kernels import get_gaussian_kernel1d, get_gaussian_kernel2d 8 | 9 | 10 | def gaussian_blur2d( 11 | input: torch.Tensor, 12 | kernel_size: Tuple[int, int], 13 | sigma: Tuple[float, float], 14 | border_type: str = 'reflect', 15 | separable: bool = True, 16 | ) -> torch.Tensor: 17 | r"""Create an operator that blurs a tensor using a Gaussian filter. 18 | 19 | .. image:: _static/img/gaussian_blur2d.png 20 | 21 | The operator smooths the given tensor with a gaussian kernel by convolving 22 | it to each channel. It supports batched operation. 23 | 24 | Arguments: 25 | input: the input tensor with shape :math:`(B,C,H,W)`. 26 | kernel_size: the size of the kernel. 27 | sigma: the standard deviation of the kernel. 28 | border_type: the padding mode to be applied before convolving. 29 | The expected modes are: ``'constant'``, ``'reflect'``, 30 | ``'replicate'`` or ``'circular'``. Default: ``'reflect'``. 31 | separable: run as composition of two 1d-convolutions. 32 | 33 | Returns: 34 | the blurred tensor with shape :math:`(B, C, H, W)`. 35 | 36 | .. note:: 37 | See a working example `here `__. 39 | 40 | Examples: 41 | >>> input = torch.rand(2, 4, 5, 5) 42 | >>> output = gaussian_blur2d(input, (3, 3), (1.5, 1.5)) 43 | >>> output.shape 44 | torch.Size([2, 4, 5, 5]) 45 | """ 46 | if separable: 47 | kernel_x: torch.Tensor = get_gaussian_kernel1d(kernel_size[1], sigma[1]) 48 | kernel_y: torch.Tensor = get_gaussian_kernel1d(kernel_size[0], sigma[0]) 49 | out = filter2d_separable(input, kernel_x[None], kernel_y[None], border_type) 50 | else: 51 | kernel: torch.Tensor = get_gaussian_kernel2d(kernel_size, sigma) 52 | out = filter2d(input, kernel[None], border_type) 53 | return out 54 | 55 | 56 | class GaussianBlur2d(nn.Module): 57 | r"""Create an operator that blurs a tensor using a Gaussian filter. 58 | 59 | The operator smooths the given tensor with a gaussian kernel by convolving 60 | it to each channel. It supports batched operation. 61 | 62 | Arguments: 63 | kernel_size: the size of the kernel. 64 | sigma: the standard deviation of the kernel. 65 | border_type: the padding mode to be applied before convolving. 66 | The expected modes are: ``'constant'``, ``'reflect'``, 67 | ``'replicate'`` or ``'circular'``. Default: ``'reflect'``. 68 | separable: run as composition of two 1d-convolutions. 69 | 70 | Returns: 71 | the blurred tensor. 
72 | 73 | Shape: 74 | - Input: :math:`(B, C, H, W)` 75 | - Output: :math:`(B, C, H, W)` 76 | 77 | Examples:: 78 | 79 | >>> input = torch.rand(2, 4, 5, 5) 80 | >>> gauss = GaussianBlur2d((3, 3), (1.5, 1.5)) 81 | >>> output = gauss(input) # 2x4x5x5 82 | >>> output.shape 83 | torch.Size([2, 4, 5, 5]) 84 | """ 85 | 86 | def __init__( 87 | self, 88 | kernel_size: Tuple[int, int], 89 | sigma: Tuple[float, float], 90 | border_type: str = 'reflect', 91 | separable: bool = True, 92 | ) -> None: 93 | super().__init__() 94 | self.kernel_size: Tuple[int, int] = kernel_size 95 | self.sigma: Tuple[float, float] = sigma 96 | self.border_type = border_type 97 | self.separable = separable 98 | 99 | def __repr__(self) -> str: 100 | return ( 101 | self.__class__.__name__ 102 | + '(kernel_size=' 103 | + str(self.kernel_size) 104 | + ', ' 105 | + 'sigma=' 106 | + str(self.sigma) 107 | + ', ' 108 | + 'border_type=' 109 | + self.border_type 110 | + 'separable=' 111 | + str(self.separable) 112 | + ')' 113 | ) 114 | 115 | def forward(self, input: torch.Tensor) -> torch.Tensor: 116 | return gaussian_blur2d(input, self.kernel_size, self.sigma, self.border_type, self.separable) -------------------------------------------------------------------------------- /propainter/model/canny/sobel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .kernels import get_spatial_gradient_kernel2d, get_spatial_gradient_kernel3d, normalize_kernel2d 6 | 7 | 8 | def spatial_gradient(input: torch.Tensor, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> torch.Tensor: 9 | r"""Compute the first order image derivative in both x and y using a Sobel operator. 10 | 11 | .. image:: _static/img/spatial_gradient.png 12 | 13 | Args: 14 | input: input image tensor with shape :math:`(B, C, H, W)`. 15 | mode: derivatives modality, can be: `sobel` or `diff`. 16 | order: the order of the derivatives. 17 | normalized: whether the output is normalized. 18 | 19 | Return: 20 | the derivatives of the input feature map. with shape :math:`(B, C, 2, H, W)`. 21 | 22 | .. note:: 23 | See a working example `here `__. 25 | 26 | Examples: 27 | >>> input = torch.rand(1, 3, 4, 4) 28 | >>> output = spatial_gradient(input) # 1x3x2x4x4 29 | >>> output.shape 30 | torch.Size([1, 3, 2, 4, 4]) 31 | """ 32 | if not isinstance(input, torch.Tensor): 33 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") 34 | 35 | if not len(input.shape) == 4: 36 | raise ValueError(f"Invalid input shape, we expect BxCxHxW. 
Got: {input.shape}") 37 | # allocate kernel 38 | kernel: torch.Tensor = get_spatial_gradient_kernel2d(mode, order) 39 | if normalized: 40 | kernel = normalize_kernel2d(kernel) 41 | 42 | # prepare kernel 43 | b, c, h, w = input.shape 44 | tmp_kernel: torch.Tensor = kernel.to(input).detach() 45 | tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1) 46 | 47 | # convolve input tensor with sobel kernel 48 | kernel_flip: torch.Tensor = tmp_kernel.flip(-3) 49 | 50 | # Pad with "replicate for spatial dims, but with zeros for channel 51 | spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2] 52 | out_channels: int = 3 if order == 2 else 2 53 | padded_inp: torch.Tensor = F.pad(input.reshape(b * c, 1, h, w), spatial_pad, 'replicate')[:, :, None] 54 | 55 | return F.conv3d(padded_inp, kernel_flip, padding=0).view(b, c, out_channels, h, w) 56 | 57 | 58 | def spatial_gradient3d(input: torch.Tensor, mode: str = 'diff', order: int = 1) -> torch.Tensor: 59 | r"""Compute the first and second order volume derivative in x, y and d using a diff operator. 60 | 61 | Args: 62 | input: input features tensor with shape :math:`(B, C, D, H, W)`. 63 | mode: derivatives modality, can be: `sobel` or `diff`. 64 | order: the order of the derivatives. 65 | 66 | Return: 67 | the spatial gradients of the input feature map with shape math:`(B, C, 3, D, H, W)` 68 | or :math:`(B, C, 6, D, H, W)`. 69 | 70 | Examples: 71 | >>> input = torch.rand(1, 4, 2, 4, 4) 72 | >>> output = spatial_gradient3d(input) 73 | >>> output.shape 74 | torch.Size([1, 4, 3, 2, 4, 4]) 75 | """ 76 | if not isinstance(input, torch.Tensor): 77 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") 78 | 79 | if not len(input.shape) == 5: 80 | raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}") 81 | b, c, d, h, w = input.shape 82 | dev = input.device 83 | dtype = input.dtype 84 | if (mode == 'diff') and (order == 1): 85 | # we go for the special case implementation due to conv3d bad speed 86 | x: torch.Tensor = F.pad(input, 6 * [1], 'replicate') 87 | center = slice(1, -1) 88 | left = slice(0, -2) 89 | right = slice(2, None) 90 | out = torch.empty(b, c, 3, d, h, w, device=dev, dtype=dtype) 91 | out[..., 0, :, :, :] = x[..., center, center, right] - x[..., center, center, left] 92 | out[..., 1, :, :, :] = x[..., center, right, center] - x[..., center, left, center] 93 | out[..., 2, :, :, :] = x[..., right, center, center] - x[..., left, center, center] 94 | out = 0.5 * out 95 | else: 96 | # prepare kernel 97 | # allocate kernel 98 | kernel: torch.Tensor = get_spatial_gradient_kernel3d(mode, order) 99 | 100 | tmp_kernel: torch.Tensor = kernel.to(input).detach() 101 | tmp_kernel = tmp_kernel.repeat(c, 1, 1, 1, 1) 102 | 103 | # convolve input tensor with grad kernel 104 | kernel_flip: torch.Tensor = tmp_kernel.flip(-3) 105 | 106 | # Pad with "replicate for spatial dims, but with zeros for channel 107 | spatial_pad = [ 108 | kernel.size(2) // 2, 109 | kernel.size(2) // 2, 110 | kernel.size(3) // 2, 111 | kernel.size(3) // 2, 112 | kernel.size(4) // 2, 113 | kernel.size(4) // 2, 114 | ] 115 | out_ch: int = 6 if order == 2 else 3 116 | out = F.conv3d(F.pad(input, spatial_pad, 'replicate'), kernel_flip, padding=0, groups=c).view( 117 | b, c, out_ch, d, h, w 118 | ) 119 | return out 120 | 121 | 122 | def sobel(input: torch.Tensor, normalized: bool = True, eps: float = 1e-6) -> torch.Tensor: 123 | r"""Compute the Sobel operator and returns the magnitude per channel. 124 | 125 | .. 
image:: _static/img/sobel.png 126 | 127 | Args: 128 | input: the input image with shape :math:`(B,C,H,W)`. 129 | normalized: if True, L1 norm of the kernel is set to 1. 130 | eps: regularization number to avoid NaN during backprop. 131 | 132 | Return: 133 | the sobel edge gradient magnitudes map with shape :math:`(B,C,H,W)`. 134 | 135 | .. note:: 136 | See a working example `here `__. 138 | 139 | Example: 140 | >>> input = torch.rand(1, 3, 4, 4) 141 | >>> output = sobel(input) # 1x3x4x4 142 | >>> output.shape 143 | torch.Size([1, 3, 4, 4]) 144 | """ 145 | if not isinstance(input, torch.Tensor): 146 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") 147 | 148 | if not len(input.shape) == 4: 149 | raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") 150 | 151 | # comput the x/y gradients 152 | edges: torch.Tensor = spatial_gradient(input, normalized=normalized) 153 | 154 | # unpack the edges 155 | gx: torch.Tensor = edges[:, :, 0] 156 | gy: torch.Tensor = edges[:, :, 1] 157 | 158 | # compute gradient maginitude 159 | magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps) 160 | 161 | return magnitude 162 | 163 | 164 | class SpatialGradient(nn.Module): 165 | r"""Compute the first order image derivative in both x and y using a Sobel operator. 166 | 167 | Args: 168 | mode: derivatives modality, can be: `sobel` or `diff`. 169 | order: the order of the derivatives. 170 | normalized: whether the output is normalized. 171 | 172 | Return: 173 | the sobel edges of the input feature map. 174 | 175 | Shape: 176 | - Input: :math:`(B, C, H, W)` 177 | - Output: :math:`(B, C, 2, H, W)` 178 | 179 | Examples: 180 | >>> input = torch.rand(1, 3, 4, 4) 181 | >>> output = SpatialGradient()(input) # 1x3x2x4x4 182 | """ 183 | 184 | def __init__(self, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> None: 185 | super().__init__() 186 | self.normalized: bool = normalized 187 | self.order: int = order 188 | self.mode: str = mode 189 | 190 | def __repr__(self) -> str: 191 | return ( 192 | self.__class__.__name__ + '(' 193 | 'order=' + str(self.order) + ', ' + 'normalized=' + str(self.normalized) + ', ' + 'mode=' + self.mode + ')' 194 | ) 195 | 196 | def forward(self, input: torch.Tensor) -> torch.Tensor: 197 | return spatial_gradient(input, self.mode, self.order, self.normalized) 198 | 199 | 200 | class SpatialGradient3d(nn.Module): 201 | r"""Compute the first and second order volume derivative in x, y and d using a diff operator. 202 | 203 | Args: 204 | mode: derivatives modality, can be: `sobel` or `diff`. 205 | order: the order of the derivatives. 206 | 207 | Return: 208 | the spatial gradients of the input feature map. 209 | 210 | Shape: 211 | - Input: :math:`(B, C, D, H, W)`. D, H, W are spatial dimensions, gradient is calculated w.r.t to them. 
212 | - Output: :math:`(B, C, 3, D, H, W)` or :math:`(B, C, 6, D, H, W)` 213 | 214 | Examples: 215 | >>> input = torch.rand(1, 4, 2, 4, 4) 216 | >>> output = SpatialGradient3d()(input) 217 | >>> output.shape 218 | torch.Size([1, 4, 3, 2, 4, 4]) 219 | """ 220 | 221 | def __init__(self, mode: str = 'diff', order: int = 1) -> None: 222 | super().__init__() 223 | self.order: int = order 224 | self.mode: str = mode 225 | self.kernel = get_spatial_gradient_kernel3d(mode, order) 226 | return 227 | 228 | def __repr__(self) -> str: 229 | return self.__class__.__name__ + '(' 'order=' + str(self.order) + ', ' + 'mode=' + self.mode + ')' 230 | 231 | def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore 232 | return spatial_gradient3d(input, self.mode, self.order) 233 | 234 | 235 | class Sobel(nn.Module): 236 | r"""Compute the Sobel operator and returns the magnitude per channel. 237 | 238 | Args: 239 | normalized: if True, L1 norm of the kernel is set to 1. 240 | eps: regularization number to avoid NaN during backprop. 241 | 242 | Return: 243 | the sobel edge gradient magnitudes map. 244 | 245 | Shape: 246 | - Input: :math:`(B, C, H, W)` 247 | - Output: :math:`(B, C, H, W)` 248 | 249 | Examples: 250 | >>> input = torch.rand(1, 3, 4, 4) 251 | >>> output = Sobel()(input) # 1x3x4x4 252 | """ 253 | 254 | def __init__(self, normalized: bool = True, eps: float = 1e-6) -> None: 255 | super().__init__() 256 | self.normalized: bool = normalized 257 | self.eps: float = eps 258 | 259 | def __repr__(self) -> str: 260 | return self.__class__.__name__ + '(' 'normalized=' + str(self.normalized) + ')' 261 | 262 | def forward(self, input: torch.Tensor) -> torch.Tensor: 263 | return sobel(input, self.normalized, self.eps) -------------------------------------------------------------------------------- /propainter/model/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import random 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import logging 8 | import numpy as np 9 | from os import path as osp 10 | from packaging import version 11 | 12 | def constant_init(module, val, bias=0): 13 | if hasattr(module, 'weight') and module.weight is not None: 14 | nn.init.constant_(module.weight, val) 15 | if hasattr(module, 'bias') and module.bias is not None: 16 | nn.init.constant_(module.bias, bias) 17 | 18 | initialized_logger = {} 19 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): 20 | """Get the root logger. 21 | The logger will be initialized if it has not been initialized. By default a 22 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 23 | also be added. 24 | Args: 25 | logger_name (str): root logger name. Default: 'basicsr'. 26 | log_file (str | None): The log filename. If specified, a FileHandler 27 | will be added to the root logger. 28 | log_level (int): The root logger level. Note that only the process of 29 | rank 0 is affected, while other processes will set the level to 30 | "Error" and be silent most of the time. 31 | Returns: 32 | logging.Logger: The root logger. 
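        Example (illustrative sketch, appended here; the log file name is hypothetical):
            >>> logger = get_root_logger(log_level=logging.INFO, log_file='propainter.log')
            >>> logger.info('flow completion finished')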
33 | """ 34 | logger = logging.getLogger(logger_name) 35 | # if the logger has been initialized, just return it 36 | if logger_name in initialized_logger: 37 | return logger 38 | 39 | format_str = '%(asctime)s %(levelname)s: %(message)s' 40 | stream_handler = logging.StreamHandler() 41 | stream_handler.setFormatter(logging.Formatter(format_str)) 42 | logger.addHandler(stream_handler) 43 | logger.propagate = False 44 | 45 | if log_file is not None: 46 | logger.setLevel(log_level) 47 | # add file handler 48 | # file_handler = logging.FileHandler(log_file, 'w') 49 | file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log 50 | file_handler.setFormatter(logging.Formatter(format_str)) 51 | file_handler.setLevel(log_level) 52 | logger.addHandler(file_handler) 53 | initialized_logger[logger_name] = True 54 | return logger 55 | 56 | required_version = version.parse("1.12.0") 57 | current_version = version.parse(torch.__version__) 58 | 59 | IS_HIGH_VERSION = current_version >= required_version 60 | 61 | def gpu_is_available(): 62 | if IS_HIGH_VERSION: 63 | if torch.backends.mps.is_available(): 64 | return True 65 | return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False 66 | 67 | def get_device(gpu_id=None): 68 | if gpu_id is None: 69 | gpu_str = '' 70 | elif isinstance(gpu_id, int): 71 | gpu_str = f':{gpu_id}' 72 | else: 73 | raise TypeError('Input should be int value.') 74 | 75 | if IS_HIGH_VERSION: 76 | if torch.backends.mps.is_available(): 77 | return torch.device('mps'+gpu_str) 78 | return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu') 79 | 80 | 81 | def set_random_seed(seed): 82 | """Set random seeds.""" 83 | random.seed(seed) 84 | np.random.seed(seed) 85 | torch.manual_seed(seed) 86 | torch.cuda.manual_seed(seed) 87 | torch.cuda.manual_seed_all(seed) 88 | 89 | 90 | def get_time_str(): 91 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 92 | 93 | 94 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 95 | """Scan a directory to find the interested files. 96 | 97 | Args: 98 | dir_path (str): Path of the directory. 99 | suffix (str | tuple(str), optional): File suffix that we are 100 | interested in. Default: None. 101 | recursive (bool, optional): If set to True, recursively scan the 102 | directory. Default: False. 103 | full_path (bool, optional): If set to True, include the dir_path. 104 | Default: False. 105 | 106 | Returns: 107 | A generator for all the interested files with relative pathes. 
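        Example (illustrative sketch, appended here; the directory path and suffix are hypothetical):
            >>> for path in scandir('weights', suffix='.pth', recursive=True):
            ...     print(path)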
108 | """ 109 | 110 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 111 | raise TypeError('"suffix" must be a string or tuple of strings') 112 | 113 | root = dir_path 114 | 115 | def _scandir(dir_path, suffix, recursive): 116 | for entry in os.scandir(dir_path): 117 | if not entry.name.startswith('.') and entry.is_file(): 118 | if full_path: 119 | return_path = entry.path 120 | else: 121 | return_path = osp.relpath(entry.path, root) 122 | 123 | if suffix is None: 124 | yield return_path 125 | elif return_path.endswith(suffix): 126 | yield return_path 127 | else: 128 | if recursive: 129 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 130 | else: 131 | continue 132 | 133 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 134 | -------------------------------------------------------------------------------- /propainter/model/modules/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /propainter/model/modules/base_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from functools import reduce 6 | 7 | class BaseNetwork(nn.Module): 8 | def __init__(self): 9 | super(BaseNetwork, self).__init__() 10 | 11 | def print_network(self): 12 | if isinstance(self, list): 13 | self = self[0] 14 | num_params = 0 15 | for param in self.parameters(): 16 | num_params += param.numel() 17 | print( 18 | 'Network [%s] was created. Total number of parameters: %.1f million. ' 19 | 'To see the architecture, do print(network).' % 20 | (type(self).__name__, num_params / 1000000)) 21 | 22 | def init_weights(self, init_type='normal', gain=0.02): 23 | ''' 24 | initialize network's weights 25 | init_type: normal | xavier | kaiming | orthogonal 26 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 27 | ''' 28 | def init_func(m): 29 | classname = m.__class__.__name__ 30 | if classname.find('InstanceNorm2d') != -1: 31 | if hasattr(m, 'weight') and m.weight is not None: 32 | nn.init.constant_(m.weight.data, 1.0) 33 | if hasattr(m, 'bias') and m.bias is not None: 34 | nn.init.constant_(m.bias.data, 0.0) 35 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 36 | or classname.find('Linear') != -1): 37 | if init_type == 'normal': 38 | nn.init.normal_(m.weight.data, 0.0, gain) 39 | elif init_type == 'xavier': 40 | nn.init.xavier_normal_(m.weight.data, gain=gain) 41 | elif init_type == 'xavier_uniform': 42 | nn.init.xavier_uniform_(m.weight.data, gain=1.0) 43 | elif init_type == 'kaiming': 44 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 45 | elif init_type == 'orthogonal': 46 | nn.init.orthogonal_(m.weight.data, gain=gain) 47 | elif init_type == 'none': # uses pytorch's default init method 48 | m.reset_parameters() 49 | else: 50 | raise NotImplementedError( 51 | 'initialization method [%s] is not implemented' % 52 | init_type) 53 | if hasattr(m, 'bias') and m.bias is not None: 54 | nn.init.constant_(m.bias.data, 0.0) 55 | 56 | self.apply(init_func) 57 | 58 | # propagate to children 59 | for m in self.children(): 60 | if hasattr(m, 'init_weights'): 61 | m.init_weights(init_type, gain) 62 | 63 | 64 | class Vec2Feat(nn.Module): 65 | def __init__(self, channel, hidden, kernel_size, stride, padding): 66 | super(Vec2Feat, 
self).__init__() 67 | self.relu = nn.LeakyReLU(0.2, inplace=True) 68 | c_out = reduce((lambda x, y: x * y), kernel_size) * channel 69 | self.embedding = nn.Linear(hidden, c_out) 70 | self.kernel_size = kernel_size 71 | self.stride = stride 72 | self.padding = padding 73 | self.bias_conv = nn.Conv2d(channel, 74 | channel, 75 | kernel_size=3, 76 | stride=1, 77 | padding=1) 78 | 79 | def forward(self, x, t, output_size): 80 | b_, _, _, _, c_ = x.shape 81 | x = x.view(b_, -1, c_) 82 | feat = self.embedding(x) 83 | b, _, c = feat.size() 84 | feat = feat.view(b * t, -1, c).permute(0, 2, 1) 85 | feat = F.fold(feat, 86 | output_size=output_size, 87 | kernel_size=self.kernel_size, 88 | stride=self.stride, 89 | padding=self.padding) 90 | feat = self.bias_conv(feat) 91 | return feat 92 | 93 | 94 | class FusionFeedForward(nn.Module): 95 | def __init__(self, dim, hidden_dim=1960, t2t_params=None): 96 | super(FusionFeedForward, self).__init__() 97 | # We set hidden_dim as a default to 1960 98 | self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim)) 99 | self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim)) 100 | assert t2t_params is not None 101 | self.t2t_params = t2t_params 102 | self.kernel_shape = reduce((lambda x, y: x * y), t2t_params['kernel_size']) # 49 103 | 104 | def forward(self, x, output_size): 105 | n_vecs = 1 106 | for i, d in enumerate(self.t2t_params['kernel_size']): 107 | n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] - 108 | (d - 1) - 1) / self.t2t_params['stride'][i] + 1) 109 | 110 | x = self.fc1(x) 111 | b, n, c = x.size() 112 | normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1) 113 | normalizer = F.fold(normalizer, 114 | output_size=output_size, 115 | kernel_size=self.t2t_params['kernel_size'], 116 | padding=self.t2t_params['padding'], 117 | stride=self.t2t_params['stride']) 118 | 119 | x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1), 120 | output_size=output_size, 121 | kernel_size=self.t2t_params['kernel_size'], 122 | padding=self.t2t_params['padding'], 123 | stride=self.t2t_params['stride']) 124 | 125 | x = F.unfold(x / normalizer, 126 | kernel_size=self.t2t_params['kernel_size'], 127 | padding=self.t2t_params['padding'], 128 | stride=self.t2t_params['stride']).permute( 129 | 0, 2, 1).contiguous().view(b, n, c) 130 | x = self.fc2(x) 131 | return x 132 | -------------------------------------------------------------------------------- /propainter/model/modules/deformconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init as init 4 | from torch.nn.modules.utils import _pair, _single 5 | import math 6 | 7 | class ModulatedDeformConv2d(nn.Module): 8 | def __init__(self, 9 | in_channels, 10 | out_channels, 11 | kernel_size, 12 | stride=1, 13 | padding=0, 14 | dilation=1, 15 | groups=1, 16 | deform_groups=1, 17 | bias=True): 18 | super(ModulatedDeformConv2d, self).__init__() 19 | 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.kernel_size = _pair(kernel_size) 23 | self.stride = stride 24 | self.padding = padding 25 | self.dilation = dilation 26 | self.groups = groups 27 | self.deform_groups = deform_groups 28 | self.with_bias = bias 29 | # enable compatibility with nn.Conv2d 30 | self.transposed = False 31 | self.output_padding = _single(0) 32 | 33 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) 34 | if bias: 35 | 
self.bias = nn.Parameter(torch.Tensor(out_channels)) 36 | else: 37 | self.register_parameter('bias', None) 38 | self.init_weights() 39 | 40 | def init_weights(self): 41 | n = self.in_channels 42 | for k in self.kernel_size: 43 | n *= k 44 | stdv = 1. / math.sqrt(n) 45 | self.weight.data.uniform_(-stdv, stdv) 46 | if self.bias is not None: 47 | self.bias.data.zero_() 48 | 49 | if hasattr(self, 'conv_offset'): 50 | self.conv_offset.weight.data.zero_() 51 | self.conv_offset.bias.data.zero_() 52 | 53 | def forward(self, x, offset, mask): 54 | pass -------------------------------------------------------------------------------- /propainter/model/modules/flow_comp_raft.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | from ...RAFT import RAFT 8 | from .flow_loss_utils import flow_warp, ternary_loss2 9 | # except: 10 | # from propainter.RAFT import RAFT 11 | # from propainter.model.modules.flow_loss_utils import flow_warp, ternary_loss2 12 | 13 | 14 | 15 | def initialize_RAFT(model_path='weights/raft-things.pth', device='cuda'): 16 | """Initializes the RAFT model. 17 | """ 18 | args = argparse.ArgumentParser() 19 | args.raft_model = model_path 20 | args.small = False 21 | args.mixed_precision = False 22 | args.alternate_corr = False 23 | model = torch.nn.DataParallel(RAFT(args)) 24 | model.load_state_dict(torch.load(args.raft_model, map_location='cpu')) 25 | model = model.module 26 | 27 | model.to(device) 28 | 29 | return model 30 | 31 | 32 | class RAFT_bi(nn.Module): 33 | """Flow completion loss""" 34 | def __init__(self, model_path='weights/raft-things.pth', device='cuda'): 35 | super().__init__() 36 | self.fix_raft = initialize_RAFT(model_path, device=device) 37 | 38 | for p in self.fix_raft.parameters(): 39 | p.requires_grad = False 40 | 41 | self.l1_criterion = nn.L1Loss() 42 | self.eval() 43 | 44 | def forward(self, gt_local_frames, iters=20): 45 | b, l_t, c, h, w = gt_local_frames.size() 46 | # print(gt_local_frames.shape) 47 | 48 | with torch.no_grad(): 49 | gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape(-1, c, h, w) 50 | gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape(-1, c, h, w) 51 | # print(gtlf_1.shape) 52 | 53 | _, gt_flows_forward = self.fix_raft(gtlf_1, gtlf_2, iters=iters, test_mode=True) 54 | _, gt_flows_backward = self.fix_raft(gtlf_2, gtlf_1, iters=iters, test_mode=True) 55 | 56 | 57 | gt_flows_forward = gt_flows_forward.view(b, l_t-1, 2, h, w) 58 | gt_flows_backward = gt_flows_backward.view(b, l_t-1, 2, h, w) 59 | 60 | return gt_flows_forward, gt_flows_backward 61 | 62 | 63 | ################################################################################## 64 | def smoothness_loss(flow, cmask): 65 | delta_u, delta_v, mask = smoothness_deltas(flow) 66 | loss_u = charbonnier_loss(delta_u, cmask) 67 | loss_v = charbonnier_loss(delta_v, cmask) 68 | return loss_u + loss_v 69 | 70 | 71 | def smoothness_deltas(flow): 72 | """ 73 | flow: [b, c, h, w] 74 | """ 75 | mask_x = create_mask(flow, [[0, 0], [0, 1]]) 76 | mask_y = create_mask(flow, [[0, 1], [0, 0]]) 77 | mask = torch.cat((mask_x, mask_y), dim=1) 78 | mask = mask.to(flow.device) 79 | filter_x = torch.tensor([[0, 0, 0.], [0, 1, -1], [0, 0, 0]]) 80 | filter_y = torch.tensor([[0, 0, 0.], [0, 1, 0], [0, -1, 0]]) 81 | weights = torch.ones([2, 1, 3, 3]) 82 | weights[0, 0] = filter_x 83 | weights[1, 0] = filter_y 84 | weights = weights.to(flow.device) 85 | 86 | flow_u, flow_v = 
torch.split(flow, split_size_or_sections=1, dim=1) 87 | delta_u = F.conv2d(flow_u, weights, stride=1, padding=1) 88 | delta_v = F.conv2d(flow_v, weights, stride=1, padding=1) 89 | return delta_u, delta_v, mask 90 | 91 | 92 | def second_order_loss(flow, cmask): 93 | delta_u, delta_v, mask = second_order_deltas(flow) 94 | loss_u = charbonnier_loss(delta_u, cmask) 95 | loss_v = charbonnier_loss(delta_v, cmask) 96 | return loss_u + loss_v 97 | 98 | 99 | def charbonnier_loss(x, mask=None, truncate=None, alpha=0.45, beta=1.0, epsilon=0.001): 100 | """ 101 | Compute the generalized charbonnier loss of the difference tensor x 102 | All positions where mask == 0 are not taken into account 103 | x: a tensor of shape [b, c, h, w] 104 | mask: a mask of shape [b, mc, h, w], where mask channels must be either 1 or the same as 105 | the number of channels of x. Entries should be 0 or 1 106 | return: loss 107 | """ 108 | b, c, h, w = x.shape 109 | norm = b * c * h * w 110 | error = torch.pow(torch.square(x * beta) + torch.square(torch.tensor(epsilon)), alpha) 111 | if mask is not None: 112 | error = mask * error 113 | if truncate is not None: 114 | error = torch.min(error, truncate) 115 | return torch.sum(error) / norm 116 | 117 | 118 | def second_order_deltas(flow): 119 | """ 120 | consider the single flow first 121 | flow shape: [b, c, h, w] 122 | """ 123 | # create mask 124 | mask_x = create_mask(flow, [[0, 0], [1, 1]]) 125 | mask_y = create_mask(flow, [[1, 1], [0, 0]]) 126 | mask_diag = create_mask(flow, [[1, 1], [1, 1]]) 127 | mask = torch.cat((mask_x, mask_y, mask_diag, mask_diag), dim=1) 128 | mask = mask.to(flow.device) 129 | 130 | filter_x = torch.tensor([[0, 0, 0.], [1, -2, 1], [0, 0, 0]]) 131 | filter_y = torch.tensor([[0, 1, 0.], [0, -2, 0], [0, 1, 0]]) 132 | filter_diag1 = torch.tensor([[1, 0, 0.], [0, -2, 0], [0, 0, 1]]) 133 | filter_diag2 = torch.tensor([[0, 0, 1.], [0, -2, 0], [1, 0, 0]]) 134 | weights = torch.ones([4, 1, 3, 3]) 135 | weights[0] = filter_x 136 | weights[1] = filter_y 137 | weights[2] = filter_diag1 138 | weights[3] = filter_diag2 139 | weights = weights.to(flow.device) 140 | 141 | # split the flow into flow_u and flow_v, conv them with the weights 142 | flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1) 143 | delta_u = F.conv2d(flow_u, weights, stride=1, padding=1) 144 | delta_v = F.conv2d(flow_v, weights, stride=1, padding=1) 145 | return delta_u, delta_v, mask 146 | 147 | def create_mask(tensor, paddings): 148 | """ 149 | tensor shape: [b, c, h, w] 150 | paddings: [2 x 2] shape list, the first row indicates up and down paddings 151 | the second row indicates left and right paddings 152 | | | 153 | | x | 154 | | x * x | 155 | | x | 156 | | | 157 | """ 158 | shape = tensor.shape 159 | inner_height = shape[2] - (paddings[0][0] + paddings[0][1]) 160 | inner_width = shape[3] - (paddings[1][0] + paddings[1][1]) 161 | inner = torch.ones([inner_height, inner_width]) 162 | torch_paddings = [paddings[1][0], paddings[1][1], paddings[0][0], paddings[0][1]] # left, right, up and down 163 | mask2d = F.pad(inner, pad=torch_paddings) 164 | mask3d = mask2d.unsqueeze(0).repeat(shape[0], 1, 1) 165 | mask4d = mask3d.unsqueeze(1) 166 | return mask4d.detach() 167 | 168 | def ternary_loss(flow_comp, flow_gt, mask, current_frame, shift_frame, scale_factor=1): 169 | if scale_factor != 1: 170 | current_frame = F.interpolate(current_frame, scale_factor=1 / scale_factor, mode='bilinear') 171 | shift_frame = F.interpolate(shift_frame, scale_factor=1 / scale_factor, mode='bilinear') 
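    # Added explanatory note: warp the shifted frame with the ground-truth flow to build a
    # non-occlusion confidence mask, then score the warp produced by the completed flow
    # against the current frame via the ternary (census) loss below.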
172 | warped_sc = flow_warp(shift_frame, flow_gt.permute(0, 2, 3, 1)) 173 | noc_mask = torch.exp(-50. * torch.sum(torch.abs(current_frame - warped_sc), dim=1).pow(2)).unsqueeze(1) 174 | warped_comp_sc = flow_warp(shift_frame, flow_comp.permute(0, 2, 3, 1)) 175 | loss = ternary_loss2(current_frame, warped_comp_sc, noc_mask, mask) 176 | return loss 177 | 178 | class FlowLoss(nn.Module): 179 | def __init__(self): 180 | super().__init__() 181 | self.l1_criterion = nn.L1Loss() 182 | 183 | def forward(self, pred_flows, gt_flows, masks, frames): 184 | # pred_flows: b t-1 2 h w 185 | loss = 0 186 | warp_loss = 0 187 | h, w = pred_flows[0].shape[-2:] 188 | masks = [masks[:,:-1,...].contiguous(), masks[:, 1:, ...].contiguous()] 189 | frames0 = frames[:,:-1,...] 190 | frames1 = frames[:,1:,...] 191 | current_frames = [frames0, frames1] 192 | next_frames = [frames1, frames0] 193 | for i in range(len(pred_flows)): 194 | # print(pred_flows[i].shape) 195 | combined_flow = pred_flows[i] * masks[i] + gt_flows[i] * (1-masks[i]) 196 | l1_loss = self.l1_criterion(pred_flows[i] * masks[i], gt_flows[i] * masks[i]) / torch.mean(masks[i]) 197 | l1_loss += self.l1_criterion(pred_flows[i] * (1-masks[i]), gt_flows[i] * (1-masks[i])) / torch.mean((1-masks[i])) 198 | 199 | smooth_loss = smoothness_loss(combined_flow.reshape(-1,2,h,w), masks[i].reshape(-1,1,h,w)) 200 | smooth_loss2 = second_order_loss(combined_flow.reshape(-1,2,h,w), masks[i].reshape(-1,1,h,w)) 201 | 202 | warp_loss_i = ternary_loss(combined_flow.reshape(-1,2,h,w), gt_flows[i].reshape(-1,2,h,w), 203 | masks[i].reshape(-1,1,h,w), current_frames[i].reshape(-1,3,h,w), next_frames[i].reshape(-1,3,h,w)) 204 | 205 | loss += l1_loss + smooth_loss + smooth_loss2 206 | 207 | warp_loss += warp_loss_i 208 | 209 | return loss, warp_loss 210 | 211 | 212 | def edgeLoss(preds_edges, edges): 213 | """ 214 | 215 | Args: 216 | preds_edges: with shape [b, c, h , w] 217 | edges: with shape [b, c, h, w] 218 | 219 | Returns: Edge losses 220 | 221 | """ 222 | mask = (edges > 0.5).float() 223 | b, c, h, w = mask.shape 224 | num_pos = torch.sum(mask, dim=[1, 2, 3]).float() # Shape: [b,]. 225 | num_neg = c * h * w - num_pos # Shape: [b,]. 
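    # Added explanatory note: class-balance the BCE weights below. Edge pixels (mask == 1)
    # receive the (large) negative-pixel ratio and non-edge pixels the (small) positive-pixel
    # ratio, compensating for how sparse edges are in each frame.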
226 | neg_weights = (num_neg / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2).unsqueeze(3) 227 | pos_weights = (num_pos / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2).unsqueeze(3) 228 | weight = neg_weights * mask + pos_weights * (1 - mask) # weight for debug 229 | losses = F.binary_cross_entropy_with_logits(preds_edges.float(), edges.float(), weight=weight, reduction='none') 230 | loss = torch.mean(losses) 231 | return loss 232 | 233 | class EdgeLoss(nn.Module): 234 | def __init__(self): 235 | super().__init__() 236 | 237 | def forward(self, pred_edges, gt_edges, masks): 238 | # pred_flows: b t-1 1 h w 239 | loss = 0 240 | h, w = pred_edges[0].shape[-2:] 241 | masks = [masks[:,:-1,...].contiguous(), masks[:, 1:, ...].contiguous()] 242 | for i in range(len(pred_edges)): 243 | # print(f'edges_{i}', torch.sum(gt_edges[i])) # debug 244 | combined_edge = pred_edges[i] * masks[i] + gt_edges[i] * (1-masks[i]) 245 | edge_loss = (edgeLoss(pred_edges[i].reshape(-1,1,h,w), gt_edges[i].reshape(-1,1,h,w)) \ 246 | + 5 * edgeLoss(combined_edge.reshape(-1,1,h,w), gt_edges[i].reshape(-1,1,h,w))) 247 | loss += edge_loss 248 | 249 | return loss 250 | 251 | 252 | class FlowSimpleLoss(nn.Module): 253 | def __init__(self): 254 | super().__init__() 255 | self.l1_criterion = nn.L1Loss() 256 | 257 | def forward(self, pred_flows, gt_flows): 258 | # pred_flows: b t-1 2 h w 259 | loss = 0 260 | h, w = pred_flows[0].shape[-2:] 261 | h_orig, w_orig = gt_flows[0].shape[-2:] 262 | pred_flows = [f.view(-1, 2, h, w) for f in pred_flows] 263 | gt_flows = [f.view(-1, 2, h_orig, w_orig) for f in gt_flows] 264 | 265 | ds_factor = 1.0*h/h_orig 266 | gt_flows = [F.interpolate(f, scale_factor=ds_factor, mode='area') * ds_factor for f in gt_flows] 267 | for i in range(len(pred_flows)): 268 | loss += self.l1_criterion(pred_flows[i], gt_flows[i]) 269 | 270 | return loss -------------------------------------------------------------------------------- /propainter/model/modules/flow_loss_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def flow_warp(x, 7 | flow, 8 | interpolation='bilinear', 9 | padding_mode='zeros', 10 | align_corners=True): 11 | """Warp an image or a feature map with optical flow. 12 | Args: 13 | x (Tensor): Tensor with size (n, c, h, w). 14 | flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is 15 | a two-channel, denoting the width and height relative offsets. 16 | Note that the values are not normalized to [-1, 1]. 17 | interpolation (str): Interpolation mode: 'nearest' or 'bilinear'. 18 | Default: 'bilinear'. 19 | padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'. 20 | Default: 'zeros'. 21 | align_corners (bool): Whether align corners. Default: True. 22 | Returns: 23 | Tensor: Warped image or feature map. 
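    Example (illustrative sketch; a zero flow field should return the input nearly unchanged):
        >>> x = torch.rand(1, 3, 8, 8)
        >>> flow = torch.zeros(1, 8, 8, 2)
        >>> flow_warp(x, flow).shape
        torch.Size([1, 3, 8, 8])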
24 | """ 25 | if x.size()[-2:] != flow.size()[1:3]: 26 | raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and ' 27 | f'flow ({flow.size()[1:3]}) are not the same.') 28 | _, _, h, w = x.size() 29 | # create mesh grid 30 | device = flow.device 31 | grid_y, grid_x = torch.meshgrid(torch.arange(0, h, device=device), torch.arange(0, w, device=device)) 32 | grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2) 33 | grid.requires_grad = False 34 | 35 | grid_flow = grid + flow 36 | # scale grid_flow to [-1,1] 37 | grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0 38 | grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0 39 | grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3) 40 | output = F.grid_sample(x, 41 | grid_flow, 42 | mode=interpolation, 43 | padding_mode=padding_mode, 44 | align_corners=align_corners) 45 | return output 46 | 47 | 48 | # def image_warp(image, flow): 49 | # b, c, h, w = image.size() 50 | # device = image.device 51 | # flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0), flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1) # normalize to [-1~1](from upper left to lower right 52 | # flow = flow.permute(0, 2, 3, 1) # if you wanna use grid_sample function, the channel(band) shape of show must be in the last dimension 53 | # x = np.linspace(-1, 1, w) 54 | # y = np.linspace(-1, 1, h) 55 | # X, Y = np.meshgrid(x, y) 56 | # grid = torch.cat((torch.from_numpy(X.astype('float32')).unsqueeze(0).unsqueeze(3), 57 | # torch.from_numpy(Y.astype('float32')).unsqueeze(0).unsqueeze(3)), 3).to(device) 58 | # output = torch.nn.functional.grid_sample(image, grid + flow, mode='bilinear', padding_mode='zeros') 59 | # return output 60 | 61 | 62 | def length_sq(x): 63 | return torch.sum(torch.square(x), dim=1, keepdim=True) 64 | 65 | 66 | def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5): 67 | flow_bw_warped = flow_warp(flow_bw, flow_fw.permute(0, 2, 3, 1)) # wb(wf(x)) 68 | flow_fw_warped = flow_warp(flow_fw, flow_bw.permute(0, 2, 3, 1)) # wf(wb(x)) 69 | flow_diff_fw = flow_fw + flow_bw_warped # wf + wb(wf(x)) 70 | flow_diff_bw = flow_bw + flow_fw_warped # wb + wf(wb(x)) 71 | 72 | mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped) # |wf| + |wb(wf(x))| 73 | mag_sq_bw = length_sq(flow_bw) + length_sq(flow_fw_warped) # |wb| + |wf(wb(x))| 74 | occ_thresh_fw = alpha1 * mag_sq_fw + alpha2 75 | occ_thresh_bw = alpha1 * mag_sq_bw + alpha2 76 | 77 | fb_occ_fw = (length_sq(flow_diff_fw) > occ_thresh_fw).float() 78 | fb_occ_bw = (length_sq(flow_diff_bw) > occ_thresh_bw).float() 79 | 80 | return fb_occ_fw, fb_occ_bw # fb_occ_fw -> frame2 area occluded by frame1, fb_occ_bw -> frame1 area occluded by frame2 81 | 82 | 83 | def rgb2gray(image): 84 | gray_image = image[:, 0] * 0.299 + image[:, 1] * 0.587 + 0.110 * image[:, 2] 85 | gray_image = gray_image.unsqueeze(1) 86 | return gray_image 87 | 88 | 89 | def ternary_transform(image, max_distance=1): 90 | device = image.device 91 | patch_size = 2 * max_distance + 1 92 | intensities = rgb2gray(image) * 255 93 | out_channels = patch_size * patch_size 94 | w = np.eye(out_channels).reshape(out_channels, 1, patch_size, patch_size) 95 | weights = torch.from_numpy(w).float().to(device) 96 | patches = F.conv2d(intensities, weights, stride=1, padding=1) 97 | transf = patches - intensities 98 | transf_norm = transf / torch.sqrt(0.81 + torch.square(transf)) 99 | return transf_norm 100 | 101 | 102 | def hamming_distance(t1, t2): 103 | dist = torch.square(t1 - t2) 104 | dist_norm = dist / (0.1 + 
dist) 105 | dist_sum = torch.sum(dist_norm, dim=1, keepdim=True) 106 | return dist_sum 107 | 108 | 109 | def create_mask(mask, paddings): 110 | """ 111 | padding: [[top, bottom], [left, right]] 112 | """ 113 | shape = mask.shape 114 | inner_height = shape[2] - (paddings[0][0] + paddings[0][1]) 115 | inner_width = shape[3] - (paddings[1][0] + paddings[1][1]) 116 | inner = torch.ones([inner_height, inner_width]) 117 | 118 | mask2d = F.pad(inner, pad=[paddings[1][0], paddings[1][1], paddings[0][0], paddings[0][1]]) 119 | mask3d = mask2d.unsqueeze(0) 120 | mask4d = mask3d.unsqueeze(0).repeat(shape[0], 1, 1, 1) 121 | return mask4d.detach() 122 | 123 | 124 | def ternary_loss2(frame1, warp_frame21, confMask, masks, max_distance=1): 125 | """ 126 | 127 | Args: 128 | frame1: torch tensor, with shape [b * t, c, h, w] 129 | warp_frame21: torch tensor, with shape [b * t, c, h, w] 130 | confMask: confidence mask, with shape [b * t, c, h, w] 131 | masks: torch tensor, with shape [b * t, c, h, w] 132 | max_distance: maximum distance. 133 | 134 | Returns: ternary loss 135 | 136 | """ 137 | t1 = ternary_transform(frame1) 138 | t21 = ternary_transform(warp_frame21) 139 | dist = hamming_distance(t1, t21) 140 | loss = torch.mean(dist * confMask * masks) / torch.mean(masks) 141 | return loss 142 | 143 | -------------------------------------------------------------------------------- /propainter/model/vgg_arch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from torch import nn as nn 5 | from torchvision.models import vgg as vgg 6 | 7 | VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' 8 | NAMES = { 9 | 'vgg11': [ 10 | 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 11 | 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 12 | 'pool5' 13 | ], 14 | 'vgg13': [ 15 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 16 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 17 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' 18 | ], 19 | 'vgg16': [ 20 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 21 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 22 | 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 23 | 'pool5' 24 | ], 25 | 'vgg19': [ 26 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 27 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', 28 | 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', 29 | 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' 30 | ] 31 | } 32 | 33 | 34 | def insert_bn(names): 35 | """Insert bn layer after each conv. 36 | 37 | Args: 38 | names (list): The list of layer names. 39 | 40 | Returns: 41 | list: The list of layer names with bn layers. 
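    Example (illustrative):
        >>> insert_bn(['conv1_1', 'relu1_1', 'pool1'])
        ['conv1_1', 'bn1_1', 'relu1_1', 'pool1']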
42 | """ 43 | names_bn = [] 44 | for name in names: 45 | names_bn.append(name) 46 | if 'conv' in name: 47 | position = name.replace('conv', '') 48 | names_bn.append('bn' + position) 49 | return names_bn 50 | 51 | class VGGFeatureExtractor(nn.Module): 52 | """VGG network for feature extraction. 53 | 54 | In this implementation, we allow users to choose whether use normalization 55 | in the input feature and the type of vgg network. Note that the pretrained 56 | path must fit the vgg type. 57 | 58 | Args: 59 | layer_name_list (list[str]): Forward function returns the corresponding 60 | features according to the layer_name_list. 61 | Example: {'relu1_1', 'relu2_1', 'relu3_1'}. 62 | vgg_type (str): Set the type of vgg network. Default: 'vgg19'. 63 | use_input_norm (bool): If True, normalize the input image. Importantly, 64 | the input feature must in the range [0, 1]. Default: True. 65 | range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. 66 | Default: False. 67 | requires_grad (bool): If true, the parameters of VGG network will be 68 | optimized. Default: False. 69 | remove_pooling (bool): If true, the max pooling operations in VGG net 70 | will be removed. Default: False. 71 | pooling_stride (int): The stride of max pooling operation. Default: 2. 72 | """ 73 | 74 | def __init__(self, 75 | layer_name_list, 76 | vgg_type='vgg19', 77 | use_input_norm=True, 78 | range_norm=False, 79 | requires_grad=False, 80 | remove_pooling=False, 81 | pooling_stride=2): 82 | super(VGGFeatureExtractor, self).__init__() 83 | 84 | self.layer_name_list = layer_name_list 85 | self.use_input_norm = use_input_norm 86 | self.range_norm = range_norm 87 | 88 | self.names = NAMES[vgg_type.replace('_bn', '')] 89 | if 'bn' in vgg_type: 90 | self.names = insert_bn(self.names) 91 | 92 | # only borrow layers that will be used to avoid unused params 93 | max_idx = 0 94 | for v in layer_name_list: 95 | idx = self.names.index(v) 96 | if idx > max_idx: 97 | max_idx = idx 98 | 99 | if os.path.exists(VGG_PRETRAIN_PATH): 100 | vgg_net = getattr(vgg, vgg_type)(pretrained=False) 101 | state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) 102 | vgg_net.load_state_dict(state_dict) 103 | else: 104 | vgg_net = getattr(vgg, vgg_type)(pretrained=True) 105 | 106 | features = vgg_net.features[:max_idx + 1] 107 | 108 | modified_net = OrderedDict() 109 | for k, v in zip(self.names, features): 110 | if 'pool' in k: 111 | # if remove_pooling is true, pooling operation will be removed 112 | if remove_pooling: 113 | continue 114 | else: 115 | # in some cases, we may want to change the default stride 116 | modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) 117 | else: 118 | modified_net[k] = v 119 | 120 | self.vgg_net = nn.Sequential(modified_net) 121 | 122 | if not requires_grad: 123 | self.vgg_net.eval() 124 | for param in self.parameters(): 125 | param.requires_grad = False 126 | else: 127 | self.vgg_net.train() 128 | for param in self.parameters(): 129 | param.requires_grad = True 130 | 131 | if self.use_input_norm: 132 | # the mean is for image with range [0, 1] 133 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 134 | # the std is for image with range [0, 1] 135 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 136 | 137 | def forward(self, x): 138 | """Forward function. 139 | 140 | Args: 141 | x (Tensor): Input tensor with shape (n, c, h, w). 142 | 143 | Returns: 144 | Tensor: Forward results. 
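        Example (illustrative sketch; layer names follow the vgg19 table above, and torchvision
        may download the pretrained VGG weights on first use):
            >>> extractor = VGGFeatureExtractor(['relu1_1', 'relu2_1'])
            >>> feats = extractor(torch.rand(1, 3, 64, 64))
            >>> sorted(feats.keys())
            ['relu1_1', 'relu2_1']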
145 | """ 146 | if self.range_norm: 147 | x = (x + 1) / 2 148 | if self.use_input_norm: 149 | x = (x - self.mean) / self.std 150 | output = {} 151 | 152 | for key, layer in self.vgg_net._modules.items(): 153 | x = layer(x) 154 | if key in self.layer_name_list: 155 | output[key] = x.clone() 156 | 157 | return output 158 | -------------------------------------------------------------------------------- /propainter/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /propainter/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import requests 4 | from torch.hub import download_url_to_file, get_dir 5 | from tqdm import tqdm 6 | from urllib.parse import urlparse 7 | 8 | def sizeof_fmt(size, suffix='B'): 9 | """Get human readable file size. 10 | 11 | Args: 12 | size (int): File size. 13 | suffix (str): Suffix. Default: 'B'. 14 | 15 | Return: 16 | str: Formated file siz. 17 | """ 18 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 19 | if abs(size) < 1024.0: 20 | return f'{size:3.1f} {unit}{suffix}' 21 | size /= 1024.0 22 | return f'{size:3.1f} Y{suffix}' 23 | 24 | 25 | def download_file_from_google_drive(file_id, save_path): 26 | """Download files from google drive. 27 | Ref: 28 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 29 | Args: 30 | file_id (str): File id. 31 | save_path (str): Save path. 32 | """ 33 | 34 | session = requests.Session() 35 | URL = 'https://docs.google.com/uc?export=download' 36 | params = {'id': file_id} 37 | 38 | response = session.get(URL, params=params, stream=True) 39 | token = get_confirm_token(response) 40 | if token: 41 | params['confirm'] = token 42 | response = session.get(URL, params=params, stream=True) 43 | 44 | # get file size 45 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 46 | print(response_file_size) 47 | if 'Content-Range' in response_file_size.headers: 48 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) 49 | else: 50 | file_size = None 51 | 52 | save_response_content(response, save_path, file_size) 53 | 54 | 55 | def get_confirm_token(response): 56 | for key, value in response.cookies.items(): 57 | if key.startswith('download_warning'): 58 | return value 59 | return None 60 | 61 | 62 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 63 | if file_size is not None: 64 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 65 | 66 | readable_file_size = sizeof_fmt(file_size) 67 | else: 68 | pbar = None 69 | 70 | with open(destination, 'wb') as f: 71 | downloaded_size = 0 72 | for chunk in response.iter_content(chunk_size): 73 | downloaded_size += chunk_size 74 | if pbar is not None: 75 | pbar.update(1) 76 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') 77 | if chunk: # filter out keep-alive new chunks 78 | f.write(chunk) 79 | if pbar is not None: 80 | pbar.close() 81 | 82 | 83 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 84 | """Load file form http url, will download models if necessary. 85 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 86 | Args: 87 | url (str): URL to be downloaded. 
88 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 89 | Default: None. 90 | progress (bool): Whether to show the download progress. Default: True. 91 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 92 | Returns: 93 | str: The path to the downloaded file. 94 | """ 95 | if model_dir is None: # use the pytorch hub_dir 96 | hub_dir = get_dir() 97 | model_dir = os.path.join(hub_dir, 'checkpoints') 98 | 99 | os.makedirs(model_dir, exist_ok=True) 100 | 101 | parts = urlparse(url) 102 | filename = os.path.basename(parts.path) 103 | if file_name is not None: 104 | filename = file_name 105 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 106 | if not os.path.exists(cached_file): 107 | print(f'Downloading: "{url}" to {cached_file}\n') 108 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 109 | return cached_file -------------------------------------------------------------------------------- /propainter/utils/file_client.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | 4 | class BaseStorageBackend(metaclass=ABCMeta): 5 | """Abstract class of storage backends. 6 | 7 | All backends need to implement two apis: ``get()`` and ``get_text()``. 8 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 9 | as texts. 10 | """ 11 | 12 | @abstractmethod 13 | def get(self, filepath): 14 | pass 15 | 16 | @abstractmethod 17 | def get_text(self, filepath): 18 | pass 19 | 20 | 21 | class MemcachedBackend(BaseStorageBackend): 22 | """Memcached storage backend. 23 | 24 | Attributes: 25 | server_list_cfg (str): Config file for memcached server list. 26 | client_cfg (str): Config file for memcached client. 27 | sys_path (str | None): Additional path to be appended to `sys.path`. 28 | Default: None. 29 | """ 30 | 31 | def __init__(self, server_list_cfg, client_cfg, sys_path=None): 32 | if sys_path is not None: 33 | import sys 34 | sys.path.append(sys_path) 35 | try: 36 | import mc 37 | except ImportError: 38 | raise ImportError('Please install memcached to enable MemcachedBackend.') 39 | 40 | self.server_list_cfg = server_list_cfg 41 | self.client_cfg = client_cfg 42 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) 43 | # mc.pyvector servers as a point which points to a memory cache 44 | self._mc_buffer = mc.pyvector() 45 | 46 | def get(self, filepath): 47 | filepath = str(filepath) 48 | import mc 49 | self._client.Get(filepath, self._mc_buffer) 50 | value_buf = mc.ConvertBuffer(self._mc_buffer) 51 | return value_buf 52 | 53 | def get_text(self, filepath): 54 | raise NotImplementedError 55 | 56 | 57 | class HardDiskBackend(BaseStorageBackend): 58 | """Raw hard disks storage backend.""" 59 | 60 | def get(self, filepath): 61 | filepath = str(filepath) 62 | with open(filepath, 'rb') as f: 63 | value_buf = f.read() 64 | return value_buf 65 | 66 | def get_text(self, filepath): 67 | filepath = str(filepath) 68 | with open(filepath, 'r') as f: 69 | value_buf = f.read() 70 | return value_buf 71 | 72 | 73 | class LmdbBackend(BaseStorageBackend): 74 | """Lmdb storage backend. 75 | 76 | Args: 77 | db_paths (str | list[str]): Lmdb database paths. 78 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 79 | readonly (bool, optional): Lmdb environment parameter. If True, 80 | disallow any write operations. 
Default: True. 81 | lock (bool, optional): Lmdb environment parameter. If False, when 82 | concurrent access occurs, do not lock the database. Default: False. 83 | readahead (bool, optional): Lmdb environment parameter. If False, 84 | disable the OS filesystem readahead mechanism, which may improve 85 | random read performance when a database is larger than RAM. 86 | Default: False. 87 | 88 | Attributes: 89 | db_paths (list): Lmdb database path. 90 | _client (list): A list of several lmdb envs. 91 | """ 92 | 93 | def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): 94 | try: 95 | import lmdb 96 | except ImportError: 97 | raise ImportError('Please install lmdb to enable LmdbBackend.') 98 | 99 | if isinstance(client_keys, str): 100 | client_keys = [client_keys] 101 | 102 | if isinstance(db_paths, list): 103 | self.db_paths = [str(v) for v in db_paths] 104 | elif isinstance(db_paths, str): 105 | self.db_paths = [str(db_paths)] 106 | assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' 107 | f'but received {len(client_keys)} and {len(self.db_paths)}.') 108 | 109 | self._client = {} 110 | for client, path in zip(client_keys, self.db_paths): 111 | self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) 112 | 113 | def get(self, filepath, client_key): 114 | """Get values according to the filepath from one lmdb named client_key. 115 | 116 | Args: 117 | filepath (str | obj:`Path`): Here, filepath is the lmdb key. 118 | client_key (str): Used for distinguishing differnet lmdb envs. 119 | """ 120 | filepath = str(filepath) 121 | assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.') 122 | client = self._client[client_key] 123 | with client.begin(write=False) as txn: 124 | value_buf = txn.get(filepath.encode('ascii')) 125 | return value_buf 126 | 127 | def get_text(self, filepath): 128 | raise NotImplementedError 129 | 130 | 131 | class FileClient(object): 132 | """A general file client to access files in different backend. 133 | 134 | The client loads a file or text in a specified backend from its path 135 | and return it as a binary file. it can also register other backend 136 | accessor with a given name and backend class. 137 | 138 | Attributes: 139 | backend (str): The storage backend type. Options are "disk", 140 | "memcached" and "lmdb". 141 | client (:obj:`BaseStorageBackend`): The backend object. 142 | """ 143 | 144 | _backends = { 145 | 'disk': HardDiskBackend, 146 | 'memcached': MemcachedBackend, 147 | 'lmdb': LmdbBackend, 148 | } 149 | 150 | def __init__(self, backend='disk', **kwargs): 151 | if backend not in self._backends: 152 | raise ValueError(f'Backend {backend} is not supported. Currently supported ones' 153 | f' are {list(self._backends.keys())}') 154 | self.backend = backend 155 | self.client = self._backends[backend](**kwargs) 156 | 157 | def get(self, filepath, client_key='default'): 158 | # client_key is used only for lmdb, where different fileclients have 159 | # different lmdb environments. 
160 | if self.backend == 'lmdb': 161 | return self.client.get(filepath, client_key) 162 | else: 163 | return self.client.get(filepath) 164 | 165 | def get_text(self, filepath): 166 | return self.client.get_text(filepath) 167 | -------------------------------------------------------------------------------- /propainter/utils/flow_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | import torch.nn.functional as F 5 | 6 | def resize_flow(flow, newh, neww): 7 | oldh, oldw = flow.shape[0:2] 8 | flow = cv2.resize(flow, (neww, newh), interpolation=cv2.INTER_LINEAR) 9 | flow[:, :, 0] *= neww / oldw 10 | flow[:, :, 1] *= newh / oldh 11 | return flow 12 | 13 | def resize_flow_pytorch(flow, newh, neww): 14 | oldh, oldw = flow.shape[-2:] 15 | flow = F.interpolate(flow, (newh, neww), mode='bilinear') 16 | flow[:, :, 0] *= neww / oldw 17 | flow[:, :, 1] *= newh / oldh 18 | return flow 19 | 20 | 21 | def imwrite(img, file_path, params=None, auto_mkdir=True): 22 | if auto_mkdir: 23 | dir_name = os.path.abspath(os.path.dirname(file_path)) 24 | os.makedirs(dir_name, exist_ok=True) 25 | return cv2.imwrite(file_path, img, params) 26 | 27 | 28 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): 29 | """Read an optical flow map. 30 | 31 | Args: 32 | flow_path (ndarray or str): Flow path. 33 | quantize (bool): whether to read quantized pair, if set to True, 34 | remaining args will be passed to :func:`dequantize_flow`. 35 | concat_axis (int): The axis that dx and dy are concatenated, 36 | can be either 0 or 1. Ignored if quantize is False. 37 | 38 | Returns: 39 | ndarray: Optical flow represented as a (h, w, 2) numpy array 40 | """ 41 | if quantize: 42 | assert concat_axis in [0, 1] 43 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) 44 | if cat_flow.ndim != 2: 45 | raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.') 46 | assert cat_flow.shape[concat_axis] % 2 == 0 47 | dx, dy = np.split(cat_flow, 2, axis=concat_axis) 48 | flow = dequantize_flow(dx, dy, *args, **kwargs) 49 | else: 50 | with open(flow_path, 'rb') as f: 51 | try: 52 | header = f.read(4).decode('utf-8') 53 | except Exception: 54 | raise IOError(f'Invalid flow file: {flow_path}') 55 | else: 56 | if header != 'PIEH': 57 | raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH') 58 | 59 | w = np.fromfile(f, np.int32, 1).squeeze() 60 | h = np.fromfile(f, np.int32, 1).squeeze() 61 | # flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) 62 | flow = np.fromfile(f, np.float16, w * h * 2).reshape((h, w, 2)) 63 | 64 | return flow.astype(np.float32) 65 | 66 | 67 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): 68 | """Write optical flow to file. 69 | 70 | If the flow is not quantized, it will be saved as a .flo file losslessly, 71 | otherwise a jpeg image which is lossy but of much smaller size. (dx and dy 72 | will be concatenated horizontally into a single image if quantize is True.) 73 | 74 | Args: 75 | flow (ndarray): (h, w, 2) array of optical flow. 76 | filename (str): Output filepath. 77 | quantize (bool): Whether to quantize the flow and save it to 2 jpeg 78 | images. If set to True, remaining args will be passed to 79 | :func:`quantize_flow`. 80 | concat_axis (int): The axis that dx and dy are concatenated, 81 | can be either 0 or 1. Ignored if quantize is False. 
82 | """ 83 | dir_name = os.path.abspath(os.path.dirname(filename)) 84 | os.makedirs(dir_name, exist_ok=True) 85 | if not quantize: 86 | with open(filename, 'wb') as f: 87 | f.write('PIEH'.encode('utf-8')) 88 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) 89 | # flow = flow.astype(np.float32) 90 | flow = flow.astype(np.float16) 91 | flow.tofile(f) 92 | f.flush() 93 | else: 94 | assert concat_axis in [0, 1] 95 | dx, dy = quantize_flow(flow, *args, **kwargs) 96 | dxdy = np.concatenate((dx, dy), axis=concat_axis) 97 | # os.makedirs(os.path.dirname(filename), exist_ok=True) 98 | cv2.imwrite(filename, dxdy) 99 | # imwrite(dxdy, filename) 100 | 101 | 102 | def quantize_flow(flow, max_val=0.02, norm=True): 103 | """Quantize flow to [0, 255]. 104 | 105 | After this step, the size of flow will be much smaller, and can be 106 | dumped as jpeg images. 107 | 108 | Args: 109 | flow (ndarray): (h, w, 2) array of optical flow. 110 | max_val (float): Maximum value of flow, values beyond 111 | [-max_val, max_val] will be truncated. 112 | norm (bool): Whether to divide flow values by image width/height. 113 | 114 | Returns: 115 | tuple[ndarray]: Quantized dx and dy. 116 | """ 117 | h, w, _ = flow.shape 118 | dx = flow[..., 0] 119 | dy = flow[..., 1] 120 | if norm: 121 | dx = dx / w # avoid inplace operations 122 | dy = dy / h 123 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. 124 | flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]] 125 | return tuple(flow_comps) 126 | 127 | 128 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True): 129 | """Recover from quantized flow. 130 | 131 | Args: 132 | dx (ndarray): Quantized dx. 133 | dy (ndarray): Quantized dy. 134 | max_val (float): Maximum value used when quantizing. 135 | denorm (bool): Whether to multiply flow values with width/height. 136 | 137 | Returns: 138 | ndarray: Dequantized flow. 139 | """ 140 | assert dx.shape == dy.shape 141 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) 142 | 143 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] 144 | 145 | if denorm: 146 | dx *= dx.shape[1] 147 | dy *= dx.shape[0] 148 | flow = np.dstack((dx, dy)) 149 | return flow 150 | 151 | 152 | def quantize(arr, min_val, max_val, levels, dtype=np.int64): 153 | """Quantize an array of (-inf, inf) to [0, levels-1]. 154 | 155 | Args: 156 | arr (ndarray): Input array. 157 | min_val (scalar): Minimum value to be clipped. 158 | max_val (scalar): Maximum value to be clipped. 159 | levels (int): Quantization levels. 160 | dtype (np.type): The type of the quantized array. 161 | 162 | Returns: 163 | tuple: Quantized array. 164 | """ 165 | if not (isinstance(levels, int) and levels > 1): 166 | raise ValueError(f'levels must be a positive integer, but got {levels}') 167 | if min_val >= max_val: 168 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') 169 | 170 | arr = np.clip(arr, min_val, max_val) - min_val 171 | quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) 172 | 173 | return quantized_arr 174 | 175 | 176 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64): 177 | """Dequantize an array. 178 | 179 | Args: 180 | arr (ndarray): Input array. 181 | min_val (scalar): Minimum value to be clipped. 182 | max_val (scalar): Maximum value to be clipped. 183 | levels (int): Quantization levels. 184 | dtype (np.type): The type of the dequantized array. 
185 | 186 | Returns: 187 | tuple: Dequantized array. 188 | """ 189 | if not (isinstance(levels, int) and levels > 1): 190 | raise ValueError(f'levels must be a positive integer, but got {levels}') 191 | if min_val >= max_val: 192 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') 193 | 194 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val 195 | 196 | return dequantized_arr -------------------------------------------------------------------------------- /propainter/utils/img_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def img2tensor(imgs, bgr2rgb=True, float32=True): 10 | """Numpy array to tensor. 11 | 12 | Args: 13 | imgs (list[ndarray] | ndarray): Input images. 14 | bgr2rgb (bool): Whether to change bgr to rgb. 15 | float32 (bool): Whether to change to float32. 16 | 17 | Returns: 18 | list[tensor] | tensor: Tensor images. If returned results only have 19 | one element, just return tensor. 20 | """ 21 | 22 | def _totensor(img, bgr2rgb, float32): 23 | if img.shape[2] == 3 and bgr2rgb: 24 | if img.dtype == 'float64': 25 | img = img.astype('float32') 26 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 27 | img = torch.from_numpy(img.transpose(2, 0, 1)) 28 | if float32: 29 | img = img.float() 30 | return img 31 | 32 | if isinstance(imgs, list): 33 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 34 | else: 35 | return _totensor(imgs, bgr2rgb, float32) 36 | 37 | 38 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 39 | """Convert torch Tensors into image numpy arrays. 40 | 41 | After clamping to [min, max], values will be normalized to [0, 1]. 42 | 43 | Args: 44 | tensor (Tensor or list[Tensor]): Accept shapes: 45 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); 46 | 2) 3D Tensor of shape (3/1 x H x W); 47 | 3) 2D Tensor of shape (H x W). 48 | Tensor channel should be in RGB order. 49 | rgb2bgr (bool): Whether to change rgb to bgr. 50 | out_type (numpy type): output types. If ``np.uint8``, transform outputs 51 | to uint8 type with range [0, 255]; otherwise, float type with 52 | range [0, 1]. Default: ``np.uint8``. 53 | min_max (tuple[int]): min and max values for clamp. 54 | 55 | Returns: 56 | (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of 57 | shape (H x W). The channel order is BGR. 
58 | """ 59 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): 60 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') 61 | 62 | if torch.is_tensor(tensor): 63 | tensor = [tensor] 64 | result = [] 65 | for _tensor in tensor: 66 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 67 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 68 | 69 | n_dim = _tensor.dim() 70 | if n_dim == 4: 71 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() 72 | img_np = img_np.transpose(1, 2, 0) 73 | if rgb2bgr: 74 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 75 | elif n_dim == 3: 76 | img_np = _tensor.numpy() 77 | img_np = img_np.transpose(1, 2, 0) 78 | if img_np.shape[2] == 1: # gray image 79 | img_np = np.squeeze(img_np, axis=2) 80 | else: 81 | if rgb2bgr: 82 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 83 | elif n_dim == 2: 84 | img_np = _tensor.numpy() 85 | else: 86 | raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}') 87 | if out_type == np.uint8: 88 | # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 89 | img_np = (img_np * 255.0).round() 90 | img_np = img_np.astype(out_type) 91 | result.append(img_np) 92 | if len(result) == 1: 93 | result = result[0] 94 | return result 95 | 96 | 97 | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): 98 | """This implementation is slightly faster than tensor2img. 99 | It now only supports torch tensor with shape (1, c, h, w). 100 | 101 | Args: 102 | tensor (Tensor): Now only support torch tensor with (1, c, h, w). 103 | rgb2bgr (bool): Whether to change rgb to bgr. Default: True. 104 | min_max (tuple[int]): min and max values for clamp. 105 | """ 106 | output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) 107 | output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 108 | output = output.type(torch.uint8).cpu().numpy() 109 | if rgb2bgr: 110 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 111 | return output 112 | 113 | 114 | def imfrombytes(content, flag='color', float32=False): 115 | """Read an image from bytes. 116 | 117 | Args: 118 | content (bytes): Image bytes got from files or other streams. 119 | flag (str): Flags specifying the color type of a loaded image, 120 | candidates are `color`, `grayscale` and `unchanged`. 121 | float32 (bool): Whether to change to float32., If True, will also norm 122 | to [0, 1]. Default: False. 123 | 124 | Returns: 125 | ndarray: Loaded image array. 126 | """ 127 | img_np = np.frombuffer(content, np.uint8) 128 | imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} 129 | img = cv2.imdecode(img_np, imread_flags[flag]) 130 | if float32: 131 | img = img.astype(np.float32) / 255. 132 | return img 133 | 134 | 135 | def imwrite(img, file_path, params=None, auto_mkdir=True): 136 | """Write image to file. 137 | 138 | Args: 139 | img (ndarray): Image array to be written. 140 | file_path (str): Image file path. 141 | params (None or list): Same as opencv's :func:`imwrite` interface. 142 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 143 | whether to create it automatically. 144 | 145 | Returns: 146 | bool: Successful or not. 
147 | """ 148 | if auto_mkdir: 149 | dir_name = os.path.abspath(os.path.dirname(file_path)) 150 | os.makedirs(dir_name, exist_ok=True) 151 | return cv2.imwrite(file_path, img, params) 152 | 153 | 154 | def crop_border(imgs, crop_border): 155 | """Crop borders of images. 156 | 157 | Args: 158 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 159 | crop_border (int): Crop border for each end of height and weight. 160 | 161 | Returns: 162 | list[ndarray]: Cropped images. 163 | """ 164 | if crop_border == 0: 165 | return imgs 166 | else: 167 | if isinstance(imgs, list): 168 | return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] 169 | else: 170 | return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] 171 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui_diffueraser" 3 | description = "DiffuEraser is a diffusion model for video Inpainting, you can use it in ComfyUI" 4 | version = "1.0.0" 5 | license = { file = "LICENSE" } 6 | dependencies = ["torch", "torchvision", "torchaudio", "diffusers", "accelerate", "opencv-python", "imageio", "#matplotlib", "transformers", "einops", "#datasets", "#numpy==1.26.4", "#pillow==10.4.0", "#tqdm==4.66.4", "#urllib3==2.2.2", "#zipp==3.19.2", "peft", "#scipy==1.13.1", "#av==14.0.1"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/smthemex/ComfyUI_DiffuEraser" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "smthemex" 14 | DisplayName = "ComfyUI_DiffuEraser" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | torchaudio 4 | diffusers 5 | accelerate 6 | opencv-python 7 | imageio 8 | #matplotlib 9 | transformers 10 | einops 11 | #datasets 12 | #numpy==1.26.4 13 | #pillow==10.4.0 14 | #tqdm==4.66.4 15 | #urllib3==2.2.2 16 | #zipp==3.19.2 17 | peft 18 | #scipy==1.13.1 19 | #av==14.0.1 -------------------------------------------------------------------------------- /run_diffueraser.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import time 4 | import random 5 | from .libs.diffueraser import DiffuEraser 6 | from .propainter.inference import Propainter, get_device 7 | import folder_paths 8 | import gc 9 | 10 | 11 | def load_diffueraser(pre_model_path, pcm_lora_path,sd_repo,ckpt_path,original_config_file,device): 12 | 13 | start_time = time.time() 14 | device = get_device() 15 | #ckpt = "2-Step" 16 | propainter_model_dir=os.path.join(pre_model_path, "propainter") 17 | if not os.path.exists(propainter_model_dir): 18 | os.makedirs(propainter_model_dir) 19 | video_inpainting_sd = DiffuEraser(device, sd_repo, pre_model_path,ckpt_path,original_config_file, ckpt=pcm_lora_path) 20 | propainter = Propainter(propainter_model_dir, device=device) 21 | 22 | end_time = time.time() 23 | load_time = end_time - start_time 24 | print(f"DiffuEraser load time: {load_time:.4f} s") 25 | return {"video_inpainting_sd":video_inpainting_sd,"propainter":propainter} 26 | 27 | 28 | def diffueraser_inference(video_inpainting_sd,propainter,input_video,input_mask,video_length,width,height,mask_dilation_iter, 29 | 
max_img_size,ref_stride,neighbor_length,subvideo_length,guidance_scale,num_inference_steps,seed,fps,save_result_video,): 30 | 31 | prefix = ''.join(random.choice("0123456789") for _ in range(6)) 32 | priori_path = os.path.join(folder_paths.get_output_directory(), f"priori_{prefix}.mp4") 33 | if not os.path.exists(os.path.dirname(priori_path)): 34 | os.makedirs(os.path.dirname(priori_path)) 35 | output_path = os.path.join(folder_paths.get_output_directory(), f"diffueraser_result_{prefix}.mp4") 36 | start_time = time.time() 37 | load_videobypath=False 38 | # if load_videobypath: 39 | # input_mask="F:/test/ComfyUI/input/mask.mp4" 40 | # input_video="F:/test/ComfyUI/input/video.mp4" 41 | ## priori 42 | res=propainter.forward(input_video, input_mask, priori_path,load_videobypath=load_videobypath,video_length=video_length, height=height,width=width, 43 | ref_stride=ref_stride, neighbor_length=neighbor_length, subvideo_length = subvideo_length, 44 | mask_dilation = mask_dilation_iter,save_fps=fps) 45 | 46 | propainter.to("cpu") 47 | gc.collect() 48 | torch.cuda.empty_cache() 49 | ## diffueraser 50 | # The default value is 0. 51 | video_path,image_list,Propainter_img=video_inpainting_sd.forward(input_video, input_mask, priori_path,output_path,load_videobypath=load_videobypath, 52 | max_img_size = max_img_size, video_length=video_length, mask_dilation_iter=mask_dilation_iter,seed=seed, 53 | guidance_scale=guidance_scale,num_inference_steps=num_inference_steps,fps=fps,img_size=(width,height),if_save_video=save_result_video) 54 | 55 | end_time = time.time() 56 | inference_time = end_time - start_time 57 | print(f"DiffuEraser inference time: {inference_time:.4f} s") 58 | 59 | torch.cuda.empty_cache() 60 | return video_path,image_list,Propainter_img 61 | 62 | -------------------------------------------------------------------------------- /sd15_repo/feature_extractor/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "crop_size": { 3 | "height": 224, 4 | "width": 224 5 | }, 6 | "do_center_crop": true, 7 | "do_convert_rgb": true, 8 | "do_normalize": true, 9 | "do_rescale": true, 10 | "do_resize": true, 11 | "feature_extractor_type": "CLIPFeatureExtractor", 12 | "image_mean": [ 13 | 0.48145466, 14 | 0.4578275, 15 | 0.40821073 16 | ], 17 | "image_processor_type": "CLIPFeatureExtractor", 18 | "image_std": [ 19 | 0.26862954, 20 | 0.26130258, 21 | 0.27577711 22 | ], 23 | "resample": 3, 24 | "rescale_factor": 0.00392156862745098, 25 | "size": { 26 | "shortest_edge": 224 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /sd15_repo/model_index.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "StableDiffusionPipeline", 3 | "_diffusers_version": "0.21.0.dev0", 4 | "_name_or_path": "lykon-models/dreamshaper-8", 5 | "feature_extractor": [ 6 | "transformers", 7 | "CLIPFeatureExtractor" 8 | ], 9 | "requires_safety_checker": true, 10 | "safety_checker": [ 11 | "stable_diffusion", 12 | "StableDiffusionSafetyChecker" 13 | ], 14 | "scheduler": [ 15 | "diffusers", 16 | "DEISMultistepScheduler" 17 | ], 18 | "text_encoder": [ 19 | "transformers", 20 | "CLIPTextModel" 21 | ], 22 | "tokenizer": [ 23 | "transformers", 24 | "CLIPTokenizer" 25 | ], 26 | "unet": [ 27 | "diffusers", 28 | "UNet2DConditionModel" 29 | ], 30 | "vae": [ 31 | "diffusers", 32 | "AutoencoderKL" 33 | ] 34 | } 35 | 
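The sd15_repo folder above ships only diffusers-format configs, no weights; model_index.json is the manifest that tells a diffusers-style loader which class backs each pipeline component and in which subfolder that component's config lives. A minimal sketch of how such a manifest is read, using only the Python standard library (the helper name resolve_pipeline_components is illustrative and not part of this repo):

```python
import json
import os

def resolve_pipeline_components(repo_dir):
    # Read the manifest; keys starting with "_" are metadata
    # (_class_name, _diffusers_version, _name_or_path), not components.
    with open(os.path.join(repo_dir, "model_index.json"), "r", encoding="utf-8") as f:
        index = json.load(f)
    components = {}
    for name, spec in index.items():
        if name.startswith("_") or not isinstance(spec, list):
            continue  # skips metadata and plain flags such as requires_safety_checker
        library, class_name = spec  # e.g. ["diffusers", "UNet2DConditionModel"]
        components[name] = {
            "library": library,
            "class_name": class_name,
            # A real loader would instantiate class_name from `library`
            # using the config stored in repo_dir/<name>/.
            "config_dir": os.path.join(repo_dir, name),
        }
    return components

# Example: resolve_pipeline_components("sd15_repo")["unet"]
# -> {"library": "diffusers", "class_name": "UNet2DConditionModel", "config_dir": "sd15_repo/unet"}
```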
-------------------------------------------------------------------------------- /sd15_repo/safety_checker/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "/home/patrick/.cache/huggingface/hub/models--lykon-models--dreamshaper-8/snapshots/7e855e3f481832419503d1fa18d4a4379597f04b/safety_checker", 3 | "architectures": [ 4 | "StableDiffusionSafetyChecker" 5 | ], 6 | "initializer_factor": 1.0, 7 | "logit_scale_init_value": 2.6592, 8 | "model_type": "clip", 9 | "projection_dim": 768, 10 | "text_config": { 11 | "dropout": 0.0, 12 | "hidden_size": 768, 13 | "intermediate_size": 3072, 14 | "model_type": "clip_text_model", 15 | "num_attention_heads": 12 16 | }, 17 | "torch_dtype": "float16", 18 | "transformers_version": "4.33.0.dev0", 19 | "vision_config": { 20 | "dropout": 0.0, 21 | "hidden_size": 1024, 22 | "intermediate_size": 4096, 23 | "model_type": "clip_vision_model", 24 | "num_attention_heads": 16, 25 | "num_hidden_layers": 24, 26 | "patch_size": 14 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /sd15_repo/scheduler/scheduler_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "DEISMultistepScheduler", 3 | "_diffusers_version": "0.21.0.dev0", 4 | "algorithm_type": "deis", 5 | "beta_end": 0.012, 6 | "beta_schedule": "scaled_linear", 7 | "beta_start": 0.00085, 8 | "clip_sample": false, 9 | "dynamic_thresholding_ratio": 0.995, 10 | "lower_order_final": true, 11 | "num_train_timesteps": 1000, 12 | "prediction_type": "epsilon", 13 | "sample_max_value": 1.0, 14 | "set_alpha_to_one": false, 15 | "skip_prk_steps": true, 16 | "solver_order": 2, 17 | "solver_type": "logrho", 18 | "steps_offset": 1, 19 | "thresholding": false, 20 | "timestep_spacing": "leading", 21 | "trained_betas": null, 22 | "use_karras_sigmas": false 23 | } 24 | -------------------------------------------------------------------------------- /sd15_repo/text_encoder/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "/home/patrick/.cache/huggingface/hub/models--lykon-models--dreamshaper-8/snapshots/7e855e3f481832419503d1fa18d4a4379597f04b/text_encoder", 3 | "architectures": [ 4 | "CLIPTextModel" 5 | ], 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 0, 8 | "dropout": 0.0, 9 | "eos_token_id": 2, 10 | "hidden_act": "quick_gelu", 11 | "hidden_size": 768, 12 | "initializer_factor": 1.0, 13 | "initializer_range": 0.02, 14 | "intermediate_size": 3072, 15 | "layer_norm_eps": 1e-05, 16 | "max_position_embeddings": 77, 17 | "model_type": "clip_text_model", 18 | "num_attention_heads": 12, 19 | "num_hidden_layers": 12, 20 | "pad_token_id": 1, 21 | "projection_dim": 768, 22 | "torch_dtype": "float16", 23 | "transformers_version": "4.33.0.dev0", 24 | "vocab_size": 49408 25 | } 26 | -------------------------------------------------------------------------------- /sd15_repo/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<|startoftext|>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "<|endoftext|>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": "<|endoftext|>", 17 | "unk_token": { 18 | "content": "<|endoftext|>", 19 | 
"lstrip": false, 20 | "normalized": true, 21 | "rstrip": false, 22 | "single_word": false 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /sd15_repo/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "bos_token": { 4 | "__type": "AddedToken", 5 | "content": "<|startoftext|>", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false 10 | }, 11 | "clean_up_tokenization_spaces": true, 12 | "do_lower_case": true, 13 | "eos_token": { 14 | "__type": "AddedToken", 15 | "content": "<|endoftext|>", 16 | "lstrip": false, 17 | "normalized": true, 18 | "rstrip": false, 19 | "single_word": false 20 | }, 21 | "errors": "replace", 22 | "model_max_length": 77, 23 | "pad_token": "<|endoftext|>", 24 | "tokenizer_class": "CLIPTokenizer", 25 | "unk_token": { 26 | "__type": "AddedToken", 27 | "content": "<|endoftext|>", 28 | "lstrip": false, 29 | "normalized": true, 30 | "rstrip": false, 31 | "single_word": false 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /sd15_repo/unet/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "UNet2DConditionModel", 3 | "_diffusers_version": "0.21.0.dev0", 4 | "_name_or_path": "/home/patrick/.cache/huggingface/hub/models--lykon-models--dreamshaper-8/snapshots/7e855e3f481832419503d1fa18d4a4379597f04b/unet", 5 | "act_fn": "silu", 6 | "addition_embed_type": null, 7 | "addition_embed_type_num_heads": 64, 8 | "addition_time_embed_dim": null, 9 | "attention_head_dim": 8, 10 | "attention_type": "default", 11 | "block_out_channels": [ 12 | 320, 13 | 640, 14 | 1280, 15 | 1280 16 | ], 17 | "center_input_sample": false, 18 | "class_embed_type": null, 19 | "class_embeddings_concat": false, 20 | "conv_in_kernel": 3, 21 | "conv_out_kernel": 3, 22 | "cross_attention_dim": 768, 23 | "cross_attention_norm": null, 24 | "down_block_types": [ 25 | "CrossAttnDownBlock2D", 26 | "CrossAttnDownBlock2D", 27 | "CrossAttnDownBlock2D", 28 | "DownBlock2D" 29 | ], 30 | "downsample_padding": 1, 31 | "dual_cross_attention": false, 32 | "encoder_hid_dim": null, 33 | "encoder_hid_dim_type": null, 34 | "flip_sin_to_cos": true, 35 | "freq_shift": 0, 36 | "in_channels": 4, 37 | "layers_per_block": 2, 38 | "mid_block_only_cross_attention": null, 39 | "mid_block_scale_factor": 1, 40 | "mid_block_type": "UNetMidBlock2DCrossAttn", 41 | "norm_eps": 1e-05, 42 | "norm_num_groups": 32, 43 | "num_attention_heads": null, 44 | "num_class_embeds": null, 45 | "only_cross_attention": false, 46 | "out_channels": 4, 47 | "projection_class_embeddings_input_dim": null, 48 | "resnet_out_scale_factor": 1.0, 49 | "resnet_skip_time_act": false, 50 | "resnet_time_scale_shift": "default", 51 | "sample_size": 64, 52 | "time_cond_proj_dim": null, 53 | "time_embedding_act_fn": null, 54 | "time_embedding_dim": null, 55 | "time_embedding_type": "positional", 56 | "timestep_post_act": null, 57 | "transformer_layers_per_block": 1, 58 | "up_block_types": [ 59 | "UpBlock2D", 60 | "CrossAttnUpBlock2D", 61 | "CrossAttnUpBlock2D", 62 | "CrossAttnUpBlock2D" 63 | ], 64 | "upcast_attention": null, 65 | "use_linear_projection": false 66 | } 67 | -------------------------------------------------------------------------------- /sd15_repo/vae/config.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"_class_name": "AutoencoderKL", 3 | "_diffusers_version": "0.21.0.dev0", 4 | "_name_or_path": "/home/patrick/.cache/huggingface/hub/models--lykon-models--dreamshaper-8/snapshots/7e855e3f481832419503d1fa18d4a4379597f04b/vae", 5 | "act_fn": "silu", 6 | "block_out_channels": [ 7 | 128, 8 | 256, 9 | 512, 10 | 512 11 | ], 12 | "down_block_types": [ 13 | "DownEncoderBlock2D", 14 | "DownEncoderBlock2D", 15 | "DownEncoderBlock2D", 16 | "DownEncoderBlock2D" 17 | ], 18 | "force_upcast": true, 19 | "in_channels": 3, 20 | "latent_channels": 4, 21 | "layers_per_block": 2, 22 | "norm_num_groups": 32, 23 | "out_channels": 3, 24 | "sample_size": 512, 25 | "scaling_factor": 0.18215, 26 | "up_block_types": [ 27 | "UpDecoderBlock2D", 28 | "UpDecoderBlock2D", 29 | "UpDecoderBlock2D", 30 | "UpDecoderBlock2D" 31 | ] 32 | } 33 | --------------------------------------------------------------------------------