├── .gitignore ├── LICENSE.txt ├── ProPainter ├── .gitignore ├── LICENSE ├── RAFT │ ├── __init__.py │ ├── corr.py │ ├── datasets.py │ ├── demo.py │ ├── extractor.py │ ├── raft.py │ ├── update.py │ └── utils │ │ ├── __init__.py │ │ ├── augmentor.py │ │ ├── flow_viz.py │ │ ├── flow_viz_pt.py │ │ ├── frame_utils.py │ │ └── utils.py ├── README.md ├── assets │ ├── ProPainter_pipeline.png │ ├── object_removal1.gif │ ├── object_removal2.gif │ ├── propainter_logo1.png │ ├── propainter_logo1_glow.png │ ├── video_completion1.gif │ ├── video_completion2.gif │ ├── video_completion3.gif │ └── video_completion4.gif ├── configs │ ├── train_flowcomp.json │ └── train_propainter.json ├── core │ ├── dataset.py │ ├── dist.py │ ├── loss.py │ ├── lr_scheduler.py │ ├── metrics.py │ ├── prefetch_dataloader.py │ ├── trainer.py │ ├── trainer_flow_w_edge.py │ └── utils.py ├── inference_propainter.py ├── model │ ├── __init__.py │ ├── canny │ │ ├── canny_filter.py │ │ ├── filter.py │ │ ├── gaussian.py │ │ ├── kernels.py │ │ └── sobel.py │ ├── misc.py │ ├── modules │ │ ├── base_module.py │ │ ├── deformconv.py │ │ ├── flow_comp_raft.py │ │ ├── flow_loss_utils.py │ │ ├── sparse_transformer.py │ │ └── spectral_norm.py │ ├── propainter.py │ ├── recurrent_flow_completion.py │ └── vgg_arch.py ├── requirements.txt ├── scripts │ ├── compute_flow.py │ ├── evaluate_flow_completion.py │ └── evaluate_propainter.py ├── train.py └── utils │ ├── download_util.py │ ├── file_client.py │ ├── flow_util.py │ └── img_util.py ├── SegTracker.py ├── aot ├── LICENSE ├── MODEL_ZOO.md ├── README.md ├── __init__.py ├── configs │ ├── default.py │ ├── models │ │ ├── aotb.py │ │ ├── aotl.py │ │ ├── aots.py │ │ ├── aott.py │ │ ├── deaotb.py │ │ ├── deaotl.py │ │ ├── deaots.py │ │ ├── deaott.py │ │ ├── default.py │ │ ├── default_deaot.py │ │ ├── r101_aotl.py │ │ ├── r50_aotl.py │ │ ├── r50_deaotl.py │ │ ├── rs101_aotl.py │ │ ├── swinb_aotl.py │ │ └── swinb_deaotl.py │ ├── pre.py │ ├── pre_dav.py │ ├── pre_ytb.py │ ├── pre_ytb_dav.py │ └── ytb.py ├── dataloaders │ ├── __init__.py │ ├── eval_datasets.py │ ├── image_transforms.py │ ├── train_datasets.py │ └── video_transforms.py ├── networks │ ├── .DS_Store │ ├── __init__.py │ ├── decoders │ │ ├── __init__.py │ │ └── fpn.py │ ├── encoders │ │ ├── .DS_Store │ │ ├── __init__.py │ │ ├── mobilenetv2.py │ │ ├── mobilenetv3.py │ │ ├── resnest │ │ │ ├── __init__.py │ │ │ ├── resnest.py │ │ │ ├── resnet.py │ │ │ └── splat.py │ │ ├── resnet.py │ │ └── swin │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ └── swin_transformer.py │ ├── engines │ │ ├── __init__.py │ │ ├── aot_engine.py │ │ └── deaot_engine.py │ ├── layers │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── basic.py │ │ ├── loss.py │ │ ├── normalization.py │ │ ├── position.py │ │ └── transformer.py │ ├── managers │ │ ├── evaluator.py │ │ └── trainer.py │ └── models │ │ ├── __init__.py │ │ ├── aot.py │ │ └── deaot.py ├── source │ ├── .DS_Store │ ├── overview.png │ └── overview_deaot.png ├── tools │ ├── demo.py │ ├── eval.py │ └── train.py ├── train_eval.sh └── utils │ ├── __init__.py │ ├── checkpoint.py │ ├── cp_ckpt.py │ ├── ema.py │ ├── eval.py │ ├── image.py │ ├── learning.py │ ├── math.py │ ├── meters.py │ └── metric.py ├── aot_tracker.py ├── app.py ├── demo.gif ├── groundingdino ├── _C.cp310-win_amd64.pyd ├── __init__.py ├── config │ ├── GroundingDINO_SwinB_cfg.py │ ├── GroundingDINO_SwinT_OGC.py │ └── __init__.py ├── datasets │ ├── __init__.py │ ├── cocogrounding_eval.py │ └── transforms.py ├── models │ ├── GroundingDINO │ │ ├── __init__.py │ │ ├── 
backbone │ │ │ ├── __init__.py │ │ │ ├── backbone.py │ │ │ ├── position_encoding.py │ │ │ └── swin_transformer.py │ │ ├── bertwarper.py │ │ ├── csrc │ │ │ ├── MsDeformAttn │ │ │ │ ├── ms_deform_attn.h │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ ├── ms_deform_attn_cpu.h │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ ├── cuda_version.cu │ │ │ └── vision.cpp │ │ ├── fuse_modules.py │ │ ├── groundingdino.py │ │ ├── ms_deform_attn.py │ │ ├── transformer.py │ │ ├── transformer_vanilla.py │ │ └── utils.py │ ├── __init__.py │ └── registry.py ├── util │ ├── __init__.py │ ├── box_ops.py │ ├── get_tokenlizer.py │ ├── inference.py │ ├── logger.py │ ├── misc.py │ ├── slconfig.py │ ├── slio.py │ ├── time_counter.py │ ├── utils.py │ ├── visualizer.py │ └── vl_utils.py └── version.py ├── img2vid.py ├── licenses.md ├── model_args.py ├── readme.md ├── readme_en.md ├── readme_zh.md ├── requirements.txt ├── sam ├── .flake8 ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── __init__.py ├── assets │ ├── masks1.png │ ├── masks2.jpg │ ├── model_diagram.png │ ├── notebook1.png │ └── notebook2.png ├── linter.sh ├── notebooks │ ├── automatic_mask_generator_example.ipynb │ ├── images │ │ ├── dog.jpg │ │ ├── groceries.jpg │ │ └── truck.jpg │ ├── onnx_model_example.ipynb │ └── predictor_example.ipynb ├── scripts │ ├── amg.py │ └── export_onnx_model.py ├── segment_anything │ ├── .DS_Store │ ├── __init__.py │ ├── automatic_mask_generator.py │ ├── build_sam.py │ ├── modeling │ │ ├── __init__.py │ │ ├── common.py │ │ ├── image_encoder.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ ├── sam.py │ │ └── transformer.py │ ├── predictor.py │ └── utils │ │ ├── __init__.py │ │ ├── amg.py │ │ ├── onnx.py │ │ └── transforms.py ├── setup.cfg └── setup.py ├── seg_track_anything.py ├── tool ├── detector.py ├── segmentor.py └── transfer_tools.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | ckpt/* 3 | assets/*masks 4 | assets/*mp4 5 | # assets/*zip 6 | assets/*gif 7 | *.pyc 8 | debug 9 | cym_utils 10 | /src 11 | /tracking_results 12 | /aot/results 13 | /aot/pretrain_models 14 | /aot/datasets 15 | /Propainter/weights 16 | env 17 | demo 18 | ckpt 19 | output 20 | results 21 | result 22 | img_logs 23 | __pycache__ -------------------------------------------------------------------------------- /ProPainter/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | 3 | # ignored files 4 | version.py 5 | 6 | # ignored files with suffix 7 | *.html 8 | # *.png 9 | # *.jpeg 10 | # *.jpg 11 | # *.gif 12 | *.pt 13 | *.pth 14 | *.dat 15 | *.zip 16 | 17 | # template 18 | 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | MANIFEST 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *.cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # celery beat schedule file 97 | celerybeat-schedule 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | # mypy 122 | .mypy_cache/ 123 | 124 | # project 125 | experiments_model/ 126 | unreleased/ 127 | results_eval/ 128 | results/ 129 | *debug* 130 | *old* 131 | *.sh -------------------------------------------------------------------------------- /ProPainter/LICENSE: -------------------------------------------------------------------------------- 1 | # S-Lab License 1.0 2 | 3 | Copyright 2023 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\ 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 10 | 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. 11 | 12 | 13 | --- 14 | For the commercial use of the code, please consult Prof. 
Chen Change Loy (ccloy@ntu.edu.sg) -------------------------------------------------------------------------------- /ProPainter/RAFT/__init__.py: -------------------------------------------------------------------------------- 1 | # from .demo import RAFT_infer 2 | from .raft import RAFT 3 | -------------------------------------------------------------------------------- /ProPainter/RAFT/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1) 38 | dy = torch.linspace(-r, r, 2*r+1) 39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class CorrLayer(torch.autograd.Function): 64 | @staticmethod 65 | def forward(ctx, fmap1, fmap2, coords, r): 66 | fmap1 = fmap1.contiguous() 67 | fmap2 = fmap2.contiguous() 68 | coords = coords.contiguous() 69 | ctx.save_for_backward(fmap1, fmap2, coords) 70 | ctx.r = r 71 | corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r) 72 | return corr 73 | 74 | @staticmethod 75 | def backward(ctx, grad_corr): 76 | fmap1, fmap2, coords = ctx.saved_tensors 77 | grad_corr = grad_corr.contiguous() 78 | fmap1_grad, fmap2_grad, coords_grad = \ 79 | correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r) 80 | return fmap1_grad, fmap2_grad, coords_grad, None 81 | 82 | 83 | class AlternateCorrBlock: 84 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 85 | self.num_levels = num_levels 86 | self.radius = radius 87 | 88 | self.pyramid = [(fmap1, fmap2)] 89 | for i in range(self.num_levels): 90 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 91 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 92 | self.pyramid.append((fmap1, fmap2)) 93 | 94 | def __call__(self, coords): 95 | 96 | coords = coords.permute(0, 2, 3, 1) 97 | B, H, W, _ = coords.shape 98 | 99 | corr_list = [] 100 | for i in 
range(self.num_levels): 101 | r = self.radius 102 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1) 103 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1) 104 | 105 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 106 | corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r) 107 | corr_list.append(corr.squeeze(1)) 108 | 109 | corr = torch.stack(corr_list, dim=1) 110 | corr = corr.reshape(B, -1, H, W) 111 | return corr / 16.0 112 | -------------------------------------------------------------------------------- /ProPainter/RAFT/demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import os 4 | import cv2 5 | import glob 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | 10 | from .raft import RAFT 11 | from .utils import flow_viz 12 | from .utils.utils import InputPadder 13 | 14 | 15 | 16 | DEVICE = 'cuda' 17 | 18 | def load_image(imfile): 19 | img = np.array(Image.open(imfile)).astype(np.uint8) 20 | img = torch.from_numpy(img).permute(2, 0, 1).float() 21 | return img 22 | 23 | 24 | def load_image_list(image_files): 25 | images = [] 26 | for imfile in sorted(image_files): 27 | images.append(load_image(imfile)) 28 | 29 | images = torch.stack(images, dim=0) 30 | images = images.to(DEVICE) 31 | 32 | padder = InputPadder(images.shape) 33 | return padder.pad(images)[0] 34 | 35 | 36 | def viz(img, flo): 37 | img = img[0].permute(1,2,0).cpu().numpy() 38 | flo = flo[0].permute(1,2,0).cpu().numpy() 39 | 40 | # map flow to rgb image 41 | flo = flow_viz.flow_to_image(flo) 42 | # img_flo = np.concatenate([img, flo], axis=0) 43 | img_flo = flo 44 | 45 | cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]]) 46 | # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) 47 | # cv2.waitKey() 48 | 49 | 50 | def demo(args): 51 | model = torch.nn.DataParallel(RAFT(args)) 52 | model.load_state_dict(torch.load(args.model)) 53 | 54 | model = model.module 55 | model.to(DEVICE) 56 | model.eval() 57 | 58 | with torch.no_grad(): 59 | images = glob.glob(os.path.join(args.path, '*.png')) + \ 60 | glob.glob(os.path.join(args.path, '*.jpg')) 61 | 62 | images = load_image_list(images) 63 | for i in range(images.shape[0]-1): 64 | image1 = images[i,None] 65 | image2 = images[i+1,None] 66 | 67 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 68 | viz(image1, flow_up) 69 | 70 | 71 | def RAFT_infer(args): 72 | model = torch.nn.DataParallel(RAFT(args)) 73 | model.load_state_dict(torch.load(args.model)) 74 | 75 | model = model.module 76 | model.to(DEVICE) 77 | model.eval() 78 | 79 | return model 80 | -------------------------------------------------------------------------------- /ProPainter/RAFT/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .flow_viz import flow_to_image 2 | from .frame_utils import writeFlow 3 | -------------------------------------------------------------------------------- /ProPainter/RAFT/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, 
copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /ProPainter/RAFT/utils/flow_viz_pt.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization 2 | import torch 3 | torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 4 | 5 | @torch.no_grad() 6 | def flow_to_image(flow: torch.Tensor) -> torch.Tensor: 7 | 8 | """ 9 | Converts a flow to an RGB image. 10 | 11 | Args: 12 | flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float. 13 | 14 | Returns: 15 | img (Tensor): Image Tensor of dtype uint8 where each color corresponds 16 | to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input. 17 | """ 18 | 19 | if flow.dtype != torch.float: 20 | raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.") 21 | 22 | orig_shape = flow.shape 23 | if flow.ndim == 3: 24 | flow = flow[None] # Add batch dim 25 | 26 | if flow.ndim != 4 or flow.shape[1] != 2: 27 | raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.") 28 | 29 | max_norm = torch.sum(flow**2, dim=1).sqrt().max() 30 | epsilon = torch.finfo((flow).dtype).eps 31 | normalized_flow = flow / (max_norm + epsilon) 32 | img = _normalized_flow_to_image(normalized_flow) 33 | 34 | if len(orig_shape) == 3: 35 | img = img[0] # Remove batch dim 36 | return img 37 | 38 | @torch.no_grad() 39 | def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor: 40 | 41 | """ 42 | Converts a batch of normalized flow to an RGB image. 43 | 44 | Args: 45 | normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W) 46 | Returns: 47 | img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8. 48 | """ 49 | 50 | N, _, H, W = normalized_flow.shape 51 | device = normalized_flow.device 52 | flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device) 53 | colorwheel = _make_colorwheel().to(device) # shape [55x3] 54 | num_cols = colorwheel.shape[0] 55 | norm = torch.sum(normalized_flow**2, dim=1).sqrt() 56 | a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi 57 | fk = (a + 1) / 2 * (num_cols - 1) 58 | k0 = torch.floor(fk).to(torch.long) 59 | k1 = k0 + 1 60 | k1[k1 == num_cols] = 0 61 | f = fk - k0 62 | 63 | for c in range(colorwheel.shape[1]): 64 | tmp = colorwheel[:, c] 65 | col0 = tmp[k0] / 255.0 66 | col1 = tmp[k1] / 255.0 67 | col = (1 - f) * col0 + f * col1 68 | col = 1 - norm * (1 - col) 69 | flow_image[:, c, :, :] = torch.floor(255. * col) 70 | return flow_image 71 | 72 | 73 | @torch.no_grad() 74 | def _make_colorwheel() -> torch.Tensor: 75 | """ 76 | Generates a color wheel for optical flow visualization as presented in: 77 | Baker et al. 
"A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 78 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf. 79 | 80 | Returns: 81 | colorwheel (Tensor[55, 3]): Colorwheel Tensor. 82 | """ 83 | 84 | RY = 15 85 | YG = 6 86 | GC = 4 87 | CB = 11 88 | BM = 13 89 | MR = 6 90 | 91 | ncols = RY + YG + GC + CB + BM + MR 92 | colorwheel = torch.zeros((ncols, 3)) 93 | col = 0 94 | 95 | # RY 96 | colorwheel[0:RY, 0] = 255 97 | colorwheel[0:RY, 1] = torch.floor(255. * torch.arange(0., RY) / RY) 98 | col = col + RY 99 | # YG 100 | colorwheel[col : col + YG, 0] = 255 - torch.floor(255. * torch.arange(0., YG) / YG) 101 | colorwheel[col : col + YG, 1] = 255 102 | col = col + YG 103 | # GC 104 | colorwheel[col : col + GC, 1] = 255 105 | colorwheel[col : col + GC, 2] = torch.floor(255. * torch.arange(0., GC) / GC) 106 | col = col + GC 107 | # CB 108 | colorwheel[col : col + CB, 1] = 255 - torch.floor(255. * torch.arange(CB) / CB) 109 | colorwheel[col : col + CB, 2] = 255 110 | col = col + CB 111 | # BM 112 | colorwheel[col : col + BM, 2] = 255 113 | colorwheel[col : col + BM, 0] = torch.floor(255. * torch.arange(0., BM) / BM) 114 | col = col + BM 115 | # MR 116 | colorwheel[col : col + MR, 2] = 255 - torch.floor(255. * torch.arange(MR) / MR) 117 | colorwheel[col : col + MR, 0] = 255 118 | return colorwheel 119 | -------------------------------------------------------------------------------- /ProPainter/RAFT/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. 
Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /ProPainter/RAFT/utils/utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd): 75 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /ProPainter/assets/ProPainter_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/ProPainter/assets/ProPainter_pipeline.png -------------------------------------------------------------------------------- /ProPainter/assets/object_removal1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/ProPainter/assets/object_removal1.gif -------------------------------------------------------------------------------- /ProPainter/assets/object_removal2.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/ProPainter/assets/object_removal2.gif -------------------------------------------------------------------------------- /ProPainter/assets/propainter_logo1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/ProPainter/assets/propainter_logo1.png -------------------------------------------------------------------------------- /ProPainter/assets/propainter_logo1_glow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/ProPainter/assets/propainter_logo1_glow.png -------------------------------------------------------------------------------- /ProPainter/assets/video_completion1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/ProPainter/assets/video_completion1.gif -------------------------------------------------------------------------------- /ProPainter/assets/video_completion2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/ProPainter/assets/video_completion2.gif -------------------------------------------------------------------------------- /ProPainter/assets/video_completion3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/ProPainter/assets/video_completion3.gif -------------------------------------------------------------------------------- /ProPainter/assets/video_completion4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/ProPainter/assets/video_completion4.gif -------------------------------------------------------------------------------- /ProPainter/configs/train_flowcomp.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 2023, 3 | "save_dir": "experiments_model/", 4 | "train_data_loader": { 5 | "name": "youtube-vos", 6 | "video_root": "your_video_root", 7 | "flow_root": "your_flow_root", 8 | "w": 432, 9 | "h": 240, 10 | "num_local_frames": 10, 11 | "num_ref_frames": 1, 12 | "load_flow": 0 13 | }, 14 | "losses": { 15 | "flow_weight": 0.25 16 | }, 17 | "model": { 18 | "net": "recurrent_flow_completion" 19 | }, 20 | "trainer": { 21 | "version": "trainer_flow_w_edge", 22 | "type": "Adam", 23 | "beta1": 0, 24 | "beta2": 0.99, 25 | "lr": 5e-5, 26 | "batch_size": 8, 27 | "num_workers": 4, 28 | "num_prefetch_queue": 4, 29 | "log_freq": 100, 30 | "save_freq": 5e3, 31 | "iterations": 700e3, 32 | "scheduler": { 33 | "type": "MultiStepLR", 34 | "milestones": [ 35 | 300e3, 400e3, 500e3, 600e3 36 | ], 37 | "gamma": 0.2 38 | } 39 | } 40 | } -------------------------------------------------------------------------------- /ProPainter/configs/train_propainter.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 2023, 3 | "save_dir": "experiments_model/", 4 | "train_data_loader": { 5 | "name": 
"youtube-vos", 6 | "video_root": "your_video_root", 7 | "flow_root": "your_flow_root", 8 | "w": 432, 9 | "h": 240, 10 | "num_local_frames": 10, 11 | "num_ref_frames": 6, 12 | "load_flow": 0 13 | }, 14 | "losses": { 15 | "hole_weight": 1, 16 | "valid_weight": 1, 17 | "flow_weight": 1, 18 | "adversarial_weight": 0.01, 19 | "GAN_LOSS": "hinge", 20 | "perceptual_weight": 0 21 | }, 22 | "model": { 23 | "net": "propainter", 24 | "no_dis": 0, 25 | "load_d": 1, 26 | "interp_mode": "nearest" 27 | }, 28 | "trainer": { 29 | "version": "trainer", 30 | "type": "Adam", 31 | "beta1": 0, 32 | "beta2": 0.99, 33 | "lr": 1e-4, 34 | "batch_size": 8, 35 | "num_workers": 8, 36 | "num_prefetch_queue": 8, 37 | "log_freq": 100, 38 | "save_freq": 1e4, 39 | "iterations": 700e3, 40 | "scheduler": { 41 | "type": "MultiStepLR", 42 | "milestones": [ 43 | 400e3 44 | ], 45 | "gamma": 0.1 46 | } 47 | } 48 | } -------------------------------------------------------------------------------- /ProPainter/core/dist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | def get_world_size(): 6 | """Find OMPI world size without calling mpi functions 7 | :rtype: int 8 | """ 9 | if os.environ.get('PMI_SIZE') is not None: 10 | return int(os.environ.get('PMI_SIZE') or 1) 11 | elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: 12 | return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) 13 | else: 14 | return torch.cuda.device_count() 15 | 16 | 17 | def get_global_rank(): 18 | """Find OMPI world rank without calling mpi functions 19 | :rtype: int 20 | """ 21 | if os.environ.get('PMI_RANK') is not None: 22 | return int(os.environ.get('PMI_RANK') or 0) 23 | elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None: 24 | return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) 25 | else: 26 | return 0 27 | 28 | 29 | def get_local_rank(): 30 | """Find OMPI local rank without calling mpi functions 31 | :rtype: int 32 | """ 33 | if os.environ.get('MPI_LOCALRANKID') is not None: 34 | return int(os.environ.get('MPI_LOCALRANKID') or 0) 35 | elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: 36 | return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) 37 | else: 38 | return 0 39 | 40 | 41 | def get_master_ip(): 42 | if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: 43 | return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] 44 | elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: 45 | return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') 46 | else: 47 | return "127.0.0.1" 48 | -------------------------------------------------------------------------------- /ProPainter/core/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | LR scheduler from BasicSR https://github.com/xinntao/BasicSR 3 | """ 4 | import math 5 | from collections import Counter 6 | from torch.optim.lr_scheduler import _LRScheduler 7 | 8 | 9 | class MultiStepRestartLR(_LRScheduler): 10 | """ MultiStep with restarts learning rate scheme. 11 | Args: 12 | optimizer (torch.nn.optimizer): Torch optimizer. 13 | milestones (list): Iterations that will decrease learning rate. 14 | gamma (float): Decrease ratio. Default: 0.1. 15 | restarts (list): Restart iterations. Default: [0]. 16 | restart_weights (list): Restart weights at each restart iteration. 17 | Default: [1]. 18 | last_epoch (int): Used in _LRScheduler. Default: -1. 
19 | """ 20 | def __init__(self, 21 | optimizer, 22 | milestones, 23 | gamma=0.1, 24 | restarts=(0, ), 25 | restart_weights=(1, ), 26 | last_epoch=-1): 27 | self.milestones = Counter(milestones) 28 | self.gamma = gamma 29 | self.restarts = restarts 30 | self.restart_weights = restart_weights 31 | assert len(self.restarts) == len( 32 | self.restart_weights), 'restarts and their weights do not match.' 33 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 34 | 35 | def get_lr(self): 36 | if self.last_epoch in self.restarts: 37 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 38 | return [ 39 | group['initial_lr'] * weight 40 | for group in self.optimizer.param_groups 41 | ] 42 | if self.last_epoch not in self.milestones: 43 | return [group['lr'] for group in self.optimizer.param_groups] 44 | return [ 45 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 46 | for group in self.optimizer.param_groups 47 | ] 48 | 49 | 50 | def get_position_from_periods(iteration, cumulative_period): 51 | """Get the position from a period list. 52 | It will return the index of the right-closest number in the period list. 53 | For example, the cumulative_period = [100, 200, 300, 400], 54 | if iteration == 50, return 0; 55 | if iteration == 210, return 2; 56 | if iteration == 300, return 2. 57 | Args: 58 | iteration (int): Current iteration. 59 | cumulative_period (list[int]): Cumulative period list. 60 | Returns: 61 | int: The position of the right-closest number in the period list. 62 | """ 63 | for i, period in enumerate(cumulative_period): 64 | if iteration <= period: 65 | return i 66 | 67 | 68 | class CosineAnnealingRestartLR(_LRScheduler): 69 | """ Cosine annealing with restarts learning rate scheme. 70 | An example of config: 71 | periods = [10, 10, 10, 10] 72 | restart_weights = [1, 0.5, 0.5, 0.5] 73 | eta_min=1e-7 74 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 75 | scheduler will restart with the weights in restart_weights. 76 | Args: 77 | optimizer (torch.nn.optimizer): Torch optimizer. 78 | periods (list): Period for each cosine anneling cycle. 79 | restart_weights (list): Restart weights at each restart iteration. 80 | Default: [1]. 81 | eta_min (float): The mimimum lr. Default: 0. 82 | last_epoch (int): Used in _LRScheduler. Default: -1. 83 | """ 84 | def __init__(self, 85 | optimizer, 86 | periods, 87 | restart_weights=(1, ), 88 | eta_min=1e-7, 89 | last_epoch=-1): 90 | self.periods = periods 91 | self.restart_weights = restart_weights 92 | self.eta_min = eta_min 93 | assert (len(self.periods) == len(self.restart_weights) 94 | ), 'periods and restart_weights should have the same length.' 
95 | self.cumulative_period = [ 96 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) 97 | ] 98 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 99 | 100 | def get_lr(self): 101 | idx = get_position_from_periods(self.last_epoch, 102 | self.cumulative_period) 103 | current_weight = self.restart_weights[idx] 104 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 105 | current_period = self.periods[idx] 106 | 107 | return [ 108 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 109 | (1 + math.cos(math.pi * ( 110 | (self.last_epoch - nearest_restart) / current_period))) 111 | for base_lr in self.base_lrs 112 | ] 113 | -------------------------------------------------------------------------------- /ProPainter/core/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Ref: 11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | 13 | Args: 14 | generator: Python generator. 15 | num_prefetch_queue (int): Number of prefetch queue. 16 | """ 17 | 18 | def __init__(self, generator, num_prefetch_queue): 19 | threading.Thread.__init__(self) 20 | self.queue = Queue.Queue(num_prefetch_queue) 21 | self.generator = generator 22 | self.daemon = True 23 | self.start() 24 | 25 | def run(self): 26 | for item in self.generator: 27 | self.queue.put(item) 28 | self.queue.put(None) 29 | 30 | def __next__(self): 31 | next_item = self.queue.get() 32 | if next_item is None: 33 | raise StopIteration 34 | return next_item 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | 40 | class PrefetchDataLoader(DataLoader): 41 | """Prefetch version of dataloader. 42 | 43 | Ref: 44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 45 | 46 | TODO: 47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 48 | ddp. 49 | 50 | Args: 51 | num_prefetch_queue (int): Number of prefetch queue. 52 | kwargs (dict): Other arguments for dataloader. 53 | """ 54 | 55 | def __init__(self, num_prefetch_queue, **kwargs): 56 | self.num_prefetch_queue = num_prefetch_queue 57 | super(PrefetchDataLoader, self).__init__(**kwargs) 58 | 59 | def __iter__(self): 60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 61 | 62 | 63 | class CPUPrefetcher(): 64 | """CPU prefetcher. 65 | 66 | Args: 67 | loader: Dataloader. 68 | """ 69 | 70 | def __init__(self, loader): 71 | self.ori_loader = loader 72 | self.loader = iter(loader) 73 | 74 | def next(self): 75 | try: 76 | return next(self.loader) 77 | except StopIteration: 78 | return None 79 | 80 | def reset(self): 81 | self.loader = iter(self.ori_loader) 82 | 83 | 84 | class CUDAPrefetcher(): 85 | """CUDA prefetcher. 86 | 87 | Ref: 88 | https://github.com/NVIDIA/apex/issues/304# 89 | 90 | It may consums more GPU memory. 91 | 92 | Args: 93 | loader: Dataloader. 94 | opt (dict): Options. 
95 | """ 96 | 97 | def __init__(self, loader, opt): 98 | self.ori_loader = loader 99 | self.loader = iter(loader) 100 | self.opt = opt 101 | self.stream = torch.cuda.Stream() 102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 103 | self.preload() 104 | 105 | def preload(self): 106 | try: 107 | self.batch = next(self.loader) # self.batch is a dict 108 | except StopIteration: 109 | self.batch = None 110 | return None 111 | # put tensors to gpu 112 | with torch.cuda.stream(self.stream): 113 | for k, v in self.batch.items(): 114 | if torch.is_tensor(v): 115 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 116 | 117 | def next(self): 118 | torch.cuda.current_stream().wait_stream(self.stream) 119 | batch = self.batch 120 | self.preload() 121 | return batch 122 | 123 | def reset(self): 124 | self.loader = iter(self.ori_loader) 125 | self.preload() 126 | -------------------------------------------------------------------------------- /ProPainter/model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /ProPainter/model/canny/gaussian.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .filter import filter2d, filter2d_separable 7 | from .kernels import get_gaussian_kernel1d, get_gaussian_kernel2d 8 | 9 | 10 | def gaussian_blur2d( 11 | input: torch.Tensor, 12 | kernel_size: Tuple[int, int], 13 | sigma: Tuple[float, float], 14 | border_type: str = 'reflect', 15 | separable: bool = True, 16 | ) -> torch.Tensor: 17 | r"""Create an operator that blurs a tensor using a Gaussian filter. 18 | 19 | .. image:: _static/img/gaussian_blur2d.png 20 | 21 | The operator smooths the given tensor with a gaussian kernel by convolving 22 | it to each channel. It supports batched operation. 23 | 24 | Arguments: 25 | input: the input tensor with shape :math:`(B,C,H,W)`. 26 | kernel_size: the size of the kernel. 27 | sigma: the standard deviation of the kernel. 28 | border_type: the padding mode to be applied before convolving. 29 | The expected modes are: ``'constant'``, ``'reflect'``, 30 | ``'replicate'`` or ``'circular'``. Default: ``'reflect'``. 31 | separable: run as composition of two 1d-convolutions. 32 | 33 | Returns: 34 | the blurred tensor with shape :math:`(B, C, H, W)`. 35 | 36 | .. note:: 37 | See a working example `here `__. 39 | 40 | Examples: 41 | >>> input = torch.rand(2, 4, 5, 5) 42 | >>> output = gaussian_blur2d(input, (3, 3), (1.5, 1.5)) 43 | >>> output.shape 44 | torch.Size([2, 4, 5, 5]) 45 | """ 46 | if separable: 47 | kernel_x: torch.Tensor = get_gaussian_kernel1d(kernel_size[1], sigma[1]) 48 | kernel_y: torch.Tensor = get_gaussian_kernel1d(kernel_size[0], sigma[0]) 49 | out = filter2d_separable(input, kernel_x[None], kernel_y[None], border_type) 50 | else: 51 | kernel: torch.Tensor = get_gaussian_kernel2d(kernel_size, sigma) 52 | out = filter2d(input, kernel[None], border_type) 53 | return out 54 | 55 | 56 | class GaussianBlur2d(nn.Module): 57 | r"""Create an operator that blurs a tensor using a Gaussian filter. 58 | 59 | The operator smooths the given tensor with a gaussian kernel by convolving 60 | it to each channel. It supports batched operation. 61 | 62 | Arguments: 63 | kernel_size: the size of the kernel. 64 | sigma: the standard deviation of the kernel. 
65 | border_type: the padding mode to be applied before convolving. 66 | The expected modes are: ``'constant'``, ``'reflect'``, 67 | ``'replicate'`` or ``'circular'``. Default: ``'reflect'``. 68 | separable: run as composition of two 1d-convolutions. 69 | 70 | Returns: 71 | the blurred tensor. 72 | 73 | Shape: 74 | - Input: :math:`(B, C, H, W)` 75 | - Output: :math:`(B, C, H, W)` 76 | 77 | Examples:: 78 | 79 | >>> input = torch.rand(2, 4, 5, 5) 80 | >>> gauss = GaussianBlur2d((3, 3), (1.5, 1.5)) 81 | >>> output = gauss(input) # 2x4x5x5 82 | >>> output.shape 83 | torch.Size([2, 4, 5, 5]) 84 | """ 85 | 86 | def __init__( 87 | self, 88 | kernel_size: Tuple[int, int], 89 | sigma: Tuple[float, float], 90 | border_type: str = 'reflect', 91 | separable: bool = True, 92 | ) -> None: 93 | super().__init__() 94 | self.kernel_size: Tuple[int, int] = kernel_size 95 | self.sigma: Tuple[float, float] = sigma 96 | self.border_type = border_type 97 | self.separable = separable 98 | 99 | def __repr__(self) -> str: 100 | return ( 101 | self.__class__.__name__ 102 | + '(kernel_size=' 103 | + str(self.kernel_size) 104 | + ', ' 105 | + 'sigma=' 106 | + str(self.sigma) 107 | + ', ' 108 | + 'border_type=' 109 | + self.border_type 110 | + 'separable=' 111 | + str(self.separable) 112 | + ')' 113 | ) 114 | 115 | def forward(self, input: torch.Tensor) -> torch.Tensor: 116 | return gaussian_blur2d(input, self.kernel_size, self.sigma, self.border_type, self.separable) -------------------------------------------------------------------------------- /ProPainter/model/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import random 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import logging 8 | import numpy as np 9 | from os import path as osp 10 | 11 | def constant_init(module, val, bias=0): 12 | if hasattr(module, 'weight') and module.weight is not None: 13 | nn.init.constant_(module.weight, val) 14 | if hasattr(module, 'bias') and module.bias is not None: 15 | nn.init.constant_(module.bias, bias) 16 | 17 | initialized_logger = {} 18 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): 19 | """Get the root logger. 20 | The logger will be initialized if it has not been initialized. By default a 21 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 22 | also be added. 23 | Args: 24 | logger_name (str): root logger name. Default: 'basicsr'. 25 | log_file (str | None): The log filename. If specified, a FileHandler 26 | will be added to the root logger. 27 | log_level (int): The root logger level. Note that only the process of 28 | rank 0 is affected, while other processes will set the level to 29 | "Error" and be silent most of the time. 30 | Returns: 31 | logging.Logger: The root logger. 
32 | """ 33 | logger = logging.getLogger(logger_name) 34 | # if the logger has been initialized, just return it 35 | if logger_name in initialized_logger: 36 | return logger 37 | 38 | format_str = '%(asctime)s %(levelname)s: %(message)s' 39 | stream_handler = logging.StreamHandler() 40 | stream_handler.setFormatter(logging.Formatter(format_str)) 41 | logger.addHandler(stream_handler) 42 | logger.propagate = False 43 | 44 | if log_file is not None: 45 | logger.setLevel(log_level) 46 | # add file handler 47 | # file_handler = logging.FileHandler(log_file, 'w') 48 | file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log 49 | file_handler.setFormatter(logging.Formatter(format_str)) 50 | file_handler.setLevel(log_level) 51 | logger.addHandler(file_handler) 52 | initialized_logger[logger_name] = True 53 | return logger 54 | 55 | 56 | IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ 57 | torch.__version__)[0][:3])] >= [1, 12, 0] 58 | 59 | def gpu_is_available(): 60 | if IS_HIGH_VERSION: 61 | if torch.backends.mps.is_available(): 62 | return True 63 | return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False 64 | 65 | def get_device(gpu_id=None): 66 | if gpu_id is None: 67 | gpu_str = '' 68 | elif isinstance(gpu_id, int): 69 | gpu_str = f':{gpu_id}' 70 | else: 71 | raise TypeError('Input should be int value.') 72 | 73 | if IS_HIGH_VERSION: 74 | if torch.backends.mps.is_available(): 75 | return torch.device('mps'+gpu_str) 76 | return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu') 77 | 78 | 79 | def set_random_seed(seed): 80 | """Set random seeds.""" 81 | random.seed(seed) 82 | np.random.seed(seed) 83 | torch.manual_seed(seed) 84 | torch.cuda.manual_seed(seed) 85 | torch.cuda.manual_seed_all(seed) 86 | 87 | 88 | def get_time_str(): 89 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 90 | 91 | 92 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 93 | """Scan a directory to find the interested files. 94 | 95 | Args: 96 | dir_path (str): Path of the directory. 97 | suffix (str | tuple(str), optional): File suffix that we are 98 | interested in. Default: None. 99 | recursive (bool, optional): If set to True, recursively scan the 100 | directory. Default: False. 101 | full_path (bool, optional): If set to True, include the dir_path. 102 | Default: False. 103 | 104 | Returns: 105 | A generator for all the interested files with relative pathes. 
106 | """ 107 | 108 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 109 | raise TypeError('"suffix" must be a string or tuple of strings') 110 | 111 | root = dir_path 112 | 113 | def _scandir(dir_path, suffix, recursive): 114 | for entry in os.scandir(dir_path): 115 | if not entry.name.startswith('.') and entry.is_file(): 116 | if full_path: 117 | return_path = entry.path 118 | else: 119 | return_path = osp.relpath(entry.path, root) 120 | 121 | if suffix is None: 122 | yield return_path 123 | elif return_path.endswith(suffix): 124 | yield return_path 125 | else: 126 | if recursive: 127 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 128 | else: 129 | continue 130 | 131 | return _scandir(dir_path, suffix=suffix, recursive=recursive) -------------------------------------------------------------------------------- /ProPainter/model/modules/deformconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init as init 4 | from torch.nn.modules.utils import _pair, _single 5 | import math 6 | 7 | class ModulatedDeformConv2d(nn.Module): 8 | def __init__(self, 9 | in_channels, 10 | out_channels, 11 | kernel_size, 12 | stride=1, 13 | padding=0, 14 | dilation=1, 15 | groups=1, 16 | deform_groups=1, 17 | bias=True): 18 | super(ModulatedDeformConv2d, self).__init__() 19 | 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.kernel_size = _pair(kernel_size) 23 | self.stride = stride 24 | self.padding = padding 25 | self.dilation = dilation 26 | self.groups = groups 27 | self.deform_groups = deform_groups 28 | self.with_bias = bias 29 | # enable compatibility with nn.Conv2d 30 | self.transposed = False 31 | self.output_padding = _single(0) 32 | 33 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) 34 | if bias: 35 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 36 | else: 37 | self.register_parameter('bias', None) 38 | self.init_weights() 39 | 40 | def init_weights(self): 41 | n = self.in_channels 42 | for k in self.kernel_size: 43 | n *= k 44 | stdv = 1. 
/ math.sqrt(n) 45 | self.weight.data.uniform_(-stdv, stdv) 46 | if self.bias is not None: 47 | self.bias.data.zero_() 48 | 49 | if hasattr(self, 'conv_offset'): 50 | self.conv_offset.weight.data.zero_() 51 | self.conv_offset.bias.data.zero_() 52 | 53 | def forward(self, x, offset, mask): 54 | pass -------------------------------------------------------------------------------- /ProPainter/requirements.txt: -------------------------------------------------------------------------------- 1 | av 2 | addict 3 | einops 4 | future 5 | numpy 6 | scipy 7 | opencv-python 8 | matplotlib 9 | scikit-image 10 | torch>=1.7.1 11 | torchvision>=0.8.2 12 | imageio-ffmpeg 13 | pyyaml 14 | requests 15 | timm 16 | yapf -------------------------------------------------------------------------------- /ProPainter/scripts/compute_flow.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | sys.path.append(".") 4 | 5 | import os 6 | import cv2 7 | import argparse 8 | from PIL import Image 9 | import torch 10 | import torch.nn.functional as F 11 | from torchvision import transforms 12 | 13 | from RAFT import RAFT 14 | from utils.flow_util import * 15 | 16 | def imwrite(img, file_path, params=None, auto_mkdir=True): 17 | if auto_mkdir: 18 | dir_name = os.path.abspath(os.path.dirname(file_path)) 19 | os.makedirs(dir_name, exist_ok=True) 20 | return cv2.imwrite(file_path, img, params) 21 | 22 | def initialize_RAFT(model_path='weights/raft-things.pth', device='cuda'): 23 | """Initializes the RAFT model. 24 | """ 25 | args = argparse.ArgumentParser() 26 | args.raft_model = model_path 27 | args.small = False 28 | args.mixed_precision = False 29 | args.alternate_corr = False 30 | 31 | model = torch.nn.DataParallel(RAFT(args)) 32 | model.load_state_dict(torch.load(args.raft_model)) 33 | 34 | model = model.module 35 | model.to(device) 36 | model.eval() 37 | 38 | return model 39 | 40 | 41 | if __name__ == '__main__': 42 | device = 'cuda' 43 | 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument('-i', '--root_path', type=str, default='your_dataset_root/youtube-vos/JPEGImages') 46 | parser.add_argument('-o', '--save_path', type=str, default='your_dataset_root/youtube-vos/Flows_flo') 47 | parser.add_argument('--height', type=int, default=240) 48 | parser.add_argument('--width', type=int, default=432) 49 | 50 | args = parser.parse_args() 51 | 52 | # Flow model 53 | RAFT_model = initialize_RAFT(device=device) 54 | 55 | root_path = args.root_path 56 | save_path = args.save_path 57 | h_new, w_new = (args.height, args.width) 58 | 59 | file_list = sorted(os.listdir(root_path)) 60 | for f in file_list: 61 | print(f'Processing: {f} ...') 62 | m_list = sorted(os.listdir(os.path.join(root_path, f))) 63 | len_m = len(m_list) 64 | for i in range(len_m-1): 65 | img1_path = os.path.join(root_path, f, m_list[i]) 66 | img2_path = os.path.join(root_path, f, m_list[i+1]) 67 | img1 = Image.fromarray(cv2.imread(img1_path)) 68 | img2 = Image.fromarray(cv2.imread(img2_path)) 69 | 70 | transform = transforms.Compose([transforms.ToTensor()]) 71 | 72 | img1 = transform(img1).unsqueeze(0).to(device)[:,[2,1,0],:,:] 73 | img2 = transform(img2).unsqueeze(0).to(device)[:,[2,1,0],:,:] 74 | 75 | # upsize to a multiple of 16 76 | # h, w = img1.shape[2:4] 77 | # w_new = w if (w % 16) == 0 else 16 * (w // 16 + 1) 78 | # h_new = h if (h % 16) == 0 else 16 * (h // 16 + 1) 79 | 80 | 81 | img1 = F.interpolate(input=img1, 82 | size=(h_new, w_new), 83 | mode='bilinear', 84 | 
align_corners=False) 85 | img2 = F.interpolate(input=img2, 86 | size=(h_new, w_new), 87 | mode='bilinear', 88 | align_corners=False) 89 | 90 | with torch.no_grad(): 91 | img1 = img1*2 - 1 92 | img2 = img2*2 - 1 93 | 94 | _, flow_f = RAFT_model(img1, img2, iters=20, test_mode=True) 95 | _, flow_b = RAFT_model(img2, img1, iters=20, test_mode=True) 96 | 97 | 98 | flow_f = flow_f[0].permute(1,2,0).cpu().numpy() 99 | flow_b = flow_b[0].permute(1,2,0).cpu().numpy() 100 | 101 | # flow_f = resize_flow(flow_f, w_new, h_new) 102 | # flow_b = resize_flow(flow_b, w_new, h_new) 103 | 104 | save_flow_f = os.path.join(save_path, f, f'{m_list[i][:-4]}_{m_list[i+1][:-4]}_f.flo') 105 | save_flow_b = os.path.join(save_path, f, f'{m_list[i+1][:-4]}_{m_list[i][:-4]}_b.flo') 106 | 107 | flowwrite(flow_f, save_flow_f, quantize=False) 108 | flowwrite(flow_b, save_flow_b, quantize=False) 109 | -------------------------------------------------------------------------------- /ProPainter/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import subprocess 5 | 6 | from shutil import copyfile 7 | import torch.distributed as dist 8 | 9 | import torch 10 | import torch.multiprocessing as mp 11 | 12 | import core 13 | import core.trainer 14 | import core.trainer_flow_w_edge 15 | 16 | 17 | # import warnings 18 | # warnings.filterwarnings("ignore") 19 | 20 | from core.dist import ( 21 | get_world_size, 22 | get_local_rank, 23 | get_global_rank, 24 | get_master_ip, 25 | ) 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('-c', 29 | '--config', 30 | default='configs/train_propainter.json', 31 | type=str) 32 | parser.add_argument('-p', '--port', default='23490', type=str) 33 | args = parser.parse_args() 34 | 35 | 36 | def main_worker(rank, config): 37 | if 'local_rank' not in config: 38 | config['local_rank'] = config['global_rank'] = rank 39 | if config['distributed']: 40 | torch.cuda.set_device(int(config['local_rank'])) 41 | torch.distributed.init_process_group(backend='nccl', 42 | init_method=config['init_method'], 43 | world_size=config['world_size'], 44 | rank=config['global_rank'], 45 | group_name='mtorch') 46 | print('using GPU {}-{} for training'.format(int(config['global_rank']), 47 | int(config['local_rank']))) 48 | 49 | 50 | config['save_dir'] = os.path.join( 51 | config['save_dir'], 52 | '{}_{}'.format(config['model']['net'], 53 | os.path.basename(args.config).split('.')[0])) 54 | 55 | config['save_metric_dir'] = os.path.join( 56 | './scores', 57 | '{}_{}'.format(config['model']['net'], 58 | os.path.basename(args.config).split('.')[0])) 59 | 60 | if torch.cuda.is_available(): 61 | config['device'] = torch.device("cuda:{}".format(config['local_rank'])) 62 | else: 63 | config['device'] = 'cpu' 64 | 65 | if (not config['distributed']) or config['global_rank'] == 0: 66 | os.makedirs(config['save_dir'], exist_ok=True) 67 | config_path = os.path.join(config['save_dir'], 68 | args.config.split('/')[-1]) 69 | if not os.path.isfile(config_path): 70 | copyfile(args.config, config_path) 71 | print('[**] create folder {}'.format(config['save_dir'])) 72 | 73 | trainer_version = config['trainer']['version'] 74 | trainer = core.__dict__[trainer_version].__dict__['Trainer'](config) 75 | # Trainer(config) 76 | trainer.train() 77 | 78 | 79 | if __name__ == "__main__": 80 | 81 | torch.backends.cudnn.benchmark = True 82 | 83 | mp.set_sharing_strategy('file_system') 84 | 85 | # loading configs 86 | config = 
json.load(open(args.config)) 87 | 88 | # setting distributed configurations 89 | # config['world_size'] = get_world_size() 90 | config['world_size'] = torch.cuda.device_count() 91 | config['init_method'] = f"tcp://{get_master_ip()}:{args.port}" 92 | config['distributed'] = True if config['world_size'] > 1 else False 93 | print('world_size:', config['world_size']) 94 | # setup distributed parallel training environments 95 | 96 | # if get_master_ip() == "127.0.0.X": 97 | # # manually launch distributed processes 98 | # mp.spawn(main_worker, nprocs=config['world_size'], args=(config, )) 99 | # else: 100 | # # multiple processes have been launched by openmpi 101 | # config['local_rank'] = get_local_rank() 102 | # config['global_rank'] = get_global_rank() 103 | # main_worker(-1, config) 104 | 105 | mp.spawn(main_worker, nprocs=torch.cuda.device_count(), args=(config, )) -------------------------------------------------------------------------------- /ProPainter/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import requests 4 | from torch.hub import download_url_to_file, get_dir 5 | from tqdm import tqdm 6 | from urllib.parse import urlparse 7 | 8 | def sizeof_fmt(size, suffix='B'): 9 | """Get a human-readable file size. 10 | 11 | Args: 12 | size (int): File size. 13 | suffix (str): Suffix. Default: 'B'. 14 | 15 | Return: 16 | str: Formatted file size. 17 | """ 18 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 19 | if abs(size) < 1024.0: 20 | return f'{size:3.1f} {unit}{suffix}' 21 | size /= 1024.0 22 | return f'{size:3.1f} Y{suffix}' 23 | 24 | 25 | def download_file_from_google_drive(file_id, save_path): 26 | """Download files from Google Drive. 27 | Ref: 28 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 29 | Args: 30 | file_id (str): File id. 31 | save_path (str): Save path.
32 | """ 33 | 34 | session = requests.Session() 35 | URL = 'https://docs.google.com/uc?export=download' 36 | params = {'id': file_id} 37 | 38 | response = session.get(URL, params=params, stream=True) 39 | token = get_confirm_token(response) 40 | if token: 41 | params['confirm'] = token 42 | response = session.get(URL, params=params, stream=True) 43 | 44 | # get file size 45 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 46 | print(response_file_size) 47 | if 'Content-Range' in response_file_size.headers: 48 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) 49 | else: 50 | file_size = None 51 | 52 | save_response_content(response, save_path, file_size) 53 | 54 | 55 | def get_confirm_token(response): 56 | for key, value in response.cookies.items(): 57 | if key.startswith('download_warning'): 58 | return value 59 | return None 60 | 61 | 62 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 63 | if file_size is not None: 64 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 65 | 66 | readable_file_size = sizeof_fmt(file_size) 67 | else: 68 | pbar = None 69 | 70 | with open(destination, 'wb') as f: 71 | downloaded_size = 0 72 | for chunk in response.iter_content(chunk_size): 73 | downloaded_size += chunk_size 74 | if pbar is not None: 75 | pbar.update(1) 76 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') 77 | if chunk: # filter out keep-alive new chunks 78 | f.write(chunk) 79 | if pbar is not None: 80 | pbar.close() 81 | 82 | 83 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 84 | """Load file form http url, will download models if necessary. 85 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 86 | Args: 87 | url (str): URL to be downloaded. 88 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 89 | Default: None. 90 | progress (bool): Whether to show the download progress. Default: True. 91 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 92 | Returns: 93 | str: The path to the downloaded file. 94 | """ 95 | if model_dir is None: # use the pytorch hub_dir 96 | hub_dir = get_dir() 97 | model_dir = os.path.join(hub_dir, 'checkpoints') 98 | 99 | os.makedirs(model_dir, exist_ok=True) 100 | 101 | parts = urlparse(url) 102 | filename = os.path.basename(parts.path) 103 | if file_name is not None: 104 | filename = file_name 105 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 106 | if not os.path.exists(cached_file): 107 | print(f'Downloading: "{url}" to {cached_file}\n') 108 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 109 | return cached_file -------------------------------------------------------------------------------- /aot/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, z-x-yang 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. 
Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /aot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/aot/__init__.py -------------------------------------------------------------------------------- /aot/configs/models/aotb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTB' 8 | 9 | self.MODEL_LSTT_NUM = 3 10 | -------------------------------------------------------------------------------- /aot/configs/models/aotl.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTL' 8 | 9 | self.MODEL_LSTT_NUM = 3 10 | 11 | self.TRAIN_LONG_TERM_MEM_GAP = 2 12 | 13 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot/configs/models/aots.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTS' 8 | 9 | self.MODEL_LSTT_NUM = 2 10 | -------------------------------------------------------------------------------- /aot/configs/models/aott.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTT' 8 | -------------------------------------------------------------------------------- /aot/configs/models/deaotb.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | 
super().__init__() 7 | self.MODEL_NAME = 'DeAOTB' 8 | 9 | self.MODEL_LSTT_NUM = 3 10 | -------------------------------------------------------------------------------- /aot/configs/models/deaotl.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'DeAOTL' 8 | 9 | self.MODEL_LSTT_NUM = 3 10 | 11 | self.TRAIN_LONG_TERM_MEM_GAP = 2 12 | 13 | self.TEST_LONG_TERM_MEM_GAP = 5 14 | -------------------------------------------------------------------------------- /aot/configs/models/deaots.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'DeAOTS' 8 | 9 | self.MODEL_LSTT_NUM = 2 10 | -------------------------------------------------------------------------------- /aot/configs/models/deaott.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'DeAOTT' 8 | -------------------------------------------------------------------------------- /aot/configs/models/default.py: -------------------------------------------------------------------------------- 1 | class DefaultModelConfig(): 2 | def __init__(self): 3 | self.MODEL_NAME = 'AOTDefault' 4 | 5 | self.MODEL_VOS = 'aot' 6 | self.MODEL_ENGINE = 'aotengine' 7 | self.MODEL_ALIGN_CORNERS = True 8 | self.MODEL_ENCODER = 'mobilenetv2' 9 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/mobilenet_v2-b0353104.pth' 10 | self.MODEL_ENCODER_DIM = [24, 32, 96, 1280] # 4x, 8x, 16x, 16x 11 | self.MODEL_ENCODER_EMBEDDING_DIM = 256 12 | self.MODEL_DECODER_INTERMEDIATE_LSTT = True 13 | self.MODEL_FREEZE_BN = True 14 | self.MODEL_FREEZE_BACKBONE = False 15 | self.MODEL_MAX_OBJ_NUM = 10 16 | self.MODEL_SELF_HEADS = 8 17 | self.MODEL_ATT_HEADS = 8 18 | self.MODEL_LSTT_NUM = 1 19 | self.MODEL_EPSILON = 1e-5 20 | self.MODEL_USE_PREV_PROB = False 21 | 22 | self.TRAIN_LONG_TERM_MEM_GAP = 9999 23 | self.TRAIN_AUG_TYPE = 'v1' 24 | 25 | self.TEST_LONG_TERM_MEM_GAP = 9999 26 | 27 | self.TEST_SHORT_TERM_MEM_SKIP = 1 28 | -------------------------------------------------------------------------------- /aot/configs/models/default_deaot.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig as BaseConfig 2 | 3 | 4 | class DefaultModelConfig(BaseConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'DeAOTDefault' 8 | 9 | self.MODEL_VOS = 'deaot' 10 | self.MODEL_ENGINE = 'deaotengine' 11 | 12 | self.MODEL_DECODER_INTERMEDIATE_LSTT = False 13 | 14 | self.MODEL_SELF_HEADS = 1 15 | self.MODEL_ATT_HEADS = 1 16 | 17 | self.TRAIN_AUG_TYPE = 'v2' 18 | -------------------------------------------------------------------------------- /aot/configs/models/r101_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R101_AOTL' 8 | 9 | self.MODEL_ENCODER = 'resnet101' 10 | self.MODEL_ENCODER_PRETRAIN = 
'./pretrain_models/resnet101-63fe2227.pth' # https://download.pytorch.org/models/resnet101-63fe2227.pth 11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 12 | self.MODEL_LSTT_NUM = 3 13 | 14 | self.TRAIN_LONG_TERM_MEM_GAP = 2 15 | 16 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot/configs/models/r50_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R50_AOTL' 8 | 9 | self.MODEL_ENCODER = 'resnet50' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet50-0676ba61.pth' # https://download.pytorch.org/models/resnet50-0676ba61.pth 11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 12 | self.MODEL_LSTT_NUM = 3 13 | 14 | self.TRAIN_LONG_TERM_MEM_GAP = 2 15 | 16 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot/configs/models/r50_deaotl.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R50_DeAOTL' 8 | 9 | self.MODEL_ENCODER = 'resnet50' 10 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 11 | 12 | self.MODEL_LSTT_NUM = 3 13 | 14 | self.TRAIN_LONG_TERM_MEM_GAP = 2 15 | 16 | self.TEST_LONG_TERM_MEM_GAP = 5 17 | -------------------------------------------------------------------------------- /aot/configs/models/rs101_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R101_AOTL' 8 | 9 | self.MODEL_ENCODER = 'resnest101' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnest101-22405ba7.pth' # https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest101-22405ba7.pth 11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 12 | self.MODEL_LSTT_NUM = 3 13 | 14 | self.TRAIN_LONG_TERM_MEM_GAP = 2 15 | 16 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot/configs/models/swinb_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'SwinB_AOTL' 8 | 9 | self.MODEL_ENCODER = 'swin_base' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/swin_base_patch4_window7_224_22k.pth' # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth 11 | self.MODEL_ALIGN_CORNERS = False 12 | self.MODEL_ENCODER_DIM = [128, 256, 512, 512] # 4x, 8x, 16x, 16x 13 | self.MODEL_LSTT_NUM = 3 14 | 15 | self.TRAIN_LONG_TERM_MEM_GAP = 2 16 | 17 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot/configs/models/swinb_deaotl.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 
5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'SwinB_DeAOTL' 8 | 9 | self.MODEL_ENCODER = 'swin_base' 10 | self.MODEL_ALIGN_CORNERS = False 11 | self.MODEL_ENCODER_DIM = [128, 256, 512, 512] # 4x, 8x, 16x, 16x 12 | 13 | self.MODEL_LSTT_NUM = 3 14 | 15 | self.TRAIN_LONG_TERM_MEM_GAP = 2 16 | 17 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot/configs/pre.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultEngineConfig 2 | 3 | 4 | class EngineConfig(DefaultEngineConfig): 5 | def __init__(self, exp_name='default', model='AOTT'): 6 | super().__init__(exp_name, model) 7 | self.STAGE_NAME = 'PRE' 8 | 9 | self.init_dir() 10 | 11 | self.DATASETS = ['static'] 12 | 13 | self.DATA_DYNAMIC_MERGE_PROB = 1.0 14 | 15 | self.TRAIN_LR = 4e-4 16 | self.TRAIN_LR_MIN = 2e-5 17 | self.TRAIN_WEIGHT_DECAY = 0.03 18 | self.TRAIN_SEQ_TRAINING_START_RATIO = 1.0 19 | self.TRAIN_AUX_LOSS_RATIO = 0.1 20 | -------------------------------------------------------------------------------- /aot/configs/pre_dav.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'PRE_DAV' 9 | 10 | self.init_dir() 11 | 12 | self.DATASETS = ['davis2017'] 13 | 14 | self.TRAIN_TOTAL_STEPS = 50000 15 | 16 | pretrain_stage = 'PRE' 17 | pretrain_ckpt = 'save_step_100000.pth' 18 | self.PRETRAIN_FULL = True # if False, load encoder only 19 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', 20 | self.EXP_NAME, pretrain_stage, 21 | 'ema_ckpt', pretrain_ckpt) 22 | -------------------------------------------------------------------------------- /aot/configs/pre_ytb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'PRE_YTB' 9 | 10 | self.init_dir() 11 | 12 | pretrain_stage = 'PRE' 13 | pretrain_ckpt = 'save_step_100000.pth' 14 | self.PRETRAIN_FULL = True # if False, load encoder only 15 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', 16 | self.EXP_NAME, pretrain_stage, 17 | 'ema_ckpt', pretrain_ckpt) 18 | -------------------------------------------------------------------------------- /aot/configs/pre_ytb_dav.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'PRE_YTB_DAV' 9 | 10 | self.init_dir() 11 | 12 | self.DATASETS = ['youtubevos', 'davis2017'] 13 | 14 | pretrain_stage = 'PRE' 15 | pretrain_ckpt = 'save_step_100000.pth' 16 | self.PRETRAIN_FULL = True # if False, load encoder only 17 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', 18 | self.EXP_NAME, pretrain_stage, 19 | 'ema_ckpt', pretrain_ckpt) 20 | -------------------------------------------------------------------------------- /aot/configs/ytb.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'YTB' 9 | 10 | self.init_dir() 11 | -------------------------------------------------------------------------------- /aot/dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/aot/dataloaders/__init__.py -------------------------------------------------------------------------------- /aot/networks/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/aot/networks/.DS_Store -------------------------------------------------------------------------------- /aot/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/aot/networks/__init__.py -------------------------------------------------------------------------------- /aot/networks/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.decoders.fpn import FPNSegmentationHead 2 | 3 | 4 | def build_decoder(name, **kwargs): 5 | 6 | if name == 'fpn': 7 | return FPNSegmentationHead(**kwargs) 8 | else: 9 | raise NotImplementedError 10 | -------------------------------------------------------------------------------- /aot/networks/decoders/fpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from networks.layers.basic import ConvGN 5 | 6 | 7 | class FPNSegmentationHead(nn.Module): 8 | def __init__(self, 9 | in_dim, 10 | out_dim, 11 | decode_intermediate_input=True, 12 | hidden_dim=256, 13 | shortcut_dims=[24, 32, 96, 1280], 14 | align_corners=True): 15 | super().__init__() 16 | self.align_corners = align_corners 17 | 18 | self.decode_intermediate_input = decode_intermediate_input 19 | 20 | self.conv_in = ConvGN(in_dim, hidden_dim, 1) 21 | 22 | self.conv_16x = ConvGN(hidden_dim, hidden_dim, 3) 23 | self.conv_8x = ConvGN(hidden_dim, hidden_dim // 2, 3) 24 | self.conv_4x = ConvGN(hidden_dim // 2, hidden_dim // 2, 3) 25 | 26 | self.adapter_16x = nn.Conv2d(shortcut_dims[-2], hidden_dim, 1) 27 | self.adapter_8x = nn.Conv2d(shortcut_dims[-3], hidden_dim, 1) 28 | self.adapter_4x = nn.Conv2d(shortcut_dims[-4], hidden_dim // 2, 1) 29 | 30 | self.conv_out = nn.Conv2d(hidden_dim // 2, out_dim, 1) 31 | 32 | self._init_weight() 33 | 34 | def forward(self, inputs, shortcuts): 35 | 36 | if self.decode_intermediate_input: 37 | x = torch.cat(inputs, dim=1) 38 | else: 39 | x = inputs[-1] 40 | 41 | x = F.relu_(self.conv_in(x)) 42 | x = F.relu_(self.conv_16x(self.adapter_16x(shortcuts[-2]) + x)) 43 | 44 | x = F.interpolate(x, 45 | size=shortcuts[-3].size()[-2:], 46 | mode="bilinear", 47 | align_corners=self.align_corners) 48 | x = F.relu_(self.conv_8x(self.adapter_8x(shortcuts[-3]) + x)) 49 | 50 | x = F.interpolate(x, 51 | size=shortcuts[-4].size()[-2:], 52 | mode="bilinear", 53 | align_corners=self.align_corners) 54 | x = 
F.relu_(self.conv_4x(self.adapter_4x(shortcuts[-4]) + x)) 55 | 56 | x = self.conv_out(x) 57 | 58 | return x 59 | 60 | def _init_weight(self): 61 | for p in self.parameters(): 62 | if p.dim() > 1: 63 | nn.init.xavier_uniform_(p) 64 | -------------------------------------------------------------------------------- /aot/networks/encoders/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/aot/networks/encoders/.DS_Store -------------------------------------------------------------------------------- /aot/networks/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.encoders.mobilenetv2 import MobileNetV2 2 | from networks.encoders.mobilenetv3 import MobileNetV3Large 3 | from networks.encoders.resnet import ResNet101, ResNet50 4 | from networks.encoders.resnest import resnest 5 | from networks.encoders.swin import build_swin_model 6 | from networks.layers.normalization import FrozenBatchNorm2d 7 | from torch import nn 8 | 9 | 10 | def build_encoder(name, frozen_bn=True, freeze_at=-1): 11 | if frozen_bn: 12 | BatchNorm = FrozenBatchNorm2d 13 | else: 14 | BatchNorm = nn.BatchNorm2d 15 | 16 | if name == 'mobilenetv2': 17 | return MobileNetV2(16, BatchNorm, freeze_at=freeze_at) 18 | elif name == 'mobilenetv3': 19 | return MobileNetV3Large(16, BatchNorm, freeze_at=freeze_at) 20 | elif name == 'resnet50': 21 | return ResNet50(16, BatchNorm, freeze_at=freeze_at) 22 | elif name == 'resnet101': 23 | return ResNet101(16, BatchNorm, freeze_at=freeze_at) 24 | elif name == 'resnest50': 25 | return resnest.resnest50(norm_layer=BatchNorm, 26 | dilation=2, 27 | freeze_at=freeze_at) 28 | elif name == 'resnest101': 29 | return resnest.resnest101(norm_layer=BatchNorm, 30 | dilation=2, 31 | freeze_at=freeze_at) 32 | elif 'swin' in name: 33 | return build_swin_model(name, freeze_at=freeze_at) 34 | else: 35 | raise NotImplementedError 36 | -------------------------------------------------------------------------------- /aot/networks/encoders/resnest/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnest import * 2 | -------------------------------------------------------------------------------- /aot/networks/encoders/resnest/resnest.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .resnet import ResNet, Bottleneck 3 | 4 | __all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269'] 5 | 6 | _url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth' 7 | 8 | _model_sha256 = { 9 | name: checksum 10 | for checksum, name in [ 11 | ('528c19ca', 'resnest50'), 12 | ('22405ba7', 'resnest101'), 13 | ('75117900', 'resnest200'), 14 | ('0cc87c48', 'resnest269'), 15 | ] 16 | } 17 | 18 | 19 | def short_hash(name): 20 | if name not in _model_sha256: 21 | raise ValueError( 22 | 'Pretrained model for {name} is not available.'.format(name=name)) 23 | return _model_sha256[name][:8] 24 | 25 | 26 | resnest_model_urls = { 27 | name: _url_format.format(name, short_hash(name)) 28 | for name in _model_sha256.keys() 29 | } 30 | 31 | 32 | def resnest50(pretrained=False, root='~/.encoding/models', **kwargs): 33 | model = ResNet(Bottleneck, [3, 4, 6, 3], 34 | radix=2, 35 | groups=1, 36 | bottleneck_width=64, 37 | deep_stem=True, 38 | stem_width=32, 39 | avg_down=True, 40 | avd=True, 41 | avd_first=False, 
42 | **kwargs) 43 | if pretrained: 44 | model.load_state_dict( 45 | torch.hub.load_state_dict_from_url(resnest_model_urls['resnest50'], 46 | progress=True, 47 | check_hash=True)) 48 | return model 49 | 50 | 51 | def resnest101(pretrained=False, root='~/.encoding/models', **kwargs): 52 | model = ResNet(Bottleneck, [3, 4, 23, 3], 53 | radix=2, 54 | groups=1, 55 | bottleneck_width=64, 56 | deep_stem=True, 57 | stem_width=64, 58 | avg_down=True, 59 | avd=True, 60 | avd_first=False, 61 | **kwargs) 62 | if pretrained: 63 | model.load_state_dict( 64 | torch.hub.load_state_dict_from_url( 65 | resnest_model_urls['resnest101'], 66 | progress=True, 67 | check_hash=True)) 68 | return model 69 | 70 | 71 | def resnest200(pretrained=False, root='~/.encoding/models', **kwargs): 72 | model = ResNet(Bottleneck, [3, 24, 36, 3], 73 | radix=2, 74 | groups=1, 75 | bottleneck_width=64, 76 | deep_stem=True, 77 | stem_width=64, 78 | avg_down=True, 79 | avd=True, 80 | avd_first=False, 81 | **kwargs) 82 | if pretrained: 83 | model.load_state_dict( 84 | torch.hub.load_state_dict_from_url( 85 | resnest_model_urls['resnest200'], 86 | progress=True, 87 | check_hash=True)) 88 | return model 89 | 90 | 91 | def resnest269(pretrained=False, root='~/.encoding/models', **kwargs): 92 | model = ResNet(Bottleneck, [3, 30, 48, 8], 93 | radix=2, 94 | groups=1, 95 | bottleneck_width=64, 96 | deep_stem=True, 97 | stem_width=64, 98 | avg_down=True, 99 | avd=True, 100 | avd_first=False, 101 | **kwargs) 102 | if pretrained: 103 | model.load_state_dict( 104 | torch.hub.load_state_dict_from_url( 105 | resnest_model_urls['resnest269'], 106 | progress=True, 107 | check_hash=True)) 108 | return model 109 | -------------------------------------------------------------------------------- /aot/networks/encoders/resnest/splat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn import Conv2d, Module, ReLU 5 | from torch.nn.modules.utils import _pair 6 | 7 | __all__ = ['SplAtConv2d', 'DropBlock2D'] 8 | 9 | 10 | class DropBlock2D(object): 11 | def __init__(self, *args, **kwargs): 12 | raise NotImplementedError 13 | 14 | 15 | class SplAtConv2d(Module): 16 | """Split-Attention Conv2d 17 | """ 18 | def __init__(self, 19 | in_channels, 20 | channels, 21 | kernel_size, 22 | stride=(1, 1), 23 | padding=(0, 0), 24 | dilation=(1, 1), 25 | groups=1, 26 | bias=True, 27 | radix=2, 28 | reduction_factor=4, 29 | rectify=False, 30 | rectify_avg=False, 31 | norm_layer=None, 32 | dropblock_prob=0.0, 33 | **kwargs): 34 | super(SplAtConv2d, self).__init__() 35 | padding = _pair(padding) 36 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0) 37 | self.rectify_avg = rectify_avg 38 | inter_channels = max(in_channels * radix // reduction_factor, 32) 39 | self.radix = radix 40 | self.cardinality = groups 41 | self.channels = channels 42 | self.dropblock_prob = dropblock_prob 43 | if self.rectify: 44 | from rfconv import RFConv2d 45 | self.conv = RFConv2d(in_channels, 46 | channels * radix, 47 | kernel_size, 48 | stride, 49 | padding, 50 | dilation, 51 | groups=groups * radix, 52 | bias=bias, 53 | average_mode=rectify_avg, 54 | **kwargs) 55 | else: 56 | self.conv = Conv2d(in_channels, 57 | channels * radix, 58 | kernel_size, 59 | stride, 60 | padding, 61 | dilation, 62 | groups=groups * radix, 63 | bias=bias, 64 | **kwargs) 65 | self.use_bn = norm_layer is not None 66 | if self.use_bn: 67 | self.bn0 = norm_layer(channels * radix) 68 
| self.relu = ReLU(inplace=True) 69 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) 70 | if self.use_bn: 71 | self.bn1 = norm_layer(inter_channels) 72 | self.fc2 = Conv2d(inter_channels, 73 | channels * radix, 74 | 1, 75 | groups=self.cardinality) 76 | if dropblock_prob > 0.0: 77 | self.dropblock = DropBlock2D(dropblock_prob, 3) 78 | self.rsoftmax = rSoftMax(radix, groups) 79 | 80 | def forward(self, x): 81 | x = self.conv(x) 82 | if self.use_bn: 83 | x = self.bn0(x) 84 | if self.dropblock_prob > 0.0: 85 | x = self.dropblock(x) 86 | x = self.relu(x) 87 | 88 | batch, rchannel = x.shape[:2] 89 | if self.radix > 1: 90 | if torch.__version__ < '1.5': 91 | splited = torch.split(x, int(rchannel // self.radix), dim=1) 92 | else: 93 | splited = torch.split(x, rchannel // self.radix, dim=1) 94 | gap = sum(splited) 95 | else: 96 | gap = x 97 | gap = F.adaptive_avg_pool2d(gap, 1) 98 | gap = self.fc1(gap) 99 | 100 | if self.use_bn: 101 | gap = self.bn1(gap) 102 | gap = self.relu(gap) 103 | 104 | atten = self.fc2(gap) 105 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1) 106 | 107 | if self.radix > 1: 108 | if torch.__version__ < '1.5': 109 | attens = torch.split(atten, int(rchannel // self.radix), dim=1) 110 | else: 111 | attens = torch.split(atten, rchannel // self.radix, dim=1) 112 | out = sum([att * split for (att, split) in zip(attens, splited)]) 113 | else: 114 | out = atten * x 115 | return out.contiguous() 116 | 117 | 118 | class rSoftMax(nn.Module): 119 | def __init__(self, radix, cardinality): 120 | super().__init__() 121 | self.radix = radix 122 | self.cardinality = cardinality 123 | 124 | def forward(self, x): 125 | batch = x.size(0) 126 | if self.radix > 1: 127 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 128 | x = F.softmax(x, dim=1) 129 | x = x.reshape(batch, -1) 130 | else: 131 | x = torch.sigmoid(x) 132 | return x 133 | -------------------------------------------------------------------------------- /aot/networks/encoders/swin/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_swin_model -------------------------------------------------------------------------------- /aot/networks/encoders/swin/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from .swin_transformer import SwinTransformer 9 | 10 | 11 | def build_swin_model(model_type, freeze_at=0): 12 | if model_type == 'swin_base': 13 | model = SwinTransformer(embed_dim=128, 14 | depths=[2, 2, 18, 2], 15 | num_heads=[4, 8, 16, 32], 16 | window_size=7, 17 | drop_path_rate=0.3, 18 | out_indices=(0, 1, 2), 19 | ape=False, 20 | patch_norm=True, 21 | frozen_stages=freeze_at, 22 | use_checkpoint=False) 23 | 24 | else: 25 | raise NotImplementedError(f"Unkown model: {model_type}") 26 | 27 | return model 28 | -------------------------------------------------------------------------------- /aot/networks/engines/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.engines.aot_engine import AOTEngine, AOTInferEngine 2 | from networks.engines.deaot_engine import DeAOTEngine, DeAOTInferEngine 3 | 4 | 5 | def build_engine(name, phase='train', **kwargs): 
6 | if name == 'aotengine': 7 | if phase == 'train': 8 | return AOTEngine(**kwargs) 9 | elif phase == 'eval': 10 | return AOTInferEngine(**kwargs) 11 | else: 12 | raise NotImplementedError 13 | elif name == 'deaotengine': 14 | if phase == 'train': 15 | return DeAOTEngine(**kwargs) 16 | elif phase == 'eval': 17 | return DeAOTInferEngine(**kwargs) 18 | else: 19 | raise NotImplementedError 20 | else: 21 | raise NotImplementedError 22 | -------------------------------------------------------------------------------- /aot/networks/engines/deaot_engine.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from utils.image import one_hot_mask 4 | 5 | from networks.layers.basic import seq_to_2d 6 | from networks.engines.aot_engine import AOTEngine, AOTInferEngine 7 | 8 | 9 | class DeAOTEngine(AOTEngine): 10 | def __init__(self, 11 | aot_model, 12 | gpu_id=0, 13 | long_term_mem_gap=9999, 14 | short_term_mem_skip=1, 15 | layer_loss_scaling_ratio=2., 16 | max_len_long_term=9999): 17 | super().__init__(aot_model, gpu_id, long_term_mem_gap, 18 | short_term_mem_skip, max_len_long_term) 19 | self.layer_loss_scaling_ratio = layer_loss_scaling_ratio 20 | def update_short_term_memory(self, curr_mask, curr_id_emb=None, skip_long_term_update=False): 21 | 22 | if curr_id_emb is None: 23 | if len(curr_mask.size()) == 3 or curr_mask.size()[0] == 1: 24 | curr_one_hot_mask = one_hot_mask(curr_mask, self.max_obj_num) 25 | else: 26 | curr_one_hot_mask = curr_mask 27 | curr_id_emb = self.assign_identity(curr_one_hot_mask) 28 | 29 | lstt_curr_memories = self.curr_lstt_output[1] 30 | lstt_curr_memories_2d = [] 31 | for layer_idx in range(len(lstt_curr_memories)): 32 | curr_k, curr_v, curr_id_k, curr_id_v = lstt_curr_memories[ 33 | layer_idx] 34 | curr_id_k, curr_id_v = self.AOT.LSTT.layers[ 35 | layer_idx].fuse_key_value_id(curr_id_k, curr_id_v, curr_id_emb) 36 | lstt_curr_memories[layer_idx][2], lstt_curr_memories[layer_idx][ 37 | 3] = curr_id_k, curr_id_v 38 | local_curr_id_k = seq_to_2d( 39 | curr_id_k, self.enc_size_2d) if curr_id_k is not None else None 40 | local_curr_id_v = seq_to_2d(curr_id_v, self.enc_size_2d) 41 | lstt_curr_memories_2d.append([ 42 | seq_to_2d(curr_k, self.enc_size_2d), 43 | seq_to_2d(curr_v, self.enc_size_2d), local_curr_id_k, 44 | local_curr_id_v 45 | ]) 46 | 47 | self.short_term_memories_list.append(lstt_curr_memories_2d) 48 | self.short_term_memories_list = self.short_term_memories_list[ 49 | -self.short_term_mem_skip:] 50 | self.short_term_memories = self.short_term_memories_list[0] 51 | 52 | if self.frame_step - self.last_mem_step >= self.long_term_mem_gap: 53 | # skip the update of long-term memory or not 54 | if not skip_long_term_update: 55 | self.update_long_term_memory(lstt_curr_memories) 56 | self.last_mem_step = self.frame_step 57 | 58 | 59 | class DeAOTInferEngine(AOTInferEngine): 60 | def __init__(self, 61 | aot_model, 62 | gpu_id=0, 63 | long_term_mem_gap=9999, 64 | short_term_mem_skip=1, 65 | max_aot_obj_num=None, 66 | max_len_long_term=9999): 67 | super().__init__(aot_model, gpu_id, long_term_mem_gap, 68 | short_term_mem_skip, max_aot_obj_num, max_len_long_term) 69 | def add_reference_frame(self, img, mask, obj_nums, frame_step=-1): 70 | if isinstance(obj_nums, list): 71 | obj_nums = obj_nums[0] 72 | self.obj_nums = obj_nums 73 | aot_num = max(np.ceil(obj_nums / self.max_aot_obj_num), 1) 74 | while (aot_num > len(self.aot_engines)): 75 | new_engine = DeAOTEngine(self.AOT, self.gpu_id, 76 | self.long_term_mem_gap, 
77 | self.short_term_mem_skip, 78 | max_len_long_term = self.max_len_long_term) 79 | new_engine.eval() 80 | self.aot_engines.append(new_engine) 81 | 82 | separated_masks, separated_obj_nums = self.separate_mask( 83 | mask, obj_nums) 84 | img_embs = None 85 | for aot_engine, separated_mask, separated_obj_num in zip( 86 | self.aot_engines, separated_masks, separated_obj_nums): 87 | if aot_engine.obj_nums is None or aot_engine.obj_nums[0] < separated_obj_num: 88 | aot_engine.add_reference_frame(img, 89 | separated_mask, 90 | obj_nums=[separated_obj_num], 91 | frame_step=frame_step, 92 | img_embs=img_embs) 93 | else: 94 | aot_engine.update_short_term_memory(separated_mask) 95 | if img_embs is None: # reuse image embeddings 96 | img_embs = aot_engine.curr_enc_embs 97 | 98 | self.update_size() 99 | -------------------------------------------------------------------------------- /aot/networks/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/aot/networks/layers/__init__.py -------------------------------------------------------------------------------- /aot/networks/layers/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FrozenBatchNorm2d(nn.Module): 7 | """ 8 | BatchNorm2d where the batch statistics and the affine parameters 9 | are fixed 10 | """ 11 | def __init__(self, n, epsilon=1e-5): 12 | super(FrozenBatchNorm2d, self).__init__() 13 | self.register_buffer("weight", torch.ones(n)) 14 | self.register_buffer("bias", torch.zeros(n)) 15 | self.register_buffer("running_mean", torch.zeros(n)) 16 | self.register_buffer("running_var", torch.ones(n) - epsilon) 17 | self.epsilon = epsilon 18 | 19 | def forward(self, x): 20 | """ 21 | Refer to Detectron2 (https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/batch_norm.py) 22 | """ 23 | if x.requires_grad: 24 | # When gradients are needed, F.batch_norm will use extra memory 25 | # because its backward op computes gradients for weight/bias as well. 26 | scale = self.weight * (self.running_var + self.epsilon).rsqrt() 27 | bias = self.bias - self.running_mean * scale 28 | scale = scale.reshape(1, -1, 1, 1) 29 | bias = bias.reshape(1, -1, 1, 1) 30 | out_dtype = x.dtype # may be half 31 | return x * scale.to(out_dtype) + bias.to(out_dtype) 32 | else: 33 | # When gradients are not needed, F.batch_norm is a single fused op 34 | # and provide more optimization opportunities. 
35 | return F.batch_norm( 36 | x, 37 | self.running_mean, 38 | self.running_var, 39 | self.weight, 40 | self.bias, 41 | training=False, 42 | eps=self.epsilon, 43 | ) 44 | -------------------------------------------------------------------------------- /aot/networks/layers/position.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from utils.math import truncated_normal_ 8 | 9 | 10 | class Downsample2D(nn.Module): 11 | def __init__(self, mode='nearest', scale=4): 12 | super().__init__() 13 | self.mode = mode 14 | self.scale = scale 15 | 16 | def forward(self, x): 17 | n, c, h, w = x.size() 18 | x = F.interpolate(x, 19 | size=(h // self.scale + 1, w // self.scale + 1), 20 | mode=self.mode) 21 | return x 22 | 23 | 24 | def generate_coord(x): 25 | _, _, h, w = x.size() 26 | device = x.device 27 | col = torch.arange(0, h, device=device) 28 | row = torch.arange(0, w, device=device) 29 | grid_h, grid_w = torch.meshgrid(col, row) 30 | return grid_h, grid_w 31 | 32 | 33 | class PositionEmbeddingSine(nn.Module): 34 | def __init__(self, 35 | num_pos_feats=64, 36 | temperature=10000, 37 | normalize=False, 38 | scale=None): 39 | super().__init__() 40 | self.num_pos_feats = num_pos_feats 41 | self.temperature = temperature 42 | self.normalize = normalize 43 | if scale is not None and normalize is False: 44 | raise ValueError("normalize should be True if scale is passed") 45 | if scale is None: 46 | scale = 2 * math.pi 47 | self.scale = scale 48 | 49 | def forward(self, x): 50 | grid_y, grid_x = generate_coord(x) 51 | 52 | y_embed = grid_y.unsqueeze(0).float() 53 | x_embed = grid_x.unsqueeze(0).float() 54 | 55 | if self.normalize: 56 | eps = 1e-6 57 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 58 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 59 | 60 | dim_t = torch.arange(self.num_pos_feats, 61 | dtype=torch.float32, 62 | device=x.device) 63 | dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats) 64 | 65 | pos_x = x_embed[:, :, :, None] / dim_t 66 | pos_y = y_embed[:, :, :, None] / dim_t 67 | pos_x = torch.stack( 68 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), 69 | dim=4).flatten(3) 70 | pos_y = torch.stack( 71 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), 72 | dim=4).flatten(3) 73 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 74 | return pos 75 | 76 | 77 | class PositionEmbeddingLearned(nn.Module): 78 | def __init__(self, num_pos_feats=64, H=30, W=30): 79 | super().__init__() 80 | self.H = H 81 | self.W = W 82 | self.pos_emb = nn.Parameter( 83 | truncated_normal_(torch.zeros(1, num_pos_feats, H, W))) 84 | 85 | def forward(self, x): 86 | bs, _, h, w = x.size() 87 | pos_emb = self.pos_emb 88 | if h != self.H or w != self.W: 89 | pos_emb = F.interpolate(pos_emb, size=(h, w), mode="bilinear") 90 | return pos_emb 91 | -------------------------------------------------------------------------------- /aot/networks/models/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.models.aot import AOT 2 | from networks.models.deaot import DeAOT 3 | 4 | 5 | def build_vos_model(name, cfg, **kwargs): 6 | if name == 'aot': 7 | return AOT(cfg, encoder=cfg.MODEL_ENCODER, **kwargs) 8 | elif name == 'deaot': 9 | return DeAOT(cfg, encoder=cfg.MODEL_ENCODER, **kwargs) 10 | else: 11 | raise NotImplementedError 12 | 
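# Usage sketch (illustrative only; names are taken from the factory and configs shown above):
#     from networks.models import build_vos_model
#     model = build_vos_model(cfg.MODEL_VOS, cfg)   # 'aot' -> AOT, 'deaot' -> DeAOT
# where cfg is a config object carrying the MODEL_* fields defined in aot/configs/models
# (e.g. MODEL_VOS, MODEL_ENCODER, MODEL_MAX_OBJ_NUM), which the constructors below consume.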
-------------------------------------------------------------------------------- /aot/networks/models/aot.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from networks.encoders import build_encoder 4 | from networks.layers.transformer import LongShortTermTransformer 5 | from networks.decoders import build_decoder 6 | from networks.layers.position import PositionEmbeddingSine 7 | 8 | 9 | class AOT(nn.Module): 10 | def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'): 11 | super().__init__() 12 | self.cfg = cfg 13 | self.max_obj_num = cfg.MODEL_MAX_OBJ_NUM 14 | self.epsilon = cfg.MODEL_EPSILON 15 | 16 | self.encoder = build_encoder(encoder, 17 | frozen_bn=cfg.MODEL_FREEZE_BN, 18 | freeze_at=cfg.TRAIN_ENCODER_FREEZE_AT) 19 | self.encoder_projector = nn.Conv2d(cfg.MODEL_ENCODER_DIM[-1], 20 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 21 | kernel_size=1) 22 | 23 | self.LSTT = LongShortTermTransformer( 24 | cfg.MODEL_LSTT_NUM, 25 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 26 | cfg.MODEL_SELF_HEADS, 27 | cfg.MODEL_ATT_HEADS, 28 | emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT, 29 | droppath=cfg.TRAIN_LSTT_DROPPATH, 30 | lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT, 31 | st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT, 32 | droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST, 33 | droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING, 34 | intermediate_norm=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, 35 | return_intermediate=True) 36 | 37 | decoder_indim = cfg.MODEL_ENCODER_EMBEDDING_DIM * \ 38 | (cfg.MODEL_LSTT_NUM + 39 | 1) if cfg.MODEL_DECODER_INTERMEDIATE_LSTT else cfg.MODEL_ENCODER_EMBEDDING_DIM 40 | 41 | self.decoder = build_decoder( 42 | decoder, 43 | in_dim=decoder_indim, 44 | out_dim=cfg.MODEL_MAX_OBJ_NUM + 1, 45 | decode_intermediate_input=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, 46 | hidden_dim=cfg.MODEL_ENCODER_EMBEDDING_DIM, 47 | shortcut_dims=cfg.MODEL_ENCODER_DIM, 48 | align_corners=cfg.MODEL_ALIGN_CORNERS) 49 | 50 | if cfg.MODEL_ALIGN_CORNERS: 51 | self.patch_wise_id_bank = nn.Conv2d( 52 | cfg.MODEL_MAX_OBJ_NUM + 1, 53 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 54 | kernel_size=17, 55 | stride=16, 56 | padding=8) 57 | else: 58 | self.patch_wise_id_bank = nn.Conv2d( 59 | cfg.MODEL_MAX_OBJ_NUM + 1, 60 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 61 | kernel_size=16, 62 | stride=16, 63 | padding=0) 64 | 65 | self.id_dropout = nn.Dropout(cfg.TRAIN_LSTT_ID_DROPOUT, True) 66 | 67 | self.pos_generator = PositionEmbeddingSine( 68 | cfg.MODEL_ENCODER_EMBEDDING_DIM // 2, normalize=True) 69 | 70 | self._init_weight() 71 | 72 | def get_pos_emb(self, x): 73 | pos_emb = self.pos_generator(x) 74 | return pos_emb 75 | 76 | def get_id_emb(self, x): 77 | id_emb = self.patch_wise_id_bank(x) 78 | id_emb = self.id_dropout(id_emb) 79 | return id_emb 80 | 81 | def encode_image(self, img): 82 | xs = self.encoder(img) 83 | xs[-1] = self.encoder_projector(xs[-1]) 84 | return xs 85 | 86 | def decode_id_logits(self, lstt_emb, shortcuts): 87 | n, c, h, w = shortcuts[-1].size() 88 | decoder_inputs = [shortcuts[-1]] 89 | for emb in lstt_emb: 90 | decoder_inputs.append(emb.view(h, w, n, c).permute(2, 3, 0, 1)) 91 | pred_logit = self.decoder(decoder_inputs, shortcuts) 92 | return pred_logit 93 | 94 | def LSTT_forward(self, 95 | curr_embs, 96 | long_term_memories, 97 | short_term_memories, 98 | curr_id_emb=None, 99 | pos_emb=None, 100 | size_2d=(30, 30)): 101 | n, c, h, w = curr_embs[-1].size() 102 | curr_emb = curr_embs[-1].view(n, c, h * w).permute(2, 0, 1) 103 | lstt_embs, lstt_memories = self.LSTT(curr_emb, 
long_term_memories, 104 | short_term_memories, curr_id_emb, 105 | pos_emb, size_2d) 106 | lstt_curr_memories, lstt_long_memories, lstt_short_memories = zip( 107 | *lstt_memories) 108 | return lstt_embs, lstt_curr_memories, lstt_long_memories, lstt_short_memories 109 | 110 | def _init_weight(self): 111 | nn.init.xavier_uniform_(self.encoder_projector.weight) 112 | nn.init.orthogonal_( 113 | self.patch_wise_id_bank.weight.view( 114 | self.cfg.MODEL_ENCODER_EMBEDDING_DIM, -1).permute(0, 1), 115 | gain=17**-2 if self.cfg.MODEL_ALIGN_CORNERS else 16**-2) 116 | -------------------------------------------------------------------------------- /aot/networks/models/deaot.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from networks.layers.transformer import DualBranchGPM 4 | from networks.models.aot import AOT 5 | from networks.decoders import build_decoder 6 | 7 | 8 | class DeAOT(AOT): 9 | def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'): 10 | super().__init__(cfg, encoder, decoder) 11 | 12 | self.LSTT = DualBranchGPM( 13 | cfg.MODEL_LSTT_NUM, 14 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 15 | cfg.MODEL_SELF_HEADS, 16 | cfg.MODEL_ATT_HEADS, 17 | emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT, 18 | droppath=cfg.TRAIN_LSTT_DROPPATH, 19 | lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT, 20 | st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT, 21 | droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST, 22 | droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING, 23 | intermediate_norm=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, 24 | return_intermediate=True) 25 | 26 | decoder_indim = cfg.MODEL_ENCODER_EMBEDDING_DIM * \ 27 | (cfg.MODEL_LSTT_NUM * 2 + 28 | 1) if cfg.MODEL_DECODER_INTERMEDIATE_LSTT else cfg.MODEL_ENCODER_EMBEDDING_DIM * 2 29 | 30 | self.decoder = build_decoder( 31 | decoder, 32 | in_dim=decoder_indim, 33 | out_dim=cfg.MODEL_MAX_OBJ_NUM + 1, 34 | decode_intermediate_input=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, 35 | hidden_dim=cfg.MODEL_ENCODER_EMBEDDING_DIM, 36 | shortcut_dims=cfg.MODEL_ENCODER_DIM, 37 | align_corners=cfg.MODEL_ALIGN_CORNERS) 38 | 39 | self.id_norm = nn.LayerNorm(cfg.MODEL_ENCODER_EMBEDDING_DIM) 40 | 41 | self._init_weight() 42 | 43 | def decode_id_logits(self, lstt_emb, shortcuts): 44 | n, c, h, w = shortcuts[-1].size() 45 | decoder_inputs = [shortcuts[-1]] 46 | for emb in lstt_emb: 47 | decoder_inputs.append(emb.view(h, w, n, -1).permute(2, 3, 0, 1)) 48 | pred_logit = self.decoder(decoder_inputs, shortcuts) 49 | return pred_logit 50 | 51 | def get_id_emb(self, x): 52 | id_emb = self.patch_wise_id_bank(x) 53 | id_emb = self.id_norm(id_emb.permute(2, 3, 0, 1)).permute(2, 3, 0, 1) 54 | id_emb = self.id_dropout(id_emb) 55 | return id_emb 56 | -------------------------------------------------------------------------------- /aot/source/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/aot/source/.DS_Store -------------------------------------------------------------------------------- /aot/source/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/aot/source/overview.png -------------------------------------------------------------------------------- /aot/source/overview_deaot.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/aot/source/overview_deaot.png -------------------------------------------------------------------------------- /aot/tools/eval.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | 4 | sys.path.append('.') 5 | sys.path.append('..') 6 | 7 | import torch 8 | import torch.multiprocessing as mp 9 | 10 | from networks.managers.evaluator import Evaluator 11 | 12 | 13 | def main_worker(gpu, cfg, seq_queue=None, info_queue=None, enable_amp=False): 14 | # Initiate a evaluating manager 15 | evaluator = Evaluator(rank=gpu, 16 | cfg=cfg, 17 | seq_queue=seq_queue, 18 | info_queue=info_queue) 19 | # Start evaluation 20 | if enable_amp: 21 | with torch.cuda.amp.autocast(enabled=True): 22 | evaluator.evaluating() 23 | else: 24 | evaluator.evaluating() 25 | 26 | 27 | def main(): 28 | import argparse 29 | parser = argparse.ArgumentParser(description="Eval VOS") 30 | parser.add_argument('--exp_name', type=str, default='default') 31 | 32 | parser.add_argument('--stage', type=str, default='pre') 33 | parser.add_argument('--model', type=str, default='aott') 34 | parser.add_argument('--lstt_num', type=int, default=-1) 35 | parser.add_argument('--lt_gap', type=int, default=-1) 36 | parser.add_argument('--st_skip', type=int, default=-1) 37 | parser.add_argument('--max_id_num', type=int, default='-1') 38 | 39 | parser.add_argument('--gpu_id', type=int, default=0) 40 | parser.add_argument('--gpu_num', type=int, default=1) 41 | 42 | parser.add_argument('--ckpt_path', type=str, default='') 43 | parser.add_argument('--ckpt_step', type=int, default=-1) 44 | 45 | parser.add_argument('--dataset', type=str, default='') 46 | parser.add_argument('--split', type=str, default='') 47 | 48 | parser.add_argument('--ema', action='store_true') 49 | parser.set_defaults(ema=False) 50 | 51 | parser.add_argument('--flip', action='store_true') 52 | parser.set_defaults(flip=False) 53 | parser.add_argument('--ms', nargs='+', type=float, default=[1.]) 54 | 55 | parser.add_argument('--max_resolution', type=float, default=480 * 1.3) 56 | 57 | parser.add_argument('--amp', action='store_true') 58 | parser.set_defaults(amp=False) 59 | 60 | args = parser.parse_args() 61 | 62 | engine_config = importlib.import_module('configs.' + args.stage) 63 | cfg = engine_config.EngineConfig(args.exp_name, args.model) 64 | 65 | cfg.TEST_EMA = args.ema 66 | 67 | cfg.TEST_GPU_ID = args.gpu_id 68 | cfg.TEST_GPU_NUM = args.gpu_num 69 | 70 | if args.lstt_num > 0: 71 | cfg.MODEL_LSTT_NUM = args.lstt_num 72 | if args.lt_gap > 0: 73 | cfg.TEST_LONG_TERM_MEM_GAP = args.lt_gap 74 | if args.st_skip > 0: 75 | cfg.TEST_SHORT_TERM_MEM_SKIP = args.st_skip 76 | 77 | if args.max_id_num > 0: 78 | cfg.MODEL_MAX_OBJ_NUM = args.max_id_num 79 | 80 | if args.ckpt_path != '': 81 | cfg.TEST_CKPT_PATH = args.ckpt_path 82 | if args.ckpt_step > 0: 83 | cfg.TEST_CKPT_STEP = args.ckpt_step 84 | 85 | if args.dataset != '': 86 | cfg.TEST_DATASET = args.dataset 87 | 88 | if args.split != '': 89 | cfg.TEST_DATASET_SPLIT = args.split 90 | 91 | cfg.TEST_FLIP = args.flip 92 | cfg.TEST_MULTISCALE = args.ms 93 | 94 | if cfg.TEST_MULTISCALE != [1.]: 95 | cfg.TEST_MAX_SHORT_EDGE = args.max_resolution # for preventing OOM 96 | else: 97 | cfg.TEST_MAX_SHORT_EDGE = None # the default resolution setting of CFBI and AOT 98 | cfg.TEST_MAX_LONG_EDGE = args.max_resolution * 800. / 480. 
99 | 100 | if args.gpu_num > 1: 101 | mp.set_start_method('spawn') 102 | seq_queue = mp.Queue() 103 | info_queue = mp.Queue() 104 | mp.spawn(main_worker, 105 | nprocs=cfg.TEST_GPU_NUM, 106 | args=(cfg, seq_queue, info_queue, args.amp)) 107 | else: 108 | main_worker(0, cfg, enable_amp=args.amp) 109 | 110 | 111 | if __name__ == '__main__': 112 | main() 113 | -------------------------------------------------------------------------------- /aot/tools/train.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import random 3 | import sys 4 | 5 | sys.setrecursionlimit(10000) 6 | sys.path.append('.') 7 | sys.path.append('..') 8 | 9 | import torch.multiprocessing as mp 10 | 11 | from networks.managers.trainer import Trainer 12 | 13 | 14 | def main_worker(gpu, cfg, enable_amp=True): 15 | # Initiate a training manager 16 | trainer = Trainer(rank=gpu, cfg=cfg, enable_amp=enable_amp) 17 | # Start Training 18 | trainer.sequential_training() 19 | 20 | 21 | def main(): 22 | import argparse 23 | parser = argparse.ArgumentParser(description="Train VOS") 24 | parser.add_argument('--exp_name', type=str, default='') 25 | parser.add_argument('--stage', type=str, default='pre') 26 | parser.add_argument('--model', type=str, default='aott') 27 | parser.add_argument('--max_id_num', type=int, default='-1') 28 | 29 | parser.add_argument('--start_gpu', type=int, default=0) 30 | parser.add_argument('--gpu_num', type=int, default=-1) 31 | parser.add_argument('--batch_size', type=int, default=-1) 32 | parser.add_argument('--dist_url', type=str, default='') 33 | parser.add_argument('--amp', action='store_true') 34 | parser.set_defaults(amp=False) 35 | 36 | parser.add_argument('--pretrained_path', type=str, default='') 37 | 38 | parser.add_argument('--datasets', nargs='+', type=str, default=[]) 39 | parser.add_argument('--lr', type=float, default=-1.) 40 | parser.add_argument('--total_step', type=int, default=-1.) 41 | parser.add_argument('--start_step', type=int, default=-1.) 42 | 43 | args = parser.parse_args() 44 | 45 | engine_config = importlib.import_module('configs.' 
+ args.stage) 46 | 47 | cfg = engine_config.EngineConfig(args.exp_name, args.model) 48 | 49 | if len(args.datasets) > 0: 50 | cfg.DATASETS = args.datasets 51 | 52 | cfg.DIST_START_GPU = args.start_gpu 53 | if args.gpu_num > 0: 54 | cfg.TRAIN_GPUS = args.gpu_num 55 | if args.batch_size > 0: 56 | cfg.TRAIN_BATCH_SIZE = args.batch_size 57 | 58 | if args.pretrained_path != '': 59 | cfg.PRETRAIN_MODEL = args.pretrained_path 60 | 61 | if args.max_id_num > 0: 62 | cfg.MODEL_MAX_OBJ_NUM = args.max_id_num 63 | 64 | if args.lr > 0: 65 | cfg.TRAIN_LR = args.lr 66 | 67 | if args.total_step > 0: 68 | cfg.TRAIN_TOTAL_STEPS = args.total_step 69 | 70 | if args.start_step > 0: 71 | cfg.TRAIN_START_STEP = args.start_step 72 | 73 | if args.dist_url == '': 74 | cfg.DIST_URL = 'tcp://127.0.0.1:123' + str(random.randint(0, 9)) + str( 75 | random.randint(0, 9)) 76 | else: 77 | cfg.DIST_URL = args.dist_url 78 | 79 | if cfg.TRAIN_GPUS > 1: 80 | # Use torch.multiprocessing.spawn to launch distributed processes 81 | mp.spawn(main_worker, nprocs=cfg.TRAIN_GPUS, args=(cfg, args.amp)) 82 | else: 83 | cfg.TRAIN_GPUS = 1 84 | main_worker(0, cfg, args.amp) 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /aot/train_eval.sh: -------------------------------------------------------------------------------- 1 | exp="default" 2 | gpu_num="4" 3 | 4 | model="aott" 5 | # model="aots" 6 | # model="aotb" 7 | # model="aotl" 8 | # model="r50_deaotl" 9 | # model="swinb_aotl" 10 | 11 | ## Training ## 12 | stage="pre" 13 | python tools/train.py --amp \ 14 | --exp_name ${exp} \ 15 | --stage ${stage} \ 16 | --model ${model} \ 17 | --gpu_num ${gpu_num} 18 | 19 | stage="pre_ytb_dav" 20 | python tools/train.py --amp \ 21 | --exp_name ${exp} \ 22 | --stage ${stage} \ 23 | --model ${model} \ 24 | --gpu_num ${gpu_num} 25 | 26 | ## Evaluation ## 27 | dataset="davis2017" 28 | split="test" 29 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ 30 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} 31 | 32 | dataset="davis2017" 33 | split="val" 34 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ 35 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} 36 | 37 | dataset="davis2016" 38 | split="val" 39 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ 40 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} 41 | 42 | dataset="youtubevos2018" 43 | split="val" # or "val_all_frames" 44 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ 45 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} 46 | 47 | dataset="youtubevos2019" 48 | split="val" # or "val_all_frames" 49 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ 50 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} -------------------------------------------------------------------------------- /aot/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/aot/utils/__init__.py -------------------------------------------------------------------------------- /aot/utils/cp_ckpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | def cp_ckpt(remote_dir="data_wd/youtube_vos_jobs/result", curr_dir="backup"): 6 | exps = 
os.listdir(curr_dir) 7 | for exp in exps: 8 | print("Exp: ", exp) 9 | exp_dir = os.path.join(curr_dir, exp) 10 | stages = os.listdir(exp_dir) 11 | for stage in stages: 12 | print("Stage: ", stage) 13 | stage_dir = os.path.join(exp_dir, stage) 14 | finals = ["ema_ckpt", "ckpt"] 15 | for final in finals: 16 | print("Final: ", final) 17 | final_dir = os.path.join(stage_dir, final) 18 | ckpts = os.listdir(final_dir) 19 | for ckpt in ckpts: 20 | if '.pth' not in ckpt: 21 | continue 22 | curr_ckpt_path = os.path.join(final_dir, ckpt) 23 | remote_ckpt_path = os.path.join(remote_dir, exp, stage, 24 | final, ckpt) 25 | if os.path.exists(remote_ckpt_path): 26 | os.system('rm {}'.format(remote_ckpt_path)) 27 | try: 28 | shutil.copy(curr_ckpt_path, remote_ckpt_path) 29 | print(ckpt, ': OK') 30 | except OSError as Inst: 31 | print(Inst) 32 | print(ckpt, ': Fail') 33 | 34 | 35 | if __name__ == "__main__": 36 | cp_ckpt() 37 | -------------------------------------------------------------------------------- /aot/utils/ema.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import unicode_literals 3 | 4 | import torch 5 | 6 | 7 | def get_param_buffer_for_ema(model, 8 | update_buffer=False, 9 | required_buffers=['running_mean', 'running_var']): 10 | params = model.parameters() 11 | all_param_buffer = [p for p in params if p.requires_grad] 12 | if update_buffer: 13 | named_buffers = model.named_buffers() 14 | for key, value in named_buffers: 15 | for buffer_name in required_buffers: 16 | if buffer_name in key: 17 | all_param_buffer.append(value) 18 | break 19 | return all_param_buffer 20 | 21 | 22 | class ExponentialMovingAverage: 23 | """ 24 | Maintains (exponential) moving average of a set of parameters. 25 | """ 26 | def __init__(self, parameters, decay, use_num_updates=True): 27 | """ 28 | Args: 29 | parameters: Iterable of `torch.nn.Parameter`; usually the result of 30 | `model.parameters()`. 31 | decay: The exponential decay. 32 | use_num_updates: Whether to use number of updates when computing 33 | averages. 34 | """ 35 | if decay < 0.0 or decay > 1.0: 36 | raise ValueError('Decay must be between 0 and 1') 37 | self.decay = decay 38 | self.num_updates = 0 if use_num_updates else None 39 | self.shadow_params = [p.clone().detach() for p in parameters] 40 | self.collected_params = [] 41 | 42 | def update(self, parameters): 43 | """ 44 | Update currently maintained parameters. 45 | Call this every time the parameters are updated, such as the result of 46 | the `optimizer.step()` call. 47 | Args: 48 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 49 | parameters used to initialize this object. 50 | """ 51 | decay = self.decay 52 | if self.num_updates is not None: 53 | self.num_updates += 1 54 | decay = min(decay, 55 | (1 + self.num_updates) / (10 + self.num_updates)) 56 | one_minus_decay = 1.0 - decay 57 | with torch.no_grad(): 58 | for s_param, param in zip(self.shadow_params, parameters): 59 | s_param.sub_(one_minus_decay * (s_param - param)) 60 | 61 | def copy_to(self, parameters): 62 | """ 63 | Copy current parameters into given collection of parameters. 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | updated with the stored moving averages. 67 | """ 68 | for s_param, param in zip(self.shadow_params, parameters): 69 | param.data.copy_(s_param.data) 70 | 71 | def store(self, parameters): 72 | """ 73 | Save the current parameters for restoring later. 
74 | Args: 75 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 76 | temporarily stored. 77 | """ 78 | self.collected_params = [param.clone() for param in parameters] 79 | 80 | def restore(self, parameters): 81 | """ 82 | Restore the parameters stored with the `store` method. 83 | Useful to validate the model with EMA parameters without affecting the 84 | original optimization process. Store the parameters before the 85 | `copy_to` method. After validation (or model saving), use this to 86 | restore the former parameters. 87 | Args: 88 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 89 | updated with the stored parameters. 90 | """ 91 | for c_param, param in zip(self.collected_params, parameters): 92 | param.data.copy_(c_param.data) 93 | del (self.collected_params) 94 | -------------------------------------------------------------------------------- /aot/utils/eval.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | import os 3 | 4 | 5 | def zip_folder(source_folder, zip_dir): 6 | f = zipfile.ZipFile(zip_dir, 'w', zipfile.ZIP_DEFLATED) 7 | pre_len = len(os.path.dirname(source_folder)) 8 | for dirpath, dirnames, filenames in os.walk(source_folder): 9 | for filename in filenames: 10 | pathfile = os.path.join(dirpath, filename) 11 | arcname = pathfile[pre_len:].strip(os.path.sep) 12 | f.write(pathfile, arcname) 13 | f.close() -------------------------------------------------------------------------------- /aot/utils/learning.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def adjust_learning_rate(optimizer, 5 | base_lr, 6 | p, 7 | itr, 8 | max_itr, 9 | restart=1, 10 | warm_up_steps=1000, 11 | is_cosine_decay=False, 12 | min_lr=1e-5, 13 | encoder_lr_ratio=1.0, 14 | freeze_params=[]): 15 | 16 | if restart > 1: 17 | each_max_itr = int(math.ceil(float(max_itr) / restart)) 18 | itr = itr % each_max_itr 19 | warm_up_steps /= restart 20 | max_itr = each_max_itr 21 | 22 | if itr < warm_up_steps: 23 | now_lr = min_lr + (base_lr - min_lr) * itr / warm_up_steps 24 | else: 25 | itr = itr - warm_up_steps 26 | max_itr = max_itr - warm_up_steps 27 | if is_cosine_decay: 28 | now_lr = min_lr + (base_lr - min_lr) * (math.cos(math.pi * itr / 29 | (max_itr + 1)) + 30 | 1.) * 0.5 31 | else: 32 | now_lr = min_lr + (base_lr - min_lr) * (1 - itr / (max_itr + 1))**p 33 | 34 | for param_group in optimizer.param_groups: 35 | if encoder_lr_ratio != 1.0 and "encoder." in param_group["name"]: 36 | param_group['lr'] = (now_lr - min_lr) * encoder_lr_ratio + min_lr 37 | else: 38 | param_group['lr'] = now_lr 39 | 40 | for freeze_param in freeze_params: 41 | if freeze_param in param_group["name"]: 42 | param_group['lr'] = 0 43 | param_group['weight_decay'] = 0 44 | break 45 | 46 | return now_lr 47 | 48 | 49 | def get_trainable_params(model, 50 | base_lr, 51 | weight_decay, 52 | use_frozen_bn=False, 53 | exclusive_wd_dict={}, 54 | no_wd_keys=[]): 55 | params = [] 56 | memo = set() 57 | total_param = 0 58 | for key, value in model.named_parameters(): 59 | if value in memo: 60 | continue 61 | total_param += value.numel() 62 | if not value.requires_grad: 63 | continue 64 | memo.add(value) 65 | wd = weight_decay 66 | for exclusive_key in exclusive_wd_dict.keys(): 67 | if exclusive_key in key: 68 | wd = exclusive_wd_dict[exclusive_key] 69 | break 70 | if len(value.shape) == 1: # normalization layers 71 | if 'bias' in key: # bias requires no weight decay 72 | wd = 0. 
73 | elif not use_frozen_bn: # if not use frozen BN, apply zero weight decay 74 | wd = 0. 75 | elif 'encoder.' not in key: # if use frozen BN, apply weight decay to all frozen BNs in the encoder 76 | wd = 0. 77 | else: 78 | for no_wd_key in no_wd_keys: 79 | if no_wd_key in key: 80 | wd = 0. 81 | break 82 | params += [{ 83 | "params": [value], 84 | "lr": base_lr, 85 | "weight_decay": wd, 86 | "name": key 87 | }] 88 | 89 | print('Total Param: {:.2f}M'.format(total_param / 1e6)) 90 | return params 91 | 92 | 93 | def freeze_params(module): 94 | for p in module.parameters(): 95 | p.requires_grad = False 96 | 97 | 98 | def calculate_params(state_dict): 99 | memo = set() 100 | total_param = 0 101 | for key, value in state_dict.items(): 102 | if value in memo: 103 | continue 104 | memo.add(value) 105 | total_param += value.numel() 106 | print('Total Param: {:.2f}M'.format(total_param / 1e6)) 107 | -------------------------------------------------------------------------------- /aot/utils/math.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def generate_permute_matrix(dim, num, keep_first=True, gpu_id=0): 5 | all_matrix = [] 6 | for idx in range(num): 7 | random_matrix = torch.eye(dim, device=torch.device('cuda', gpu_id)) 8 | if keep_first: 9 | fg = random_matrix[1:][torch.randperm(dim - 1)] 10 | random_matrix = torch.cat([random_matrix[0:1], fg], dim=0) 11 | else: 12 | random_matrix = random_matrix[torch.randperm(dim)] 13 | all_matrix.append(random_matrix) 14 | return torch.stack(all_matrix, dim=0) 15 | 16 | 17 | def truncated_normal_(tensor, mean=0, std=.02): 18 | size = tensor.shape 19 | tmp = tensor.new_empty(size + (4, )).normal_() 20 | valid = (tmp < 2) & (tmp > -2) 21 | ind = valid.max(-1, keepdim=True)[1] 22 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 23 | tensor.data.mul_(std).add_(mean) 24 | return tensor 25 | -------------------------------------------------------------------------------- /aot/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | def __init__(self, momentum=0.999): 7 | self.val = 0 8 | self.avg = 0 9 | self.sum = 0 10 | self.count = 0 11 | self.long_count = 0 12 | self.momentum = momentum 13 | self.moving_avg = 0 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | if self.long_count == 0: 23 | self.moving_avg = val 24 | else: 25 | momentum = min(self.momentum, 1. - 1. 
/ self.long_count) 26 | self.moving_avg = self.moving_avg * momentum + val * (1 - momentum) 27 | self.val = val 28 | self.sum += val * n 29 | self.count += n 30 | self.long_count += n 31 | self.avg = self.sum / self.count 32 | -------------------------------------------------------------------------------- /aot/utils/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pytorch_iou(pred, target, obj_num, epsilon=1e-6): 5 | ''' 6 | pred: [bs, h, w] 7 | target: [bs, h, w] 8 | obj_num: [bs] 9 | ''' 10 | bs = pred.size(0) 11 | all_iou = [] 12 | for idx in range(bs): 13 | now_pred = pred[idx].unsqueeze(0) 14 | now_target = target[idx].unsqueeze(0) 15 | now_obj_num = obj_num[idx] 16 | 17 | obj_ids = torch.arange(0, now_obj_num + 1, 18 | device=now_pred.device).int().view(-1, 1, 1) 19 | if obj_ids.size(0) == 1: # only contain background 20 | continue 21 | else: 22 | obj_ids = obj_ids[1:] 23 | now_pred = (now_pred == obj_ids).float() 24 | now_target = (now_target == obj_ids).float() 25 | 26 | intersection = (now_pred * now_target).sum((1, 2)) 27 | union = ((now_pred + now_target) > 0).float().sum((1, 2)) 28 | 29 | now_iou = (intersection + epsilon) / (union + epsilon) 30 | 31 | all_iou.append(now_iou.mean()) 32 | if len(all_iou) > 0: 33 | all_iou = torch.stack(all_iou).mean() 34 | else: 35 | all_iou = torch.ones((1), device=pred.device) 36 | return all_iou 37 | -------------------------------------------------------------------------------- /demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/demo.gif -------------------------------------------------------------------------------- /groundingdino/_C.cp310-win_amd64.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/groundingdino/_C.cp310-win_amd64.pyd -------------------------------------------------------------------------------- /groundingdino/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/groundingdino/__init__.py -------------------------------------------------------------------------------- /groundingdino/config/GroundingDINO_SwinB_cfg.py: -------------------------------------------------------------------------------- 1 | batch_size = 1 2 | modelname = "groundingdino" 3 | backbone = "swin_B_384_22k" 4 | position_embedding = "sine" 5 | pe_temperatureH = 20 6 | pe_temperatureW = 20 7 | return_interm_indices = [1, 2, 3] 8 | backbone_freeze_keywords = None 9 | enc_layers = 6 10 | dec_layers = 6 11 | pre_norm = False 12 | dim_feedforward = 2048 13 | hidden_dim = 256 14 | dropout = 0.0 15 | nheads = 8 16 | num_queries = 900 17 | query_dim = 4 18 | num_patterns = 0 19 | num_feature_levels = 4 20 | enc_n_points = 4 21 | dec_n_points = 4 22 | two_stage_type = "standard" 23 | two_stage_bbox_embed_share = False 24 | two_stage_class_embed_share = False 25 | transformer_activation = "relu" 26 | dec_pred_bbox_embed_share = True 27 | dn_box_noise_scale = 1.0 28 | dn_label_noise_ratio = 0.5 29 | dn_label_coef = 1.0 30 | dn_bbox_coef = 1.0 31 | embed_init_tgt = True 32 | dn_labelbook_size = 2000 33 | max_text_len = 256 34 | text_encoder_type = "bert-base-uncased" 35 | 
use_text_enhancer = True 36 | use_fusion_layer = True 37 | use_checkpoint = True 38 | use_transformer_ckpt = True 39 | use_text_cross_attention = True 40 | text_dropout = 0.0 41 | fusion_dropout = 0.0 42 | fusion_droppath = 0.1 43 | sub_sentence_present = True 44 | -------------------------------------------------------------------------------- /groundingdino/config/GroundingDINO_SwinT_OGC.py: -------------------------------------------------------------------------------- 1 | batch_size = 1 2 | modelname = "groundingdino" 3 | backbone = "swin_T_224_1k" 4 | position_embedding = "sine" 5 | pe_temperatureH = 20 6 | pe_temperatureW = 20 7 | return_interm_indices = [1, 2, 3] 8 | backbone_freeze_keywords = None 9 | enc_layers = 6 10 | dec_layers = 6 11 | pre_norm = False 12 | dim_feedforward = 2048 13 | hidden_dim = 256 14 | dropout = 0.0 15 | nheads = 8 16 | num_queries = 900 17 | query_dim = 4 18 | num_patterns = 0 19 | num_feature_levels = 4 20 | enc_n_points = 4 21 | dec_n_points = 4 22 | two_stage_type = "standard" 23 | two_stage_bbox_embed_share = False 24 | two_stage_class_embed_share = False 25 | transformer_activation = "relu" 26 | dec_pred_bbox_embed_share = True 27 | dn_box_noise_scale = 1.0 28 | dn_label_noise_ratio = 0.5 29 | dn_label_coef = 1.0 30 | dn_bbox_coef = 1.0 31 | embed_init_tgt = True 32 | dn_labelbook_size = 2000 33 | max_text_len = 256 34 | text_encoder_type = "bert-base-uncased" 35 | use_text_enhancer = True 36 | use_fusion_layer = True 37 | use_checkpoint = True 38 | use_transformer_ckpt = True 39 | use_text_cross_attention = True 40 | text_dropout = 0.0 41 | fusion_dropout = 0.0 42 | fusion_droppath = 0.1 43 | sub_sentence_present = True 44 | -------------------------------------------------------------------------------- /groundingdino/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/groundingdino/config/__init__.py -------------------------------------------------------------------------------- /groundingdino/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/groundingdino/datasets/__init__.py -------------------------------------------------------------------------------- /groundingdino/models/GroundingDINO/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Grounding DINO 3 | # url: https://github.com/IDEA-Research/GroundingDINO 4 | # Copyright (c) 2023 IDEA. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------ 7 | # Conditional DETR 8 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 9 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 10 | # ------------------------------------------------------------------------ 11 | # Copied from DETR (https://github.com/facebookresearch/detr) 12 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
13 | # ------------------------------------------------------------------------ 14 | 15 | from .groundingdino import build_groundingdino 16 | -------------------------------------------------------------------------------- /groundingdino/models/GroundingDINO/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbone import build_backbone 2 | -------------------------------------------------------------------------------- /groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | namespace groundingdino { 20 | 21 | at::Tensor 22 | ms_deform_attn_forward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const int im2col_step) 29 | { 30 | if (value.type().is_cuda()) 31 | { 32 | #ifdef WITH_CUDA 33 | return ms_deform_attn_cuda_forward( 34 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 35 | #else 36 | AT_ERROR("Not compiled with GPU support"); 37 | #endif 38 | } 39 | AT_ERROR("Not implemented on the CPU"); 40 | } 41 | 42 | std::vector<at::Tensor> 43 | ms_deform_attn_backward( 44 | const at::Tensor &value, 45 | const at::Tensor &spatial_shapes, 46 | const at::Tensor &level_start_index, 47 | const at::Tensor &sampling_loc, 48 | const at::Tensor &attn_weight, 49 | const at::Tensor &grad_output, 50 | const int im2col_step) 51 | { 52 | if (value.type().is_cuda()) 53 | { 54 | #ifdef WITH_CUDA 55 | return ms_deform_attn_cuda_backward( 56 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 57 | #else 58 | AT_ERROR("Not compiled with GPU support"); 59 | #endif 60 | } 61 | AT_ERROR("Not implemented on the CPU"); 62 | } 63 | 64 | } // namespace groundingdino -------------------------------------------------------------------------------- /groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include <vector> 12 | 13 | #include <ATen/ATen.h> 14 | #include <ATen/cuda/CUDAContext.h> 15 | 16 | namespace groundingdino { 17 | 18 | at::Tensor 19 | ms_deform_attn_cpu_forward( 20 | const at::Tensor &value, 21 | const at::Tensor &spatial_shapes, 22 | const at::Tensor &level_start_index, 23 | const at::Tensor &sampling_loc, 24 | const at::Tensor &attn_weight, 25 | const int im2col_step) 26 | { 27 | AT_ERROR("Not implemented on the CPU"); 28 | } 29 | 30 | std::vector<at::Tensor> 31 | ms_deform_attn_cpu_backward( 32 | const at::Tensor &value, 33 | const at::Tensor &spatial_shapes, 34 | const at::Tensor &level_start_index, 35 | const at::Tensor &sampling_loc, 36 | const at::Tensor &attn_weight, 37 | const at::Tensor &grad_output, 38 | const int im2col_step) 39 | { 40 | AT_ERROR("Not implemented on the CPU"); 41 | } 42 | 43 | } // namespace groundingdino 44 | -------------------------------------------------------------------------------- /groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include <torch/extension.h> 13 | 14 | namespace groundingdino { 15 | 16 | at::Tensor 17 | ms_deform_attn_cpu_forward( 18 | const at::Tensor &value, 19 | const at::Tensor &spatial_shapes, 20 | const at::Tensor &level_start_index, 21 | const at::Tensor &sampling_loc, 22 | const at::Tensor &attn_weight, 23 | const int im2col_step); 24 | 25 | std::vector<at::Tensor> 26 | ms_deform_attn_cpu_backward( 27 | const at::Tensor &value, 28 | const at::Tensor &spatial_shapes, 29 | const at::Tensor &level_start_index, 30 | const at::Tensor &sampling_loc, 31 | const at::Tensor &attn_weight, 32 | const at::Tensor &grad_output, 33 | const int im2col_step); 34 | 35 | } // namespace groundingdino 36 | -------------------------------------------------------------------------------- /groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include <torch/extension.h> 13 | 14 | namespace groundingdino { 15 | 16 | at::Tensor ms_deform_attn_cuda_forward( 17 | const at::Tensor &value, 18 | const at::Tensor &spatial_shapes, 19 | const at::Tensor &level_start_index, 20 | const at::Tensor &sampling_loc, 21 | const at::Tensor &attn_weight, 22 | const int im2col_step); 23 | 24 | std::vector<at::Tensor> ms_deform_attn_cuda_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | } // namespace groundingdino -------------------------------------------------------------------------------- /groundingdino/models/GroundingDINO/csrc/cuda_version.cu: -------------------------------------------------------------------------------- 1 | #include <cuda_runtime_api.h> 2 | 3 | namespace groundingdino { 4 | int get_cudart_version() { 5 | return CUDART_VERSION; 6 | } 7 | } // namespace groundingdino 8 | -------------------------------------------------------------------------------- /groundingdino/models/GroundingDINO/csrc/vision.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | #include "MsDeformAttn/ms_deform_attn.h" 4 | 5 | namespace groundingdino { 6 | 7 | #ifdef WITH_CUDA 8 | extern int get_cudart_version(); 9 | #endif 10 | 11 | std::string get_cuda_version() { 12 | #ifdef WITH_CUDA 13 | std::ostringstream oss; 14 | 15 | // copied from 16 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231 17 | auto printCudaStyleVersion = [&](int v) { 18 | oss << (v / 1000) << "." << (v / 10 % 100); 19 | if (v % 10 != 0) { 20 | oss << "." << (v % 10); 21 | } 22 | }; 23 | printCudaStyleVersion(get_cudart_version()); 24 | return oss.str(); 25 | #else 26 | return std::string("not available"); 27 | #endif 28 | } 29 | 30 | // similar to 31 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp 32 | std::string get_compiler_version() { 33 | std::ostringstream ss; 34 | #if defined(__GNUC__) 35 | #ifndef __clang__ 36 | { ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; } 37 | #endif 38 | #endif 39 | 40 | #if defined(__clang_major__) 41 | { 42 | ss << "clang " << __clang_major__ << "." << __clang_minor__ << "."
43 | << __clang_patchlevel__; 44 | } 45 | #endif 46 | 47 | #if defined(_MSC_VER) 48 | { ss << "MSVC " << _MSC_FULL_VER; } 49 | #endif 50 | return ss.str(); 51 | } 52 | 53 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 54 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 55 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 56 | } 57 | 58 | } // namespace groundingdino -------------------------------------------------------------------------------- /groundingdino/models/GroundingDINO/transformer_vanilla.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Grounding DINO 3 | # url: https://github.com/IDEA-Research/GroundingDINO 4 | # Copyright (c) 2023 IDEA. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------ 7 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | """ 10 | DETR Transformer class. 11 | 12 | Copy-paste from torch.nn.Transformer with modifications: 13 | * positional encodings are passed in MHattention 14 | * extra LN at the end of encoder is removed 15 | * decoder returns a stack of activations from all decoding layers 16 | """ 17 | from typing import Optional 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | from torch import Tensor, nn 22 | 23 | from .utils import ( 24 | MLP, 25 | _get_activation_fn, 26 | _get_clones, 27 | gen_encoder_output_proposals, 28 | gen_sineembed_for_position, 29 | sigmoid_focal_loss, 30 | ) 31 | 32 | 33 | class TextTransformer(nn.Module): 34 | def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1): 35 | super().__init__() 36 | self.num_layers = num_layers 37 | self.d_model = d_model 38 | self.nheads = nheads 39 | self.dim_feedforward = dim_feedforward 40 | self.norm = None 41 | 42 | single_encoder_layer = TransformerEncoderLayer( 43 | d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout 44 | ) 45 | self.layers = _get_clones(single_encoder_layer, num_layers) 46 | 47 | def forward(self, memory_text: torch.Tensor, text_attention_mask: torch.Tensor): 48 | """ 49 | 50 | Args: 51 | text_attention_mask: bs, num_token 52 | memory_text: bs, num_token, d_model 53 | 54 | Raises: 55 | RuntimeError: _description_ 56 | 57 | Returns: 58 | output: bs, num_token, d_model 59 | """ 60 | 61 | output = memory_text.transpose(0, 1) 62 | 63 | for layer in self.layers: 64 | output = layer(output, src_key_padding_mask=text_attention_mask) 65 | 66 | if self.norm is not None: 67 | output = self.norm(output) 68 | 69 | return output.transpose(0, 1) 70 | 71 | 72 | class TransformerEncoderLayer(nn.Module): 73 | def __init__( 74 | self, 75 | d_model, 76 | nhead, 77 | dim_feedforward=2048, 78 | dropout=0.1, 79 | activation="relu", 80 | normalize_before=False, 81 | ): 82 | super().__init__() 83 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 84 | # Implementation of Feedforward model 85 | self.linear1 = nn.Linear(d_model, dim_feedforward) 86 | self.dropout = nn.Dropout(dropout) 87 | self.linear2 = nn.Linear(dim_feedforward, d_model) 88 | 89 | self.norm1 = nn.LayerNorm(d_model) 90 | self.norm2 = nn.LayerNorm(d_model) 91 | self.dropout1 = 
nn.Dropout(dropout) 92 | self.dropout2 = nn.Dropout(dropout) 93 | 94 | self.activation = _get_activation_fn(activation) 95 | self.normalize_before = normalize_before 96 | self.nhead = nhead 97 | 98 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 99 | return tensor if pos is None else tensor + pos 100 | 101 | def forward( 102 | self, 103 | src, 104 | src_mask: Optional[Tensor] = None, 105 | src_key_padding_mask: Optional[Tensor] = None, 106 | pos: Optional[Tensor] = None, 107 | ): 108 | # repeat attn mask 109 | if src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]: 110 | # bs, num_q, num_k 111 | src_mask = src_mask.repeat(self.nhead, 1, 1) 112 | 113 | q = k = self.with_pos_embed(src, pos) 114 | 115 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0] 116 | 117 | # src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] 118 | src = src + self.dropout1(src2) 119 | src = self.norm1(src) 120 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 121 | src = src + self.dropout2(src2) 122 | src = self.norm2(src) 123 | return src 124 | -------------------------------------------------------------------------------- /groundingdino/models/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Grounding DINO 3 | # url: https://github.com/IDEA-Research/GroundingDINO 4 | # Copyright (c) 2023 IDEA. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------ 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | from .GroundingDINO import build_groundingdino 9 | 10 | 11 | def build_model(args): 12 | # we use register to maintain models from catdet6 on. 13 | from .registry import MODULE_BUILD_FUNCS 14 | 15 | assert args.modelname in MODULE_BUILD_FUNCS._module_dict 16 | build_func = MODULE_BUILD_FUNCS.get(args.modelname) 17 | model = build_func(args) 18 | return model 19 | -------------------------------------------------------------------------------- /groundingdino/models/registry.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Grounding DINO 3 | # url: https://github.com/IDEA-Research/GroundingDINO 4 | # Copyright (c) 2023 IDEA. All Rights Reserved. 
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------ 7 | # -*- coding: utf-8 -*- 8 | # @Author: Yihao Chen 9 | # @Date: 2021-08-16 16:03:17 10 | # @Last Modified by: Shilong Liu 11 | # @Last Modified time: 2022-01-23 15:26 12 | # modified from mmcv 13 | 14 | import inspect 15 | from functools import partial 16 | 17 | 18 | class Registry(object): 19 | def __init__(self, name): 20 | self._name = name 21 | self._module_dict = dict() 22 | 23 | def __repr__(self): 24 | format_str = self.__class__.__name__ + "(name={}, items={})".format( 25 | self._name, list(self._module_dict.keys()) 26 | ) 27 | return format_str 28 | 29 | def __len__(self): 30 | return len(self._module_dict) 31 | 32 | @property 33 | def name(self): 34 | return self._name 35 | 36 | @property 37 | def module_dict(self): 38 | return self._module_dict 39 | 40 | def get(self, key): 41 | return self._module_dict.get(key, None) 42 | 43 | def registe_with_name(self, module_name=None, force=False): 44 | return partial(self.register, module_name=module_name, force=force) 45 | 46 | def register(self, module_build_function, module_name=None, force=False): 47 | """Register a module build function. 48 | Args: 49 | module (:obj:`nn.Module`): Module to be registered. 50 | """ 51 | if not inspect.isfunction(module_build_function): 52 | raise TypeError( 53 | "module_build_function must be a function, but got {}".format( 54 | type(module_build_function) 55 | ) 56 | ) 57 | if module_name is None: 58 | module_name = module_build_function.__name__ 59 | if not force and module_name in self._module_dict: 60 | raise KeyError("{} is already registered in {}".format(module_name, self.name)) 61 | self._module_dict[module_name] = module_build_function 62 | 63 | return module_build_function 64 | 65 | 66 | MODULE_BUILD_FUNCS = Registry("model build functions") 67 | -------------------------------------------------------------------------------- /groundingdino/util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /groundingdino/util/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 
4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] 12 | return torch.stack(b, dim=-1) 13 | 14 | 15 | def box_xyxy_to_cxcywh(x): 16 | x0, y0, x1, y1 = x.unbind(-1) 17 | b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] 18 | return torch.stack(b, dim=-1) 19 | 20 | 21 | # modified from torchvision to also return the union 22 | def box_iou(boxes1, boxes2): 23 | area1 = box_area(boxes1) 24 | area2 = box_area(boxes2) 25 | 26 | # import ipdb; ipdb.set_trace() 27 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 28 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 29 | 30 | wh = (rb - lt).clamp(min=0) # [N,M,2] 31 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 32 | 33 | union = area1[:, None] + area2 - inter 34 | 35 | iou = inter / (union + 1e-6) 36 | return iou, union 37 | 38 | 39 | def generalized_box_iou(boxes1, boxes2): 40 | """ 41 | Generalized IoU from https://giou.stanford.edu/ 42 | 43 | The boxes should be in [x0, y0, x1, y1] format 44 | 45 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 46 | and M = len(boxes2) 47 | """ 48 | # degenerate boxes gives inf / nan results 49 | # so do an early check 50 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 51 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 52 | # except: 53 | # import ipdb; ipdb.set_trace() 54 | iou, union = box_iou(boxes1, boxes2) 55 | 56 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 57 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 58 | 59 | wh = (rb - lt).clamp(min=0) # [N,M,2] 60 | area = wh[:, :, 0] * wh[:, :, 1] 61 | 62 | return iou - (area - union) / (area + 1e-6) 63 | 64 | 65 | # modified from torchvision to also return the union 66 | def box_iou_pairwise(boxes1, boxes2): 67 | area1 = box_area(boxes1) 68 | area2 = box_area(boxes2) 69 | 70 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2] 71 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2] 72 | 73 | wh = (rb - lt).clamp(min=0) # [N,2] 74 | inter = wh[:, 0] * wh[:, 1] # [N] 75 | 76 | union = area1 + area2 - inter 77 | 78 | iou = inter / union 79 | return iou, union 80 | 81 | 82 | def generalized_box_iou_pairwise(boxes1, boxes2): 83 | """ 84 | Generalized IoU from https://giou.stanford.edu/ 85 | 86 | Input: 87 | - boxes1, boxes2: N,4 88 | Output: 89 | - giou: N, 4 90 | """ 91 | # degenerate boxes gives inf / nan results 92 | # so do an early check 93 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 94 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 95 | assert boxes1.shape == boxes2.shape 96 | iou, union = box_iou_pairwise(boxes1, boxes2) # N, 4 97 | 98 | lt = torch.min(boxes1[:, :2], boxes2[:, :2]) 99 | rb = torch.max(boxes1[:, 2:], boxes2[:, 2:]) 100 | 101 | wh = (rb - lt).clamp(min=0) # [N,2] 102 | area = wh[:, 0] * wh[:, 1] 103 | 104 | return iou - (area - union) / area 105 | 106 | 107 | def masks_to_boxes(masks): 108 | """Compute the bounding boxes around the provided masks 109 | 110 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
111 | 112 | Returns a [N, 4] tensors, with the boxes in xyxy format 113 | """ 114 | if masks.numel() == 0: 115 | return torch.zeros((0, 4), device=masks.device) 116 | 117 | h, w = masks.shape[-2:] 118 | 119 | y = torch.arange(0, h, dtype=torch.float) 120 | x = torch.arange(0, w, dtype=torch.float) 121 | y, x = torch.meshgrid(y, x) 122 | 123 | x_mask = masks * x.unsqueeze(0) 124 | x_max = x_mask.flatten(1).max(-1)[0] 125 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 126 | 127 | y_mask = masks * y.unsqueeze(0) 128 | y_max = y_mask.flatten(1).max(-1)[0] 129 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 130 | 131 | return torch.stack([x_min, y_min, x_max, y_max], 1) 132 | 133 | 134 | if __name__ == "__main__": 135 | x = torch.rand(5, 4) 136 | y = torch.rand(3, 4) 137 | iou, union = box_iou(x, y) 138 | import ipdb 139 | 140 | ipdb.set_trace() 141 | -------------------------------------------------------------------------------- /groundingdino/util/get_tokenlizer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast 2 | import os 3 | 4 | 5 | def get_tokenlizer(text_encoder_type): 6 | if not isinstance(text_encoder_type, str): 7 | # print("text_encoder_type is not a str") 8 | if hasattr(text_encoder_type, "text_encoder_type"): 9 | text_encoder_type = text_encoder_type.text_encoder_type 10 | elif text_encoder_type.get("text_encoder_type", False): 11 | text_encoder_type = text_encoder_type.get("text_encoder_type") 12 | elif os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type): 13 | pass 14 | else: 15 | raise ValueError( 16 | "Unknown type of text_encoder_type: {}".format(type(text_encoder_type)) 17 | ) 18 | print("final text_encoder_type: {}".format(text_encoder_type)) 19 | 20 | tokenizer = AutoTokenizer.from_pretrained(f'ckpt/{text_encoder_type}') 21 | return tokenizer 22 | 23 | 24 | def get_pretrained_language_model(text_encoder_type): 25 | if text_encoder_type == "bert-base-uncased" or (os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type)): 26 | return BertModel.from_pretrained(f'ckpt/{text_encoder_type}') 27 | if text_encoder_type == "roberta-base": 28 | return RobertaModel.from_pretrained(text_encoder_type) 29 | 30 | raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type)) 31 | -------------------------------------------------------------------------------- /groundingdino/util/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import functools 3 | import logging 4 | import os 5 | import sys 6 | 7 | from termcolor import colored 8 | 9 | 10 | class _ColorfulFormatter(logging.Formatter): 11 | def __init__(self, *args, **kwargs): 12 | self._root_name = kwargs.pop("root_name") + "." 13 | self._abbrev_name = kwargs.pop("abbrev_name", "") 14 | if len(self._abbrev_name): 15 | self._abbrev_name = self._abbrev_name + "." 
16 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 17 | 18 | def formatMessage(self, record): 19 | record.name = record.name.replace(self._root_name, self._abbrev_name) 20 | log = super(_ColorfulFormatter, self).formatMessage(record) 21 | if record.levelno == logging.WARNING: 22 | prefix = colored("WARNING", "red", attrs=["blink"]) 23 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 24 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 25 | else: 26 | return log 27 | return prefix + " " + log 28 | 29 | 30 | # so that calling setup_logger multiple times won't add many handlers 31 | @functools.lru_cache() 32 | def setup_logger(output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None): 33 | """ 34 | Initialize the detectron2 logger and set its verbosity level to "INFO". 35 | 36 | Args: 37 | output (str): a file name or a directory to save log. If None, will not save log file. 38 | If ends with ".txt" or ".log", assumed to be a file name. 39 | Otherwise, logs will be saved to `output/log.txt`. 40 | name (str): the root module name of this logger 41 | 42 | Returns: 43 | logging.Logger: a logger 44 | """ 45 | logger = logging.getLogger(name) 46 | logger.setLevel(logging.DEBUG) 47 | logger.propagate = False 48 | 49 | if abbrev_name is None: 50 | abbrev_name = name 51 | 52 | plain_formatter = logging.Formatter( 53 | "[%(asctime)s.%(msecs)03d]: %(message)s", datefmt="%m/%d %H:%M:%S" 54 | ) 55 | # stdout logging: master only 56 | if distributed_rank == 0: 57 | ch = logging.StreamHandler(stream=sys.stdout) 58 | ch.setLevel(logging.DEBUG) 59 | if color: 60 | formatter = _ColorfulFormatter( 61 | colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s", 62 | datefmt="%m/%d %H:%M:%S", 63 | root_name=name, 64 | abbrev_name=str(abbrev_name), 65 | ) 66 | else: 67 | formatter = plain_formatter 68 | ch.setFormatter(formatter) 69 | logger.addHandler(ch) 70 | 71 | # file logging: all workers 72 | if output is not None: 73 | if output.endswith(".txt") or output.endswith(".log"): 74 | filename = output 75 | else: 76 | filename = os.path.join(output, "log.txt") 77 | if distributed_rank > 0: 78 | filename = filename + f".rank{distributed_rank}" 79 | os.makedirs(os.path.dirname(filename), exist_ok=True) 80 | 81 | fh = logging.StreamHandler(_cached_log_stream(filename)) 82 | fh.setLevel(logging.DEBUG) 83 | fh.setFormatter(plain_formatter) 84 | logger.addHandler(fh) 85 | 86 | return logger 87 | 88 | 89 | # cache the opened file object, so that different calls to `setup_logger` 90 | # with the same file name can safely write to the same file. 
91 | @functools.lru_cache(maxsize=None) 92 | def _cached_log_stream(filename): 93 | return open(filename, "a") 94 | -------------------------------------------------------------------------------- /groundingdino/util/time_counter.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | 4 | 5 | class TimeCounter: 6 | def __init__(self) -> None: 7 | pass 8 | 9 | def clear(self): 10 | self.timedict = {} 11 | self.basetime = time.perf_counter() 12 | 13 | def timeit(self, name): 14 | nowtime = time.perf_counter() - self.basetime 15 | self.timedict[name] = nowtime 16 | self.basetime = time.perf_counter() 17 | 18 | 19 | class TimeHolder: 20 | def __init__(self) -> None: 21 | self.timedict = {} 22 | 23 | def update(self, _timedict: dict): 24 | for k, v in _timedict.items(): 25 | if k not in self.timedict: 26 | self.timedict[k] = AverageMeter(name=k, val_only=True) 27 | self.timedict[k].update(val=v) 28 | 29 | def final_res(self): 30 | return {k: v.avg for k, v in self.timedict.items()} 31 | 32 | def __str__(self): 33 | return json.dumps(self.final_res(), indent=2) 34 | 35 | 36 | class AverageMeter(object): 37 | """Computes and stores the average and current value""" 38 | 39 | def __init__(self, name, fmt=":f", val_only=False): 40 | self.name = name 41 | self.fmt = fmt 42 | self.val_only = val_only 43 | self.reset() 44 | 45 | def reset(self): 46 | self.val = 0 47 | self.avg = 0 48 | self.sum = 0 49 | self.count = 0 50 | 51 | def update(self, val, n=1): 52 | self.val = val 53 | self.sum += val * n 54 | self.count += n 55 | self.avg = self.sum / self.count 56 | 57 | def __str__(self): 58 | if self.val_only: 59 | fmtstr = "{name} {val" + self.fmt + "}" 60 | else: 61 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 62 | return fmtstr.format(**self.__dict__) 63 | -------------------------------------------------------------------------------- /groundingdino/util/vl_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from typing import List 4 | 5 | import torch 6 | 7 | 8 | def create_positive_map_from_span(tokenized, token_span, max_text_len=256): 9 | """construct a map such that positive_map[i,j] = True iff box i is associated to token j 10 | Input: 11 | - tokenized: 12 | - input_ids: Tensor[1, ntokens] 13 | - attention_mask: Tensor[1, ntokens] 14 | - token_span: list with length num_boxes. 
15 | - each item: [start_idx, end_idx] 16 | """ 17 | positive_map = torch.zeros((len(token_span), max_text_len), dtype=torch.float) 18 | for j, tok_list in enumerate(token_span): 19 | for (beg, end) in tok_list: 20 | beg_pos = tokenized.char_to_token(beg) 21 | end_pos = tokenized.char_to_token(end - 1) 22 | if beg_pos is None: 23 | try: 24 | beg_pos = tokenized.char_to_token(beg + 1) 25 | if beg_pos is None: 26 | beg_pos = tokenized.char_to_token(beg + 2) 27 | except: 28 | beg_pos = None 29 | if end_pos is None: 30 | try: 31 | end_pos = tokenized.char_to_token(end - 2) 32 | if end_pos is None: 33 | end_pos = tokenized.char_to_token(end - 3) 34 | except: 35 | end_pos = None 36 | if beg_pos is None or end_pos is None: 37 | continue 38 | 39 | assert beg_pos is not None and end_pos is not None 40 | if os.environ.get("SHILONG_DEBUG_ONLY_ONE_POS", None) == "TRUE": 41 | positive_map[j, beg_pos] = 1 42 | break 43 | else: 44 | positive_map[j, beg_pos : end_pos + 1].fill_(1) 45 | 46 | return positive_map / (positive_map.sum(-1)[:, None] + 1e-6) 47 | 48 | 49 | def build_captions_and_token_span(cat_list, force_lowercase): 50 | """ 51 | Return: 52 | captions: str 53 | cat2tokenspan: dict 54 | { 55 | 'dog': [[0, 2]], 56 | ... 57 | } 58 | """ 59 | 60 | cat2tokenspan = {} 61 | captions = "" 62 | for catname in cat_list: 63 | class_name = catname 64 | if force_lowercase: 65 | class_name = class_name.lower() 66 | if "/" in class_name: 67 | class_name_list: List = class_name.strip().split("/") 68 | class_name_list.append(class_name) 69 | class_name: str = random.choice(class_name_list) 70 | 71 | tokens_positive_i = [] 72 | subnamelist = [i.strip() for i in class_name.strip().split(" ")] 73 | for subname in subnamelist: 74 | if len(subname) == 0: 75 | continue 76 | if len(captions) > 0: 77 | captions = captions + " " 78 | strat_idx = len(captions) 79 | end_idx = strat_idx + len(subname) 80 | tokens_positive_i.append([strat_idx, end_idx]) 81 | captions = captions + subname 82 | 83 | if len(tokens_positive_i) > 0: 84 | captions = captions + " ." 
85 | cat2tokenspan[class_name] = tokens_positive_i 86 | 87 | return captions, cat2tokenspan 88 | 89 | 90 | def build_id2posspan_and_caption(category_dict: dict): 91 | """Build id2pos_span and caption from category_dict 92 | 93 | Args: 94 | category_dict (dict): category_dict 95 | """ 96 | cat_list = [item["name"].lower() for item in category_dict] 97 | id2catname = {item["id"]: item["name"].lower() for item in category_dict} 98 | caption, cat2posspan = build_captions_and_token_span(cat_list, force_lowercase=True) 99 | id2posspan = {catid: cat2posspan[catname] for catid, catname in id2catname.items()} 100 | return id2posspan, caption 101 | -------------------------------------------------------------------------------- /groundingdino/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.0' 2 | -------------------------------------------------------------------------------- /img2vid.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | 4 | # set the directory containing the images 5 | img_dir = './assets/840_iSXIa0hE8Ek' 6 | 7 | # set the output video file name and codec 8 | out_file = './assets/840_iSXIa0hE8Ek.mp4' 9 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 10 | 11 | # get the dimensions of the first image 12 | img_path = os.path.join(img_dir, os.listdir(img_dir)[0]) 13 | img = cv2.imread(img_path) 14 | height, width, channels = img.shape 15 | 16 | # create the VideoWriter object 17 | out = cv2.VideoWriter(out_file, fourcc, 10, (width, height)) 18 | 19 | # loop through the images and write them to the video 20 | for img_name in sorted(os.listdir(img_dir)): 21 | img_path = os.path.join(img_dir, img_name) 22 | img = cv2.imread(img_path) 23 | out.write(img) 24 | 25 | # release the VideoWriter object and close the video file 26 | out.release() 27 | -------------------------------------------------------------------------------- /model_args.py: -------------------------------------------------------------------------------- 1 | # Explanation of generator_args is in sam/segment_anything/automatic_mask_generator.py: SamAutomaticMaskGenerator 2 | sam_args = { 3 | 'sam_checkpoint': "ckpt/sam_vit_b_01ec64.pth", 4 | 'model_type': "vit_b", 5 | 'generator_args':{ 6 | 'points_per_side': 16, 7 | 'pred_iou_thresh': 0.8, 8 | 'stability_score_thresh': 0.9, 9 | 'crop_n_layers': 1, 10 | 'crop_n_points_downscale_factor': 2, 11 | 'min_mask_region_area': 200, 12 | }, 13 | 'gpu_id': 0, 14 | } 15 | aot_args = { 16 | 'phase': 'PRE_YTB_DAV', 17 | 'model': 'r50_deaotl', 18 | 'model_path': 'ckpt/R50_DeAOTL_PRE_YTB_DAV.pth', 19 | 'long_term_mem_gap': 9999, 20 | 'max_len_long_term': 9999, 21 | 'gpu_id': 0, 22 | } 23 | segtracker_args = { 24 | 'sam_gap': 10, # the interval to run sam to segment new objects 25 | 'min_area': 200, # minimal mask area to add a new mask as a new object 26 | 'max_obj_num': 255, # maximal object number to track in a video 27 | 'min_new_obj_iou': 0.8, # the background area ratio of a new object should > 80% 28 | } -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # A webui for Propainter 2 | [简体中文](./readme_zh.md)\ 3 | A webui that you can easily pick up objects from the video and eliminate them. 
4 | 5 | ## faster-propainter | speed-up version 6 | https://github.com/halfzm/faster-propainter 7 | 8 | ## Demo 9 | ![](./demo.gif) 10 | 11 | ## Getting started 12 | If you don't want to install the environment, you can download the package directly.\ 13 | link [百度网盘](https://pan.baidu.com/s/1XkQhzCzTtzVfgQg5heQQrA?pwd=jo38 )\ 14 | tutorial [bilibili](https://www.bilibili.com/video/BV1NH4y1o7mS/) [youtube](https://www.youtube.com/watch?v=CcivHjbHIcQ) 15 | 16 | ### clone repo 17 | ``` 18 | git clone git@github.com:halfzm/ProPainter-Webui.git 19 | ``` 20 | 21 | ### create conda environment 22 | ``` 23 | conda create -n propainter python=3.10 24 | conda activate propainter 25 | ``` 26 | 27 | ### install dependencies 28 | Just follow the instructions in [Segment-and-Track-Anything](https://github.com/z-x-yang/Segment-and-Track-Anything) 29 | and [ProPainter](https://github.com/sczhou/ProPainter) (P.S. You don't need to install groundingdino; it is already included in this project.) 30 | ``` 31 | pip install -r requirements.txt 32 | ``` 33 | 34 | ### prepare pretrained models 35 | Download all the needed models for propainter \ 36 | [propainter](https://github.com/sczhou/ProPainter/releases/download/v0.1.0/ProPainter.pth)\ 37 | [raft-things](https://github.com/sczhou/ProPainter/releases/download/v0.1.0/raft-things.pth)\ 38 | [recurrent_flow_completion](https://github.com/sczhou/ProPainter/releases/download/v0.1.0/recurrent_flow_completion.pth)\ 39 | [i3d_rgb_imagenet](https://github.com/sczhou/ProPainter/releases/download/v0.1.0/i3d_rgb_imagenet.pt) 40 | 41 | Download all the needed models for segment-and-track-anything\ 42 | SAM-VIT-B ([sam_vit_b_01ec64.pth](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth))\ 43 | R50-DeAOT-L ([R50_DeAOTL_PRE_YTB_DAV.pth](https://drive.google.com/file/d/1QoChMkTVxdYZ_eBlZhK2acq9KMQZccPJ/view))\ 44 | GroundingDINO-T ([groundingdino_swint_ogc](https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth))\ 45 | 46 | The directory structure will be arranged as: 47 | ``` 48 | ckpt 49 | |- bert-base-uncased 50 | |- groundingdino_swint_ogc.pth 51 | |- R50_DeAOTL_PRE_YTB_DAV.pth 52 | |- sam_vit_b_01ec64.pth 53 | ... 54 | ProPainter/weights 55 | |- ProPainter.pth 56 | |- recurrent_flow_completion.pth 57 | |- raft-things.pth 58 | |- i3d_rgb_imagenet.pt (for evaluating VFID metric) 59 | |- README.md 60 | ``` 61 | 62 | ### quick start 63 | ``` 64 | python app.py 65 | ``` 66 | 67 | 68 | ## Reference 69 | - [Segment-and-Track-Anything](https://github.com/z-x-yang/Segment-and-Track-Anything) 70 | - [ProPainter](https://github.com/sczhou/ProPainter) 71 | 72 | 73 | ## Star History 74 | [![Star History Chart](https://api.star-history.com/svg?repos=halfzm/ProPainter-Webui&type=Date)](https://star-history.com/#halfzm/ProPainter-Webui&Date) 75 | -------------------------------------------------------------------------------- /readme_en.md: -------------------------------------------------------------------------------- 1 | # A webui for Propainter 2 | A webui that lets you easily select objects in a video and remove them.
3 | 4 | ## Demo 5 | ![](./demo.gif) 6 | 7 | ## Getting started 8 | If you don't want to install the environment, you can download the package directly.\ 9 | link [百度网盘](https://pan.baidu.com/s/1XkQhzCzTtzVfgQg5heQQrA?pwd=jo38 )\ 10 | tutorial [bilibili](https://www.bilibili.com/video/BV1NH4y1o7mS/) [youtube](https://www.youtube.com/watch?v=CcivHjbHIcQ) 11 | 12 | ### clone repo 13 | ``` 14 | git clone https://github.com/halfzm/ProPainter-Webui.git 15 | ``` 16 | 17 | ### create conda environment 18 | ``` 19 | conda create -n propainter python=3.10 20 | conda activate propainter 21 | ``` 22 | 23 | ### install dependencies 24 | Just follow the instructions in [Segment-and-Track-Anything](https://github.com/z-x-yang/Segment-and-Track-Anything) 25 | and [ProPainter](https://github.com/sczhou/ProPainter) (P.S. You don't need to install groundingdino; it is already included in this project.) 26 | ``` 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | ### prepare pretrained models 31 | Download all the needed models for propainter \ 32 | [propainter](https://github.com/sczhou/ProPainter/releases/download/v0.1.0/ProPainter.pth)\ 33 | [raft-things](https://github.com/sczhou/ProPainter/releases/download/v0.1.0/raft-things.pth)\ 34 | [recurrent_flow_completion](https://github.com/sczhou/ProPainter/releases/download/v0.1.0/recurrent_flow_completion.pth)\ 35 | [i3d_rgb_imagenet](https://github.com/sczhou/ProPainter/releases/download/v0.1.0/i3d_rgb_imagenet.pt) 36 | 37 | Download all the needed models for segment-and-track-anything\ 38 | SAM-VIT-B ([sam_vit_b_01ec64.pth](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth))\ 39 | R50-DeAOT-L ([R50_DeAOTL_PRE_YTB_DAV.pth](https://drive.google.com/file/d/1QoChMkTVxdYZ_eBlZhK2acq9KMQZccPJ/view))\ 40 | GroundingDINO-T ([groundingdino_swint_ogc](https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth))\ 41 | 42 | The directory structure will be arranged as: 43 | ``` 44 | ckpt 45 | |- bert-base-uncased 46 | |- groundingdino_swint_ogc.pth 47 | |- R50_DeAOTL_PRE_YTB_DAV.pth 48 | |- sam_vit_b_01ec64.pth 49 | ...
50 | ProPainter/weights 51 | |- ProPainter.pth 52 | |- recurrent_flow_completion.pth 53 | |- raft-things.pth 54 | |- i3d_rgb_imagenet.pt (for evaluating VFID metric) 55 | |- README.md 56 | ``` 57 | 58 | ### quick start 59 | ``` 60 | python app.py 61 | ``` 62 | 63 | 64 | ## Reference 65 | - [Segment-and-Track-Anything](https://github.com/z-x-yang/Segment-and-Track-Anything) 66 | - [ProPainter](https://github.com/sczhou/ProPainter) 67 | 68 | 69 | ## Star History 70 | [![Star History Chart](https://api.star-history.com/svg?repos=halfzm/ProPainter-Webui&type=Date)](https://star-history.com/#halfzm/ProPainter-Webui&Date) -------------------------------------------------------------------------------- /readme_zh.md: -------------------------------------------------------------------------------- 1 | # Propainter 的一个简单 web-ui 2 | 一个可以快速选择物体并将其从视频中消除的web-ui 3 | 4 | ## 效果演示 5 | ![](./demo.gif) 6 | 7 | ## 使用 8 | 如果不想安装环境也可以直接下载整合包,下载后双击start.bat即可\ 9 | 下载链接 [百度网盘](https://pan.baidu.com/s/1XkQhzCzTtzVfgQg5heQQrA?pwd=jo38 )\ 10 | 使用教程 [bilibili](https://www.bilibili.com/video/BV1NH4y1o7mS/) [youtube](https://www.youtube.com/watch?v=CcivHjbHIcQ) 11 | 12 | ### 克隆项目到本地 13 | ``` 14 | git clone https://github.com/halfzm/ProPainter-Webui.git 15 | ``` 16 | 17 | ### 创建虚拟环境 18 | ``` 19 | conda create -n propainter python=3.10 20 | conda activate propainter 21 | ``` 22 | 23 | ### 安装依赖 24 | 请参考 [Segment-and-Track-Anything](https://github.com/z-x-yang/Segment-and-Track-Anything) 25 | 和 [ProPainter](https://github.com/sczhou/ProPainter) 项目中的安装要求(P.S.无需安装groundingdino) 26 | ``` 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | ### 下载预训练模型 31 | 下载 propainter 需要的模型 \ 32 | [propainter](https://github.com/sczhou/ProPainter/releases/download/v0.1.0/ProPainter.pth)\ 33 | [raft-things](https://github.com/sczhou/ProPainter/releases/download/v0.1.0/raft-things.pth)\ 34 | [recurrent_flow_completion](https://github.com/sczhou/ProPainter/releases/download/v0.1.0/recurrent_flow_completion.pth)\ 35 | [i3d_rgb_imagenet](https://github.com/sczhou/ProPainter/releases/download/v0.1.0/i3d_rgb_imagenet.pt) 36 | 37 | 下载 segment-and-track-anything 所需的模型\ 38 | SAM-VIT-B ([sam_vit_b_01ec64.pth](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth))\ 39 | R50-DeAOT-L ([R50_DeAOTL_PRE_YTB_DAV.pth](https://drive.google.com/file/d/1QoChMkTVxdYZ_eBlZhK2acq9KMQZccPJ/view))\ 40 | GroundingDINO-T ([groundingdino_swint_ogc](https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth))\ 41 | 42 | 下载后的文件结构应该如下: 43 | ``` 44 | ckpt 45 | |- bert-base-uncased 46 | |- groundingdino_swint_ogc.pth 47 | |- R50_DeAOTL_PRE_YTB_DAV.pth 48 | |- sam_vit_b_01ec64.pth 49 | ...
50 | ProPainter/weights 51 | |- ProPainter.pth 52 | |- recurrent_flow_completion.pth 53 | |- raft-things.pth 54 | |- i3d_rgb_imagenet.pt (for evaluating VFID metric) 55 | |- README.md 56 | ``` 57 | 58 | ### 快速启动 59 | ``` 60 | python app.py 61 | ``` 62 | 63 | 64 | ## 参考 65 | - [Segment-ant-Track-Anything](https://github.com/z-x-yang/Segment-and-Track-Anything) 66 | - [ProPainter](https://github.com/sczhou/ProPainter) 67 | 68 | 69 | ## 星标历史 70 | [![Star History Chart](https://api.star-history.com/svg?repos=halfzm/ProPainter-Webui&type=Date)](https://star-history.com/#halfzm/ProPainter-Webui&Date) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.23.5 2 | onnx==1.14.0 3 | onnxruntime-gpu==1.15.0 4 | torch==2.0.1+cu117 5 | torchvision==0.15.2+cu117 6 | transformers==4.34.0 7 | 8 | -------------------------------------------------------------------------------- /sam/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811, C101, EXE001, EXE002 3 | max-line-length = 100 4 | max-complexity = 18 5 | select = B,C,E,F,W,T4,B9 6 | per-file-ignores = 7 | **/__init__.py:F401,F403,E402 8 | -------------------------------------------------------------------------------- /sam/.gitignore: -------------------------------------------------------------------------------- 1 | .nfs* 2 | 3 | # compilation and distribution 4 | __pycache__ 5 | _ext 6 | *.pyc 7 | *.pyd 8 | *.so 9 | *.dll 10 | *.egg-info/ 11 | build/ 12 | dist/ 13 | wheels/ 14 | 15 | # pytorch/python/numpy formats 16 | *.pth 17 | *.pkl 18 | *.npy 19 | *.ts 20 | model_ts*.txt 21 | 22 | # onnx models 23 | *.onnx 24 | 25 | # ipython/jupyter notebooks 26 | **/.ipynb_checkpoints/ 27 | 28 | # Editor temporaries 29 | *.swn 30 | *.swo 31 | *.swp 32 | *~ 33 | 34 | # editor settings 35 | .idea 36 | .vscode 37 | _darcs 38 | -------------------------------------------------------------------------------- /sam/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /sam/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to segment-anything 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints, using the `linter.sh` script in the project's root directory. Linting requires `black==23.*`, `isort==5.12.0`, `flake8`, and `mypy`. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to segment-anything, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
32 | -------------------------------------------------------------------------------- /sam/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/sam/__init__.py -------------------------------------------------------------------------------- /sam/assets/masks1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/sam/assets/masks1.png -------------------------------------------------------------------------------- /sam/assets/masks2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/sam/assets/masks2.jpg -------------------------------------------------------------------------------- /sam/assets/model_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/sam/assets/model_diagram.png -------------------------------------------------------------------------------- /sam/assets/notebook1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/sam/assets/notebook1.png -------------------------------------------------------------------------------- /sam/assets/notebook2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/sam/assets/notebook2.png -------------------------------------------------------------------------------- /sam/linter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | { 5 | black --version | grep -E "23\." > /dev/null 6 | } || { 7 | echo "Linter requires 'black==23.*' !" 8 | exit 1 9 | } 10 | 11 | ISORT_VERSION=$(isort --version-number) 12 | if [[ "$ISORT_VERSION" != 5.12* ]]; then 13 | echo "Linter requires isort==5.12.0 !" 14 | exit 1 15 | fi 16 | 17 | echo "Running isort ..." 18 | isort . --atomic 19 | 20 | echo "Running black ..." 21 | black -l 100 . 22 | 23 | echo "Running flake8 ..." 24 | if [ -x "$(command -v flake8)" ]; then 25 | flake8 . 26 | else 27 | python3 -m flake8 . 28 | fi 29 | 30 | echo "Running mypy..." 31 | 32 | mypy --exclude 'setup.py|notebooks' . 
33 | -------------------------------------------------------------------------------- /sam/notebooks/images/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/sam/notebooks/images/dog.jpg -------------------------------------------------------------------------------- /sam/notebooks/images/groceries.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/sam/notebooks/images/groceries.jpg -------------------------------------------------------------------------------- /sam/notebooks/images/truck.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/sam/notebooks/images/truck.jpg -------------------------------------------------------------------------------- /sam/segment_anything/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/halfzm/ProPainter-Webui/5165465a025803a2821308b1eac9709293c8981f/sam/segment_anything/.DS_Store -------------------------------------------------------------------------------- /sam/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /sam/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /sam/segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /sam/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /sam/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format, so H and W are the last two dimensions. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /sam/setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length=100 3 | multi_line_output=3 4 | include_trailing_comma=True 5 | known_standard_library=numpy,setuptools 6 | skip_glob=*/__init__.py 7 | known_myself=segment_anything 8 | known_third_party=matplotlib,cv2,torch,torchvision,pycocotools,onnx,black,isort 9 | no_lines_before=STDLIB,THIRDPARTY 10 | sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER 11 | default_section=FIRSTPARTY 12 | -------------------------------------------------------------------------------- /sam/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | from setuptools import find_packages, setup 8 | 9 | setup( 10 | name="segment_anything", 11 | version="1.0", 12 | install_requires=[], 13 | packages=find_packages(exclude="notebooks"), 14 | extras_require={ 15 | "all": ["matplotlib", "pycocotools", "opencv-python", "onnx", "onnxruntime"], 16 | "dev": ["flake8", "isort", "black", "mypy"], 17 | }, 18 | ) 19 | -------------------------------------------------------------------------------- /tool/detector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | import PIL 5 | 6 | from groundingdino.models import build_model as build_grounding_dino 7 | from groundingdino.util.slconfig import SLConfig 8 | from groundingdino.util.utils import clean_state_dict 9 | from groundingdino.util.inference import annotate, predict 10 | import groundingdino.datasets.transforms as T 11 | 12 | from torchvision.ops import box_convert 13 | 14 | 15 | class Detector: 16 | 17 | def __init__(self, device): 18 | config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py" 19 | grounding_dino_ckpt = './ckpt/groundingdino_swint_ogc.pth' 20 | args = SLConfig.fromfile(config_file) 21 | args.device = device 22 | self.deivce = device 23 | self.gd = build_grounding_dino(args) 24 | 25 | checkpoint = torch.load(grounding_dino_ckpt, map_location='cpu') 26 | log = self.gd.load_state_dict(clean_state_dict(checkpoint['model']), strict=False) 27 | print("Model loaded from {} \n => {}".format(grounding_dino_ckpt, log)) 28 | self.gd.eval() 29 | 30 | def image_transform_grounding(self, init_image): 31 | transform = T.Compose([ 32 | T.RandomResize([800], max_size=1333), 33 | T.ToTensor(), 34 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 35 | ]) 36 | image, _ = transform(init_image, None) # 3, h, w 37 | return init_image, image 38 | 39 | def image_transform_grounding_for_vis(self, init_image): 40 | transform = T.Compose([ 41 | T.RandomResize([800], max_size=1333), 42 | ]) 43 | image, _ = transform(init_image, None) # 3, h, w 44 | return image 45 | 46 | def transfer_boxes_format(self, boxes, height, width): 47 | boxes = boxes * torch.Tensor([width, height, width, height]) 48 | boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy") 49 | 50 | transfered_boxes = [] 51 | for i in range(len(boxes)): 52 | box = boxes[i] 53 | transfered_box = [[int(box[0]), int(box[1])], [int(box[2]), int(box[3])]] 54 | transfered_boxes.append(transfered_box) 55 | 56 | transfered_boxes = np.array(transfered_boxes) 57 | return transfered_boxes 58 | 59 | @torch.no_grad() 60 | def run_grounding(self, origin_frame, grounding_caption, box_threshold, text_threshold): 61 | ''' 62 | return: 63 | annotated_frame:nd.array 64 | transfered_boxes: nd.array [N, 4]: [[x0, y0], [x1, y1]] 65 | ''' 66 | height, width, _ = origin_frame.shape 67 | img_pil = PIL.Image.fromarray(origin_frame) 68 | re_width, re_height = img_pil.size 69 | _, image_tensor = self.image_transform_grounding(img_pil) 70 | # img_pil = self.image_transform_grounding_for_vis(img_pil) 71 | 72 | # run grounidng 73 | boxes, logits, phrases = predict(self.gd, 74 | image_tensor, 75 | grounding_caption, 76 | box_threshold, 77 | text_threshold, 78 | device=self.deivce) 79 | annotated_frame = annotate(image_source=np.asarray(img_pil), boxes=boxes, logits=logits, 80 | phrases=phrases)[:, :, ::-1] 81 | annotated_frame = cv2.resize(annotated_frame, (width, height), interpolation=cv2.INTER_LINEAR) 82 | 83 | # transfer boxes to sam-format 84 | 
transfered_boxes = self.transfer_boxes_format(boxes, re_height, re_width) 85 | return annotated_frame, transfered_boxes 86 | -------------------------------------------------------------------------------- /tool/segmentor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | from sam.segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator 5 | 6 | class Segmentor: 7 | def __init__(self, sam_args): 8 | """ 9 | sam_args: 10 | sam_checkpoint: path of SAM checkpoint 11 | generator_args: args for everything_generator 12 | gpu_id: device 13 | """ 14 | self.device = sam_args["gpu_id"] 15 | self.sam = sam_model_registry[sam_args["model_type"]](checkpoint=sam_args["sam_checkpoint"]) 16 | self.sam.to(device=self.device) 17 | self.everything_generator = SamAutomaticMaskGenerator(model=self.sam, **sam_args['generator_args']) 18 | self.interactive_predictor = self.everything_generator.predictor 19 | self.have_embedded = False 20 | 21 | @torch.no_grad() 22 | def set_image(self, image): 23 | # calculate the embedding only once per frame. 24 | if not self.have_embedded: 25 | self.interactive_predictor.set_image(image) 26 | self.have_embedded = True 27 | @torch.no_grad() 28 | def interactive_predict(self, prompts, mode, multimask=True): 29 | assert self.have_embedded, 'image embedding for sam need be set before predict.' 30 | 31 | if mode == 'point': 32 | masks, scores, logits = self.interactive_predictor.predict(point_coords=prompts['point_coords'], 33 | point_labels=prompts['point_modes'], 34 | multimask_output=multimask) 35 | elif mode == 'mask': 36 | masks, scores, logits = self.interactive_predictor.predict(mask_input=prompts['mask_prompt'], 37 | multimask_output=multimask) 38 | elif mode == 'point_mask': 39 | masks, scores, logits = self.interactive_predictor.predict(point_coords=prompts['point_coords'], 40 | point_labels=prompts['point_modes'], 41 | mask_input=prompts['mask_prompt'], 42 | multimask_output=multimask) 43 | 44 | return masks, scores, logits 45 | 46 | @torch.no_grad() 47 | def segment_with_click(self, origin_frame, coords, modes, multimask=True): 48 | ''' 49 | 50 | return: 51 | mask: one-hot 52 | ''' 53 | self.set_image(origin_frame) 54 | 55 | prompts = { 56 | 'point_coords': coords, 57 | 'point_modes': modes, 58 | } 59 | masks, scores, logits = self.interactive_predict(prompts, 'point', multimask) 60 | mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] 61 | prompts = { 62 | 'point_coords': coords, 63 | 'point_modes': modes, 64 | 'mask_prompt': logit[None, :, :] 65 | } 66 | masks, scores, logits = self.interactive_predict(prompts, 'point_mask', multimask) 67 | mask = masks[np.argmax(scores)] 68 | 69 | return mask.astype(np.uint8) 70 | 71 | def segment_with_box(self, origin_frame, bbox, reset_image=False): 72 | if reset_image: 73 | self.interactive_predictor.set_image(origin_frame) 74 | else: 75 | self.set_image(origin_frame) 76 | # coord = np.array([[int((bbox[1][0] - bbox[0][0]) / 2.), int((bbox[1][1] - bbox[0][1]) / 2)]]) 77 | # point_label = np.array([1]) 78 | 79 | masks, scores, logits = self.interactive_predictor.predict( 80 | point_coords=None, 81 | point_labels=None, 82 | box=np.array([bbox[0][0], bbox[0][1], bbox[1][0], bbox[1][1]]), 83 | multimask_output=True 84 | ) 85 | mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] 86 | 87 | masks, scores, logits = self.interactive_predictor.predict( 88 | point_coords=None, 89 | 
point_labels=None, 90 | box=np.array([[bbox[0][0], bbox[0][1], bbox[1][0], bbox[1][1]]]), 91 | mask_input=logit[None, :, :], 92 | multimask_output=True 93 | ) 94 | mask = masks[np.argmax(scores)] 95 | 96 | return [mask] 97 | -------------------------------------------------------------------------------- /tool/transfer_tools.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | def mask2bbox(mask): 5 | if len(np.where(mask > 0)[0]) == 0: 6 | print(f'not mask') 7 | return np.array([[0, 0], [0, 0]]).astype(np.int64) 8 | 9 | x_ = np.sum(mask, axis=0) 10 | y_ = np.sum(mask, axis=1) 11 | 12 | x0 = np.min(np.nonzero(x_)[0]) 13 | x1 = np.max(np.nonzero(x_)[0]) 14 | y0 = np.min(np.nonzero(y_)[0]) 15 | y1 = np.max(np.nonzero(y_)[0]) 16 | 17 | return np.array([[x0, y0], [x1, y1]]).astype(np.int64) 18 | 19 | def draw_outline(mask, frame): 20 | _, binary_mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY) 21 | 22 | contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 23 | 24 | cv2.drawContours(frame, contours, -1, (0, 0, 255), 2) 25 | 26 | return frame 27 | 28 | def draw_points(points, modes, frame): 29 | neg_points = points[np.argwhere(modes==0)[:, 0]] 30 | pos_points = points[np.argwhere(modes==1)[:, 0]] 31 | 32 | for i in range(len(neg_points)): 33 | point = neg_points[i] 34 | cv2.circle(frame, (point[0], point[1]), 8, (255, 80, 80), -1) 35 | 36 | for i in range(len(pos_points)): 37 | point = pos_points[i] 38 | cv2.circle(frame, (point[0], point[1]), 8, (0, 153, 255), -1) 39 | 40 | return frame 41 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def resolve_relative_path(path: str) -> str: 5 | return os.path.abspath(os.path.join(os.path.dirname(__file__), path)) 6 | --------------------------------------------------------------------------------
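
To close, a sketch of how the helper modules at the end of this dump compose: `tool/detector.py` turns a text prompt into boxes with GroundingDINO, `tool/segmentor.py` turns a box into a mask with SAM, and `tool/transfer_tools.py` converts and draws the result; the webui entry point `app.py` presumably chains them the same way before handing masks to ProPainter. The Python below is a minimal, hedged example of driving these helpers from a plain script. It assumes the checkpoints under `./ckpt` described in the readmes, a CUDA device, and placeholder values for the frame path, the caption, the thresholds, and the output file name; none of these values are taken from `app.py`.

```python
import cv2
import numpy as np

from model_args import sam_args
from tool.detector import Detector
from tool.segmentor import Segmentor
from tool.transfer_tools import mask2bbox, draw_outline


def text_prompted_mask(frame_bgr: np.ndarray, caption: str):
    # Both helpers load their checkpoints from ./ckpt, as laid out in the readmes
    # (groundingdino_swint_ogc.pth for the detector, sam_vit_b_01ec64.pth for SAM),
    # so this script is assumed to run from the repository root on a CUDA machine.
    detector = Detector("cuda")
    segmentor = Segmentor(sam_args)

    # Detector wraps the frame in a PIL image, so pass it as RGB.
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

    # run_grounding returns an annotated preview plus boxes in the corner format
    # [[x0, y0], [x1, y1]] that segment_with_box expects. Thresholds are illustrative.
    annotated, boxes = detector.run_grounding(
        frame_rgb, grounding_caption=caption, box_threshold=0.35, text_threshold=0.25
    )
    if len(boxes) == 0:
        return annotated, None

    # Prompt SAM with the first detected box; the helper returns a list with one mask.
    mask = segmentor.segment_with_box(frame_rgb, boxes[0], reset_image=True)[0]
    mask = mask.astype(np.uint8)

    # mask2bbox recovers the tight bounding box; draw_outline overlays the contour.
    print("mask bbox:", mask2bbox(mask))
    return draw_outline(mask, frame_bgr.copy()), mask


if __name__ == "__main__":
    frame = cv2.imread("./assets/example_frame.png")  # hypothetical input frame
    if frame is None:
        raise FileNotFoundError("point this at any frame extracted from your video")
    vis, mask = text_prompted_mask(frame, "person")  # "person" is an example prompt
    cv2.imwrite("./grounded_mask_vis.png", vis)
```

Passing the detector's corner-format boxes straight into `segment_with_box` avoids re-deriving point prompts, which is why `transfer_boxes_format` in `tool/detector.py` already emits boxes in that shape.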