├── RAFT ├── 1.png ├── LICENSE ├── RAFT.png ├── README.md ├── __init__.py ├── alt_cuda_corr │ ├── correlation.cpp │ ├── correlation_kernel.cu │ └── setup.py ├── chairs_split.txt ├── core │ ├── __init__.py │ ├── corr.py │ ├── datasets.py │ ├── extractor.py │ ├── raft.py │ ├── update.py │ └── utils │ │ ├── __init__.py │ │ ├── augmentor.py │ │ ├── flow_viz.py │ │ ├── frame_utils.py │ │ └── utils.py ├── demo-frames │ ├── frame_0016.png │ └── frame_0017.png ├── demo.py ├── download_models.sh ├── evaluate.py ├── example.png ├── flow_warp.py ├── logs │ └── hadoop.kylin.libdfs.log ├── models │ ├── 1.py │ └── raft-things.pth-no-zip ├── tosave.npy ├── train.py ├── train_mixed.sh └── train_standard.sh ├── RAFT_core ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── corr.cpython-37.pyc │ ├── extractor.cpython-37.pyc │ ├── raft.cpython-37.pyc │ └── update.cpython-37.pyc ├── corr.py ├── datasets.py ├── extractor.py ├── raft-things.pth-no-zip ├── raft.py ├── update.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── utils.cpython-37.pyc │ ├── augmentor.py │ ├── flow_viz.py │ ├── frame_utils.py │ └── utils.py ├── README.md ├── TC_cal.py ├── VC_perclip.py ├── change2_480p.py ├── config ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── defaults.cpython-37.pyc ├── defaults.py ├── vsp-hrnetv2.yaml ├── vsp-mobilenetv2dilated-c1_deepsup.yaml ├── vsp-mobilenetv2dilated-ppm_deepsup.yaml ├── vsp-resnet101-upernet.yaml ├── vsp-resnet101dilated-deeplab.yaml ├── vsp-resnet101dilated-nonlocal2d.yaml ├── vsp-resnet101dilated-ocr_deepsup.yaml ├── vsp-resnet101dilated-ppm_clip.yaml ├── vsp-resnet101dilated-ppm_deepsup.yaml ├── vsp-resnet101dilated-ppm_deepsup_clip.yaml ├── vsp-resnet101dilated_tdnet.yaml ├── vsp-resnet18dilated-ppm_deepsup.yaml ├── vsp-resnet18dilated-ppm_deepsup_clip.yaml ├── vsp-resnet50-upernet.yaml ├── vsp-resnet50dilated-deeplab.yaml ├── vsp-resnet50dilated-ppm_deepsup.yaml ├── vsp-resnet50dilated-ppm_deepsup_clip.yaml └── vsp-resnet50dilated-tdnet.yaml ├── dataset.py ├── dataset2.py ├── lib ├── nn │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-37.pyc │ ├── modules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── batchnorm.cpython-37.pyc │ │ │ ├── comm.cpython-37.pyc │ │ │ └── replicate.cpython-37.pyc │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ ├── tests │ │ │ ├── test_numeric_batchnorm.py │ │ │ └── test_sync_batchnorm.py │ │ └── unittest.py │ └── parallel │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── data_parallel.cpython-37.pyc │ │ └── data_parallel.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── th.cpython-37.pyc │ ├── data │ ├── __init__.py │ ├── dataloader.py │ ├── dataset.py │ ├── distributed.py │ └── sampler.py │ └── th.py ├── models ├── .non_local2d.py.swp ├── .propnet.py.swo ├── .propnet.py.swp ├── BiConvLSTM.py ├── ETC.py ├── ETC_ocr.py ├── __init__.py ├── __pycache__ │ ├── BiConvLSTM.cpython-37.pyc │ ├── ETC.cpython-37.pyc │ ├── ETC_ocr.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── clip_ocr.cpython-37.pyc │ ├── clip_psp.cpython-37.pyc │ ├── deeplab.cpython-37.pyc │ ├── hrnet.cpython-37.pyc │ ├── hrnet_clip.cpython-37.pyc │ ├── mobilenet.cpython-37.pyc │ ├── models.cpython-37.pyc │ ├── netwarp.cpython-37.pyc │ ├── netwarp_ocr.cpython-37.pyc │ ├── non_local.cpython-37.pyc │ ├── non_local_models.cpython-37.pyc │ ├── ocrnet.cpython-37.pyc │ ├── propnet.cpython-37.pyc │ ├── 
resnet.cpython-37.pyc │ ├── resnext.cpython-37.pyc │ ├── utils.cpython-37.pyc │ ├── warp_our.cpython-37.pyc │ └── warp_our_merge.cpython-37.pyc ├── clip_ocr.py ├── clip_psp.py ├── deeplab.py ├── deeplabv3 │ ├── aspp.py │ └── decoder.py ├── hrnet.py ├── hrnet_clip.py ├── hrnet_clip_2.py ├── mobilenet.py ├── models.py ├── netwarp.py ├── netwarp_ocr.py ├── non_local.py ├── non_local_models.py ├── ocr_modules │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── spatial_ocr_block.cpython-37.pyc │ ├── spatial_ocr_block.py │ └── spatial_ocr_block_max.py ├── ocrnet.py ├── propnet.py ├── resnet.py ├── resnext.py ├── sync_batchnorm │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── batchnorm.cpython-37.pyc │ │ ├── comm.cpython-37.pyc │ │ └── replicate.cpython-37.pyc │ ├── batchnorm.py │ ├── batchnorm_reimpl.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py ├── td4_psp │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── loss.cpython-37.pyc │ │ ├── td4_psp.cpython-37.pyc │ │ └── transformer.cpython-37.pyc │ ├── loss.py │ ├── pspnet_4p.py │ ├── resnet_bak.py │ ├── td4_psp.py │ ├── td4_psp_bak.py │ ├── transformer.py │ └── utils │ │ ├── __init__.py │ │ ├── files.py │ │ └── model_store.py ├── utils.py ├── warp_our.py └── warp_our_merge.py ├── scripts ├── run_etc.sh ├── run_netwarp.sh ├── run_ocr.sh ├── run_psp.sh ├── run_temporal_ocr.sh └── run_temporal_psp.sh ├── test.py ├── test_clip2.py ├── train.py ├── train_clip2.py └── utils.py /RAFT/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/1.png -------------------------------------------------------------------------------- /RAFT/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, princeton-vl 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /RAFT/RAFT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/RAFT.png -------------------------------------------------------------------------------- /RAFT/README.md: -------------------------------------------------------------------------------- 1 | # RAFT 2 | This repository contains the source code for our paper: 3 | 4 | [RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)
5 | ECCV 2020
6 | Zachary Teed and Jia Deng
7 | 8 | 9 | 10 | ## Requirements 11 | The code has been tested with PyTorch 1.6 and CUDA 10.1. 12 | ```Shell 13 | conda create --name raft 14 | conda activate raft 15 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 -c pytorch 16 | conda install matplotlib 17 | conda install tensorboard 18 | conda install scipy 19 | conda install opencv 20 | ``` 21 | 22 | ## Demos 23 | Pretrained models can be downloaded by running 24 | ```Shell 25 | ./download_models.sh 26 | ``` 27 | or downloaded from [google drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing) 28 | 29 | You can demo a trained model on a sequence of frames 30 | ```Shell 31 | python demo.py --model=models/raft-things.pth --path=demo-frames 32 | ``` 33 | 34 | ## Required Data 35 | To evaluate/train RAFT, you will need to download the required datasets. 36 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) 37 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 38 | * [Sintel](http://sintel.is.tue.mpg.de/) 39 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) 40 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional) 41 | 42 | 43 | By default `datasets.py` will search for the datasets in these locations. You can create symbolic links in the `datasets` folder to wherever the datasets were downloaded 44 | 45 | ```Shell 46 | ├── datasets 47 | ├── Sintel 48 | ├── test 49 | ├── training 50 | ├── KITTI 51 | ├── testing 52 | ├── training 53 | ├── devkit 54 | ├── FlyingChairs_release 55 | ├── data 56 | ├── FlyingThings3D 57 | ├── frames_cleanpass 58 | ├── frames_finalpass 59 | ├── optical_flow 60 | ``` 61 | 62 | ## Evaluation 63 | You can evaluate a trained model using `evaluate.py` 64 | ```Shell 65 | python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision 66 | ``` 67 | 68 | ## Training 69 | We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` directory, which can be visualized using tensorboard 70 | ```Shell 71 | ./train_standard.sh 72 | ``` 73 | 74 | If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU) 75 | ```Shell 76 | ./train_mixed.sh 77 | ``` 78 | 79 | ## (Optional) Efficient Implementation 80 | You can optionally use our alternate (efficient) implementation by compiling the provided CUDA extension 81 | ```Shell 82 | cd alt_cuda_corr && python setup.py install && cd .. 83 | ``` 84 | and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note: this implementation is somewhat slower than the all-pairs version, but uses significantly less GPU memory during the forward pass.
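For programmatic use, `demo.py` boils down to a few lines. The following minimal sketch mirrors the model construction, checkpoint loading, and `test_mode` call from `demo.py` in this repository; the random tensors are placeholders for real frames, which must be float tensors of shape `(1, 3, H, W)` in `[0, 255]` with `H` and `W` divisible by 8 (see `InputPadder` in `core/utils/utils.py`):

```Python
import sys
sys.path.append('core')

import argparse
import torch

from raft import RAFT

# Flags mirror the argparse options in demo.py (large model, full precision).
args = argparse.Namespace(small=False, mixed_precision=False, alternate_corr=False)

model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load('models/raft-things.pth'))
model = model.module.cuda().eval()

with torch.no_grad():
    # Placeholder frames; load real images as in demo.py.
    image1 = torch.randint(0, 255, (1, 3, 440, 1024)).float().cuda()
    image2 = torch.randint(0, 255, (1, 3, 440, 1024)).float().cuda()
    flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
```

`flow_up` is the full-resolution flow of shape `(1, 2, H, W)`; `flow_low` is the 1/8-resolution estimate, which can seed the next frame pair through the `flow_init` argument for warm-start inference.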
85 | -------------------------------------------------------------------------------- /RAFT/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/__init__.py -------------------------------------------------------------------------------- /RAFT/alt_cuda_corr/correlation.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <vector> 3 | 4 | // CUDA forward declarations 5 | std::vector<torch::Tensor> corr_cuda_forward( 6 | torch::Tensor fmap1, 7 | torch::Tensor fmap2, 8 | torch::Tensor coords, 9 | int radius); 10 | 11 | std::vector<torch::Tensor> corr_cuda_backward( 12 | torch::Tensor fmap1, 13 | torch::Tensor fmap2, 14 | torch::Tensor coords, 15 | torch::Tensor corr_grad, 16 | int radius); 17 | 18 | // C++ interface 19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 22 | 23 | std::vector<torch::Tensor> corr_forward( 24 | torch::Tensor fmap1, 25 | torch::Tensor fmap2, 26 | torch::Tensor coords, 27 | int radius) { 28 | CHECK_INPUT(fmap1); 29 | CHECK_INPUT(fmap2); 30 | CHECK_INPUT(coords); 31 | 32 | return corr_cuda_forward(fmap1, fmap2, coords, radius); 33 | } 34 | 35 | 36 | std::vector<torch::Tensor> corr_backward( 37 | torch::Tensor fmap1, 38 | torch::Tensor fmap2, 39 | torch::Tensor coords, 40 | torch::Tensor corr_grad, 41 | int radius) { 42 | CHECK_INPUT(fmap1); 43 | CHECK_INPUT(fmap2); 44 | CHECK_INPUT(coords); 45 | CHECK_INPUT(corr_grad); 46 | 47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); 48 | } 49 | 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &corr_forward, "CORR forward"); 53 | m.def("backward", &corr_backward, "CORR backward"); 54 | } -------------------------------------------------------------------------------- /RAFT/alt_cuda_corr/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='correlation', 7 | ext_modules=[ 8 | CUDAExtension('alt_cuda_corr', 9 | sources=['correlation.cpp', 'correlation_kernel.cu'], 10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | 16 | -------------------------------------------------------------------------------- /RAFT/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/core/__init__.py -------------------------------------------------------------------------------- /RAFT/core/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape
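# The all-pairs volume has shape (batch, h1, w1, 1, h2, w2): one correlation value
# for every pair of source and target pixels. Folding (batch, h1, w1) into the
# batch axis below lets avg_pool2d build the pyramid over the target dimensions
# (h2, w2) only, so every level keeps full resolution in the source image.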
22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1) 38 | dy = torch.linspace(-r, r, 2*r+1) 39 | delta = torch.stack(torch.meshgrid(dy, dx), dim=-1).to(coords.device) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class AlternateCorrBlock: 64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 65 | self.num_levels = num_levels 66 | self.radius = radius 67 | 68 | self.pyramid = [(fmap1, fmap2)] 69 | for i in range(self.num_levels): 70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 72 | self.pyramid.append((fmap1, fmap2)) 73 | 74 | def __call__(self, coords): 75 | coords = coords.permute(0, 2, 3, 1) 76 | B, H, W, _ = coords.shape 77 | dim = self.pyramid[0][0].shape[1] 78 | 79 | corr_list = [] 80 | for i in range(self.num_levels): 81 | r = self.radius 82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 84 | 85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 87 | corr_list.append(corr.squeeze(1)) 88 | 89 | corr = torch.stack(corr_list, dim=1) 90 | corr = corr.reshape(B, -1, H, W) 91 | return corr / torch.sqrt(torch.tensor(dim).float()) 92 | -------------------------------------------------------------------------------- /RAFT/core/raft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from update import BasicUpdateBlock, SmallUpdateBlock 7 | from extractor import BasicEncoder, SmallEncoder 8 | from corr import CorrBlock, AlternateCorrBlock 9 | from utils.utils import bilinear_sampler, coords_grid, upflow8 10 | 11 | try: 12 | autocast = torch.cuda.amp.autocast 13 | except: 14 | # dummy autocast for PyTorch < 1.6 15 | class autocast: 16 | def __init__(self, enabled): 17 | pass 18 | def __enter__(self): 19 | pass 20 | def __exit__(self, *args): 21 | pass 22 | 23 | 24 | class RAFT(nn.Module): 25 | def __init__(self, args): 26 | super(RAFT, self).__init__() 27 | self.args = args 28 | 29 | if args.small: 30 | self.hidden_dim = hdim = 96 31 | self.context_dim = cdim = 64 32 | args.corr_levels = 4 33 | args.corr_radius = 3 34 | 35 | else: 36 | self.hidden_dim = hdim = 128 37 | self.context_dim = cdim = 128 38 | args.corr_levels = 4 39 | 
args.corr_radius = 4 40 | 41 | if 'dropout' not in self.args: 42 | self.args.dropout = 0 43 | 44 | if 'alternate_corr' not in self.args: 45 | self.args.alternate_corr = False 46 | 47 | # feature network, context network, and update block 48 | if args.small: 49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) 50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) 51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 52 | 53 | else: 54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) 56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 57 | 58 | def freeze_bn(self): 59 | for m in self.modules(): 60 | if isinstance(m, nn.BatchNorm2d): 61 | m.eval() 62 | 63 | def initialize_flow(self, img): 64 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 65 | N, C, H, W = img.shape 66 | coords0 = coords_grid(N, H//8, W//8).to(img.device) 67 | coords1 = coords_grid(N, H//8, W//8).to(img.device) 68 | 69 | # optical flow computed as difference: flow = coords1 - coords0 70 | return coords0, coords1 71 | 72 | def upsample_flow(self, flow, mask): 73 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 74 | N, _, H, W = flow.shape 75 | mask = mask.view(N, 1, 9, 8, 8, H, W) 76 | mask = torch.softmax(mask, dim=2) 77 | 78 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 79 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 80 | 81 | up_flow = torch.sum(mask * up_flow, dim=2) 82 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 83 | return up_flow.reshape(N, 2, 8*H, 8*W) 84 | 85 | 86 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): 87 | """ Estimate optical flow between pair of frames """ 88 | 89 | image1 = 2 * (image1 / 255.0) - 1.0 90 | image2 = 2 * (image2 / 255.0) - 1.0 91 | 92 | image1 = image1.contiguous() 93 | image2 = image2.contiguous() 94 | 95 | hdim = self.hidden_dim 96 | cdim = self.context_dim 97 | 98 | # run the feature network 99 | with autocast(enabled=self.args.mixed_precision): 100 | fmap1, fmap2 = self.fnet([image1, image2]) 101 | 102 | fmap1 = fmap1.float() 103 | fmap2 = fmap2.float() 104 | if self.args.alternate_corr: 105 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 106 | else: 107 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 108 | 109 | # run the context network 110 | with autocast(enabled=self.args.mixed_precision): 111 | cnet = self.cnet(image1) 112 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 113 | net = torch.tanh(net) 114 | inp = torch.relu(inp) 115 | 116 | coords0, coords1 = self.initialize_flow(image1) 117 | 118 | if flow_init is not None: 119 | coords1 = coords1 + flow_init 120 | 121 | flow_predictions = [] 122 | for itr in range(iters): 123 | coords1 = coords1.detach() 124 | corr = corr_fn(coords1) # index correlation volume 125 | 126 | flow = coords1 - coords0 127 | with autocast(enabled=self.args.mixed_precision): 128 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 129 | 130 | # F(t+1) = F(t) + \Delta(t) 131 | coords1 = coords1 + delta_flow 132 | 133 | # upsample predictions 134 | if up_mask is None: 135 | flow_up = upflow8(coords1 - coords0) 136 | else: 137 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 138 | 139 | flow_predictions.append(flow_up) 140 | 
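# In test mode only the final estimate is returned: the 1/8-resolution flow field
# (coords1 - coords0) together with its convexly upsampled version flow_up.
# During training the whole list of per-iteration predictions is returned instead,
# so a sequence loss can supervise every refinement step.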
141 | if test_mode: 142 | return coords1 - coords0, flow_up 143 | 144 | return flow_predictions 145 | -------------------------------------------------------------------------------- /RAFT/core/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self, args): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self, args): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 86 | 
self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 88 | 89 | def forward(self, flow, corr): 90 | cor = F.relu(self.convc1(corr)) 91 | cor = F.relu(self.convc2(cor)) 92 | flo = F.relu(self.convf1(flow)) 93 | flo = F.relu(self.convf2(flo)) 94 | 95 | cor_flo = torch.cat([cor, flo], dim=1) 96 | out = F.relu(self.conv(cor_flo)) 97 | return torch.cat([out, flow], dim=1) 98 | 99 | class SmallUpdateBlock(nn.Module): 100 | def __init__(self, args, hidden_dim=96): 101 | super(SmallUpdateBlock, self).__init__() 102 | self.encoder = SmallMotionEncoder(args) 103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | inp = torch.cat([inp, motion_features], dim=1) 109 | net = self.gru(net, inp) 110 | delta_flow = self.flow_head(net) 111 | 112 | return net, None, delta_flow 113 | 114 | class BasicUpdateBlock(nn.Module): 115 | def __init__(self, args, hidden_dim=128, input_dim=128): 116 | super(BasicUpdateBlock, self).__init__() 117 | self.args = args 118 | self.encoder = BasicMotionEncoder(args) 119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 121 | 122 | self.mask = nn.Sequential( 123 | nn.Conv2d(128, 256, 3, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(256, 64*9, 1, padding=0)) 126 | 127 | def forward(self, net, inp, corr, flow, upsample=True): 128 | motion_features = self.encoder(flow, corr) 129 | inp = torch.cat([inp, motion_features], dim=1) 130 | 131 | net = self.gru(net, inp) 132 | delta_flow = self.flow_head(net) 133 | 134 | # scale mask to balance gradients 135 | mask = .25 * self.mask(net) 136 | return net, mask, delta_flow 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /RAFT/core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/core/utils/__init__.py -------------------------------------------------------------------------------- /RAFT/core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the Matlab source code of Deqing Sun.
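The wheel concatenates six hue transitions (RY, YG, GC, CB, BM, MR) with 15, 6, 4, 11, 13, and 6 entries respectively, 55 colors in total.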
28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /RAFT/core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /RAFT/core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, x): 19 | return F.pad(x, self._pad, mode='replicate') 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | (x1, y1), dy, (x0, y0), method='nearest', 
fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd): 75 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/demo-frames/frame_0016.png -------------------------------------------------------------------------------- /RAFT/demo-frames/frame_0017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/demo-frames/frame_0017.png -------------------------------------------------------------------------------- /RAFT/demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | import argparse 5 | import os 6 | import cv2 7 | import glob 8 | import numpy as np 9 | import torch 10 | from PIL import Image 11 | 12 | from raft import RAFT 13 | from utils import flow_viz 14 | from utils.utils import InputPadder 15 | import torch.nn.functional as F 16 | 17 | 18 | 19 | DEVICE = 'cuda' 20 | 21 | def load_image(imfile): 22 | img = np.array(Image.open(imfile)).astype(np.uint8) 23 | img = torch.from_numpy(img).permute(2, 0, 1).float() 24 | return img 25 | 26 | 27 | def load_image_list(image_files): 28 | images = [] 29 | for imfile in sorted(image_files): 30 | images.append(load_image(imfile)) 31 | 32 | images = torch.stack(images, dim=0) 33 | images = images.to(DEVICE) 34 | 35 | padder = InputPadder(images.shape) 36 | return padder.pad(images)[0] 37 | 38 | 39 | def viz(img, flo): 40 | img = img[0].permute(1,2,0).cpu().numpy() 41 | flo = flo[0].permute(1,2,0).cpu().numpy() 42 | 43 | # map flow to rgb image 44 | flo = flow_viz.flow_to_image(flo) 45 | flo = Image.fromarray(flo) 46 | flo.save('example.png') 47 | 48 | 49 | def demo(args): 50 | model = torch.nn.DataParallel(RAFT(args)) 51 | model.load_state_dict(torch.load(args.model)) 52 | 53 | model = model.module 54 | model.to(DEVICE) 55 | model.eval() 56 | 57 | with torch.no_grad(): 58 | images = glob.glob(os.path.join(args.path, '*.png')) + \ 59 | glob.glob(os.path.join(args.path, '*.jpg')) 60 | 61 | images = load_image_list(images) 62 | for i in range(images.shape[0]-1): 63 | image1 = images[i,None] 64 | image2 = images[i+1,None] 65 | print(image1.size()) 66 | print(image2.size()) 67 | image1 = 
F.interpolate(image1,(480,720)) 68 | image2 = F.interpolate(image2,(480,720)) 69 | 70 | 71 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 72 | to_save = flow_up.squeeze(0).cpu().numpy() 73 | np.save('tosave.npy',to_save) 74 | viz(image1, flow_up) 75 | 76 | 77 | if __name__ == '__main__': 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument('--model', help="restore checkpoint") 80 | parser.add_argument('--path', help="dataset for evaluation") 81 | parser.add_argument('--small', action='store_true', help='use small model') 82 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 83 | parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation') 84 | args = parser.parse_args() 85 | 86 | demo(args) 87 | -------------------------------------------------------------------------------- /RAFT/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip 3 | unzip models.zip 4 | -------------------------------------------------------------------------------- /RAFT/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/example.png -------------------------------------------------------------------------------- /RAFT/flow_warp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from PIL import Image 5 | import torch.nn.functional as F 6 | 7 | def warp(x, flo): 8 | """ 9 | warp an image/tensor (im2) back to im1, according to the optical flow 10 | x: [B, C, H, W] (im2) 11 | flo: [B, 2, H, W] flow 12 | """ 13 | B, C, H, W = x.size() 14 | # mesh grid 15 | xx = torch.arange(0, W).view(1,-1).repeat(H,1) 16 | yy = torch.arange(0, H).view(-1,1).repeat(1,W) 17 | xx = xx.view(1,1,H,W).repeat(B,1,1,1) 18 | yy = yy.view(1,1,H,W).repeat(B,1,1,1) 19 | grid = torch.cat((xx,yy),1).float() 20 | 21 | if x.is_cuda: 22 | grid = grid.to(x.device) 23 | vgrid = grid + flo 24 | 25 | # scale grid to [-1,1] 26 | vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0 27 | vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0 28 | 29 | vgrid = vgrid.permute(0,2,3,1) 30 | output = nn.functional.grid_sample(x, vgrid) 31 | 32 | return output 33 | 34 | 35 | def flow_warp(x, flow): 36 | """Warp an image or feature map with optical flow 37 | Args: 38 | x (Tensor): size (n, c, h, w) 39 | flow (Tensor): size (n, 2, h, w), values range from -1 to 1 (relative to image width or height) 40 | padding_mode (str): 'zeros' or 'border' 41 | 42 | Returns: 43 | Tensor: warped image or feature map 44 | """ 45 | assert x.size()[-2:] == flow.size()[-2:] 46 | n, _, h, w = x.size() 47 | x_ = torch.arange(w).view(1, -1).expand(h, -1) 48 | y_ = torch.arange(h).view(-1, 1).expand(-1, w) 49 | grid = torch.stack([x_, y_], dim=0).float() 50 | grid = grid.unsqueeze(0).expand(n, -1, -1, -1).clone() # clone: the in-place writes below cannot target an expanded view 51 | grid[:, 0, :, :] = 2 * grid[:, 0, :, :] / (w - 1) - 1 52 | grid[:, 1, :, :] = 2 * grid[:, 1, :, :] / (h - 1) - 1 53 | grid += 2 * flow 54 | grid = grid.permute(0, 2, 3, 1) 55 | return F.grid_sample(x, grid) 56 | 57 | 58 | 59 | 60 | 61 | if __name__=='__main__': 62 | img = '/home/miaojiaxu/jiaxu3/vsp_segment/RAFT-master/demo-frames/frame_0016.png' 63 | img = 
Image.open(img) 64 | img = img.resize((1024,440)) 65 | img = np.array(img) 66 | img = img/255. 67 | 68 | 69 | 70 | flow = np.load('tosave.npy') 71 | 72 | img = torch.from_numpy(img) 73 | img = img.unsqueeze(0).permute(0,3,1,2) 74 | img = img.float() 75 | 76 | print(img.size()) 77 | flow = torch.from_numpy(flow) 78 | flow = flow.unsqueeze(0) 79 | print(flow.size()) 80 | img2 = warp(img,flow) 81 | print(img2.size()) 82 | img2 = img2.squeeze(0).permute(1,2,0).numpy() 83 | print(img2.shape) 84 | img2 = Image.fromarray((img2*255.).astype('uint8')) 85 | img2.save('1.png') 86 | 87 | -------------------------------------------------------------------------------- /RAFT/logs/hadoop.kylin.libdfs.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/logs/hadoop.kylin.libdfs.log -------------------------------------------------------------------------------- /RAFT/models/1.py: -------------------------------------------------------------------------------- 1 | print('!@!') 2 | 3 | if __name__=='__main__': 4 | print('22') 5 | -------------------------------------------------------------------------------- /RAFT/models/raft-things.pth-no-zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/models/raft-things.pth-no-zip -------------------------------------------------------------------------------- /RAFT/tosave.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/tosave.npy -------------------------------------------------------------------------------- /RAFT/train_mixed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p checkpoints 3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision 4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision 5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision 6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision 7 | -------------------------------------------------------------------------------- /RAFT/train_standard.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p checkpoints 3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 
720 --wdecay 0.0001 5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 -------------------------------------------------------------------------------- /RAFT_core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/__init__.py -------------------------------------------------------------------------------- /RAFT_core/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT_core/__pycache__/corr.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/__pycache__/corr.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT_core/__pycache__/extractor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/__pycache__/extractor.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT_core/__pycache__/raft.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/__pycache__/raft.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT_core/__pycache__/update.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/__pycache__/update.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT_core/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from RAFT_core.utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, 
coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1) 38 | dy = torch.linspace(-r, r, 2*r+1) 39 | delta = torch.stack(torch.meshgrid(dy, dx), dim=-1).to(coords.device) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class AlternateCorrBlock: 64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 65 | self.num_levels = num_levels 66 | self.radius = radius 67 | 68 | self.pyramid = [(fmap1, fmap2)] 69 | for i in range(self.num_levels): 70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 72 | self.pyramid.append((fmap1, fmap2)) 73 | 74 | def __call__(self, coords): 75 | coords = coords.permute(0, 2, 3, 1) 76 | B, H, W, _ = coords.shape 77 | dim = self.pyramid[0][0].shape[1] 78 | 79 | corr_list = [] 80 | for i in range(self.num_levels): 81 | r = self.radius 82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 84 | 85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 87 | corr_list.append(corr.squeeze(1)) 88 | 89 | corr = torch.stack(corr_list, dim=1) 90 | corr = corr.reshape(B, -1, H, W) 91 | return corr / torch.sqrt(torch.tensor(dim).float()) 92 | -------------------------------------------------------------------------------- /RAFT_core/raft-things.pth-no-zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/raft-things.pth-no-zip -------------------------------------------------------------------------------- /RAFT_core/raft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import sys 6 | sys.path.append('RAFT_core') 7 | 8 | from update import BasicUpdateBlock, SmallUpdateBlock 9 | from extractor import BasicEncoder, SmallEncoder 10 | from corr import CorrBlock, AlternateCorrBlock 11 | from RAFT_core.utils.utils import bilinear_sampler, coords_grid, upflow8 12 | 13 | try: 14 | autocast = torch.cuda.amp.autocast 15 | except: 16 | # dummy autocast for PyTorch < 1.6 17 | class autocast: 18 | def __init__(self, enabled): 19 | pass 20 | def __enter__(self): 21 | pass 22 | def __exit__(self, *args): 23 | pass 24 | 25 | 26 | class RAFT(nn.Module): 27 | def __init__(self,requires_grad=False): 28 | super(RAFT, self).__init__() 29 | # self.args = args 30 | 31 | self.hidden_dim = hdim = 128 32 | self.context_dim = cdim = 128 33 | 
corr_levels = 4 34 | corr_radius = 4 35 | self.corr_radius = corr_radius 36 | 37 | 38 | 39 | # feature network, context network, and update block 40 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0) 41 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=0) 42 | self.update_block = BasicUpdateBlock(corr_levels,corr_radius, hidden_dim=hdim) 43 | if not requires_grad: 44 | for param in self.parameters(): 45 | param.requires_grad = False 46 | 47 | def freeze_bn(self): 48 | for m in self.modules(): 49 | if isinstance(m, nn.BatchNorm2d): 50 | m.eval() 51 | 52 | def initialize_flow(self, img): 53 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 54 | N, C, H, W = img.shape 55 | coords0 = coords_grid(N, H//8, W//8).to(img.device) 56 | coords1 = coords_grid(N, H//8, W//8).to(img.device) 57 | 58 | # optical flow computed as difference: flow = coords1 - coords0 59 | return coords0, coords1 60 | 61 | def upsample_flow(self, flow, mask): 62 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 63 | N, _, H, W = flow.shape 64 | mask = mask.view(N, 1, 9, 8, 8, H, W) 65 | mask = torch.softmax(mask, dim=2) 66 | 67 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 68 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 69 | 70 | up_flow = torch.sum(mask * up_flow, dim=2) 71 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 72 | return up_flow.reshape(N, 2, 8*H, 8*W) 73 | 74 | 75 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): 76 | """ Estimate optical flow between pair of frames """ 77 | 78 | image1 = 2 * (image1 / 255.0) - 1.0 79 | image2 = 2 * (image2 / 255.0) - 1.0 80 | 81 | image1 = image1.contiguous() 82 | image2 = image2.contiguous() 83 | 84 | hdim = self.hidden_dim 85 | cdim = self.context_dim 86 | 87 | # run the feature network 88 | fmap1, fmap2 = self.fnet([image1, image2]) 89 | 90 | fmap1 = fmap1.float() 91 | fmap2 = fmap2.float() 92 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.corr_radius) 93 | 94 | # run the context network 95 | cnet = self.cnet(image1) 96 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 97 | net = torch.tanh(net) 98 | inp = torch.relu(inp) 99 | 100 | coords0, coords1 = self.initialize_flow(image1) 101 | 102 | if flow_init is not None: 103 | coords1 = coords1 + flow_init 104 | 105 | flow_predictions = [] 106 | for itr in range(iters): 107 | coords1 = coords1.detach() 108 | corr = corr_fn(coords1) # index correlation volume 109 | 110 | flow = coords1 - coords0 111 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 112 | 113 | # F(t+1) = F(t) + \Delta(t) 114 | coords1 = coords1 + delta_flow 115 | 116 | # upsample predictions 117 | if up_mask is None: 118 | flow_up = upflow8(coords1 - coords0) 119 | else: 120 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 121 | 122 | flow_predictions.append(flow_up) 123 | 124 | if test_mode: 125 | return coords1 - coords0, flow_up 126 | 127 | return flow_predictions 128 | -------------------------------------------------------------------------------- /RAFT_core/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 
2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self,corr_levels,corr_radius): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = corr_levels * (2*corr_radius + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self,corr_levels,corr_radius ): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = corr_levels * (2*corr_radius + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 88 | 89 | def forward(self, flow, corr): 90 | cor = F.relu(self.convc1(corr)) 91 | cor = F.relu(self.convc2(cor)) 92 | flo = F.relu(self.convf1(flow)) 93 | flo = F.relu(self.convf2(flo)) 94 | 95 | cor_flo = torch.cat([cor, flo], dim=1) 96 | out = F.relu(self.conv(cor_flo)) 97 | return torch.cat([out, flow], dim=1) 98 | 99 | class SmallUpdateBlock(nn.Module): 100 | def __init__(self, corr_levels,corr_radius, hidden_dim=96): 101 | super(SmallUpdateBlock, self).__init__() 102 | 
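# note on the dimensions used below: input_dim = 82 + 64 reflects SmallMotionEncoder's
# output -- 80 motion channels concatenated with the 2 raw flow channels (82 in total) --
# plus the 64 context channels that forward() concatenates in as `inp`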
self.encoder = SmallMotionEncoder(corr_levels,corr_radius) 103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | inp = torch.cat([inp, motion_features], dim=1) 109 | net = self.gru(net, inp) 110 | delta_flow = self.flow_head(net) 111 | 112 | return net, None, delta_flow 113 | 114 | class BasicUpdateBlock(nn.Module): 115 | def __init__(self, corr_levels,corr_radius, hidden_dim=128, input_dim=128): 116 | super(BasicUpdateBlock, self).__init__() 117 | self.encoder = BasicMotionEncoder(corr_levels,corr_radius) 118 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 119 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 120 | 121 | self.mask = nn.Sequential( 122 | nn.Conv2d(128, 256, 3, padding=1), 123 | nn.ReLU(inplace=True), 124 | nn.Conv2d(256, 64*9, 1, padding=0)) 125 | 126 | def forward(self, net, inp, corr, flow, upsample=True): 127 | motion_features = self.encoder(flow, corr) 128 | inp = torch.cat([inp, motion_features], dim=1) 129 | 130 | net = self.gru(net, inp) 131 | delta_flow = self.flow_head(net) 132 | 133 | # scale mask to balance gradients 134 | mask = .25 * self.mask(net) 135 | return net, mask, delta_flow 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /RAFT_core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/utils/__init__.py -------------------------------------------------------------------------------- /RAFT_core/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT_core/utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT_core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape [H,W,2]. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /RAFT_core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
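The .flo layout written below is: the float32 magic number 202021.25 (TAG_CHAR),
then int32 width and int32 height, then row-major interleaved float32 u,v samples.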
76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /RAFT_core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, x): 19 | #return F.pad(x, self._pad, mode='replicate') 20 | return F.pad(x, self._pad, mode='constant') 21 | 22 | def unpad(self,x): 23 | ht, wd = x.shape[-2:] 24 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 25 | return x[:,:, c[0]:c[1], c[2]:c[3]] 26 | 27 | def forward_interpolate(flow): 28 | flow = flow.detach().cpu().numpy() 29 | dx, dy = flow[0], flow[1] 30 | 31 | ht, wd = dx.shape 32 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 33 | 34 | x1 = x0 + dx 35 | y1 = y0 + dy 36 | 37 | x1 = x1.reshape(-1) 38 | y1 = y1.reshape(-1) 39 | dx = dx.reshape(-1) 40 | dy = dy.reshape(-1) 41 | 42 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 43 | x1 = x1[valid] 44 | y1 = y1[valid] 45 | dx = dx[valid] 46 | dy = dy[valid] 47 | 48 | flow_x = interpolate.griddata( 49 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 50 | 51 | flow_y = 
interpolate.griddata( 52 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 53 | 54 | flow = np.stack([flow_x, flow_y], axis=0) 55 | return torch.from_numpy(flow).float() 56 | 57 | 58 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 59 | """ Wrapper for grid_sample, uses pixel coordinates """ 60 | H, W = img.shape[-2:] 61 | xgrid, ygrid = coords.split([1,1], dim=-1) 62 | xgrid = 2*xgrid/(W-1) - 1 63 | ygrid = 2*ygrid/(H-1) - 1 64 | 65 | grid = torch.cat([xgrid, ygrid], dim=-1) 66 | img = F.grid_sample(img, grid, align_corners=True) 67 | 68 | if mask: 69 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 70 | return img, mask.float() 71 | 72 | return img 73 | 74 | 75 | def coords_grid(batch, ht, wd): 76 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 77 | coords = torch.stack(coords[::-1], dim=0).float() 78 | return coords[None].repeat(batch, 1, 1, 1) 79 | 80 | 81 | def upflow8(flow, mode='bilinear'): 82 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 83 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 84 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VSPW: A Large-scale Dataset for Video Scene Parsing in the Wild 2 | 3 | A PyTorch implementation of the CVPR 2021 paper "VSPW: A Large-scale Dataset for Video Scene Parsing in the Wild" 4 | 5 | # Preparation 6 | 7 | ## Download VSPW dataset 8 | 9 | The VSPW dataset with extracted frames and masks is available [here](https://github.com/sssdddwww2/vspw_dataset_download); the resized [VSPW_480P dataset](https://github.com/sssdddwww2/vspw_dataset_download) can also be downloaded directly from the same page. 10 | 11 | ## Dependencies 12 | - Python 3.7 13 | - PyTorch 1.3.1 14 | - Numpy 15 | 16 | Download the ImageNet-pretrained models at [this link](https://drive.google.com/file/d/1VFmObwlx4d_K7FOjFNk5LhEb3jP8_NaD/view?usp=sharing). Put the downloaded archive in the root folder and decompress it. 17 | 18 | # Train and Test 19 | 20 | Resize the frames and masks of the VSPW dataset to *480p*. 21 | 22 | ``` 23 | python change2_480p.py 24 | ``` 25 | 26 | Edit the *.sh* files in *scripts/* and set **$DATAROOT** to your path to VSPW_480p. 27 | 28 | ## Image-based methods 29 | 30 | PSPNet 31 | 32 | ``` 33 | sh scripts/run_psp.sh 34 | ``` 35 | 36 | OCRNet 37 | 38 | ``` 39 | sh scripts/run_ocr.sh 40 | ``` 41 | 42 | ## Video-based methods 43 | 44 | TCB-PSP 45 | 46 | ``` 47 | sh scripts/run_temporal_psp.sh 48 | ``` 49 | 50 | TCB-OCR 51 | 52 | ``` 53 | sh scripts/run_temporal_ocr.sh 54 | ``` 55 | 56 | ## Evaluation on TC and VC 57 | 58 | Set the data root and prediction root in *TC_cal.py* and *VC_perclip.py*. 59 | 60 | ``` 61 | python TC_cal.py 62 | ``` 63 | 64 | ``` 65 | python VC_perclip.py 66 | ``` 67 | 68 | This implementation utilizes [this code](https://github.com/CSAILVision/semantic-segmentation-pytorch) and [RAFT](https://github.com/princeton-vl/RAFT). 
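For reference, *TC_cal.py* estimates optical flow with RAFT after padding each frame pair so height and width are divisible by 8. A minimal sketch of that pipeline (weight loading elided, frame paths are placeholders):

```
import numpy as np
import torch
from PIL import Image
from RAFT_core.raft import RAFT
from RAFT_core.utils.utils import InputPadder

model_raft = RAFT().cuda().eval()   # weight loading elided; see TC_cal.py
img1 = torch.from_numpy(np.array(Image.open('frame_t.png')))        # placeholder paths
img2 = torch.from_numpy(np.array(Image.open('frame_t_plus_1.png')))
padder = InputPadder(img1.size()[:2])            # pad H,W up to multiples of 8
img1 = padder.pad(img1.unsqueeze(0).permute(0, 3, 1, 2)).cuda()
img2 = padder.pad(img2.unsqueeze(0).permute(0, 3, 1, 2)).cuda()
with torch.no_grad():
    _, flow_up = model_raft(img1, img2, iters=20, test_mode=True)
flow = padder.unpad(flow_up)                     # [1, 2, H, W] flow from t to t+1
```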
69 | 70 | 71 | 72 | # Citation 73 | 74 | ``` 75 | @inproceedings{miao2021vspw, 76 | 77 | title={VSPW: A Large-scale Dataset for Video Scene Parsing in the Wild}, 78 | 79 | author={Miao, Jiaxu and Wei, Yunchao and Wu, Yu and Liang, Chen and Li, Guangrui and Yang, Yi}, 80 | 81 | booktitle={Proceedings of the {IEEE} Conference on Computer Vision and Pattern Recognition}, 82 | 83 | year={2021} 84 | 85 | } 86 | ``` 87 | 88 | 89 | -------------------------------------------------------------------------------- /TC_cal.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from RAFT_core.raft import RAFT 4 | from RAFT_core.utils.utils import InputPadder 5 | from collections import OrderedDict 6 | from utils import Evaluator 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import sys 11 | 12 | def flowwarp(x, flo): 13 | """ 14 | warp an image/tensor (im2) back to im1, according to the optical flow 15 | x: [B, C, H, W] (im2) 16 | flo: [B, 2, H, W] flow 17 | """ 18 | B, C, H, W = x.size() 19 | # mesh grid 20 | xx = torch.arange(0, W).view(1,-1).repeat(H,1) 21 | yy = torch.arange(0, H).view(-1,1).repeat(1,W) 22 | xx = xx.view(1,1,H,W).repeat(B,1,1,1) 23 | yy = yy.view(1,1,H,W).repeat(B,1,1,1) 24 | grid = torch.cat((xx,yy),1).float() 25 | 26 | if x.is_cuda: 27 | grid = grid.to(x.device) 28 | vgrid = grid + flo 29 | 30 | # scale grid to [-1,1] 31 | vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0 32 | vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0 33 | 34 | vgrid = vgrid.permute(0,2,3,1) 35 | output = nn.functional.grid_sample(x, vgrid,mode='nearest',align_corners=False) 36 | 37 | return output 38 | 39 | 40 | 41 | num_class=124 42 | 43 | DIR_='/your/path/to/VSPW_480p' 44 | 45 | data_dir=DIR_+'/data' 46 | result_dir='./prediction' 47 | #list_=['1001_5z_ijQjUf_0','1002_QXQ_QoswLOs'] 48 | 49 | split='val.txt' 50 | with open(os.path.join(DIR_,split),'r') as f: 51 | 52 | list_ = f.readlines() 53 | list_ = [v[:-1] for v in list_] 54 | 55 | ### 56 | gpu=0 57 | model_raft = RAFT() 58 | to_load = torch.load('./RAFT_core/raft-things.pth-no-zip') 59 | new_state_dict = OrderedDict() 60 | for k, v in to_load.items(): 61 | name = k[7:] # remove `module.`: slice the key from the 7th character onward, which drops the leading 'module.' prefix 62 | new_state_dict[name] = v # each value is carried over one-to-one under the renamed key 63 | model_raft.load_state_dict(new_state_dict) 64 | model_raft = model_raft.cuda(gpu) 65 | ### 66 | total_TC=0. 
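# The loop below computes the TC (temporal consistency) score: for every pair of
# consecutive frames, RAFT estimates the flow from frame t to frame t+1, the
# prediction at t+1 is warped back to frame t with flowwarp(), and the warped
# prediction is compared against the prediction at t; the final score is the
# mean IoU accumulated over all pairs by the Evaluator.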
67 | evaluator = Evaluator(num_class) 68 | for video in list_[:100]: # note: only the first 100 validation videos are scored here 69 | if video[0]=='.': 70 | continue 71 | imglist_ = sorted(os.listdir(os.path.join(data_dir,video,'origin'))) 72 | for i,img in enumerate(imglist_[:-1]): 73 | if img[0]=='.': 74 | continue 75 | #print('processing video : {} image: {}'.format(video,img)) 76 | next_img = imglist_[i+1] 77 | imgname = img 78 | next_imgname = next_img 79 | img = Image.open(os.path.join(data_dir,video,'origin',img)) 80 | next_img = Image.open(os.path.join(data_dir,video,'origin',next_img)) 81 | image1 = torch.from_numpy(np.array(img)) 82 | image2 = torch.from_numpy(np.array(next_img)) 83 | padder = InputPadder(image1.size()[:2]) 84 | image1 = image1.unsqueeze(0).permute(0,3,1,2) 85 | image2 = image2.unsqueeze(0).permute(0,3,1,2) 86 | image1 = padder.pad(image1) 87 | image2 = padder.pad(image2) 88 | image1 = image1.cuda(gpu) 89 | image2 = image2.cuda(gpu) 90 | with torch.no_grad(): 91 | model_raft.eval() 92 | _,flow = model_raft(image1,image2,iters=20, test_mode=True) 93 | flow = padder.unpad(flow) 94 | 95 | flow = flow.data.cpu() 96 | pred = Image.open(os.path.join(result_dir,video,imgname.split('.')[0]+'.png')) 97 | next_pred = Image.open(os.path.join(result_dir,video,next_imgname.split('.')[0]+'.png')) 98 | pred = torch.from_numpy(np.array(pred)) 99 | next_pred = torch.from_numpy(np.array(next_pred)) 100 | next_pred = next_pred.unsqueeze(0).unsqueeze(0).float() 101 | # print(next_pred) 102 | 103 | warp_pred = flowwarp(next_pred,flow) 104 | # print(warp_pred) 105 | warp_pred = warp_pred.int().squeeze(1).numpy() 106 | pred = pred.unsqueeze(0).numpy() 107 | evaluator.add_batch(pred, warp_pred) 108 | # v_mIoU = evaluator.Mean_Intersection_over_Union() 109 | # total_TC+=v_mIoU 110 | # print('processed video : {} score:{}'.format(video,v_mIoU)) 111 | 112 | #TC = total_TC/len(list_) 113 | TC = evaluator.Mean_Intersection_over_Union() 114 | 115 | print("TC score is {}".format(TC)) 116 | 117 | print(split) 118 | print(result_dir) 119 | 120 | 121 | 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /VC_perclip.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from PIL import Image 4 | #from utils import Evaluator 5 | import sys 6 | 7 | def get_common(list_,predlist,clip_num,h,w): 8 | accs = [] 9 | for i in range(len(list_)-clip_num): 10 | global_common = np.ones((h,w)) 11 | predglobal_common = np.ones((h,w)) 12 | 13 | 14 | for j in range(1,clip_num): 15 | common = (list_[i] == list_[i+j]) 16 | global_common = np.logical_and(global_common,common) 17 | pred_common = (predlist[i]==predlist[i+j]) 18 | predglobal_common = np.logical_and(predglobal_common,pred_common) 19 | pred = (predglobal_common*global_common) 20 | 21 | acc = pred.sum()/global_common.sum() 22 | accs.append(acc) 23 | return accs 24 | 25 | 26 | 27 | DIR='/your/path/to/VSPW_480p' 28 | 29 | Pred='./predicts' 30 | split = 'val.txt' 31 | 32 | with open(os.path.join(DIR,split),'r') as f: 33 | lines = f.readlines() 34 | 35 | videolist = [line[:-1] for line in lines] 36 | total_acc=[] 37 | 38 | clip_num=16 39 | 40 | 41 | for video in videolist: 42 | if video[0]=='.': 43 | continue 44 | imglist = [] 45 | predlist = [] 46 | 47 | images = sorted(os.listdir(os.path.join(DIR,'data',video,'mask'))) 48 | 49 | if len(images)<=clip_num: 50 | continue 51 | for imgname in images: 52 | if imgname[0]=='.': 53 | continue 54 | img = 
Image.open(os.path.join(DIR,'data',video,'mask',imgname)) 55 | w,h = img.size 56 | img = np.array(img) 57 | imglist.append(img) 58 | pred = Image.open(os.path.join(Pred,video,imgname)) 59 | pred = np.array(pred) 60 | predlist.append(pred) 61 | 62 | accs = get_common(imglist,predlist,clip_num,h,w) 63 | print(sum(accs)/len(accs)) 64 | total_acc.extend(accs) 65 | Acc = np.array(total_acc) 66 | Acc = np.nanmean(Acc) 67 | print(Pred) 68 | print('*'*10) 69 | print('VC{} score: {} on {} set'.format(clip_num,Acc,split)) 70 | print('*'*10) 71 | 72 | -------------------------------------------------------------------------------- /change2_480p.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from multiprocessing import Pool 4 | 5 | 6 | DIR='/your/path/to/VSPW' 7 | 8 | Target_Dir = '/your/path/to/VSPW_480p' 9 | 10 | 11 | def change(DIR,video,image): 12 | img = Image.open(os.path.join(DIR,'data',video,'origin',image)) 13 | w,h = img.size 14 | 15 | if not os.path.exists(os.path.join(Target_Dir,'data',video,'origin')): 16 | os.makedirs(os.path.join(Target_Dir,'data',video,'origin')) 17 | img = img.resize((int(480*w/h),480),Image.BILINEAR) 18 | img.save(os.path.join(Target_Dir,'data',video,'origin',image)) 19 | 20 | if os.path.isfile(os.path.join(DIR,'data',video,'mask',image.split('.')[0]+'.png')): 21 | 22 | 23 | mask = Image.open(os.path.join(DIR,'data',video,'mask',image.split('.')[0]+'.png')) 24 | mask = mask.resize((int(480*w/h),480),Image.NEAREST) 25 | 26 | if not os.path.exists(os.path.join(Target_Dir,'data',video,'mask')): 27 | os.makedirs(os.path.join(Target_Dir,'data',video,'mask')) 28 | 29 | mask.save(os.path.join(Target_Dir,'data',video,'mask',image.split('.')[0]+'.png')) 30 | print('Processing video {} image {}'.format(video,image)) 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | #p = Pool(8) 39 | for video in sorted(os.listdir(os.path.join(DIR,'data'))): 40 | if video[0]=='.': 41 | continue 42 | for image in sorted(os.listdir(os.path.join(DIR,'data',video,'origin'))): 43 | if image[0]=='.': 44 | continue 45 | # p.apply_async(change,args=(DIR,video,image,)) 46 | change(DIR,video,image) 47 | #p.close() 48 | #p.join() 49 | print('finish') 50 | 51 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import _C as cfg 2 | -------------------------------------------------------------------------------- /config/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/config/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /config/__pycache__/defaults.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/config/__pycache__/defaults.cpython-37.pyc -------------------------------------------------------------------------------- /config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | # ----------------------------------------------------------------------------- 4 | # Config definition 5 | # 
----------------------------------------------------------------------------- 6 | 7 | _C = CN() 8 | _C.DIR = "ckpt/ade20k-resnet50dilated-ppm_deepsup" 9 | 10 | # ----------------------------------------------------------------------------- 11 | # Dataset 12 | # ----------------------------------------------------------------------------- 13 | _C.DATASET = CN() 14 | _C.DATASET.root_dataset = "./data/" 15 | _C.DATASET.list_train = "./data/training.odgt" 16 | _C.DATASET.list_val = "./data/validation.odgt" 17 | _C.DATASET.num_class = 150 18 | # multiscale train/test, size of short edge (int or tuple) 19 | _C.DATASET.imgSizes = (300, 375, 450, 525, 600) 20 | # maximum input image size of long edge 21 | _C.DATASET.imgMaxSize = 1000 22 | # maximum downsampling rate of the network 23 | _C.DATASET.padding_constant = 8 24 | # downsampling rate of the segmentation label 25 | _C.DATASET.segm_downsampling_rate = 8 26 | # randomly flip images horizontally during train/test 27 | _C.DATASET.random_flip = True 28 | 29 | # ----------------------------------------------------------------------------- 30 | # Model 31 | # ----------------------------------------------------------------------------- 32 | _C.MODEL = CN() 33 | # architecture of net_encoder 34 | _C.MODEL.arch_encoder = "resnet50dilated" 35 | # architecture of net_decoder 36 | _C.MODEL.arch_decoder = "ppm_deepsup" 37 | # weights to finetune net_encoder 38 | _C.MODEL.weights_encoder = "" 39 | # weights to finetune net_decoder 40 | _C.MODEL.weights_decoder = "" 41 | # number of feature channels between encoder and decoder 42 | _C.MODEL.fc_dim = 2048 43 | 44 | # ----------------------------------------------------------------------------- 45 | # Training 46 | # ----------------------------------------------------------------------------- 47 | _C.TRAIN = CN() 48 | _C.TRAIN.batch_size_per_gpu = 2 49 | # epochs to train for 50 | _C.TRAIN.num_epoch = 20 51 | # epoch to start training. 
useful if continue from a checkpoint 52 | _C.TRAIN.start_epoch = 0 53 | # iterations of each epoch (irrelevant to batch size) 54 | _C.TRAIN.epoch_iters = 5000 55 | 56 | _C.TRAIN.optim = "SGD" 57 | _C.TRAIN.lr_encoder = 0.02 58 | _C.TRAIN.lr_decoder = 0.02 59 | # power in poly to drop LR 60 | _C.TRAIN.lr_pow = 0.9 61 | # momentum for sgd, beta1 for adam 62 | _C.TRAIN.beta1 = 0.9 63 | # weights regularizer 64 | _C.TRAIN.weight_decay = 1e-4 65 | # the weighting of deep supervision loss 66 | _C.TRAIN.deep_sup_scale = 0.4 67 | # fix bn params, only under finetuning 68 | _C.TRAIN.fix_bn = False 69 | # number of data loading workers 70 | _C.TRAIN.workers = 16 71 | 72 | # frequency to display 73 | _C.TRAIN.disp_iter = 20 74 | # manual seed 75 | _C.TRAIN.seed = 304 76 | 77 | # ----------------------------------------------------------------------------- 78 | # Validation 79 | # ----------------------------------------------------------------------------- 80 | _C.VAL = CN() 81 | # currently only supports 1 82 | _C.VAL.batch_size = 1 83 | # output visualization during validation 84 | _C.VAL.visualize = False 85 | # the checkpoint to evaluate on 86 | _C.VAL.checkpoint = "epoch_20.pth" 87 | 88 | # ----------------------------------------------------------------------------- 89 | # Testing 90 | # ----------------------------------------------------------------------------- 91 | _C.TEST = CN() 92 | # currently only supports 1 93 | _C.TEST.batch_size = 1 94 | # the checkpoint to test on 95 | _C.TEST.checkpoint = "epoch_20.pth" 96 | # folder to output visualization results 97 | _C.TEST.result = "./" 98 | -------------------------------------------------------------------------------- /config/vsp-hrnetv2.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 32 9 | segm_downsampling_rate: 4 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "hrnetv2" 14 | arch_decoder: "c1" 15 | fc_dim: 720 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 30 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_30.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_30.pth" 40 | result: "./" 41 | 42 | DIR: "/home/miaojiaxu/jiaxu_2/semantic_seg/ade20k-hrnetv2-c1_pretrain" 43 | -------------------------------------------------------------------------------- /config/vsp-mobilenetv2dilated-c1_deepsup.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "mobilenetv2dilated" 14 | arch_decoder: "c1_deepsup" 15 | fc_dim: 320 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 3 19 | num_epoch: 20 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | 
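# note: lr_pow above is the exponent of the usual "poly" decay; the schedule is
# typically of the form lr = lr_base * (1 - cur_iter / max_iter) ** lr_pow
# (a sketch only; the exact form lives in the training code)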
fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_20.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_20.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-mobilenetv2dilated-c1_deepsup" 43 | -------------------------------------------------------------------------------- /config/vsp-mobilenetv2dilated-ppm_deepsup.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "mobilenetv2dilated" 14 | arch_decoder: "ppm_deepsup" 15 | fc_dim: 320 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 3 19 | num_epoch: 20 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_20.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_20.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-mobilenetv2dilated-c1_deepsup" 43 | -------------------------------------------------------------------------------- /config/vsp-resnet101-upernet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 32 9 | segm_downsampling_rate: 4 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet101" 14 | arch_decoder: "upernet" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 50 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_40.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_40.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet101-upernet" 43 | -------------------------------------------------------------------------------- /config/vsp-resnet101dilated-deeplab.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet101dilated" 14 | arch_decoder: "deeplab" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 25 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_25.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_25.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup" 43 | 
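These *vsp-\** files only override the defaults in *config/defaults.py*; they are merged through yacs. A minimal loading sketch (the exact flags passed by the training scripts may differ, but `merge_from_file` is the standard yacs entry point):

```
from config import cfg  # the yacs CfgNode built in config/defaults.py

cfg.merge_from_file('config/vsp-resnet101dilated-deeplab.yaml')
cfg.freeze()
print(cfg.MODEL.arch_encoder)  # "resnet101dilated"
```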
-------------------------------------------------------------------------------- /config/vsp-resnet101dilated-nonlocal2d.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet101dilated" 14 | arch_decoder: "nonlocal2d" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 25 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_25.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_25.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /config/vsp-resnet101dilated-ocr_deepsup.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet101dilated" 14 | arch_decoder: "ocrnet_deepsup" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 25 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_25.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_25.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /config/vsp-resnet101dilated-ppm_clip.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet101dilated" 14 | arch_decoder: "ppm_clip" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 25 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_25.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_25.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /config/vsp-resnet101dilated-ppm_deepsup.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | 
root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet101dilated" 14 | arch_decoder: "ppm_deepsup" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 25 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_25.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_25.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /config/vsp-resnet101dilated-ppm_deepsup_clip.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet101dilated" 14 | arch_decoder: "ppm_deepsup_clip" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 25 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_25.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_25.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /config/vsp-resnet101dilated_tdnet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet101" 14 | arch_decoder: "deeplab" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 25 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_25.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_25.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /config/vsp-resnet18dilated-ppm_deepsup.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 
| random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet18dilated" 14 | arch_decoder: "ppm_deepsup" 15 | fc_dim: 512 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 20 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_20.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_20.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet18dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /config/vsp-resnet18dilated-ppm_deepsup_clip.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet18" 14 | arch_decoder: "ppm_deepsup_clip" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 25 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_25.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_25.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /config/vsp-resnet50-upernet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 32 9 | segm_downsampling_rate: 4 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet50" 14 | arch_decoder: "upernet" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 30 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_30.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_30.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50-upernet" 43 | -------------------------------------------------------------------------------- /config/vsp-resnet50dilated-deeplab.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet50dilated" 14 | arch_decoder: "deeplab" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 25 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 
| lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_25.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_25.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /config/vsp-resnet50dilated-ppm_deepsup.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet50dilated" 14 | arch_decoder: "ppm_deepsup" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 20 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_20.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_20.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /config/vsp-resnet50dilated-ppm_deepsup_clip.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet50dilated" 14 | arch_decoder: "ppm_deepsup_clip" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 25 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_25.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_25.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /config/vsp-resnet50dilated-tdnet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet50" 14 | arch_decoder: "deeplab" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 25 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_25.pth" 37 
| 38 | TEST: 39 | checkpoint: "epoch_25.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /lib/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 3 | -------------------------------------------------------------------------------- /lib/nn/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/nn/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /lib/nn/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/nn/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/nn/modules/__pycache__/batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/nn/modules/__pycache__/batchnorm.cpython-37.pyc -------------------------------------------------------------------------------- /lib/nn/modules/__pycache__/comm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/nn/modules/__pycache__/comm.cpython-37.pyc -------------------------------------------------------------------------------- /lib/nn/modules/__pycache__/replicate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/nn/modules/__pycache__/replicate.cpython-37.pyc -------------------------------------------------------------------------------- /lib/nn/modules/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
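# For orientation: this module implements the master/slave message exchange that
# SynchronizedBatchNorm uses to aggregate statistics across replicas. A minimal
# (hypothetical) round trip, with the slave running on its own thread:
#
#   master = SyncMaster(lambda msgs: [(i, sum(m for _, m in msgs)) for i, m in msgs])
#   pipe = master.register_slave(1)
#   t = threading.Thread(target=lambda: print(pipe.run_slave(2)))   # prints 3
#   t.start()
#   print(master.run_master(1))                                     # prints 3
#   t.join()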
10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger a callback of each module, all slave devices should 60 | call `register(id)` and obtain a `SlavePipe` to communicate with the master. 61 | - During the forward pass, the master device invokes `run_master`, all messages from slave devices will be collected 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine the message to be passed 64 | back to each slave device. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def register_slave(self, identifier): 79 | """ 80 | Register a slave device. 81 | 82 | Args: 83 | identifier: an identifier, usually the device id. 84 | 85 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 86 | 87 | """ 88 | if self._activated: 89 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 90 | self._activated = False 91 | self._registry.clear() 92 | future = FutureResult() 93 | self._registry[identifier] = _MasterRegistry(future) 94 | return SlavePipe(identifier, self._queue, future) 95 | 96 | def run_master(self, master_msg): 97 | """ 98 | Main entry for the master device in each forward pass. 99 | The messages are first collected from each device (including the master device), and then 100 | a callback will be invoked to compute the message to be sent back to each device 101 | (including the master device). 102 | 103 | Args: 104 | master_msg: the message that the master wants to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | 107 | Returns: the message to be sent back to the master device. 
108 |
109 |         """
110 |         self._activated = True
111 |
112 |         intermediates = [(0, master_msg)]
113 |         for i in range(self.nr_slaves):
114 |             intermediates.append(self._queue.get())
115 |
116 |         results = self._master_callback(intermediates)
117 |         assert results[0][0] == 0, 'The first result should belong to the master.'
118 |
119 |         for i, res in results:
120 |             if i == 0:
121 |                 continue
122 |             self._registry[i].result.put(res)
123 |
124 |         for i in range(self.nr_slaves):
125 |             assert self._queue.get() is True
126 |
127 |         return results[0][1]
128 |
129 |     @property
130 |     def nr_slaves(self):
131 |         return len(self._registry)
132 |
--------------------------------------------------------------------------------
/lib/nn/modules/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 |     'CallbackContext',
17 |     'execute_replication_callbacks',
18 |     'DataParallelWithCallback',
19 |     'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 |     pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 |     """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 |     Note that, as all modules are isomorphic, we assign each sub-module with a context
34 |     (shared among multiple copies of this module on different devices).
35 |     Through this context, different copies can share some information.
36 |
37 |     We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38 |     of any slave copies.
39 |     """
40 |     master_copy = modules[0]
41 |     nr_modules = len(list(master_copy.modules()))
42 |     ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 |     for i, module in enumerate(modules):
45 |         for j, m in enumerate(module.modules()):
46 |             if hasattr(m, '__data_parallel_replicate__'):
47 |                 m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 |     """
52 |     Data Parallel with a replication callback.
53 |
54 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after the module is
55 |     created by the original `replicate` function.
56 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 |     Examples:
59 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 |         > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 |         # sync_bn.__data_parallel_replicate__ will be invoked.
62 |     """
63 |
64 |     def replicate(self, module, device_ids):
65 |         modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 |         execute_replication_callbacks(modules)
67 |         return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 |     """
72 |     Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 |     Useful when you have a customized `DataParallel` implementation.
74 |
75 |     Examples:
76 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 |         > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 |         > patch_replication_callback(sync_bn)
79 |         # this is equivalent to
80 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 |         > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 |     """
83 |
84 |     assert isinstance(data_parallel, DataParallel)
85 |
86 |     old_replicate = data_parallel.replicate
87 |
88 |     @functools.wraps(old_replicate)
89 |     def new_replicate(module, device_ids):
90 |         modules = old_replicate(module, device_ids)
91 |         execute_replication_callbacks(modules)
92 |         return modules
93 |
94 |     data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
/lib/nn/modules/tests/test_numeric_batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : test_numeric_batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 |
9 | import unittest
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable
14 |
15 | from sync_batchnorm.unittest import TorchTestCase
16 |
17 |
18 | def handy_var(a, unbias=True):
19 |     n = a.size(0)
20 |     asum = a.sum(dim=0)
21 |     as_sum = (a ** 2).sum(dim=0)  # sum of squares
22 |     sumvar = as_sum - asum * asum / n
23 |     if unbias:
24 |         return sumvar / (n - 1)
25 |     else:
26 |         return sumvar / n
27 |
28 |
29 | class NumericTestCase(TorchTestCase):
30 |     def testNumericBatchNorm(self):
31 |         a = torch.rand(16, 10)
32 |         bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False)  # 1d: the input below is 2D (N, C)
33 |         bn.train()
34 |
35 |         a_var1 = Variable(a, requires_grad=True)
36 |         b_var1 = bn(a_var1)
37 |         loss1 = b_var1.sum()
38 |         loss1.backward()
39 |
40 |         a_var2 = Variable(a, requires_grad=True)
41 |         a_mean2 = a_var2.mean(dim=0, keepdim=True)
42 |         a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
43 |         # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
44 |         b_var2 = (a_var2 - a_mean2) / a_std2
45 |         loss2 = b_var2.sum()
46 |         loss2.backward()
47 |
48 |         self.assertTensorClose(bn.running_mean, a.mean(dim=0))
49 |         self.assertTensorClose(bn.running_var, handy_var(a))
50 |         self.assertTensorClose(a_var1.data, a_var2.data)
51 |         self.assertTensorClose(b_var1.data, b_var2.data)
52 |         self.assertTensorClose(a_var1.grad, a_var2.grad)
53 |
54 |
55 | if __name__ == '__main__':
56 |     unittest.main()
57 |
--------------------------------------------------------------------------------
/lib/nn/modules/tests/test_sync_batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : test_sync_batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
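#
# Note: the data-parallel cases below wrap the model in
# DataParallelWithCallback(..., device_ids=[0, 1]), so running this test
# file requires a machine with at least two CUDA devices.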
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback 16 | from sync_batchnorm.unittest import TorchTestCase 17 | 18 | 19 | def handy_var(a, unbias=True): 20 | n = a.size(0) 21 | asum = a.sum(dim=0) 22 | as_sum = (a ** 2).sum(dim=0) # a square sum 23 | sumvar = as_sum - asum * asum / n 24 | if unbias: 25 | return sumvar / (n - 1) 26 | else: 27 | return sumvar / n 28 | 29 | 30 | def _find_bn(module): 31 | for m in module.modules(): 32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): 33 | return m 34 | 35 | 36 | class SyncTestCase(TorchTestCase): 37 | def _syncParameters(self, bn1, bn2): 38 | bn1.reset_parameters() 39 | bn2.reset_parameters() 40 | if bn1.affine and bn2.affine: 41 | bn2.weight.data.copy_(bn1.weight.data) 42 | bn2.bias.data.copy_(bn1.bias.data) 43 | 44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): 45 | """Check the forward and backward for the customized batch normalization.""" 46 | bn1.train(mode=is_train) 47 | bn2.train(mode=is_train) 48 | 49 | if cuda: 50 | input = input.cuda() 51 | 52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2)) 53 | 54 | input1 = Variable(input, requires_grad=True) 55 | output1 = bn1(input1) 56 | output1.sum().backward() 57 | input2 = Variable(input, requires_grad=True) 58 | output2 = bn2(input2) 59 | output2.sum().backward() 60 | 61 | self.assertTensorClose(input1.data, input2.data) 62 | self.assertTensorClose(output1.data, output2.data) 63 | self.assertTensorClose(input1.grad, input2.grad) 64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) 65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) 66 | 67 | def testSyncBatchNormNormalTrain(self): 68 | bn = nn.BatchNorm1d(10) 69 | sync_bn = SynchronizedBatchNorm1d(10) 70 | 71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) 72 | 73 | def testSyncBatchNormNormalEval(self): 74 | bn = nn.BatchNorm1d(10) 75 | sync_bn = SynchronizedBatchNorm1d(10) 76 | 77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) 78 | 79 | def testSyncBatchNormSyncTrain(self): 80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 83 | 84 | bn.cuda() 85 | sync_bn.cuda() 86 | 87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) 88 | 89 | def testSyncBatchNormSyncEval(self): 90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 93 | 94 | bn.cuda() 95 | sync_bn.cuda() 96 | 97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) 98 | 99 | def testSyncBatchNorm2DSyncTrain(self): 100 | bn = nn.BatchNorm2d(10) 101 | sync_bn = SynchronizedBatchNorm2d(10) 102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 103 | 104 | bn.cuda() 105 | sync_bn.cuda() 106 | 107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) 108 | 109 | 110 | if __name__ == '__main__': 111 | unittest.main() 112 | -------------------------------------------------------------------------------- /lib/nn/modules/unittest.py: 
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 |     if isinstance(v, Variable):
19 |         v = v.data
20 |     return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 |     def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 |         npa, npb = as_numpy(a), as_numpy(b)
26 |         self.assertTrue(
27 |             np.allclose(npa, npb, atol=atol, rtol=rtol),  # pass rtol through; it was previously accepted but ignored
28 |             'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 |         )
30 |
--------------------------------------------------------------------------------
/lib/nn/parallel/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
2 |
--------------------------------------------------------------------------------
/lib/nn/parallel/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/nn/parallel/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/nn/parallel/__pycache__/data_parallel.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/nn/parallel/__pycache__/data_parallel.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/nn/parallel/data_parallel.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf8 -*-
2 |
3 | import torch.cuda as cuda
4 | import torch.nn as nn
5 | import torch
6 | import collections.abc
7 | from torch.nn.parallel._functions import Gather
8 |
9 |
10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']
11 |
12 |
13 | def async_copy_to(obj, dev, main_stream=None):
14 |     if torch.is_tensor(obj):
15 |         v = obj.cuda(dev, non_blocking=True)
16 |         if main_stream is not None:
17 |             v.data.record_stream(main_stream)
18 |         return v
19 |     elif isinstance(obj, collections.abc.Mapping):
20 |         return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
21 |     elif isinstance(obj, collections.abc.Sequence):
22 |         return [async_copy_to(o, dev, main_stream) for o in obj]
23 |     else:
24 |         return obj
25 |
26 |
27 | def dict_gather(outputs, target_device, dim=0):
28 |     """
29 |     Gathers variables from different GPUs on a specified device
30 |     (-1 means the CPU), with dictionary support.
31 |     """
32 |     def gather_map(outputs):
33 |         out = outputs[0]
34 |         if torch.is_tensor(out):
35 |             # MJY(20180330) HACK:: force nr_dims > 0
36 |             if out.dim() == 0:
37 |                 outputs = [o.unsqueeze(0) for o in outputs]
38 |             return Gather.apply(target_device, dim, *outputs)
39 |         elif out is None:
40 |             return None
41 |         elif isinstance(out, collections.abc.Mapping):
42 |             return {k: gather_map([o[k] for o in outputs]) for k in out}
43 |         elif isinstance(out, collections.abc.Sequence):
44 |             return type(out)(map(gather_map, zip(*outputs)))
45 |     return gather_map(outputs)
46 |
47 |
48 | class DictGatherDataParallel(nn.DataParallel):
49 |     def gather(self, outputs, output_device):
50 |         return dict_gather(outputs, output_device, dim=self.dim)
51 |
52 |
53 | class UserScatteredDataParallel(DictGatherDataParallel):
54 |     def scatter(self, inputs, kwargs, device_ids):
55 |         assert len(inputs) == 1
56 |         inputs = inputs[0]
57 |         inputs = _async_copy_stream(inputs, device_ids)
58 |         inputs = [[i] for i in inputs]
59 |         assert len(kwargs) == 0
60 |         kwargs = [{} for _ in range(len(inputs))]
61 |
62 |         return inputs, kwargs
63 |
64 |
65 | def user_scattered_collate(batch):
66 |     return batch
67 |
68 |
69 | def _async_copy(inputs, device_ids):
70 |     nr_devs = len(device_ids)
71 |     assert type(inputs) in (tuple, list)
72 |     assert len(inputs) == nr_devs
73 |
74 |     outputs = []
75 |     for i, dev in zip(inputs, device_ids):
76 |         with cuda.device(dev):
77 |             outputs.append(async_copy_to(i, dev))
78 |
79 |     return tuple(outputs)
80 |
81 |
82 | def _async_copy_stream(inputs, device_ids):
83 |     nr_devs = len(device_ids)
84 |     assert type(inputs) in (tuple, list)
85 |     assert len(inputs) == nr_devs
86 |
87 |     outputs = []
88 |     streams = [_get_stream(d) for d in device_ids]
89 |     for i, dev, stream in zip(inputs, device_ids, streams):
90 |         with cuda.device(dev):
91 |             main_stream = cuda.current_stream()
92 |             with cuda.stream(stream):
93 |                 outputs.append(async_copy_to(i, dev, main_stream=main_stream))
94 |             main_stream.wait_stream(stream)
95 |
96 |     return outputs
97 |
98 |
99 | """Adapted from: torch/nn/parallel/_functions.py"""
100 | # background streams used for copying
101 | _streams = None
102 |
103 |
104 | def _get_stream(device):
105 |     """Gets a background stream for copying between CPU and GPU"""
106 |     global _streams
107 |     if device == -1:
108 |         return None
109 |     if _streams is None:
110 |         _streams = [None] * cuda.device_count()
111 |     if _streams[device] is None: _streams[device] = cuda.Stream(device)
112 |     return _streams[device]
113 |
--------------------------------------------------------------------------------
/lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .th import *
2 |
--------------------------------------------------------------------------------
/lib/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/th.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/utils/__pycache__/th.cpython-37.pyc
--------------------------------------------------------------------------------
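A minimal usage sketch of the `UserScatteredDataParallel` pipeline defined in /lib/nn/parallel/data_parallel.py above (illustration only; `net` and `MyDataset` are hypothetical placeholders, not names from this repository). The DataLoader keeps each batch as a plain list via `user_scattered_collate`, and `scatter` then async-copies one list entry to each GPU:

import torch
from lib.nn import UserScatteredDataParallel, user_scattered_collate, patch_replication_callback

gpus = [0, 1]
loader = torch.utils.data.DataLoader(
    MyDataset(),                        # hypothetical dataset yielding one sub-batch per item
    batch_size=len(gpus),               # one sub-batch per GPU
    collate_fn=user_scattered_collate,  # returns the sample list unchanged
    shuffle=True)

model = UserScatteredDataParallel(net.cuda(), device_ids=gpus)  # `net` is a placeholder nn.Module
patch_replication_callback(model)  # enables SynchronizedBatchNorm's replication callback

for batch in loader:    # `batch` is a plain list with len(gpus) entries
    out = model(batch)  # scatter() routes entry i to device i; dict_gather() merges the outputs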
/lib/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .dataset import Dataset, TensorDataset, ConcatDataset
3 | from .dataloader import DataLoader
4 |
--------------------------------------------------------------------------------
/lib/utils/data/dataset.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import warnings
3 |
4 | from torch._utils import _accumulate
5 | from torch import randperm
6 |
7 |
8 | class Dataset(object):
9 |     """An abstract class representing a Dataset.
10 |
11 |     All other datasets should subclass it. All subclasses should override
12 |     ``__len__``, which provides the size of the dataset, and ``__getitem__``,
13 |     supporting integer indexing in the range from 0 to len(self) exclusive.
14 |     """
15 |
16 |     def __getitem__(self, index):
17 |         raise NotImplementedError
18 |
19 |     def __len__(self):
20 |         raise NotImplementedError
21 |
22 |     def __add__(self, other):
23 |         return ConcatDataset([self, other])
24 |
25 |
26 | class TensorDataset(Dataset):
27 |     """Dataset wrapping data and target tensors.
28 |
29 |     Each sample will be retrieved by indexing both tensors along the first
30 |     dimension.
31 |
32 |     Arguments:
33 |         data_tensor (Tensor): contains sample data.
34 |         target_tensor (Tensor): contains sample targets (labels).
35 |     """
36 |
37 |     def __init__(self, data_tensor, target_tensor):
38 |         assert data_tensor.size(0) == target_tensor.size(0)
39 |         self.data_tensor = data_tensor
40 |         self.target_tensor = target_tensor
41 |
42 |     def __getitem__(self, index):
43 |         return self.data_tensor[index], self.target_tensor[index]
44 |
45 |     def __len__(self):
46 |         return self.data_tensor.size(0)
47 |
48 |
49 | class ConcatDataset(Dataset):
50 |     """
51 |     Dataset to concatenate multiple datasets.
52 |     Purpose: useful for assembling different existing datasets, possibly
53 |     large-scale datasets, since the concatenation operation is performed
54 |     on the fly.
55 |
56 |     Arguments:
57 |         datasets (iterable): List of datasets to be concatenated
58 |     """
59 |
60 |     @staticmethod
61 |     def cumsum(sequence):
62 |         r, s = [], 0
63 |         for e in sequence:
64 |             l = len(e)
65 |             r.append(l + s)
66 |             s += l
67 |         return r
68 |
69 |     def __init__(self, datasets):
70 |         super(ConcatDataset, self).__init__()
71 |         assert len(datasets) > 0, 'datasets should not be an empty iterable'
72 |         self.datasets = list(datasets)
73 |         self.cumulative_sizes = self.cumsum(self.datasets)
74 |
75 |     def __len__(self):
76 |         return self.cumulative_sizes[-1]
77 |
78 |     def __getitem__(self, idx):
79 |         dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
80 |         if dataset_idx == 0:
81 |             sample_idx = idx
82 |         else:
83 |             sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
84 |         return self.datasets[dataset_idx][sample_idx]
85 |
86 |     @property
87 |     def cummulative_sizes(self):
88 |         warnings.warn("cummulative_sizes attribute is renamed to "
89 |                       "cumulative_sizes", DeprecationWarning, stacklevel=2)
90 |         return self.cumulative_sizes
91 |
92 |
93 | class Subset(Dataset):
94 |     def __init__(self, dataset, indices):
95 |         self.dataset = dataset
96 |         self.indices = indices
97 |
98 |     def __getitem__(self, idx):
99 |         return self.dataset[self.indices[idx]]
100 |
101 |     def __len__(self):
102 |         return len(self.indices)
103 |
104 |
105 | def random_split(dataset, lengths):
106 |     """
107 |     Randomly split a dataset into non-overlapping new datasets of the given
108 |     lengths.
109 |
110 |     Arguments:
111 |         dataset (Dataset): Dataset to be split
112 |         lengths (iterable): lengths of splits to be produced
113 |     """
114 |     if sum(lengths) != len(dataset):
115 |         raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
116 |
117 |     indices = randperm(sum(lengths))
118 |     return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
119 |
--------------------------------------------------------------------------------
/lib/utils/data/distributed.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from .sampler import Sampler
4 | from torch.distributed import get_world_size, get_rank
5 |
6 |
7 | class DistributedSampler(Sampler):
8 |     """Sampler that restricts data loading to a subset of the dataset.
9 |
10 |     It is especially useful in conjunction with
11 |     :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
12 |     process can pass a DistributedSampler instance as a DataLoader sampler,
13 |     and load a subset of the original dataset that is exclusive to it.
14 |
15 |     .. note::
16 |         Dataset is assumed to be of constant size.
17 |
18 |     Arguments:
19 |         dataset: Dataset used for sampling.
20 |         num_replicas (optional): Number of processes participating in
21 |             distributed training.
22 |         rank (optional): Rank of the current process within num_replicas.
23 |     """
24 |
25 |     def __init__(self, dataset, num_replicas=None, rank=None):
26 |         if num_replicas is None:
27 |             num_replicas = get_world_size()
28 |         if rank is None:
29 |             rank = get_rank()
30 |         self.dataset = dataset
31 |         self.num_replicas = num_replicas
32 |         self.rank = rank
33 |         self.epoch = 0
34 |         self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
35 |         self.total_size = self.num_samples * self.num_replicas
36 |
37 |     def __iter__(self):
38 |         # deterministically shuffle based on epoch
39 |         g = torch.Generator()
40 |         g.manual_seed(self.epoch)
41 |         indices = list(torch.randperm(len(self.dataset), generator=g))
42 |
43 |         # add extra samples to make it evenly divisible
44 |         indices += indices[:(self.total_size - len(indices))]
45 |         assert len(indices) == self.total_size
46 |
47 |         # subsample
48 |         offset = self.num_samples * self.rank
49 |         indices = indices[offset:offset + self.num_samples]
50 |         assert len(indices) == self.num_samples
51 |
52 |         return iter(indices)
53 |
54 |     def __len__(self):
55 |         return self.num_samples
56 |
57 |     def set_epoch(self, epoch):
58 |         self.epoch = epoch
59 |
--------------------------------------------------------------------------------
/lib/utils/data/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Sampler(object):
5 |     """Base class for all Samplers.
6 |
7 |     Every Sampler subclass has to provide an __iter__ method, providing a way
8 |     to iterate over indices of dataset elements, and a __len__ method that
9 |     returns the length of the returned iterator.
10 |     """
11 |
12 |     def __init__(self, data_source):
13 |         pass
14 |
15 |     def __iter__(self):
16 |         raise NotImplementedError
17 |
18 |     def __len__(self):
19 |         raise NotImplementedError
20 |
21 |
22 | class SequentialSampler(Sampler):
23 |     """Samples elements sequentially, always in the same order.
24 |
25 |     Arguments:
26 |         data_source (Dataset): dataset to sample from
27 |     """
28 |
29 |     def __init__(self, data_source):
30 |         self.data_source = data_source
31 |
32 |     def __iter__(self):
33 |         return iter(range(len(self.data_source)))
34 |
35 |     def __len__(self):
36 |         return len(self.data_source)
37 |
38 |
39 | class RandomSampler(Sampler):
40 |     """Samples elements randomly, without replacement.
41 |
42 |     Arguments:
43 |         data_source (Dataset): dataset to sample from
44 |     """
45 |
46 |     def __init__(self, data_source):
47 |         self.data_source = data_source
48 |
49 |     def __iter__(self):
50 |         return iter(torch.randperm(len(self.data_source)).long())
51 |
52 |     def __len__(self):
53 |         return len(self.data_source)
54 |
55 |
56 | class SubsetRandomSampler(Sampler):
57 |     """Samples elements randomly from a given list of indices, without replacement.
58 |
59 |     Arguments:
60 |         indices (list): a list of indices
61 |     """
62 |
63 |     def __init__(self, indices):
64 |         self.indices = indices
65 |
66 |     def __iter__(self):
67 |         return (self.indices[i] for i in torch.randperm(len(self.indices)))
68 |
69 |     def __len__(self):
70 |         return len(self.indices)
71 |
72 |
73 | class WeightedRandomSampler(Sampler):
74 |     """Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
75 |
76 |     Arguments:
77 |         weights (list): a list of weights, not necessarily summing up to one
78 |         num_samples (int): number of samples to draw
79 |         replacement (bool): if ``True``, samples are drawn with replacement.
80 |             If not, they are drawn without replacement, which means that when a
81 |             sample index is drawn for a row, it cannot be drawn again for that row.
82 |     """
83 |
84 |     def __init__(self, weights, num_samples, replacement=True):
85 |         self.weights = torch.DoubleTensor(weights)
86 |         self.num_samples = num_samples
87 |         self.replacement = replacement
88 |
89 |     def __iter__(self):
90 |         return iter(torch.multinomial(self.weights, self.num_samples, self.replacement))
91 |
92 |     def __len__(self):
93 |         return self.num_samples
94 |
95 |
96 | class BatchSampler(object):
97 |     """Wraps another sampler to yield a mini-batch of indices.
98 |
99 |     Args:
100 |         sampler (Sampler): Base sampler.
101 |         batch_size (int): Size of mini-batch.
102 |         drop_last (bool): If ``True``, the sampler will drop the last batch if
103 |             its size would be less than ``batch_size``.
104 |
105 |     Example:
106 |         >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
107 |         [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
108 |         >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
109 |         [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
110 |     """
111 |
112 |     def __init__(self, sampler, batch_size, drop_last):
113 |         self.sampler = sampler
114 |         self.batch_size = batch_size
115 |         self.drop_last = drop_last
116 |
117 |     def __iter__(self):
118 |         batch = []
119 |         for idx in self.sampler:
120 |             batch.append(idx)
121 |             if len(batch) == self.batch_size:
122 |                 yield batch
123 |                 batch = []
124 |         if len(batch) > 0 and not self.drop_last:
125 |             yield batch
126 |
127 |     def __len__(self):
128 |         if self.drop_last:
129 |             return len(self.sampler) // self.batch_size
130 |         else:
131 |             return (len(self.sampler) + self.batch_size - 1) // self.batch_size
132 |
--------------------------------------------------------------------------------
/lib/utils/th.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | import numpy as np
4 | import collections.abc
5 |
6 | __all__ = ['as_variable', 'as_numpy', 'mark_volatile']
7 |
8 | def as_variable(obj):
9 |     if isinstance(obj, Variable):
10 |         return obj
11 |     if isinstance(obj, collections.abc.Sequence):
12 |         return [as_variable(v) for v in obj]
13 |     elif isinstance(obj, collections.abc.Mapping):
14 |         return {k: as_variable(v) for k, v in obj.items()}
15 |     else:
16 |         return Variable(obj)
17 |
18 | def as_numpy(obj):
19 |     if isinstance(obj, collections.abc.Sequence):
20 |         return [as_numpy(v) for v in obj]
21 |     elif isinstance(obj, collections.abc.Mapping):
22 |         return {k: as_numpy(v) for k, v in obj.items()}
23 |     elif isinstance(obj, Variable):
24 |         return obj.data.cpu().numpy()
25 |     elif torch.is_tensor(obj):
26 |         return obj.cpu().numpy()
27 |     else:
28 |         return np.array(obj)
29 |
30 | def mark_volatile(obj):
31 |     if torch.is_tensor(obj):
32 |         obj = Variable(obj)
33 |     if isinstance(obj, Variable):
34 |         obj.no_grad = True
35 |         return obj
36 |     elif isinstance(obj, collections.abc.Mapping):
37 |         return {k: mark_volatile(o) for k, o in obj.items()}
38 |     elif isinstance(obj, collections.abc.Sequence):
39 |         return [mark_volatile(o) for o in obj]
40 |     else:
41 |         return obj
42 |
--------------------------------------------------------------------------------
/models/.non_local2d.py.swp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/.non_local2d.py.swp
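A short sketch (illustration only; the weights are made-up values) of how the samplers in /lib/utils/data/sampler.py above compose: `WeightedRandomSampler` draws indices with probability proportional to per-sample weights, and `BatchSampler` chunks the drawn indices into mini-batches.

import torch
from lib.utils.data.sampler import WeightedRandomSampler, BatchSampler

weights = [0.1, 0.1, 0.4, 0.4]  # need not sum to one
sampler = WeightedRandomSampler(weights, num_samples=8, replacement=True)

for batch in BatchSampler(sampler, batch_size=3, drop_last=False):
    print(batch)  # lists of index tensors; indices 2 and 3 appear ~4x as often as 0 and 1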
-------------------------------------------------------------------------------- /models/.propnet.py.swo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/.propnet.py.swo -------------------------------------------------------------------------------- /models/.propnet.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/.propnet.py.swp -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import ModelBuilder, SegmentationModule,SegmentationModule_clip,SegmentationModule_allclip,ClipWarpNet 2 | from .netwarp import NetWarp 3 | from .ETC import ETC 4 | from .non_local_models import Non_local3d,Non_local2d 5 | from .propnet import PropNet 6 | from .warp_our_merge import OurWarpMerge 7 | from .clip_psp import Clip_PSP 8 | from .clip_ocr import ClipOCRNet 9 | from .netwarp_ocr import NetWarp_ocr 10 | from .ETC_ocr import ETC_ocr 11 | -------------------------------------------------------------------------------- /models/__pycache__/BiConvLSTM.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/BiConvLSTM.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/ETC.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/ETC.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/ETC_ocr.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/ETC_ocr.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/clip_ocr.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/clip_ocr.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/clip_psp.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/clip_psp.cpython-37.pyc -------------------------------------------------------------------------------- 
/models/__pycache__/deeplab.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/deeplab.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/hrnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/hrnet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/hrnet_clip.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/hrnet_clip.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/mobilenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/mobilenet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/models.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/netwarp.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/netwarp.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/netwarp_ocr.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/netwarp_ocr.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/non_local.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/non_local.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/non_local_models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/non_local_models.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/ocrnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/ocrnet.cpython-37.pyc 
--------------------------------------------------------------------------------
/models/__pycache__/propnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/propnet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnext.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/resnext.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/warp_our.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/warp_our.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/warp_our_merge.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/warp_our_merge.cpython-37.pyc
--------------------------------------------------------------------------------
/models/deeplabv3/aspp.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from models.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 |
7 | class _ASPPModule(nn.Module):
8 |     def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
9 |         super(_ASPPModule, self).__init__()
10 |         self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
11 |                                      stride=1, padding=padding, dilation=dilation, bias=False)
12 |         self.bn = BatchNorm(planes)
13 |         self.relu = nn.ReLU()
14 |
15 |         self._init_weight()
16 |
17 |     def forward(self, x):
18 |         x = self.atrous_conv(x)
19 |         x = self.bn(x)
20 |
21 |         return self.relu(x)
22 |
23 |     def _init_weight(self):
24 |         for m in self.modules():
25 |             if isinstance(m, nn.Conv2d):
26 |                 torch.nn.init.kaiming_normal_(m.weight)
27 |             elif isinstance(m, SynchronizedBatchNorm2d):
28 |                 m.weight.data.fill_(1)
29 |                 m.bias.data.zero_()
30 |             elif isinstance(m, nn.BatchNorm2d):
31 |                 m.weight.data.fill_(1)
32 |                 m.bias.data.zero_()
33 |
34 | class ASPP(nn.Module):
35 |     def __init__(self, backbone, output_stride, BatchNorm):
36 |         super(ASPP, self).__init__()
37 |         if backbone == 'drn':
38 |             inplanes = 512
39 |         elif backbone == 'mobilenet':
40 |             inplanes = 320
41 |         else:
42 |             inplanes = 2048
43 |         if output_stride == 16:
44 |             dilations = [1, 6, 12, 18]
45 |         elif output_stride == 8:
46 |             dilations = [1, 12, 24, 36]
47 |         else:
48 |             raise NotImplementedError
49 |
50 |         self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
51 |         self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
52 |         self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
53 |         self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)
54 |
55 |         self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
56 |                                              nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
57 |                                              BatchNorm(256),
58 |                                              nn.ReLU())
59 |         self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
60 |         self.bn1 = BatchNorm(256)
61 |         self.relu = nn.ReLU()
62 |         self.dropout = nn.Dropout(0.5)
63 |         self._init_weight()
64 |
65 |     def forward(self, x):
66 |         x1 = self.aspp1(x)
67 |         x2 = self.aspp2(x)
68 |         x3 = self.aspp3(x)
69 |         x4 = self.aspp4(x)
70 |         x5 = self.global_avg_pool(x)
71 |         x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
72 |         x = torch.cat((x1, x2, x3, x4, x5), dim=1)
73 |
74 |         x = self.conv1(x)
75 |         x = self.bn1(x)
76 |         x = self.relu(x)
77 |
78 |         return self.dropout(x)
79 |
80 |     def _init_weight(self):
81 |         for m in self.modules():
82 |             if isinstance(m, nn.Conv2d):
83 |                 # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
84 |                 # m.weight.data.normal_(0, math.sqrt(2. / n))
85 |                 torch.nn.init.kaiming_normal_(m.weight)
86 |             elif isinstance(m, SynchronizedBatchNorm2d):
87 |                 m.weight.data.fill_(1)
88 |                 m.bias.data.zero_()
89 |             elif isinstance(m, nn.BatchNorm2d):
90 |                 m.weight.data.fill_(1)
91 |                 m.bias.data.zero_()
92 |
93 |
94 | def build_aspp(backbone, output_stride, BatchNorm):
95 |     return ASPP(backbone, output_stride, BatchNorm)
--------------------------------------------------------------------------------
/models/deeplabv3/decoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from models.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 |
7 | class Decoder(nn.Module):
8 |     def __init__(self, num_classes, backbone, BatchNorm, args):
9 |         super(Decoder, self).__init__()
10 |         self.args = args
11 |         if backbone == 'resnet' or backbone == 'drn':
12 |             low_level_inplanes = 256
13 |         elif backbone == 'xception':
14 |             low_level_inplanes = 128
15 |         elif backbone == 'mobilenet':
16 |             low_level_inplanes = 24
17 |         else:
18 |             raise NotImplementedError
19 |
20 |         self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
21 |         self.bn1 = BatchNorm(48)
22 |         self.relu = nn.ReLU()
23 |         if self.args.deeplab_as_base:
24 |             self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
25 |                                            BatchNorm(256),
26 |                                            nn.ReLU(),
27 |                                            nn.Dropout(0.5),
28 |                                            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
29 |                                            BatchNorm(256),
30 |                                            nn.ReLU()
31 |                                            )
32 |             self.lastlast_conv = nn.Sequential(
33 |                 nn.Dropout(0.1),
34 |                 nn.Conv2d(256, num_classes, kernel_size=1, stride=1)
35 |             )
36 |             # nn.Dropout(0.1),
37 |             # nn.Conv2d(256, num_classes, kernel_size=1, stride=1))
38 |         else:
39 |             self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
40 |
BatchNorm(256), 41 | nn.ReLU(), 42 | nn.Dropout(0.5), 43 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 44 | BatchNorm(256), 45 | nn.ReLU(), 46 | nn.Dropout(0.1), 47 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1)) 48 | self._init_weight() 49 | 50 | 51 | def forward(self, x, low_level_feat): 52 | low_level_feat = self.conv1(low_level_feat) 53 | low_level_feat = self.bn1(low_level_feat) 54 | low_level_feat = self.relu(low_level_feat) 55 | 56 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True) 57 | x = torch.cat((x, low_level_feat), dim=1) 58 | if self.args.deeplab_as_base: 59 | x = self.last_conv(x) 60 | y = self.lastlast_conv(x) 61 | return y,x 62 | else: 63 | x = self.last_conv(x) 64 | 65 | return x 66 | 67 | def _init_weight(self): 68 | for m in self.modules(): 69 | if isinstance(m, nn.Conv2d): 70 | torch.nn.init.kaiming_normal_(m.weight) 71 | elif isinstance(m, SynchronizedBatchNorm2d): 72 | m.weight.data.fill_(1) 73 | m.bias.data.zero_() 74 | elif isinstance(m, nn.BatchNorm2d): 75 | m.weight.data.fill_(1) 76 | m.bias.data.zero_() 77 | 78 | def build_decoder(num_classes, backbone, BatchNorm,args): 79 | return Decoder(num_classes, backbone, BatchNorm,args) 80 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This MobileNetV2 implementation is modified from the following repository: 3 | https://github.com/tonylins/pytorch-mobilenet-v2 4 | """ 5 | 6 | import torch.nn as nn 7 | import math 8 | from .utils import load_url 9 | from models.sync_batchnorm import SynchronizedBatchNorm2d 10 | 11 | BatchNorm2d = SynchronizedBatchNorm2d 12 | 13 | 14 | __all__ = ['mobilenetv2'] 15 | 16 | 17 | model_urls = { 18 | 'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar', 19 | } 20 | 21 | 22 | def conv_bn(inp, oup, stride): 23 | return nn.Sequential( 24 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 25 | BatchNorm2d(oup), 26 | nn.ReLU6(inplace=True) 27 | ) 28 | 29 | 30 | def conv_1x1_bn(inp, oup): 31 | return nn.Sequential( 32 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 33 | BatchNorm2d(oup), 34 | nn.ReLU6(inplace=True) 35 | ) 36 | 37 | 38 | class InvertedResidual(nn.Module): 39 | def __init__(self, inp, oup, stride, expand_ratio): 40 | super(InvertedResidual, self).__init__() 41 | self.stride = stride 42 | assert stride in [1, 2] 43 | 44 | hidden_dim = round(inp * expand_ratio) 45 | self.use_res_connect = self.stride == 1 and inp == oup 46 | 47 | if expand_ratio == 1: 48 | self.conv = nn.Sequential( 49 | # dw 50 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 51 | BatchNorm2d(hidden_dim), 52 | nn.ReLU6(inplace=True), 53 | # pw-linear 54 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 55 | BatchNorm2d(oup), 56 | ) 57 | else: 58 | self.conv = nn.Sequential( 59 | # pw 60 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 61 | BatchNorm2d(hidden_dim), 62 | nn.ReLU6(inplace=True), 63 | # dw 64 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 65 | BatchNorm2d(hidden_dim), 66 | nn.ReLU6(inplace=True), 67 | # pw-linear 68 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 69 | BatchNorm2d(oup), 70 | ) 71 | 72 | def forward(self, x): 73 | if self.use_res_connect: 74 | return x + self.conv(x) 75 | else: 76 | return self.conv(x) 77 | 78 | 79 | class MobileNetV2(nn.Module): 80 | def 
__init__(self, n_class=1000, input_size=224, width_mult=1.): 81 | super(MobileNetV2, self).__init__() 82 | block = InvertedResidual 83 | input_channel = 32 84 | last_channel = 1280 85 | interverted_residual_setting = [ 86 | # t, c, n, s 87 | [1, 16, 1, 1], 88 | [6, 24, 2, 2], 89 | [6, 32, 3, 2], 90 | [6, 64, 4, 2], 91 | [6, 96, 3, 1], 92 | [6, 160, 3, 2], 93 | [6, 320, 1, 1], 94 | ] 95 | 96 | # building first layer 97 | assert input_size % 32 == 0 98 | input_channel = int(input_channel * width_mult) 99 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 100 | self.features = [conv_bn(3, input_channel, 2)] 101 | # building inverted residual blocks 102 | for t, c, n, s in interverted_residual_setting: 103 | output_channel = int(c * width_mult) 104 | for i in range(n): 105 | if i == 0: 106 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 107 | else: 108 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 109 | input_channel = output_channel 110 | # building last several layers 111 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 112 | # make it nn.Sequential 113 | self.features = nn.Sequential(*self.features) 114 | 115 | # building classifier 116 | self.classifier = nn.Sequential( 117 | nn.Dropout(0.2), 118 | nn.Linear(self.last_channel, n_class), 119 | ) 120 | 121 | self._initialize_weights() 122 | 123 | def forward(self, x): 124 | x = self.features(x) 125 | x = x.mean(3).mean(2) 126 | x = self.classifier(x) 127 | return x 128 | 129 | def _initialize_weights(self): 130 | for m in self.modules(): 131 | if isinstance(m, nn.Conv2d): 132 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 133 | m.weight.data.normal_(0, math.sqrt(2. / n)) 134 | if m.bias is not None: 135 | m.bias.data.zero_() 136 | elif isinstance(m, BatchNorm2d): 137 | m.weight.data.fill_(1) 138 | m.bias.data.zero_() 139 | elif isinstance(m, nn.Linear): 140 | n = m.weight.size(1) 141 | m.weight.data.normal_(0, 0.01) 142 | m.bias.data.zero_() 143 | 144 | 145 | def mobilenetv2(pretrained=False, **kwargs): 146 | """Constructs a MobileNet_V2 model. 
147 | 148 | Args: 149 | pretrained (bool): If True, returns a model pre-trained on ImageNet 150 | """ 151 | model = MobileNetV2(n_class=1000, **kwargs) 152 | if pretrained: 153 | model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False) 154 | return model 155 | -------------------------------------------------------------------------------- /models/ocr_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/ocr_modules/__init__.py -------------------------------------------------------------------------------- /models/ocr_modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/ocr_modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/ocr_modules/__pycache__/spatial_ocr_block.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/ocr_modules/__pycache__/spatial_ocr_block.cpython-37.pyc -------------------------------------------------------------------------------- /models/ocrnet.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: RainbowSecret 3 | ## Microsoft Research 4 | ## yuyua@microsoft.com 5 | ## Copyright (c) 2018 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | import pdb 11 | import torch 12 | import torch.nn as nn 13 | from torch.nn import functional as F 14 | from models.sync_batchnorm import SynchronizedBatchNorm2d 15 | BatchNorm2d = SynchronizedBatchNorm2d 16 | BN_MOMENTUM = 0.1 17 | 18 | #from lib.models.backbones.backbone_selector import BackboneSelector 19 | #from lib.models.tools.module_helper import ModuleHelper 20 | 21 | 22 | class SpatialOCRNet(nn.Module): 23 | """ 24 | Object-Contextual Representations for Semantic Segmentation, 25 | Yuan, Yuhui and Chen, Xilin and Wang, Jingdong 26 | """ 27 | def __init__(self, num_class): 28 | self.inplanes = 128 29 | super(SpatialOCRNet, self).__init__() 30 | self.num_classes=num_class 31 | in_channels = [1024, 2048] 32 | self.conv_3x3 = nn.Sequential( 33 | nn.Conv2d(in_channels[1], 512, kernel_size=3, stride=1, padding=1), 34 | BatchNorm2d(512), 35 | nn.ReLU(inplace=True) 36 | ) 37 | 38 | from models.ocr_modules.spatial_ocr_block import SpatialGather_Module, SpatialOCR_Module 39 | self.spatial_context_head = SpatialGather_Module(self.num_classes) 40 | self.spatial_ocr_head = SpatialOCR_Module(in_channels=512, 41 | key_channels=256, 42 | out_channels=512, 43 | scale=1, 44 | dropout=0.05 45 | ) 46 | 47 | self.head = nn.Conv2d(512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True) 48 | self.dsn_head = nn.Sequential( 49 | nn.Conv2d(in_channels[0], 512, kernel_size=3, stride=1, padding=1), 50 | BatchNorm2d(512), 51 | nn.ReLU(inplace=True), 52 | nn.Dropout2d(0.05), 53 | nn.Conv2d(512, self.num_classes, kernel_size=1, 
stride=1, padding=0, bias=True) 54 | ) 55 | 56 | def forward(self, x,segSize=None): 57 | 58 | x_dsn = self.dsn_head(x[-2]) 59 | x = self.conv_3x3(x[-1]) 60 | context = self.spatial_context_head(x, x_dsn) 61 | x = self.spatial_ocr_head(x, context) 62 | x = self.head(x) 63 | 64 | if segSize is not None: # is True during inference 65 | x = F.interpolate( 66 | x, size=segSize, mode='bilinear', align_corners=False) 67 | x = F.softmax(x, dim=1) 68 | return x 69 | else: 70 | x = F.log_softmax(x, dim=1) 71 | x_dsn = F.log_softmax(x_dsn, dim=1) 72 | return x,x_dsn 73 | 74 | 75 | #class ASPOCRNet(nn.Module): 76 | # """ 77 | # Object-Contextual Representations for Semantic Segmentation, 78 | # Yuan, Yuhui and Chen, Xilin and Wang, Jingdong 79 | # """ 80 | # def __init__(self, configer): 81 | # self.inplanes = 128 82 | # super(ASPOCRNet, self).__init__() 83 | # self.configer = configer 84 | # self.num_classes = self.configer.get('data', 'num_classes') 85 | # self.backbone = BackboneSelector(configer).get_backbone() 86 | # 87 | # # extra added layers 88 | # if "wide_resnet38" in self.configer.get('network', 'backbone'): 89 | # in_channels = [2048, 4096] 90 | # else: 91 | # in_channels = [1024, 2048] 92 | # 93 | # # we should increase the dilation rates as the output stride is larger 94 | # from lib.models.modules.spatial_ocr_block import SpatialOCR_ASP_Module 95 | # self.asp_ocr_head = SpatialOCR_ASP_Module(features=2048, 96 | # hidden_features=256, 97 | # out_features=256, 98 | # num_classes=self.num_classes, 99 | # bn_type=self.configer.get('network', 'bn_type')) 100 | # 101 | # self.head = nn.Conv2d(256, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True) 102 | # self.dsn_head = nn.Sequential( 103 | # nn.Conv2d(in_channels[0], 512, kernel_size=3, stride=1, padding=1), 104 | # ModuleHelper.BNReLU(512, bn_type=self.configer.get('network', 'bn_type')), 105 | # nn.Dropout2d(0.1), 106 | # nn.Conv2d(512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True) 107 | # ) 108 | # 109 | # def forward(self, x_): 110 | # x = self.backbone(x_) 111 | # x_dsn = self.dsn_head(x[-2]) 112 | # x = self.asp_ocr_head(x[-1], x_dsn) 113 | # x = self.head(x) 114 | # x_dsn = F.interpolate(x_dsn, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True) 115 | # x = F.interpolate(x, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True) 116 | # return x_dsn, x 117 | -------------------------------------------------------------------------------- /models/resnext.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | from .utils import load_url 4 | from models.sync_batchnorm import SynchronizedBatchNorm2d 5 | BatchNorm2d = SynchronizedBatchNorm2d 6 | 7 | 8 | __all__ = ['ResNeXt', 'resnext101'] # support resnext 101 9 | 10 | 11 | model_urls = { 12 | #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth', 13 | 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth' 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | "3x3 convolution with padding" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class GroupBottleneck(nn.Module): 24 | expansion = 2 25 | 26 | def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None): 27 | super(GroupBottleneck, self).__init__() 28 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, 
bias=False) 29 | self.bn1 = BatchNorm2d(planes) 30 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 31 | padding=1, groups=groups, bias=False) 32 | self.bn2 = BatchNorm2d(planes) 33 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False) 34 | self.bn3 = BatchNorm2d(planes * 2) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv3(out) 51 | out = self.bn3(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | 56 | out += residual 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class ResNeXt(nn.Module): 63 | 64 | def __init__(self, block, layers, groups=32, num_classes=1000): 65 | self.inplanes = 128 66 | super(ResNeXt, self).__init__() 67 | self.conv1 = conv3x3(3, 64, stride=2) 68 | self.bn1 = BatchNorm2d(64) 69 | self.relu1 = nn.ReLU(inplace=True) 70 | self.conv2 = conv3x3(64, 64) 71 | self.bn2 = BatchNorm2d(64) 72 | self.relu2 = nn.ReLU(inplace=True) 73 | self.conv3 = conv3x3(64, 128) 74 | self.bn3 = BatchNorm2d(128) 75 | self.relu3 = nn.ReLU(inplace=True) 76 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 77 | 78 | self.layer1 = self._make_layer(block, 128, layers[0], groups=groups) 79 | self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups) 80 | self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups) 81 | self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups) 82 | self.avgpool = nn.AvgPool2d(7, stride=1) 83 | self.fc = nn.Linear(1024 * block.expansion, num_classes) 84 | 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups 88 | m.weight.data.normal_(0, math.sqrt(2. / n)) 89 | elif isinstance(m, BatchNorm2d): 90 | m.weight.data.fill_(1) 91 | m.bias.data.zero_() 92 | 93 | def _make_layer(self, block, planes, blocks, stride=1, groups=1): 94 | downsample = None 95 | if stride != 1 or self.inplanes != planes * block.expansion: 96 | downsample = nn.Sequential( 97 | nn.Conv2d(self.inplanes, planes * block.expansion, 98 | kernel_size=1, stride=stride, bias=False), 99 | BatchNorm2d(planes * block.expansion), 100 | ) 101 | 102 | layers = [] 103 | layers.append(block(self.inplanes, planes, stride, groups, downsample)) 104 | self.inplanes = planes * block.expansion 105 | for i in range(1, blocks): 106 | layers.append(block(self.inplanes, planes, groups=groups)) 107 | 108 | return nn.Sequential(*layers) 109 | 110 | def forward(self, x): 111 | x = self.relu1(self.bn1(self.conv1(x))) 112 | x = self.relu2(self.bn2(self.conv2(x))) 113 | x = self.relu3(self.bn3(self.conv3(x))) 114 | x = self.maxpool(x) 115 | 116 | x = self.layer1(x) 117 | x = self.layer2(x) 118 | x = self.layer3(x) 119 | x = self.layer4(x) 120 | 121 | x = self.avgpool(x) 122 | x = x.view(x.size(0), -1) 123 | x = self.fc(x) 124 | 125 | return x 126 | 127 | 128 | ''' 129 | def resnext50(pretrained=False, **kwargs): 130 | """Constructs a ResNet-50 model. 
131 | 
132 | Args:
133 | pretrained (bool): If True, returns a model pre-trained on Places
134 | """
135 | model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs)
136 | if pretrained:
137 | model.load_state_dict(load_url(model_urls['resnext50']), strict=False)
138 | return model
139 | '''
140 | 
141 | 
142 | def resnext101(pretrained=False, **kwargs):
143 | """Constructs a ResNeXt-101 model.
144 | 
145 | Args:
146 | pretrained (bool): If True, returns a model pre-trained on Places
147 | """
148 | model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs)
149 | if pretrained:
150 | model.load_state_dict(load_url(model_urls['resnext101']), strict=False)
151 | return model
152 | 
153 | 
154 | # def resnext152(pretrained=False, **kwargs):
155 | # """Constructs a ResNeXt-152 model.
156 | # 
157 | # Args:
158 | # pretrained (bool): If True, returns a model pre-trained on Places
159 | # """
160 | # model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs)
161 | # if pretrained:
162 | # model.load_state_dict(load_url(model_urls['resnext152']))
163 | # return model
164 | 
-------------------------------------------------------------------------------- /models/sync_batchnorm/__init__.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .batchnorm import patch_sync_batchnorm, convert_model
13 | from .replicate import DataParallelWithCallback, patch_replication_callback
14 | 
-------------------------------------------------------------------------------- /models/sync_batchnorm/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/sync_batchnorm/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc -------------------------------------------------------------------------------- /models/sync_batchnorm/__pycache__/comm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/sync_batchnorm/__pycache__/comm.cpython-37.pyc -------------------------------------------------------------------------------- /models/sync_batchnorm/__pycache__/replicate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/sync_batchnorm/__pycache__/replicate.cpython-37.pyc -------------------------------------------------------------------------------- /models/sync_batchnorm/batchnorm_reimpl.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : batchnorm_reimpl.py
4 | # Author : acgtyrant
5 | # Date : 11/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 | 
15 | __all__ = ['BatchNorm2dReimpl']
16 | 
17 | 
18 | class BatchNorm2dReimpl(nn.Module):
19 | """
20 | A re-implementation of batch normalization, used for testing the numerical
21 | stability.
22 | 
23 | Author: acgtyrant
24 | See also:
25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
26 | """
27 | def __init__(self, num_features, eps=1e-5, momentum=0.1):
28 | super().__init__()
29 | 
30 | self.num_features = num_features
31 | self.eps = eps
32 | self.momentum = momentum
33 | self.weight = nn.Parameter(torch.empty(num_features))
34 | self.bias = nn.Parameter(torch.empty(num_features))
35 | self.register_buffer('running_mean', torch.zeros(num_features))
36 | self.register_buffer('running_var', torch.ones(num_features))
37 | self.reset_parameters()
38 | 
39 | def reset_running_stats(self):
40 | self.running_mean.zero_()
41 | self.running_var.fill_(1)
42 | 
43 | def reset_parameters(self):
44 | self.reset_running_stats()
45 | init.uniform_(self.weight)
46 | init.zeros_(self.bias)
47 | 
48 | def forward(self, input_):
49 | batchsize, channels, height, width = input_.size()
50 | numel = batchsize * height * width
51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
52 | sum_ = input_.sum(1)
53 | sum_of_square = input_.pow(2).sum(1)
54 | mean = sum_ / numel
55 | sumvar = sum_of_square - sum_ * mean
56 | 
57 | self.running_mean = (
58 | (1 - self.momentum) * self.running_mean
59 | + self.momentum * mean.detach()
60 | )
61 | unbias_var = sumvar / (numel - 1)
62 | self.running_var = (
63 | (1 - self.momentum) * self.running_var
64 | + self.momentum * unbias_var.detach()
65 | )
66 | 
67 | bias_var = sumvar / numel
68 | inv_std = 1 / (bias_var + self.eps).pow(0.5)
69 | output = (
70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
72 | 
73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
74 | 
75 | 
-------------------------------------------------------------------------------- /models/sync_batchnorm/comm.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import queue
12 | import collections
13 | import threading
14 | 
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 | 
17 | 
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 | 
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 | 
26 | def put(self, result):
27 | with self._lock:
28 | assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 | 
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 | 
37 | res = self._result
38 | self._result = None
39 | return res
40 | 
41 | 
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 | 
45 | 
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 | 
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 | 
55 | 
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 | 
59 | - During replication, as data parallel triggers a callback on each module, every slave device should
60 | call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
61 | - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
62 | and passed to a registered callback.
63 | - After receiving the messages, the master device gathers the information and determines the message to pass
64 | back to each slave device.
65 | """
66 | 
67 | def __init__(self, master_callback):
68 | """
69 | 
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 | 
78 | def __getstate__(self):
79 | return {'master_callback': self._master_callback}
80 | 
81 | def __setstate__(self, state):
82 | self.__init__(state['master_callback'])
83 | 
84 | def register_slave(self, identifier):
85 | """
86 | Register a slave device.
87 | 
88 | Args:
89 | identifier: an identifier, usually the device id.
90 | 
91 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
92 | 
93 | """
94 | if self._activated:
95 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
96 | self._activated = False
97 | self._registry.clear()
98 | future = FutureResult()
99 | self._registry[identifier] = _MasterRegistry(future)
100 | return SlavePipe(identifier, self._queue, future)
101 | 
102 | def run_master(self, master_msg):
103 | """
104 | Main entry for the master device in each forward pass.
105 | Messages are first collected from each device (including the master device), and then
106 | a callback is invoked to compute the message to be sent back to each device
107 | (including the master device).
108 | 
109 | Args:
110 | master_msg: the message that the master wants to send to itself. This will be placed as the first
111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112 | 
113 | Returns: the message to be sent back to the master device.
114 | 
115 | """
116 | self._activated = True
117 | 
118 | intermediates = [(0, master_msg)]
119 | for i in range(self.nr_slaves):
120 | intermediates.append(self._queue.get())
121 | 
122 | results = self._master_callback(intermediates)
123 | assert results[0][0] == 0, 'The first result should belong to the master.'
124 | 
125 | for i, res in results:
126 | if i == 0:
127 | continue
128 | self._registry[i].result.put(res)
129 | 
130 | for i in range(self.nr_slaves):
131 | assert self._queue.get() is True
132 | 
133 | return results[0][1]
134 | 
135 | @property
136 | def nr_slaves(self):
137 | return len(self._registry)
138 | 
-------------------------------------------------------------------------------- /models/sync_batchnorm/replicate.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import functools
12 | 
13 | from torch.nn.parallel.data_parallel import DataParallel
14 | 
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 | 
22 | 
23 | class CallbackContext(object):
24 | pass
25 | 
26 | 
27 | def execute_replication_callbacks(modules):
28 | """
29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 | 
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 | 
33 | Note that, as all modules are isomorphic, we assign each sub-module a context
34 | (shared among multiple copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 | 
37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of the callback
38 | of any slave copies.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 | 
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 | 
49 | 
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 | 
54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after the replicas are created by
55 | the original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 | 
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 | 
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 | 
69 | 
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 | Useful when you have a customized `DataParallel` implementation.
74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /models/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y), message) 29 | 30 | -------------------------------------------------------------------------------- /models/td4_psp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/td4_psp/__init__.py -------------------------------------------------------------------------------- /models/td4_psp/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/td4_psp/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/td4_psp/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/td4_psp/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /models/td4_psp/__pycache__/td4_psp.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/td4_psp/__pycache__/td4_psp.cpython-37.pyc -------------------------------------------------------------------------------- /models/td4_psp/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/td4_psp/__pycache__/transformer.cpython-37.pyc 
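A minimal usage sketch for the sync_batchnorm package above, following the example in replicate.py (the toy module, input shapes, and two-GPU device ids are illustrative assumptions, not part of this repo):

import torch
import torch.nn as nn
from models.sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback

# Toy network using the synchronized BN defined above (shapes are arbitrary).
net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    SynchronizedBatchNorm2d(16),
    nn.ReLU(inplace=True),
).cuda()

# DataParallelWithCallback triggers __data_parallel_replicate__ on every replica;
# SyncMaster then gathers per-GPU BN statistics on the master copy during each
# forward pass, so all replicas normalize with the same batch statistics.
net = DataParallelWithCallback(net, device_ids=[0, 1])
out = net(torch.randn(8, 3, 64, 64).cuda())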
-------------------------------------------------------------------------------- /models/td4_psp/loss.py: --------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | 
5 | __all__ = ['SegmentationLosses', 'OhemCELoss2D']
6 | 
7 | class SegmentationLosses(nn.CrossEntropyLoss):
8 | """2D Cross Entropy Loss with Auxiliary Loss"""
9 | def __init__(self,
10 | weight=None,
11 | ignore_index=-1):
12 | 
13 | super(SegmentationLosses, self).__init__(weight, None, ignore_index)
14 | 
15 | def forward(self, pred, target):
16 | return super(SegmentationLosses, self).forward(pred, target)
17 | 
18 | 
19 | 
20 | 
21 | class OhemCELoss2D(nn.CrossEntropyLoss):
22 | """2D Cross Entropy Loss with Online Hard Example Mining (OHEM)"""
23 | def __init__(self,
24 | n_min,
25 | thresh=0.7,
26 | ignore_index=-1):
27 | 
28 | super(OhemCELoss2D, self).__init__(None, None, ignore_index, reduction='none')
29 | 
30 | self.thresh = -math.log(thresh)
31 | self.n_min = n_min
32 | self.ignore_index = ignore_index
33 | 
34 | def forward(self, pred, target):
35 | return self.OhemCELoss(pred, target)
36 | 
37 | def OhemCELoss(self, logits, labels):
38 | loss = super(OhemCELoss2D, self).forward(logits, labels).view(-1)
39 | loss, _ = torch.sort(loss, descending=True)
40 | if loss[self.n_min] > self.thresh:
41 | loss = loss[loss>self.thresh]
42 | else:
43 | loss = loss[:self.n_min]
44 | return torch.mean(loss)
-------------------------------------------------------------------------------- /models/td4_psp/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/td4_psp/utils/__init__.py -------------------------------------------------------------------------------- /models/td4_psp/utils/files.py: --------------------------------------------------------------------------------
1 | import os
2 | import requests
3 | import errno
4 | import shutil
5 | import hashlib
6 | from tqdm import tqdm
7 | import torch
8 | 
9 | __all__ = ['save_checkpoint', 'download', 'mkdir', 'check_sha1']
10 | 
11 | def save_checkpoint(state, args, is_best, filename='pretrained.pth.tar'):
12 | """Saves a checkpoint to disk"""
13 | directory = "runs/%s/%s/%s/"%(args.dataset, args.model, args.checkname)
14 | if not os.path.exists(directory):
15 | os.makedirs(directory)
16 | filename = directory + filename
17 | torch.save(state, filename)
18 | if is_best:
19 | shutil.copyfile(filename, directory + 'model_best.pth.tar')
20 | 
21 | 
22 | def download(url, path=None, overwrite=False, sha1_hash=None):
23 | """Download a given URL
24 | Parameters
25 | ----------
26 | url : str
27 | URL to download
28 | path : str, optional
29 | Destination path to store the downloaded file. By default stores to the
30 | current directory with the same name as in the url.
31 | overwrite : bool, optional
32 | Whether to overwrite the destination file if it already exists.
33 | sha1_hash : str, optional
34 | Expected sha1 hash in hexadecimal digits. An existing file is ignored when the hash is specified
35 | but doesn't match.
36 | Returns
37 | -------
38 | str
39 | The file path of the downloaded file.
40 | """ 41 | if path is None: 42 | fname = url.split('/')[-1] 43 | else: 44 | path = os.path.expanduser(path) 45 | if os.path.isdir(path): 46 | fname = os.path.join(path, url.split('/')[-1]) 47 | else: 48 | fname = path 49 | 50 | if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): 51 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) 52 | if not os.path.exists(dirname): 53 | os.makedirs(dirname) 54 | 55 | print('Downloading %s from %s...'%(fname, url)) 56 | r = requests.get(url, stream=True) 57 | if r.status_code != 200: 58 | raise RuntimeError("Failed downloading url %s"%url) 59 | total_length = r.headers.get('content-length') 60 | with open(fname, 'wb') as f: 61 | if total_length is None: # no content length header 62 | for chunk in r.iter_content(chunk_size=1024): 63 | if chunk: # filter out keep-alive new chunks 64 | f.write(chunk) 65 | else: 66 | total_length = int(total_length) 67 | for chunk in tqdm(r.iter_content(chunk_size=1024), 68 | total=int(total_length / 1024. + 0.5), 69 | unit='KB', unit_scale=False, dynamic_ncols=True): 70 | f.write(chunk) 71 | 72 | if sha1_hash and not check_sha1(fname, sha1_hash): 73 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \ 74 | 'The repo may be outdated or download may be incomplete. ' \ 75 | 'If the "repo_url" is overridden, consider switching to ' \ 76 | 'the default repo.'.format(fname)) 77 | 78 | return fname 79 | 80 | 81 | def check_sha1(filename, sha1_hash): 82 | """Check whether the sha1 hash of the file content matches the expected hash. 83 | Parameters 84 | ---------- 85 | filename : str 86 | Path to the file. 87 | sha1_hash : str 88 | Expected sha1 hash in hexadecimal digits. 89 | Returns 90 | ------- 91 | bool 92 | Whether the file content matches the expected hash. 
93 | """
94 | sha1 = hashlib.sha1()
95 | with open(filename, 'rb') as f:
96 | while True:
97 | data = f.read(1048576)
98 | if not data:
99 | break
100 | sha1.update(data)
101 | 
102 | return sha1.hexdigest() == sha1_hash
103 | 
104 | 
105 | def mkdir(path):
106 | """Make a directory; it's okay if it already exists."""
107 | try:
108 | os.makedirs(path)
109 | except OSError as exc: # Python >2.5
110 | if exc.errno == errno.EEXIST and os.path.isdir(path):
111 | pass
112 | else:
113 | raise
114 | 
-------------------------------------------------------------------------------- /models/td4_psp/utils/model_store.py: --------------------------------------------------------------------------------
1 | """Model store which provides pretrained models."""
2 | from __future__ import print_function
3 | __all__ = ['get_model_file', 'purge']
4 | import os
5 | import zipfile
6 | 
7 | from .files import download, check_sha1
8 | 
9 | _model_sha1 = {name: checksum for checksum, name in [
10 | ('25c4b50959ef024fcc050213a06b614899f94b3d', 'resnet50'),
11 | ('2a57e44de9c853fa015b172309a1ee7e2d0e4e2a', 'resnet101'),
12 | ('0d43d698c66aceaa2bc0309f55efdd7ff4b143af', 'resnet152'),
13 | ('da4785cfc837bf00ef95b52fb218feefe703011f', 'wideresnet38'),
14 | ('b41562160173ee2e979b795c551d3c7143b1e5b5', 'wideresnet50'),
15 | ('1225f149519c7a0113c43a056153c1bb15468ac0', 'deepten_resnet50_minc'),
16 | ('662e979de25a389f11c65e9f1df7e06c2c356381', 'fcn_resnet50_ade'),
17 | ('eeed8e582f0fdccdba8579e7490570adc6d85c7c', 'fcn_resnet50_pcontext'),
18 | ('54f70c772505064e30efd1ddd3a14e1759faa363', 'psp_resnet50_ade'),
19 | ('075195c5237b778c718fd73ceddfa1376c18dfd0', 'deeplab_resnet50_ade'),
20 | ('5ee47ee28b480cc781a195d13b5806d5bbc616bf', 'encnet_resnet101_coco'),
21 | ('4de91d5922d4d3264f678b663f874da72e82db00', 'encnet_resnet50_pcontext'),
22 | ('9f27ea13d514d7010e59988341bcbd4140fcc33d', 'encnet_resnet101_pcontext'),
23 | ('07ac287cd77e53ea583f37454e17d30ce1509a4a', 'encnet_resnet50_ade'),
24 | ('3f54fa3b67bac7619cd9b3673f5c8227cf8f4718', 'encnet_resnet101_ade'),
25 | ]}
26 | 
27 | encoding_repo_url = 'https://hangzh.s3.amazonaws.com/'
28 | _url_format = '{repo_url}encoding/models/{file_name}.zip'
29 | 
30 | def short_hash(name):
31 | if name not in _model_sha1:
32 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
33 | return _model_sha1[name][:8]
34 | 
35 | def get_model_file(name, root=os.path.join('~', '.encoding', 'models')):
36 | r"""Return the location of a pretrained model file on the local file system.
37 | 
38 | This function will download from the online model zoo when the model cannot be found or has a hash mismatch.
39 | The root directory will be created if it doesn't exist.
40 | 
41 | Parameters
42 | ----------
43 | name : str
44 | Name of the model.
45 | root : str, default '~/.encoding/models'
46 | Location for keeping the model parameters.
47 | 
48 | Returns
49 | -------
50 | file_path
51 | Path to the requested pretrained model file.
52 | """
53 | file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name))
54 | root = os.path.expanduser(root)
55 | file_path = os.path.join(root, file_name+'.pth')
56 | sha1_hash = _model_sha1[name]
57 | if os.path.exists(file_path):
58 | if check_sha1(file_path, sha1_hash):
59 | return file_path
60 | else:
61 | print(('Mismatch in the content of model file {} detected.' +
62 | ' Downloading again.').format(file_path))
63 | else:
64 | print('Model file {} is not found.
Downloading.'.format(file_path)) 65 | 66 | if not os.path.exists(root): 67 | os.makedirs(root) 68 | 69 | zip_file_path = os.path.join(root, file_name+'.zip') 70 | repo_url = os.environ.get('ENCODING_REPO', encoding_repo_url) 71 | if repo_url[-1] != '/': 72 | repo_url = repo_url + '/' 73 | download(_url_format.format(repo_url=repo_url, file_name=file_name), 74 | path=zip_file_path, 75 | overwrite=True) 76 | with zipfile.ZipFile(zip_file_path) as zf: 77 | zf.extractall(root) 78 | os.remove(zip_file_path) 79 | 80 | if check_sha1(file_path, sha1_hash): 81 | return file_path 82 | else: 83 | raise ValueError('Downloaded file has different hash. Please try again.') 84 | 85 | def purge(root=os.path.join('~', '.encoding', 'models')): 86 | r"""Purge all pretrained model files in local file store. 87 | 88 | Parameters 89 | ---------- 90 | root : str, default '~/.encoding/models' 91 | Location for keeping the model parameters. 92 | """ 93 | root = os.path.expanduser(root) 94 | files = os.listdir(root) 95 | for f in files: 96 | if f.endswith(".pth"): 97 | os.remove(os.path.join(root, f)) 98 | 99 | def pretrained_model_list(): 100 | return list(_model_sha1.keys()) 101 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | try: 4 | from urllib import urlretrieve 5 | except ImportError: 6 | from urllib.request import urlretrieve 7 | import torch 8 | 9 | 10 | def load_url(url, model_dir='./pretrained', map_location=None): 11 | if not os.path.exists(model_dir): 12 | os.makedirs(model_dir) 13 | filename = url.split('/')[-1] 14 | cached_file = os.path.join(model_dir, filename) 15 | if not os.path.exists(cached_file): 16 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 17 | urlretrieve(url, cached_file) 18 | return torch.load(cached_file, map_location=map_location) 19 | -------------------------------------------------------------------------------- /scripts/run_etc.sh: -------------------------------------------------------------------------------- 1 | DATAROOT="your/path/to/VSPW_480p" 2 | 3 | 4 | 5 | SAVE="./savemodel" 6 | DATAROOT2='data2' 7 | BATCHSIZE=8 8 | WORKERS=12 9 | CROPSIZE=479 10 | 11 | 12 | 13 | START_GPU=0 14 | GPU_NUM=2 15 | TRAINFPS=1 16 | EPOCH=120 17 | LR=0.002 18 | VAL=False 19 | USETWODATA=False 20 | LESSLABEL=False 21 | CLIPNUM=2 22 | DILATION=0 23 | CLIPUP=False 24 | CLIPMIDDLE=False 25 | OTHERGT=False 26 | PROPCLIP2=False 27 | EARLYFUSE=True 28 | EARLYCAT=False 29 | CONVLSTM=False 30 | NON_LOCAL=False 31 | 32 | 33 | 34 | FIX=False 35 | ALLSUP=True 36 | ALLSUPSCALE=0.5 37 | LINEAR_COM=True 38 | DISTSOFTMAX=False 39 | DISTNEAREST=False 40 | TEMP=0.05 41 | 42 | DILATION2="3,6,9" 43 | 44 | CLIPOCR_ALL=False 45 | USEMEMORY=True 46 | 47 | METHOD='etc' 48 | 49 | 50 | 51 | PREROOT='' 52 | PRE_ENC="./imgnetpre/resnet101-imagenet.pth" 53 | MAXDIST='3' 54 | ######### 55 | ARCH=resnet101 56 | CFG='vsp-'$ARCH'dilated-ppm_deepsup_clip.yaml' 57 | #CFG='vsp-'$ARCH'dilated-ppm_clip.yaml' 58 | #CFG="vsp-"$ARCH"dilated_tdnet.yaml" 59 | PREDIR="../data/imgnetpre/"$ARCH"-imagenet.pth" 60 | 61 | NAME='newjob_lr'$LR'_bs'$BATCHSIZE'_epoch'$EPOCH'_FPS'$TRAINFPS'_clipnum'$CLIPNUM"_dilation"$DILATION"_fix"$FIX"_tdnet"$TDNET"_arch"$ARCH'_method'$METHOD'_DISTSOFTMAX'$DISTSOFTMAX'_DISTNEAREST'$DISTNEAREST"_CLIPOCR_ALL"$CLIPOCR_ALL"_USEMEMORY"$USEMEMORY"imgnetpre" 62 | 63 | 64 | SAVEROOT=$SAVE"/"$NAME 65 | python train_clip2.py --cfg 
config/$CFG --predir $PREDIR --batchsize $BATCHSIZE --workers $WORKERS --start_gpu $START_GPU --gpu_num $GPU_NUM --dataroot $DATAROOT --trainfps $TRAINFPS --lr $LR --multi_scale True --saveroot $SAVEROOT --totalepoch $EPOCH --dataroot2 $DATAROOT2 --usetwodata $USETWODATA --cropsize $CROPSIZE --validation $VAL --lesslabel $LESSLABEL --clip_num $CLIPNUM --dilation_num $DILATION --clip_up $CLIPUP --clip_middle $CLIPMIDDLE --fix $FIX --othergt $OTHERGT --propclip2 $PROPCLIP2 --earlyfuse $EARLYFUSE --early_usecat $EARLYCAT --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --pre_enc $PRE_ENC --max_distances $MAXDIST --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY 66 | 67 | 68 | 69 | ###inference 70 | echo 'val' 71 | BATCHSIZE=1 72 | GPU_NUM=1 73 | ISSAVE=True 74 | LESSLABLE=False 75 | USE720p=False 76 | EARLYFUSE=False 77 | EARLYCAT=False 78 | 79 | #CLIPNUM=5 80 | 81 | IMGSAVEROOT='./clipsaveimg/'$NAME 82 | 83 | LOAD=$SAVEROOT'/model_epoch_'$EPOCH'.pth' 84 | 85 | python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'val' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY 86 | 87 | echo 'test' 88 | 89 | python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'test' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY 90 | 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /scripts/run_netwarp.sh: -------------------------------------------------------------------------------- 1 | DATAROOT="your/path/to/VSPW_480p" 2 | 3 | 4 | 5 | SAVE="./savemodel" 6 | DATAROOT2='data2' 7 | BATCHSIZE=8 8 | WORKERS=12 9 | CROPSIZE=479 10 | 11 | 12 | 13 | START_GPU=0 14 | GPU_NUM=2 15 | TRAINFPS=1 16 | EPOCH=120 17 | LR=0.002 18 | VAL=False 19 | USETWODATA=False 20 | LESSLABEL=False 21 | CLIPNUM=2 22 | DILATION=0 23 | CLIPUP=False 24 | CLIPMIDDLE=False 25 | OTHERGT=False 26 | PROPCLIP2=False 27 | EARLYFUSE=True 28 | EARLYCAT=False 29 | CONVLSTM=False 30 | NON_LOCAL=False 31 | 32 | 33 | 34 | FIX=False 35 | ALLSUP=True 36 | ALLSUPSCALE=0.5 37 | LINEAR_COM=True 38 | DISTSOFTMAX=False 39 | DISTNEAREST=False 40 | TEMP=0.05 41 | 42 | DILATION2="3,6,9" 43 | 44 | CLIPOCR_ALL=False 45 | USEMEMORY=True 46 | 47 | METHOD='netwarp' 48 | 49 | 50 | 51 | PREROOT='' 52 | PRE_ENC="./imgnetpre/resnet101-imagenet.pth" 53 | MAXDIST='3' 54 | ######### 55 | ARCH=resnet101 56 | CFG='vsp-'$ARCH'dilated-ppm_deepsup_clip.yaml' 57 | #CFG='vsp-'$ARCH'dilated-ppm_clip.yaml' 58 | #CFG="vsp-"$ARCH"dilated_tdnet.yaml" 59 | PREDIR="../data/imgnetpre/"$ARCH"-imagenet.pth" 60 | 61 | 
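# NAME folds the main hyperparameters into the run directory name. Note that
# $TDNET is never assigned in this script, so it expands to an empty string
# (the "_tdnet" tag in the name carries no value here).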
NAME='newjob_lr'$LR'_bs'$BATCHSIZE'_epoch'$EPOCH'_FPS'$TRAINFPS'_clipnum'$CLIPNUM"_dilation"$DILATION"_fix"$FIX"_tdnet"$TDNET"_arch"$ARCH'_method'$METHOD'_DISTSOFTMAX'$DISTSOFTMAX'_DISTNEAREST'$DISTNEAREST"_CLIPOCR_ALL"$CLIPOCR_ALL"_USEMEMORY"$USEMEMORY"imgnetpre" 62 | 63 | 64 | SAVEROOT=$SAVE"/"$NAME 65 | python train_clip2.py --cfg config/$CFG --predir $PREDIR --batchsize $BATCHSIZE --workers $WORKERS --start_gpu $START_GPU --gpu_num $GPU_NUM --dataroot $DATAROOT --trainfps $TRAINFPS --lr $LR --multi_scale True --saveroot $SAVEROOT --totalepoch $EPOCH --dataroot2 $DATAROOT2 --usetwodata $USETWODATA --cropsize $CROPSIZE --validation $VAL --lesslabel $LESSLABEL --clip_num $CLIPNUM --dilation_num $DILATION --clip_up $CLIPUP --clip_middle $CLIPMIDDLE --fix $FIX --othergt $OTHERGT --propclip2 $PROPCLIP2 --earlyfuse $EARLYFUSE --early_usecat $EARLYCAT --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --pre_enc $PRE_ENC --max_distances $MAXDIST --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY 66 | 67 | 68 | 69 | ###inference 70 | echo 'val' 71 | BATCHSIZE=1 72 | GPU_NUM=1 73 | ISSAVE=True 74 | LESSLABLE=False 75 | USE720p=False 76 | EARLYFUSE=False 77 | EARLYCAT=False 78 | 79 | #CLIPNUM=5 80 | 81 | IMGSAVEROOT='./clipsaveimg/'$NAME 82 | 83 | LOAD=$SAVEROOT'/model_epoch_'$EPOCH'.pth' 84 | 85 | python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'val' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY 86 | 87 | echo 'test' 88 | 89 | python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'test' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY 90 | 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /scripts/run_ocr.sh: -------------------------------------------------------------------------------- 1 | DATAROOT="/your/path/to/LVSP_plus_data_label124_480p" 2 | 3 | 4 | ##### 5 | ARCH=res101_ocrnet 6 | CFG="config/vsp-resnet101dilated-ocr_deepsup.yaml" 7 | #### 8 | 9 | 10 | 11 | PREDIR='./imgnetpre/resnet101-imagenet.pth' 12 | 13 | 14 | SAVE="./savemodel" 15 | 16 | 17 | DATAROOT2=../data/adeour 18 | BATCHSIZE=8 19 | WORKERS=12 20 | USETWODATA=False 21 | START_GPU=0 22 | GPU_NUM=2 23 | TRAINFPS=2 24 | LR=0.002 25 | 26 | 27 | 28 | CROPSIZE=479 29 | 30 | LESSLABEL=False 31 | 32 | 33 | USE_CLIPDATASET=True 34 | EPOCH=120 35 | NAME='job_lr'$LR'batchsize'$BATCHSIZE'_EPOCH'$EPOCH'_FPS'$TRAINFPS"_arch"$ARCH"new124_gpu"$GPU_NUM"_480p""USE_CLIPDATASET"$USE_CLIPDATASET 36 | SAVEROOT=$SAVE"/"$NAME 37 | VAL=False 38 | echo $CFG 39 | 40 | 41 | echo 'train...' 
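# Train the single-frame OCR baseline end to end, then evaluate the saved
# encoder/decoder checkpoints on the 'val' and 'test' splits via test.py below.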
42 | python train.py --cfg $CFG --predir $PREDIR --batchsize $BATCHSIZE --workers $WORKERS --start_gpu $START_GPU --gpu_num $GPU_NUM --dataroot $DATAROOT --trainfps $TRAINFPS --lr $LR --multi_scale True --saveroot $SAVEROOT --totalepoch $EPOCH --dataroot2 $DATAROOT2 --usetwodata $USETWODATA --cropsize $CROPSIZE --validation $VAL --lesslabel $LESSLABEL --use_clipdataset $USE_CLIPDATASET 43 | 44 | 45 | LOAD_EN=$SAVEROOT'/encoder_epoch_'$EPOCH'.pth' 46 | LOAD_DE=$SAVEROOT'/decoder_epoch_'$EPOCH'.pth' 47 | 48 | 49 | 50 | TESTBATCHSIZE=2 51 | ISSAVE=False 52 | IMGSAVEROOT='./saveimg/'$NAME'_train' 53 | USE720p=False 54 | LESSLABLE=False 55 | 56 | echo 'val...' 57 | python test.py --cfg $CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --load_en $LOAD_EN --load_de $LOAD_DE --batchsize $TESTBATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --split 'val' 58 | echo 'test...' 59 | 60 | python test.py --cfg $CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --load_en $LOAD_EN --load_de $LOAD_DE --batchsize $TESTBATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --split 'test' 61 | 62 | -------------------------------------------------------------------------------- /scripts/run_psp.sh: -------------------------------------------------------------------------------- 1 | DATAROOT="/your/path/to/LVSP_plus_data_label124_480p" 2 | 3 | 4 | ##### 5 | #ARCH=res101_ocrnet 6 | #CFG="config/vsp-resnet101dilated-ocr_deepsup.yaml" 7 | #### 8 | 9 | #ARCH=res101_deeplab 10 | #CFG="config/vsp-resnet101dilated-deeplab.yaml" 11 | #CFG="config/vsp-resnet50dilated-deeplab.yaml" 12 | 13 | ###### 14 | #ARCH=res101_nonlocal2d_nodown 15 | #CFG="config/vsp-resnet101dilated-nonlocal2d.yaml" 16 | ######## 17 | #ARCH=mobile_ppm 18 | ARCH=res101_ppm 19 | CFG="config/vsp-resnet101dilated-ppm_deepsup.yaml" 20 | #CFG="config/ade20k-resnet50dilated-ppm_deepsup.yaml" 21 | #CFG="config/ade20k-mobilenetv2dilated-ppm_deepsup.yaml" 22 | 23 | #ARCH=resnet101uper 24 | #CFG="config/ade20k-resnet101-upernet.yaml" 25 | #CFG="config/ade20k-resnet50-upernet.yaml" 26 | 27 | 28 | #PREDIR="../data/imgnetpre/resnet101-imagenet.pth" 29 | PREDIR='./imgnetpre/resnet101-imagenet.pth' 30 | #PREDIR="../data/imgnetpre/resnet50-imagenet.pth" 31 | #PREDIR="../data/imgnetpre/mobilenet_v2.pth.tar" 32 | 33 | 34 | #SAVE="../afs/video_seg/vsp_124" 35 | SAVE="./savemodel" 36 | 37 | #ARCH='hrnet' 38 | #CFG="config/ade20k-hrnetv2.yaml" 39 | #PREDIR="../data/imgnetpre/hrnetv2_w48-imagenet.pth" 40 | 41 | DATAROOT2=../data/adeour 42 | BATCHSIZE=8 43 | WORKERS=12 44 | USETWODATA=False 45 | START_GPU=0 46 | GPU_NUM=2 47 | TRAINFPS=2 48 | LR=0.002 49 | 50 | 51 | 52 | CROPSIZE=479 53 | 54 | LESSLABEL=False 55 | 56 | 57 | USE_CLIPDATASET=True 58 | EPOCH=120 59 | NAME='job_lr'$LR'batchsize'$BATCHSIZE'_EPOCH'$EPOCH'_FPS'$TRAINFPS"_arch"$ARCH"new124_gpu"$GPU_NUM"_480p""USE_CLIPDATASET"$USE_CLIPDATASET 60 | SAVEROOT=$SAVE"/"$NAME 61 | VAL=False 62 | echo $CFG 63 | 64 | 65 | echo 'train...' 
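# The commented-out ARCH/CFG pairs above swap in alternative backbones and heads
# (OCRNet, DeepLab, non-local, UPerNet, HRNet); when changing ARCH, keep PREDIR
# pointing at the matching ImageNet-pretrained weights.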
66 | python train.py --cfg $CFG --predir $PREDIR --batchsize $BATCHSIZE --workers $WORKERS --start_gpu $START_GPU --gpu_num $GPU_NUM --dataroot $DATAROOT --trainfps $TRAINFPS --lr $LR --multi_scale True --saveroot $SAVEROOT --totalepoch $EPOCH --dataroot2 $DATAROOT2 --usetwodata $USETWODATA --cropsize $CROPSIZE --validation $VAL --lesslabel $LESSLABEL --use_clipdataset $USE_CLIPDATASET 67 | 68 | 69 | LOAD_EN=$SAVEROOT'/encoder_epoch_'$EPOCH'.pth' 70 | LOAD_DE=$SAVEROOT'/decoder_epoch_'$EPOCH'.pth' 71 | 72 | 73 | 74 | TESTBATCHSIZE=2 75 | ISSAVE=True 76 | IMGSAVEROOT='./saveimg/'$NAME'_train' 77 | USE720p=False 78 | LESSLABLE=False 79 | 80 | echo 'val...' 81 | python test.py --cfg $CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --load_en $LOAD_EN --load_de $LOAD_DE --batchsize $TESTBATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --split 'val' 82 | echo 'test...' 83 | 84 | python test.py --cfg $CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --load_en $LOAD_EN --load_de $LOAD_DE --batchsize $TESTBATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --split 'test' 85 | 86 | -------------------------------------------------------------------------------- /scripts/run_temporal_ocr.sh: -------------------------------------------------------------------------------- 1 | DATAROOT="your/path/to/VSPW_480p" 2 | 3 | 4 | 5 | SAVE="./savemodel" 6 | DATAROOT2='data2' 7 | BATCHSIZE=8 8 | WORKERS=12 9 | CROPSIZE=479 10 | 11 | 12 | 13 | START_GPU=0 14 | GPU_NUM=4 15 | TRAINFPS=1 16 | EPOCH=120 17 | LR=0.002 18 | VAL=False 19 | USETWODATA=False 20 | LESSLABEL=False 21 | CLIPNUM=4 22 | DILATION=0 23 | CLIPUP=False 24 | CLIPMIDDLE=False 25 | OTHERGT=False 26 | PROPCLIP2=False 27 | EARLYFUSE=True 28 | EARLYCAT=False 29 | CONVLSTM=False 30 | NON_LOCAL=False 31 | 32 | 33 | 34 | FIX=False 35 | ALLSUP=True 36 | ALLSUPSCALE=0.5 37 | LINEAR_COM=True 38 | DISTSOFTMAX=False 39 | DISTNEAREST=False 40 | TEMP=0.05 41 | 42 | DILATION2="3,6,9" 43 | 44 | CLIPOCR_ALL=False 45 | USEMEMORY=True 46 | 47 | METHOD='clip_ocr' 48 | 49 | 50 | 51 | PREROOT='' 52 | PRE_ENC="./imgnetpre/resnet101-imagenet.pth" 53 | MAXDIST='3' 54 | ######### 55 | ARCH=resnet101 56 | CFG='vsp-'$ARCH'dilated-ppm_deepsup_clip.yaml' 57 | #CFG='vsp-'$ARCH'dilated-ppm_clip.yaml' 58 | #CFG="vsp-"$ARCH"dilated_tdnet.yaml" 59 | PREDIR="../data/imgnetpre/"$ARCH"-imagenet.pth" 60 | 61 | NAME='newjob_lr'$LR'_bs'$BATCHSIZE'_epoch'$EPOCH'_FPS'$TRAINFPS'_clipnum'$CLIPNUM"_dilation"$DILATION"_fix"$FIX"_tdnet"$TDNET"_arch"$ARCH'_method'$METHOD'_DISTSOFTMAX'$DISTSOFTMAX'_DISTNEAREST'$DISTNEAREST"_CLIPOCR_ALL"$CLIPOCR_ALL"_USEMEMORY"$USEMEMORY"imgnetpre" 62 | 63 | 64 | SAVEROOT=$SAVE"/"$NAME 65 | python train_clip2.py --cfg config/$CFG --predir $PREDIR --batchsize $BATCHSIZE --workers $WORKERS --start_gpu $START_GPU --gpu_num $GPU_NUM --dataroot $DATAROOT --trainfps $TRAINFPS --lr $LR --multi_scale True --saveroot $SAVEROOT --totalepoch $EPOCH --dataroot2 $DATAROOT2 --usetwodata $USETWODATA --cropsize $CROPSIZE --validation $VAL --lesslabel $LESSLABEL --clip_num $CLIPNUM --dilation_num $DILATION --clip_up $CLIPUP --clip_middle $CLIPMIDDLE --fix $FIX --othergt $OTHERGT --propclip2 $PROPCLIP2 --earlyfuse $EARLYFUSE --early_usecat $EARLYCAT --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --pre_enc $PRE_ENC --max_distances $MAXDIST --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL 
--use_memory $USEMEMORY 66 | 67 | 68 | 69 | ###inference 70 | echo 'val' 71 | BATCHSIZE=1 72 | GPU_NUM=1 73 | ISSAVE=True 74 | LESSLABLE=False 75 | USE720p=False 76 | EARLYFUSE=False 77 | EARLYCAT=False 78 | 79 | #CLIPNUM=5 80 | 81 | IMGSAVEROOT='./clipsaveimg/'$NAME 82 | 83 | LOAD=$SAVEROOT'/model_epoch_'$EPOCH'.pth' 84 | 85 | python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'val' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY 86 | 87 | echo 'test' 88 | 89 | python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'test' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY 90 | 91 | #python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'valtest' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --use_memory $USEMEMORY 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /scripts/run_temporal_psp.sh: -------------------------------------------------------------------------------- 1 | DATAROOT="your/path/to/VSPW_480p" 2 | 3 | 4 | 5 | SAVE="./savemodel" 6 | DATAROOT2='data2' 7 | BATCHSIZE=8 8 | WORKERS=12 9 | CROPSIZE=479 10 | 11 | 12 | 13 | START_GPU=0 14 | GPU_NUM=4 15 | TRAINFPS=1 16 | EPOCH=120 17 | LR=0.002 18 | VAL=False 19 | USETWODATA=False 20 | LESSLABEL=False 21 | CLIPNUM=4 22 | DILATION=0 23 | CLIPUP=False 24 | CLIPMIDDLE=False 25 | OTHERGT=False 26 | PROPCLIP2=False 27 | EARLYFUSE=True 28 | EARLYCAT=False 29 | CONVLSTM=False 30 | NON_LOCAL=False 31 | 32 | 33 | 34 | FIX=False 35 | ALLSUP=True 36 | ALLSUPSCALE=0.5 37 | LINEAR_COM=True 38 | DISTSOFTMAX=False 39 | DISTNEAREST=False 40 | TEMP=0.05 41 | 42 | DILATION2="3,6,9" 43 | 44 | CLIPOCR_ALL=False 45 | USEMEMORY=True 46 | 47 | METHOD='clip_psp' 48 | 49 | 50 | 51 | PREROOT='' 52 | PRE_ENC="./imgnetpre/resnet101-imagenet.pth" 53 | MAXDIST='3' 54 | ######### 55 | ARCH=resnet101 56 | CFG='vsp-'$ARCH'dilated-ppm_deepsup_clip.yaml' 57 | #CFG='vsp-'$ARCH'dilated-ppm_clip.yaml' 58 | #CFG="vsp-"$ARCH"dilated_tdnet.yaml" 59 | PREDIR="../data/imgnetpre/"$ARCH"-imagenet.pth" 60 | 61 | NAME='newjob_lr'$LR'_bs'$BATCHSIZE'_epoch'$EPOCH'_FPS'$TRAINFPS'_clipnum'$CLIPNUM"_dilation"$DILATION"_fix"$FIX"_tdnet"$TDNET"_arch"$ARCH'_method'$METHOD'_DISTSOFTMAX'$DISTSOFTMAX'_DISTNEAREST'$DISTNEAREST"_CLIPOCR_ALL"$CLIPOCR_ALL"_USEMEMORY"$USEMEMORY"imgnetpre" 62 | 63 | 64 | 
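# Clip-level settings: CLIPNUM frames per training clip, DILATION and DILATION2
# control the frame-sampling gaps, and METHOD='clip_psp' selects the temporal
# PSP model; these are forwarded to train_clip2.py and test_clip2.py below via
# --clip_num, --dilation_num, --dilation2, and --method.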
SAVEROOT=$SAVE"/"$NAME 65 | python train_clip2.py --cfg config/$CFG --predir $PREDIR --batchsize $BATCHSIZE --workers $WORKERS --start_gpu $START_GPU --gpu_num $GPU_NUM --dataroot $DATAROOT --trainfps $TRAINFPS --lr $LR --multi_scale True --saveroot $SAVEROOT --totalepoch $EPOCH --dataroot2 $DATAROOT2 --usetwodata $USETWODATA --cropsize $CROPSIZE --validation $VAL --lesslabel $LESSLABEL --clip_num $CLIPNUM --dilation_num $DILATION --clip_up $CLIPUP --clip_middle $CLIPMIDDLE --fix $FIX --othergt $OTHERGT --propclip2 $PROPCLIP2 --earlyfuse $EARLYFUSE --early_usecat $EARLYCAT --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --pre_enc $PRE_ENC --max_distances $MAXDIST --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY 66 | 67 | 68 | 69 | ###inference 70 | echo 'val' 71 | BATCHSIZE=1 72 | GPU_NUM=1 73 | ISSAVE=True 74 | LESSLABLE=False 75 | USE720p=False 76 | EARLYFUSE=False 77 | EARLYCAT=False 78 | 79 | #CLIPNUM=5 80 | 81 | IMGSAVEROOT='./clipsaveimg/'$NAME 82 | 83 | LOAD=$SAVEROOT'/model_epoch_'$EPOCH'.pth' 84 | 85 | python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'val' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY 86 | 87 | echo 'test' 88 | 89 | python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'test' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY 90 | 91 | 92 | 93 | 94 | --------------------------------------------------------------------------------
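For reference, a short sketch of how the ImageNet-pretrained backbones above are fetched through models/utils.load_url (a minimal example assuming the CSAIL URL in models/resnext.py is reachable; weights are cached in ./pretrained by default):

from models.resnext import resnext101

# The first call downloads resnext101-imagenet.pth into ./pretrained/ via
# models.utils.load_url, then loads the weights with strict=False.
model = resnext101(pretrained=True)
model.eval()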