├── RAFT
├── 1.png
├── LICENSE
├── RAFT.png
├── README.md
├── __init__.py
├── alt_cuda_corr
│ ├── correlation.cpp
│ ├── correlation_kernel.cu
│ └── setup.py
├── chairs_split.txt
├── core
│ ├── __init__.py
│ ├── corr.py
│ ├── datasets.py
│ ├── extractor.py
│ ├── raft.py
│ ├── update.py
│   └── utils
│       ├── __init__.py
│       ├── augmentor.py
│       ├── flow_viz.py
│       ├── frame_utils.py
│       └── utils.py
├── demo-frames
│ ├── frame_0016.png
│ └── frame_0017.png
├── demo.py
├── download_models.sh
├── evaluate.py
├── example.png
├── flow_warp.py
├── logs
│ └── hadoop.kylin.libdfs.log
├── models
│ ├── 1.py
│ └── raft-things.pth-no-zip
├── tosave.npy
├── train.py
├── train_mixed.sh
└── train_standard.sh
├── RAFT_core
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-37.pyc
│ ├── corr.cpython-37.pyc
│ ├── extractor.cpython-37.pyc
│ ├── raft.cpython-37.pyc
│ └── update.cpython-37.pyc
├── corr.py
├── datasets.py
├── extractor.py
├── raft-things.pth-no-zip
├── raft.py
├── update.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-37.pyc
    │   └── utils.cpython-37.pyc
    ├── augmentor.py
    ├── flow_viz.py
    ├── frame_utils.py
    └── utils.py
├── README.md
├── TC_cal.py
├── VC_perclip.py
├── change2_480p.py
├── config
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-37.pyc
│ └── defaults.cpython-37.pyc
├── defaults.py
├── vsp-hrnetv2.yaml
├── vsp-mobilenetv2dilated-c1_deepsup.yaml
├── vsp-mobilenetv2dilated-ppm_deepsup.yaml
├── vsp-resnet101-upernet.yaml
├── vsp-resnet101dilated-deeplab.yaml
├── vsp-resnet101dilated-nonlocal2d.yaml
├── vsp-resnet101dilated-ocr_deepsup.yaml
├── vsp-resnet101dilated-ppm_clip.yaml
├── vsp-resnet101dilated-ppm_deepsup.yaml
├── vsp-resnet101dilated-ppm_deepsup_clip.yaml
├── vsp-resnet101dilated_tdnet.yaml
├── vsp-resnet18dilated-ppm_deepsup.yaml
├── vsp-resnet18dilated-ppm_deepsup_clip.yaml
├── vsp-resnet50-upernet.yaml
├── vsp-resnet50dilated-deeplab.yaml
├── vsp-resnet50dilated-ppm_deepsup.yaml
├── vsp-resnet50dilated-ppm_deepsup_clip.yaml
└── vsp-resnet50dilated-tdnet.yaml
├── dataset.py
├── dataset2.py
├── lib
├── nn
│ ├── __init__.py
│ ├── __pycache__
│ │ └── __init__.cpython-37.pyc
│ ├── modules
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── batchnorm.cpython-37.pyc
│ │ │ ├── comm.cpython-37.pyc
│ │ │ └── replicate.cpython-37.pyc
│ │ ├── batchnorm.py
│ │ ├── comm.py
│ │ ├── replicate.py
│ │ ├── tests
│ │ │ ├── test_numeric_batchnorm.py
│ │ │ └── test_sync_batchnorm.py
│ │ └── unittest.py
│   └── parallel
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-37.pyc
│       │   └── data_parallel.cpython-37.pyc
│       └── data_parallel.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-37.pyc
    │   └── th.cpython-37.pyc
    ├── data
    │   ├── __init__.py
    │   ├── dataloader.py
    │   ├── dataset.py
    │   ├── distributed.py
    │   └── sampler.py
    └── th.py
├── models
├── .non_local2d.py.swp
├── .propnet.py.swo
├── .propnet.py.swp
├── BiConvLSTM.py
├── ETC.py
├── ETC_ocr.py
├── __init__.py
├── __pycache__
│ ├── BiConvLSTM.cpython-37.pyc
│ ├── ETC.cpython-37.pyc
│ ├── ETC_ocr.cpython-37.pyc
│ ├── __init__.cpython-37.pyc
│ ├── clip_ocr.cpython-37.pyc
│ ├── clip_psp.cpython-37.pyc
│ ├── deeplab.cpython-37.pyc
│ ├── hrnet.cpython-37.pyc
│ ├── hrnet_clip.cpython-37.pyc
│ ├── mobilenet.cpython-37.pyc
│ ├── models.cpython-37.pyc
│ ├── netwarp.cpython-37.pyc
│ ├── netwarp_ocr.cpython-37.pyc
│ ├── non_local.cpython-37.pyc
│ ├── non_local_models.cpython-37.pyc
│ ├── ocrnet.cpython-37.pyc
│ ├── propnet.cpython-37.pyc
│ ├── resnet.cpython-37.pyc
│ ├── resnext.cpython-37.pyc
│ ├── utils.cpython-37.pyc
│ ├── warp_our.cpython-37.pyc
│ └── warp_our_merge.cpython-37.pyc
├── clip_ocr.py
├── clip_psp.py
├── deeplab.py
├── deeplabv3
│ ├── aspp.py
│ └── decoder.py
├── hrnet.py
├── hrnet_clip.py
├── hrnet_clip_2.py
├── mobilenet.py
├── models.py
├── netwarp.py
├── netwarp_ocr.py
├── non_local.py
├── non_local_models.py
├── ocr_modules
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ └── spatial_ocr_block.cpython-37.pyc
│ ├── spatial_ocr_block.py
│ └── spatial_ocr_block_max.py
├── ocrnet.py
├── propnet.py
├── resnet.py
├── resnext.py
├── sync_batchnorm
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── batchnorm.cpython-37.pyc
│ │ ├── comm.cpython-37.pyc
│ │ └── replicate.cpython-37.pyc
│ ├── batchnorm.py
│ ├── batchnorm_reimpl.py
│ ├── comm.py
│ ├── replicate.py
│ └── unittest.py
├── td4_psp
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── loss.cpython-37.pyc
│ │ ├── td4_psp.cpython-37.pyc
│ │ └── transformer.cpython-37.pyc
│ ├── loss.py
│ ├── pspnet_4p.py
│ ├── resnet_bak.py
│ ├── td4_psp.py
│ ├── td4_psp_bak.py
│ ├── transformer.py
│   └── utils
│       ├── __init__.py
│       ├── files.py
│       └── model_store.py
├── utils.py
├── warp_our.py
└── warp_our_merge.py
├── scripts
├── run_etc.sh
├── run_netwarp.sh
├── run_ocr.sh
├── run_psp.sh
├── run_temporal_ocr.sh
└── run_temporal_psp.sh
├── test.py
├── test_clip2.py
├── train.py
├── train_clip2.py
└── utils.py
/RAFT/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/1.png
--------------------------------------------------------------------------------
/RAFT/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020, princeton-vl
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/RAFT/RAFT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/RAFT.png
--------------------------------------------------------------------------------
/RAFT/README.md:
--------------------------------------------------------------------------------
1 | # RAFT
2 | This repository contains the source code for our paper:
3 |
4 | [RAFT: Recurrent All-Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)
5 | ECCV 2020
6 | Zachary Teed and Jia Deng
7 |
8 |
9 |
10 | ## Requirements
11 | The code has been tested with PyTorch 1.6 and CUDA 10.1.
12 | ```Shell
13 | conda create --name raft
14 | conda activate raft
15 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 -c pytorch
16 | conda install matplotlib
17 | conda install tensorboard
18 | conda install scipy
19 | conda install opencv
20 | ```
21 |
22 | ## Demos
23 | Pretrained models can be downloaded by running
24 | ```Shell
25 | ./download_models.sh
26 | ```
27 | or downloaded from [Google Drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing)
28 |
29 | You can demo a trained model on a sequence of frames
30 | ```Shell
31 | python demo.py --model=models/raft-things.pth --path=demo-frames
32 | ```
33 |
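If you prefer to call RAFT from Python directly, the minimal sketch below (not part of the repo; the checkpoint path follows the command above, and the frame tensors are stand-ins) mirrors what `demo.py` does:
```Python
import sys
sys.path.append('core')

import argparse
import torch
from raft import RAFT
from utils.utils import InputPadder

# the same flags demo.py would parse
args = argparse.Namespace(small=False, mixed_precision=False, alternate_corr=False)
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load('models/raft-things.pth'))
model = model.module.cuda().eval()

with torch.no_grad():
    # stand-in frames; in practice load real images as (N, 3, H, W) float tensors in [0, 255]
    image1 = torch.randint(0, 255, (1, 3, 436, 1024)).float().cuda()
    image2 = torch.randint(0, 255, (1, 3, 436, 1024)).float().cuda()
    padder = InputPadder(image1.shape)            # pads H and W to multiples of 8
    image1, image2 = padder.pad(image1, image2)
    flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
    print(flow_up.shape)                          # (1, 2, H, W) flow at input resolution
```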
34 | ## Required Data
35 | To evaluate/train RAFT, you will need to download the required datasets.
36 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
37 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
38 | * [Sintel](http://sintel.is.tue.mpg.de/)
39 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
40 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional)
41 |
42 |
43 | By default, `datasets.py` will search for the datasets in these locations. You can create symbolic links in the `datasets` folder that point to wherever the datasets were downloaded (a sketch follows the layout below).
44 |
45 | ```Shell
46 | ├── datasets
47 | ├── Sintel
48 | ├── test
49 | ├── training
50 | ├── KITTI
51 | ├── testing
52 | ├── training
53 | ├── devkit
54 | ├── FlyingChairs_release
55 | ├── data
56 | ├── FlyingThings3D
57 | ├── frames_cleanpass
58 | ├── frames_finalpass
59 | ├── optical_flow
60 | ```
61 |
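A minimal Python sketch of setting up those links (the download paths are placeholders):
```Python
import os

# hypothetical locations where the datasets were downloaded
downloads = {
    'Sintel': '/data/Sintel',
    'KITTI': '/data/KITTI',
    'FlyingChairs_release': '/data/FlyingChairs_release',
    'FlyingThings3D': '/data/FlyingThings3D',
}

os.makedirs('datasets', exist_ok=True)
for name, src in downloads.items():
    dst = os.path.join('datasets', name)
    if not os.path.exists(dst):
        os.symlink(src, dst)   # datasets.py will then find them under datasets/
```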
62 | ## Evaluation
63 | You can evaluate a trained model using `evaluate.py`
64 | ```Shell
65 | python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision
66 | ```
67 |
68 | ## Training
69 | We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` directory, which can be visualized using tensorboard.
70 | ```Shell
71 | ./train_standard.sh
72 | ```
73 |
74 | If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU).
75 | ```Shell
76 | ./train_mixed.sh
77 | ```
78 |
79 | ## (Optional) Efficient Implementation
80 | You can optionally use our alternate (efficient) implementation by compiling the provided CUDA extension
81 | ```Shell
82 | cd alt_cuda_corr && python setup.py install && cd ..
83 | ```
84 | and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note that this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
85 |
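To check that the extension built, a quick smoke test (a sketch; the call signature and NHWC-contiguous layout follow `AlternateCorrBlock` in `core/corr.py`, and the shapes are illustrative):
```Python
import torch
import alt_cuda_corr   # raises ImportError if `python setup.py install` did not succeed

# (B, H, W, C) feature maps and (B, 1, H, W, 2) lookup coordinates, as in core/corr.py
fmap1 = torch.randn(1, 32, 32, 64).cuda().contiguous()
fmap2 = torch.randn(1, 32, 32, 64).cuda().contiguous()
coords = torch.zeros(1, 1, 32, 32, 2).cuda().contiguous()

corr, = alt_cuda_corr.forward(fmap1, fmap2, coords, 4)   # radius = 4
print(corr.shape)
```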
--------------------------------------------------------------------------------
/RAFT/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/__init__.py
--------------------------------------------------------------------------------
/RAFT/alt_cuda_corr/correlation.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 | // CUDA forward declarations
5 | std::vector<torch::Tensor> corr_cuda_forward(
6 | torch::Tensor fmap1,
7 | torch::Tensor fmap2,
8 | torch::Tensor coords,
9 | int radius);
10 |
11 | std::vector<torch::Tensor> corr_cuda_backward(
12 | torch::Tensor fmap1,
13 | torch::Tensor fmap2,
14 | torch::Tensor coords,
15 | torch::Tensor corr_grad,
16 | int radius);
17 |
18 | // C++ interface
19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
22 |
23 | std::vector<torch::Tensor> corr_forward(
24 | torch::Tensor fmap1,
25 | torch::Tensor fmap2,
26 | torch::Tensor coords,
27 | int radius) {
28 | CHECK_INPUT(fmap1);
29 | CHECK_INPUT(fmap2);
30 | CHECK_INPUT(coords);
31 |
32 | return corr_cuda_forward(fmap1, fmap2, coords, radius);
33 | }
34 |
35 |
36 | std::vector<torch::Tensor> corr_backward(
37 | torch::Tensor fmap1,
38 | torch::Tensor fmap2,
39 | torch::Tensor coords,
40 | torch::Tensor corr_grad,
41 | int radius) {
42 | CHECK_INPUT(fmap1);
43 | CHECK_INPUT(fmap2);
44 | CHECK_INPUT(coords);
45 | CHECK_INPUT(corr_grad);
46 |
47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
48 | }
49 |
50 |
51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
52 | m.def("forward", &corr_forward, "CORR forward");
53 | m.def("backward", &corr_backward, "CORR backward");
54 | }
--------------------------------------------------------------------------------
/RAFT/alt_cuda_corr/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 |
5 | setup(
6 | name='correlation',
7 | ext_modules=[
8 | CUDAExtension('alt_cuda_corr',
9 | sources=['correlation.cpp', 'correlation_kernel.cu'],
10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
11 | ],
12 | cmdclass={
13 | 'build_ext': BuildExtension
14 | })
15 |
16 |
--------------------------------------------------------------------------------
/RAFT/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/core/__init__.py
--------------------------------------------------------------------------------
/RAFT/core/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from utils.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 | import alt_cuda_corr
7 | except:
8 | # alt_cuda_corr is not compiled
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels-1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2*r+1)
38 | dy = torch.linspace(-r, r, 2*r+1)
39 | delta = torch.stack(torch.meshgrid(dy, dx), dim=-1).to(coords.device)
40 |
41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 |
45 | corr = bilinear_sampler(corr, coords_lvl)
46 | corr = corr.view(batch, h1, w1, -1)
47 | out_pyramid.append(corr)
48 |
49 | out = torch.cat(out_pyramid, dim=-1)
50 | return out.permute(0, 3, 1, 2).contiguous().float()
51 |
52 | @staticmethod
53 | def corr(fmap1, fmap2):
54 | batch, dim, ht, wd = fmap1.shape
55 | fmap1 = fmap1.view(batch, dim, ht*wd)
56 | fmap2 = fmap2.view(batch, dim, ht*wd)
57 |
58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59 | corr = corr.view(batch, ht, wd, 1, ht, wd)
60 | return corr / torch.sqrt(torch.tensor(dim).float())
61 |
62 |
63 | class AlternateCorrBlock:
64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
65 | self.num_levels = num_levels
66 | self.radius = radius
67 |
68 | self.pyramid = [(fmap1, fmap2)]
69 | for i in range(self.num_levels):
70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
72 | self.pyramid.append((fmap1, fmap2))
73 |
74 | def __call__(self, coords):
75 | coords = coords.permute(0, 2, 3, 1)
76 | B, H, W, _ = coords.shape
77 | dim = self.pyramid[0][0].shape[1]
78 |
79 | corr_list = []
80 | for i in range(self.num_levels):
81 | r = self.radius
82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
84 |
85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
87 | corr_list.append(corr.squeeze(1))
88 |
89 | corr = torch.stack(corr_list, dim=1)
90 | corr = corr.reshape(B, -1, H, W)
91 | return corr / torch.sqrt(torch.tensor(dim).float())
92 |
--------------------------------------------------------------------------------
/RAFT/core/raft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from update import BasicUpdateBlock, SmallUpdateBlock
7 | from extractor import BasicEncoder, SmallEncoder
8 | from corr import CorrBlock, AlternateCorrBlock
9 | from utils.utils import bilinear_sampler, coords_grid, upflow8
10 |
11 | try:
12 | autocast = torch.cuda.amp.autocast
13 | except:
14 | # dummy autocast for PyTorch < 1.6
15 | class autocast:
16 | def __init__(self, enabled):
17 | pass
18 | def __enter__(self):
19 | pass
20 | def __exit__(self, *args):
21 | pass
22 |
23 |
24 | class RAFT(nn.Module):
25 | def __init__(self, args):
26 | super(RAFT, self).__init__()
27 | self.args = args
28 |
29 | if args.small:
30 | self.hidden_dim = hdim = 96
31 | self.context_dim = cdim = 64
32 | args.corr_levels = 4
33 | args.corr_radius = 3
34 |
35 | else:
36 | self.hidden_dim = hdim = 128
37 | self.context_dim = cdim = 128
38 | args.corr_levels = 4
39 | args.corr_radius = 4
40 |
41 | if 'dropout' not in self.args:
42 | self.args.dropout = 0
43 |
44 | if 'alternate_corr' not in self.args:
45 | self.args.alternate_corr = False
46 |
47 | # feature network, context network, and update block
48 | if args.small:
49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
52 |
53 | else:
54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
57 |
58 | def freeze_bn(self):
59 | for m in self.modules():
60 | if isinstance(m, nn.BatchNorm2d):
61 | m.eval()
62 |
63 | def initialize_flow(self, img):
64 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
65 | N, C, H, W = img.shape
66 | coords0 = coords_grid(N, H//8, W//8).to(img.device)
67 | coords1 = coords_grid(N, H//8, W//8).to(img.device)
68 |
69 | # optical flow computed as difference: flow = coords1 - coords0
70 | return coords0, coords1
71 |
72 | def upsample_flow(self, flow, mask):
73 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
74 | N, _, H, W = flow.shape
75 | mask = mask.view(N, 1, 9, 8, 8, H, W)
76 | mask = torch.softmax(mask, dim=2)
77 |
78 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
79 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
80 |
81 | up_flow = torch.sum(mask * up_flow, dim=2)
82 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
83 | return up_flow.reshape(N, 2, 8*H, 8*W)
84 |
85 |
86 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
87 | """ Estimate optical flow between pair of frames """
88 |
89 | image1 = 2 * (image1 / 255.0) - 1.0
90 | image2 = 2 * (image2 / 255.0) - 1.0
91 |
92 | image1 = image1.contiguous()
93 | image2 = image2.contiguous()
94 |
95 | hdim = self.hidden_dim
96 | cdim = self.context_dim
97 |
98 | # run the feature network
99 | with autocast(enabled=self.args.mixed_precision):
100 | fmap1, fmap2 = self.fnet([image1, image2])
101 |
102 | fmap1 = fmap1.float()
103 | fmap2 = fmap2.float()
104 | if self.args.alternate_corr:
105 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
106 | else:
107 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
108 |
109 | # run the context network
110 | with autocast(enabled=self.args.mixed_precision):
111 | cnet = self.cnet(image1)
112 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
113 | net = torch.tanh(net)
114 | inp = torch.relu(inp)
115 |
116 | coords0, coords1 = self.initialize_flow(image1)
117 |
118 | if flow_init is not None:
119 | coords1 = coords1 + flow_init
120 |
121 | flow_predictions = []
122 | for itr in range(iters):
123 | coords1 = coords1.detach()
124 | corr = corr_fn(coords1) # index correlation volume
125 |
126 | flow = coords1 - coords0
127 | with autocast(enabled=self.args.mixed_precision):
128 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
129 |
130 | # F(t+1) = F(t) + \Delta(t)
131 | coords1 = coords1 + delta_flow
132 |
133 | # upsample predictions
134 | if up_mask is None:
135 | flow_up = upflow8(coords1 - coords0)
136 | else:
137 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
138 |
139 | flow_predictions.append(flow_up)
140 |
141 | if test_mode:
142 | return coords1 - coords0, flow_up
143 |
144 | return flow_predictions
145 |
--------------------------------------------------------------------------------
/RAFT/core/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class ConvGRU(nn.Module):
17 | def __init__(self, hidden_dim=128, input_dim=192+128):
18 | super(ConvGRU, self).__init__()
19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
22 |
23 | def forward(self, h, x):
24 | hx = torch.cat([h, x], dim=1)
25 |
26 | z = torch.sigmoid(self.convz(hx))
27 | r = torch.sigmoid(self.convr(hx))
28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
29 |
30 | h = (1-z) * h + z * q
31 | return h
32 |
33 | class SepConvGRU(nn.Module):
34 | def __init__(self, hidden_dim=128, input_dim=192+128):
35 | super(SepConvGRU, self).__init__()
36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39 |
40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43 |
44 |
45 | def forward(self, h, x):
46 | # horizontal
47 | hx = torch.cat([h, x], dim=1)
48 | z = torch.sigmoid(self.convz1(hx))
49 | r = torch.sigmoid(self.convr1(hx))
50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
51 | h = (1-z) * h + z * q
52 |
53 | # vertical
54 | hx = torch.cat([h, x], dim=1)
55 | z = torch.sigmoid(self.convz2(hx))
56 | r = torch.sigmoid(self.convr2(hx))
57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
58 | h = (1-z) * h + z * q
59 |
60 | return h
61 |
62 | class SmallMotionEncoder(nn.Module):
63 | def __init__(self, args):
64 | super(SmallMotionEncoder, self).__init__()
65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
69 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
70 |
71 | def forward(self, flow, corr):
72 | cor = F.relu(self.convc1(corr))
73 | flo = F.relu(self.convf1(flow))
74 | flo = F.relu(self.convf2(flo))
75 | cor_flo = torch.cat([cor, flo], dim=1)
76 | out = F.relu(self.conv(cor_flo))
77 | return torch.cat([out, flow], dim=1)
78 |
79 | class BasicMotionEncoder(nn.Module):
80 | def __init__(self, args):
81 | super(BasicMotionEncoder, self).__init__()
82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
88 |
89 | def forward(self, flow, corr):
90 | cor = F.relu(self.convc1(corr))
91 | cor = F.relu(self.convc2(cor))
92 | flo = F.relu(self.convf1(flow))
93 | flo = F.relu(self.convf2(flo))
94 |
95 | cor_flo = torch.cat([cor, flo], dim=1)
96 | out = F.relu(self.conv(cor_flo))
97 | return torch.cat([out, flow], dim=1)
98 |
99 | class SmallUpdateBlock(nn.Module):
100 | def __init__(self, args, hidden_dim=96):
101 | super(SmallUpdateBlock, self).__init__()
102 | self.encoder = SmallMotionEncoder(args)
103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
105 |
106 | def forward(self, net, inp, corr, flow):
107 | motion_features = self.encoder(flow, corr)
108 | inp = torch.cat([inp, motion_features], dim=1)
109 | net = self.gru(net, inp)
110 | delta_flow = self.flow_head(net)
111 |
112 | return net, None, delta_flow
113 |
114 | class BasicUpdateBlock(nn.Module):
115 | def __init__(self, args, hidden_dim=128, input_dim=128):
116 | super(BasicUpdateBlock, self).__init__()
117 | self.args = args
118 | self.encoder = BasicMotionEncoder(args)
119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
121 |
122 | self.mask = nn.Sequential(
123 | nn.Conv2d(128, 256, 3, padding=1),
124 | nn.ReLU(inplace=True),
125 | nn.Conv2d(256, 64*9, 1, padding=0))
126 |
127 | def forward(self, net, inp, corr, flow, upsample=True):
128 | motion_features = self.encoder(flow, corr)
129 | inp = torch.cat([inp, motion_features], dim=1)
130 |
131 | net = self.gru(net, inp)
132 | delta_flow = self.flow_head(net)
133 |
134 |         # scale mask to balance gradients
135 | mask = .25 * self.mask(net)
136 | return net, mask, delta_flow
137 |
138 |
139 |
140 |
--------------------------------------------------------------------------------
/RAFT/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/core/utils/__init__.py
--------------------------------------------------------------------------------
/RAFT/core/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 |     Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 |     Expects a two dimensional flow image of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
--------------------------------------------------------------------------------
/RAFT/core/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | TAG_CHAR = np.array([202021.25], np.float32)
11 |
12 | def readFlow(fn):
13 | """ Read .flo file in Middlebury format"""
14 | # Code adapted from:
15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16 |
17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18 | # print 'fn = %s'%(fn)
19 | with open(fn, 'rb') as f:
20 | magic = np.fromfile(f, np.float32, count=1)
21 | if 202021.25 != magic:
22 | print('Magic number incorrect. Invalid .flo file')
23 | return None
24 | else:
25 | w = np.fromfile(f, np.int32, count=1)
26 | h = np.fromfile(f, np.int32, count=1)
27 | # print 'Reading %d x %d flo file\n' % (w, h)
28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
29 | # Reshape data into 3D array (columns, rows, bands)
30 | # The reshape here is for visualization, the original code is (w,h,2)
31 | return np.resize(data, (int(h), int(w), 2))
32 |
33 | def readPFM(file):
34 | file = open(file, 'rb')
35 |
36 | color = None
37 | width = None
38 | height = None
39 | scale = None
40 | endian = None
41 |
42 | header = file.readline().rstrip()
43 | if header == b'PF':
44 | color = True
45 | elif header == b'Pf':
46 | color = False
47 | else:
48 | raise Exception('Not a PFM file.')
49 |
50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
51 | if dim_match:
52 | width, height = map(int, dim_match.groups())
53 | else:
54 | raise Exception('Malformed PFM header.')
55 |
56 | scale = float(file.readline().rstrip())
57 | if scale < 0: # little-endian
58 | endian = '<'
59 | scale = -scale
60 | else:
61 | endian = '>' # big-endian
62 |
63 | data = np.fromfile(file, endian + 'f')
64 | shape = (height, width, 3) if color else (height, width)
65 |
66 | data = np.reshape(data, shape)
67 | data = np.flipud(data)
68 | return data
69 |
70 | def writeFlow(filename,uv,v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert(uv.ndim == 3)
81 | assert(uv.shape[2] == 2)
82 | u = uv[:,:,0]
83 | v = uv[:,:,1]
84 | else:
85 | u = uv
86 |
87 | assert(u.shape == v.shape)
88 | height,width = u.shape
89 | f = open(filename,'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width*nBands))
96 | tmp[:,np.arange(width)*2] = u
97 | tmp[:,np.arange(width)*2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
104 | flow = flow[:,:,::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2**15) / 64.0
107 | return flow, valid
108 |
109 | def readDispKITTI(filename):
110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
111 | valid = disp > 0.0
112 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
113 | return flow, valid
114 |
115 |
116 | def writeFlowKITTI(filename, uv):
117 | uv = 64.0 * uv + 2**15
118 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
120 | cv2.imwrite(filename, uv[..., ::-1])
121 |
122 |
123 | def read_gen(file_name, pil=False):
124 | ext = splitext(file_name)[-1]
125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
126 | return Image.open(file_name)
127 | elif ext == '.bin' or ext == '.raw':
128 | return np.load(file_name)
129 | elif ext == '.flo':
130 | return readFlow(file_name).astype(np.float32)
131 | elif ext == '.pfm':
132 | flow = readPFM(file_name).astype(np.float32)
133 | if len(flow.shape) == 2:
134 | return flow
135 | else:
136 | return flow[:, :, :-1]
137 | return []
--------------------------------------------------------------------------------
/RAFT/core/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 | def __init__(self, dims, mode='sintel'):
10 | self.ht, self.wd = dims[-2:]
11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13 | if mode == 'sintel':
14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15 | else:
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17 |
18 |     def pad(self, *inputs):
19 |         return [F.pad(x, self._pad, mode='replicate') for x in inputs]
20 |
21 | def unpad(self,x):
22 | ht, wd = x.shape[-2:]
23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
24 | return x[..., c[0]:c[1], c[2]:c[3]]
25 |
26 | def forward_interpolate(flow):
27 | flow = flow.detach().cpu().numpy()
28 | dx, dy = flow[0], flow[1]
29 |
30 | ht, wd = dx.shape
31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
32 |
33 | x1 = x0 + dx
34 | y1 = y0 + dy
35 |
36 | x1 = x1.reshape(-1)
37 | y1 = y1.reshape(-1)
38 | dx = dx.reshape(-1)
39 | dy = dy.reshape(-1)
40 |
41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
42 | x1 = x1[valid]
43 | y1 = y1[valid]
44 | dx = dx[valid]
45 | dy = dy[valid]
46 |
47 | flow_x = interpolate.griddata(
48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
49 |
50 | flow_y = interpolate.griddata(
51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
52 |
53 | flow = np.stack([flow_x, flow_y], axis=0)
54 | return torch.from_numpy(flow).float()
55 |
56 |
57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
58 | """ Wrapper for grid_sample, uses pixel coordinates """
59 | H, W = img.shape[-2:]
60 | xgrid, ygrid = coords.split([1,1], dim=-1)
61 | xgrid = 2*xgrid/(W-1) - 1
62 | ygrid = 2*ygrid/(H-1) - 1
63 |
64 | grid = torch.cat([xgrid, ygrid], dim=-1)
65 | img = F.grid_sample(img, grid, align_corners=True)
66 |
67 | if mask:
68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
69 | return img, mask.float()
70 |
71 | return img
72 |
73 |
74 | def coords_grid(batch, ht, wd):
75 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
76 | coords = torch.stack(coords[::-1], dim=0).float()
77 | return coords[None].repeat(batch, 1, 1, 1)
78 |
79 |
80 | def upflow8(flow, mode='bilinear'):
81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
83 |
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0016.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/demo-frames/frame_0016.png
--------------------------------------------------------------------------------
/RAFT/demo-frames/frame_0017.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/demo-frames/frame_0017.png
--------------------------------------------------------------------------------
/RAFT/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('core')
3 |
4 | import argparse
5 | import os
6 | import cv2
7 | import glob
8 | import numpy as np
9 | import torch
10 | from PIL import Image
11 |
12 | from raft import RAFT
13 | from utils import flow_viz
14 | from utils.utils import InputPadder
15 | import torch.nn.functional as F
16 |
17 |
18 |
19 | DEVICE = 'cuda'
20 |
21 | def load_image(imfile):
22 | img = np.array(Image.open(imfile)).astype(np.uint8)
23 | img = torch.from_numpy(img).permute(2, 0, 1).float()
24 | return img
25 |
26 |
27 | def load_image_list(image_files):
28 | images = []
29 | for imfile in sorted(image_files):
30 | images.append(load_image(imfile))
31 |
32 | images = torch.stack(images, dim=0)
33 | images = images.to(DEVICE)
34 |
35 | padder = InputPadder(images.shape)
36 | return padder.pad(images)[0]
37 |
38 |
39 | def viz(img, flo):
40 | img = img[0].permute(1,2,0).cpu().numpy()
41 | flo = flo[0].permute(1,2,0).cpu().numpy()
42 |
43 | # map flow to rgb image
44 | flo = flow_viz.flow_to_image(flo)
45 | flo = Image.fromarray(flo)
46 | flo.save('example.png')
47 |
48 |
49 | def demo(args):
50 | model = torch.nn.DataParallel(RAFT(args))
51 | model.load_state_dict(torch.load(args.model))
52 |
53 | model = model.module
54 | model.to(DEVICE)
55 | model.eval()
56 |
57 | with torch.no_grad():
58 | images = glob.glob(os.path.join(args.path, '*.png')) + \
59 | glob.glob(os.path.join(args.path, '*.jpg'))
60 |
61 | images = load_image_list(images)
62 | for i in range(images.shape[0]-1):
63 | image1 = images[i,None]
64 | image2 = images[i+1,None]
65 | print(image1.size())
66 | print(image2.size())
67 | image1 = F.interpolate(image1,(480,720))
68 | image2 = F.interpolate(image2,(480,720))
69 |
70 |
71 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
72 | to_save = flow_up.squeeze(0).cpu().numpy()
73 | np.save('tosave.npy',to_save)
74 | viz(image1, flow_up)
75 |
76 |
77 | if __name__ == '__main__':
78 | parser = argparse.ArgumentParser()
79 | parser.add_argument('--model', help="restore checkpoint")
80 | parser.add_argument('--path', help="dataset for evaluation")
81 | parser.add_argument('--small', action='store_true', help='use small model')
82 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
83 |     parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
84 | args = parser.parse_args()
85 |
86 | demo(args)
87 |
--------------------------------------------------------------------------------
/RAFT/download_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip
3 | unzip models.zip
4 |
--------------------------------------------------------------------------------
/RAFT/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/example.png
--------------------------------------------------------------------------------
/RAFT/flow_warp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from PIL import Image
5 | import torch.nn.functional as F
6 |
7 | def warp(x, flo):
8 | """
9 | warp an image/tensor (im2) back to im1, according to the optical flow
10 | x: [B, C, H, W] (im2)
11 | flo: [B, 2, H, W] flow
12 | """
13 | B, C, H, W = x.size()
14 | # mesh grid
15 | xx = torch.arange(0, W).view(1,-1).repeat(H,1)
16 | yy = torch.arange(0, H).view(-1,1).repeat(1,W)
17 | xx = xx.view(1,1,H,W).repeat(B,1,1,1)
18 | yy = yy.view(1,1,H,W).repeat(B,1,1,1)
19 | grid = torch.cat((xx,yy),1).float()
20 |
21 | if x.is_cuda:
22 |         grid = grid.to(x.device)
23 | vgrid = grid + flo
24 |
25 | # scale grid to [-1,1]
26 | vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0
27 | vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0
28 |
29 | vgrid = vgrid.permute(0,2,3,1)
30 | output = nn.functional.grid_sample(x, vgrid)
31 |
32 | return output
33 |
34 |
35 | def flow_warp(x, flow, padding_mode='zeros'):
36 | """Warp an image or feature map with optical flow
37 | Args:
38 | x (Tensor): size (n, c, h, w)
39 |         flow (Tensor): size (n, 2, h, w), values range from -1 to 1 (relative to image width or height)
40 | padding_mode (str): 'zeros' or 'border'
41 |
42 | Returns:
43 | Tensor: warped image or feature map
44 | """
45 | assert x.size()[-2:] == flow.size()[-2:]
46 | n, _, h, w = x.size()
47 | x_ = torch.arange(w).view(1, -1).expand(h, -1)
48 | y_ = torch.arange(h).view(-1, 1).expand(-1, w)
49 |     grid = torch.stack([x_, y_], dim=0).float().to(flow.device)
50 |     grid = grid.unsqueeze(0).expand(n, -1, -1, -1).clone()
51 | grid[:, 0, :, :] = 2 * grid[:, 0, :, :] / (w - 1) - 1
52 | grid[:, 1, :, :] = 2 * grid[:, 1, :, :] / (h - 1) - 1
53 | grid += 2 * flow
54 | grid = grid.permute(0, 2, 3, 1)
55 |     return F.grid_sample(x, grid, padding_mode=padding_mode)
56 |
57 |
58 |
59 |
60 |
61 | if __name__=='__main__':
62 | img = '/home/miaojiaxu/jiaxu3/vsp_segment/RAFT-master/demo-frames/frame_0016.png'
63 | img = Image.open(img)
64 | img = img.resize((1024,440))
65 | img = np.array(img)
66 | img = img/255.
67 |
68 |
69 |
70 | flow = np.load('tosave.npy')
71 |
72 | img = torch.from_numpy(img)
73 | img = img.unsqueeze(0).permute(0,3,1,2)
74 | img = img.float()
75 |
76 | print(img.size())
77 | flow = torch.from_numpy(flow)
78 | flow = flow.unsqueeze(0)
79 | print(flow.size())
80 | img2 = warp(img,flow)
81 | print(img2.size())
82 | img2 = img2.squeeze(0).permute(1,2,0).numpy()
83 | print(img2.shape)
84 | img2 = Image.fromarray((img2*255.).astype('uint8'))
85 | img2.save('1.png')
86 |
87 |
--------------------------------------------------------------------------------
/RAFT/logs/hadoop.kylin.libdfs.log:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/logs/hadoop.kylin.libdfs.log
--------------------------------------------------------------------------------
/RAFT/models/1.py:
--------------------------------------------------------------------------------
1 | print('!@!')
2 |
3 | if __name__=='__main__':
4 | print('22')
5 |
--------------------------------------------------------------------------------
/RAFT/models/raft-things.pth-no-zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/models/raft-things.pth-no-zip
--------------------------------------------------------------------------------
/RAFT/tosave.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT/tosave.npy
--------------------------------------------------------------------------------
/RAFT/train_mixed.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir -p checkpoints
3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision
4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision
5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision
6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision
7 |
--------------------------------------------------------------------------------
/RAFT/train_standard.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir -p checkpoints
3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85
6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85
--------------------------------------------------------------------------------
/RAFT_core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/__init__.py
--------------------------------------------------------------------------------
/RAFT_core/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/RAFT_core/__pycache__/corr.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/__pycache__/corr.cpython-37.pyc
--------------------------------------------------------------------------------
/RAFT_core/__pycache__/extractor.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/__pycache__/extractor.cpython-37.pyc
--------------------------------------------------------------------------------
/RAFT_core/__pycache__/raft.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/__pycache__/raft.cpython-37.pyc
--------------------------------------------------------------------------------
/RAFT_core/__pycache__/update.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/__pycache__/update.cpython-37.pyc
--------------------------------------------------------------------------------
/RAFT_core/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from RAFT_core.utils.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 | import alt_cuda_corr
7 | except:
8 | # alt_cuda_corr is not compiled
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels-1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2*r+1)
38 | dy = torch.linspace(-r, r, 2*r+1)
39 | delta = torch.stack(torch.meshgrid(dy, dx), dim=-1).to(coords.device)
40 |
41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 |
45 | corr = bilinear_sampler(corr, coords_lvl)
46 | corr = corr.view(batch, h1, w1, -1)
47 | out_pyramid.append(corr)
48 |
49 | out = torch.cat(out_pyramid, dim=-1)
50 | return out.permute(0, 3, 1, 2).contiguous().float()
51 |
52 | @staticmethod
53 | def corr(fmap1, fmap2):
54 | batch, dim, ht, wd = fmap1.shape
55 | fmap1 = fmap1.view(batch, dim, ht*wd)
56 | fmap2 = fmap2.view(batch, dim, ht*wd)
57 |
58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59 | corr = corr.view(batch, ht, wd, 1, ht, wd)
60 | return corr / torch.sqrt(torch.tensor(dim).float())
61 |
62 |
63 | class AlternateCorrBlock:
64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
65 | self.num_levels = num_levels
66 | self.radius = radius
67 |
68 | self.pyramid = [(fmap1, fmap2)]
69 | for i in range(self.num_levels):
70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
72 | self.pyramid.append((fmap1, fmap2))
73 |
74 | def __call__(self, coords):
75 | coords = coords.permute(0, 2, 3, 1)
76 | B, H, W, _ = coords.shape
77 | dim = self.pyramid[0][0].shape[1]
78 |
79 | corr_list = []
80 | for i in range(self.num_levels):
81 | r = self.radius
82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
84 |
85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
87 | corr_list.append(corr.squeeze(1))
88 |
89 | corr = torch.stack(corr_list, dim=1)
90 | corr = corr.reshape(B, -1, H, W)
91 | return corr / torch.sqrt(torch.tensor(dim).float())
92 |
--------------------------------------------------------------------------------
/RAFT_core/raft-things.pth-no-zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/raft-things.pth-no-zip
--------------------------------------------------------------------------------
/RAFT_core/raft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import sys
6 | sys.path.append('RAFT_core')
7 |
8 | from update import BasicUpdateBlock, SmallUpdateBlock
9 | from extractor import BasicEncoder, SmallEncoder
10 | from corr import CorrBlock, AlternateCorrBlock
11 | from RAFT_core.utils.utils import bilinear_sampler, coords_grid, upflow8
12 |
13 | try:
14 | autocast = torch.cuda.amp.autocast
15 | except:
16 | # dummy autocast for PyTorch < 1.6
17 | class autocast:
18 | def __init__(self, enabled):
19 | pass
20 | def __enter__(self):
21 | pass
22 | def __exit__(self, *args):
23 | pass
24 |
25 |
26 | class RAFT(nn.Module):
27 | def __init__(self,requires_grad=False):
28 | super(RAFT, self).__init__()
29 | # self.args = args
30 |
31 | self.hidden_dim = hdim = 128
32 | self.context_dim = cdim = 128
33 | corr_levels = 4
34 | corr_radius = 4
35 | self.corr_radius = corr_radius
36 |
37 |
38 |
39 | # feature network, context network, and update block
40 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0)
41 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=0)
42 | self.update_block = BasicUpdateBlock(corr_levels,corr_radius, hidden_dim=hdim)
43 | if not requires_grad:
44 | for param in self.parameters():
45 | param.requires_grad = False
46 |
47 | def freeze_bn(self):
48 | for m in self.modules():
49 | if isinstance(m, nn.BatchNorm2d):
50 | m.eval()
51 |
52 | def initialize_flow(self, img):
53 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
54 | N, C, H, W = img.shape
55 | coords0 = coords_grid(N, H//8, W//8).to(img.device)
56 | coords1 = coords_grid(N, H//8, W//8).to(img.device)
57 |
58 | # optical flow computed as difference: flow = coords1 - coords0
59 | return coords0, coords1
60 |
61 | def upsample_flow(self, flow, mask):
62 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
63 | N, _, H, W = flow.shape
64 | mask = mask.view(N, 1, 9, 8, 8, H, W)
65 | mask = torch.softmax(mask, dim=2)
66 |
67 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
68 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
69 |
70 | up_flow = torch.sum(mask * up_flow, dim=2)
71 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
72 | return up_flow.reshape(N, 2, 8*H, 8*W)
73 |
74 |
75 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
76 | """ Estimate optical flow between pair of frames """
77 |
78 | image1 = 2 * (image1 / 255.0) - 1.0
79 | image2 = 2 * (image2 / 255.0) - 1.0
80 |
81 | image1 = image1.contiguous()
82 | image2 = image2.contiguous()
83 |
84 | hdim = self.hidden_dim
85 | cdim = self.context_dim
86 |
87 | # run the feature network
88 | fmap1, fmap2 = self.fnet([image1, image2])
89 |
90 | fmap1 = fmap1.float()
91 | fmap2 = fmap2.float()
92 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.corr_radius)
93 |
94 | # run the context network
95 | cnet = self.cnet(image1)
96 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
97 | net = torch.tanh(net)
98 | inp = torch.relu(inp)
99 |
100 | coords0, coords1 = self.initialize_flow(image1)
101 |
102 | if flow_init is not None:
103 | coords1 = coords1 + flow_init
104 |
105 | flow_predictions = []
106 | for itr in range(iters):
107 | coords1 = coords1.detach()
108 | corr = corr_fn(coords1) # index correlation volume
109 |
110 | flow = coords1 - coords0
111 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
112 |
113 | # F(t+1) = F(t) + \Delta(t)
114 | coords1 = coords1 + delta_flow
115 |
116 | # upsample predictions
117 | if up_mask is None:
118 | flow_up = upflow8(coords1 - coords0)
119 | else:
120 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
121 |
122 | flow_predictions.append(flow_up)
123 |
124 | if test_mode:
125 | return coords1 - coords0, flow_up
126 |
127 | return flow_predictions
128 |
--------------------------------------------------------------------------------
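Note: a minimal inference sketch for the RAFT module above (shapes are illustrative, H and W must be divisible by 8, a CUDA device is assumed, and it must be run from the repository root so the `sys.path` hack resolves). No checkpoint is loaded here, so the weights are random and this only demonstrates the call signature; see TC_cal.py below for how the released checkpoint is actually loaded:

```
import torch
from RAFT_core.raft import RAFT

model = RAFT().cuda().eval()  # parameters are frozen by default (requires_grad=False)
image1 = torch.randint(0, 256, (1, 3, 480, 640), dtype=torch.float32).cuda()  # [B,3,H,W], values in [0,255]
image2 = torch.randint(0, 256, (1, 3, 480, 640), dtype=torch.float32).cuda()
with torch.no_grad():
    flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
print(flow_low.shape, flow_up.shape)  # [1, 2, 60, 80] and [1, 2, 480, 640]
```

--------------------------------------------------------------------------------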
/RAFT_core/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class ConvGRU(nn.Module):
17 | def __init__(self, hidden_dim=128, input_dim=192+128):
18 | super(ConvGRU, self).__init__()
19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
22 |
23 | def forward(self, h, x):
24 | hx = torch.cat([h, x], dim=1)
25 |
26 | z = torch.sigmoid(self.convz(hx))
27 | r = torch.sigmoid(self.convr(hx))
28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
29 |
30 | h = (1-z) * h + z * q
31 | return h
32 |
33 | class SepConvGRU(nn.Module):
34 | def __init__(self, hidden_dim=128, input_dim=192+128):
35 | super(SepConvGRU, self).__init__()
36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39 |
40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43 |
44 |
45 | def forward(self, h, x):
46 | # horizontal
47 | hx = torch.cat([h, x], dim=1)
48 | z = torch.sigmoid(self.convz1(hx))
49 | r = torch.sigmoid(self.convr1(hx))
50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
51 | h = (1-z) * h + z * q
52 |
53 | # vertical
54 | hx = torch.cat([h, x], dim=1)
55 | z = torch.sigmoid(self.convz2(hx))
56 | r = torch.sigmoid(self.convr2(hx))
57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
58 | h = (1-z) * h + z * q
59 |
60 | return h
61 |
62 | class SmallMotionEncoder(nn.Module):
63 | def __init__(self,corr_levels,corr_radius):
64 | super(SmallMotionEncoder, self).__init__()
65 | cor_planes = corr_levels * (2*corr_radius + 1)**2
66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
69 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
70 |
71 | def forward(self, flow, corr):
72 | cor = F.relu(self.convc1(corr))
73 | flo = F.relu(self.convf1(flow))
74 | flo = F.relu(self.convf2(flo))
75 | cor_flo = torch.cat([cor, flo], dim=1)
76 | out = F.relu(self.conv(cor_flo))
77 | return torch.cat([out, flow], dim=1)
78 |
79 | class BasicMotionEncoder(nn.Module):
80 | def __init__(self,corr_levels,corr_radius ):
81 | super(BasicMotionEncoder, self).__init__()
82 | cor_planes = corr_levels * (2*corr_radius + 1)**2
83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
88 |
89 | def forward(self, flow, corr):
90 | cor = F.relu(self.convc1(corr))
91 | cor = F.relu(self.convc2(cor))
92 | flo = F.relu(self.convf1(flow))
93 | flo = F.relu(self.convf2(flo))
94 |
95 | cor_flo = torch.cat([cor, flo], dim=1)
96 | out = F.relu(self.conv(cor_flo))
97 | return torch.cat([out, flow], dim=1)
98 |
99 | class SmallUpdateBlock(nn.Module):
100 | def __init__(self, corr_levels,corr_radius, hidden_dim=96):
101 | super(SmallUpdateBlock, self).__init__()
102 | self.encoder = SmallMotionEncoder(corr_levels,corr_radius)
103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
105 |
106 | def forward(self, net, inp, corr, flow):
107 | motion_features = self.encoder(flow, corr)
108 | inp = torch.cat([inp, motion_features], dim=1)
109 | net = self.gru(net, inp)
110 | delta_flow = self.flow_head(net)
111 |
112 | return net, None, delta_flow
113 |
114 | class BasicUpdateBlock(nn.Module):
115 | def __init__(self, corr_levels,corr_radius, hidden_dim=128, input_dim=128):
116 | super(BasicUpdateBlock, self).__init__()
117 | self.encoder = BasicMotionEncoder(corr_levels,corr_radius)
118 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
119 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
120 |
121 | self.mask = nn.Sequential(
122 | nn.Conv2d(128, 256, 3, padding=1),
123 | nn.ReLU(inplace=True),
124 | nn.Conv2d(256, 64*9, 1, padding=0))
125 |
126 | def forward(self, net, inp, corr, flow, upsample=True):
127 | motion_features = self.encoder(flow, corr)
128 | inp = torch.cat([inp, motion_features], dim=1)
129 |
130 | net = self.gru(net, inp)
131 | delta_flow = self.flow_head(net)
132 |
133 |         # scale mask to balance gradients
134 | mask = .25 * self.mask(net)
135 | return net, mask, delta_flow
136 |
137 |
138 |
139 |
--------------------------------------------------------------------------------
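Note: a quick shape check of BasicUpdateBlock with the dimensions RAFT uses above (corr_levels=4, corr_radius=4, so the indexed correlation volume has 4*(2*4+1)**2 = 324 channels); all tensor values are dummies:

```
import torch
from RAFT_core.update import BasicUpdateBlock

block = BasicUpdateBlock(4, 4, hidden_dim=128)
B, H, W = 2, 60, 80
net  = torch.randn(B, 128, H, W)   # GRU hidden state
inp  = torch.randn(B, 128, H, W)   # context features
corr = torch.randn(B, 324, H, W)   # indexed correlation volume
flow = torch.randn(B, 2, H, W)     # current flow estimate
net, mask, delta_flow = block(net, inp, corr, flow)
print(mask.shape, delta_flow.shape)  # [2, 576, 60, 80] and [2, 2, 60, 80]
```

--------------------------------------------------------------------------------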
/RAFT_core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/utils/__init__.py
--------------------------------------------------------------------------------
/RAFT_core/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/RAFT_core/utils/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/RAFT_core/utils/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/RAFT_core/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 |     Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 |     Expects a two-dimensional flow image of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
--------------------------------------------------------------------------------
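Note: a small usage example for flow_to_image, visualizing a synthetic radial flow field (the sizes are arbitrary):

```
import numpy as np
from RAFT_core.utils.flow_viz import flow_to_image

h, w = 120, 160
y, x = np.mgrid[0:h, 0:w].astype(np.float32)
flow = np.stack([x - w / 2, y - h / 2], axis=-1) / 20.0  # [H, W, 2] flow, (u, v)
rgb = flow_to_image(flow)                                # uint8 [H, W, 3]
print(rgb.shape, rgb.dtype)                              # (120, 160, 3) uint8
```

--------------------------------------------------------------------------------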
/RAFT_core/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | TAG_CHAR = np.array([202021.25], np.float32)
11 |
12 | def readFlow(fn):
13 | """ Read .flo file in Middlebury format"""
14 | # Code adapted from:
15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16 |
17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18 | # print 'fn = %s'%(fn)
19 | with open(fn, 'rb') as f:
20 | magic = np.fromfile(f, np.float32, count=1)
21 | if 202021.25 != magic:
22 | print('Magic number incorrect. Invalid .flo file')
23 | return None
24 | else:
25 | w = np.fromfile(f, np.int32, count=1)
26 | h = np.fromfile(f, np.int32, count=1)
27 | # print 'Reading %d x %d flo file\n' % (w, h)
28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
29 | # Reshape data into 3D array (columns, rows, bands)
30 | # The reshape here is for visualization, the original code is (w,h,2)
31 | return np.resize(data, (int(h), int(w), 2))
32 |
33 | def readPFM(file):
34 | file = open(file, 'rb')
35 |
36 | color = None
37 | width = None
38 | height = None
39 | scale = None
40 | endian = None
41 |
42 | header = file.readline().rstrip()
43 | if header == b'PF':
44 | color = True
45 | elif header == b'Pf':
46 | color = False
47 | else:
48 | raise Exception('Not a PFM file.')
49 |
50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
51 | if dim_match:
52 | width, height = map(int, dim_match.groups())
53 | else:
54 | raise Exception('Malformed PFM header.')
55 |
56 | scale = float(file.readline().rstrip())
57 | if scale < 0: # little-endian
58 | endian = '<'
59 | scale = -scale
60 | else:
61 | endian = '>' # big-endian
62 |
63 | data = np.fromfile(file, endian + 'f')
64 | shape = (height, width, 3) if color else (height, width)
65 |
66 | data = np.reshape(data, shape)
67 | data = np.flipud(data)
68 | return data
69 |
70 | def writeFlow(filename,uv,v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert(uv.ndim == 3)
81 | assert(uv.shape[2] == 2)
82 | u = uv[:,:,0]
83 | v = uv[:,:,1]
84 | else:
85 | u = uv
86 |
87 | assert(u.shape == v.shape)
88 | height,width = u.shape
89 | f = open(filename,'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width*nBands))
96 | tmp[:,np.arange(width)*2] = u
97 | tmp[:,np.arange(width)*2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
104 | flow = flow[:,:,::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2**15) / 64.0
107 | return flow, valid
108 |
109 | def readDispKITTI(filename):
110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
111 | valid = disp > 0.0
112 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
113 | return flow, valid
114 |
115 |
116 | def writeFlowKITTI(filename, uv):
117 | uv = 64.0 * uv + 2**15
118 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
120 | cv2.imwrite(filename, uv[..., ::-1])
121 |
122 |
123 | def read_gen(file_name, pil=False):
124 | ext = splitext(file_name)[-1]
125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
126 | return Image.open(file_name)
127 | elif ext == '.bin' or ext == '.raw':
128 | return np.load(file_name)
129 | elif ext == '.flo':
130 | return readFlow(file_name).astype(np.float32)
131 | elif ext == '.pfm':
132 | flow = readPFM(file_name).astype(np.float32)
133 | if len(flow.shape) == 2:
134 | return flow
135 | else:
136 | return flow[:, :, :-1]
137 | return []
--------------------------------------------------------------------------------
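Note: a round-trip sanity check for the Middlebury .flo reader/writer above (the file name is illustrative); writeFlow stores float32, so the values survive exactly:

```
import numpy as np
from RAFT_core.utils.frame_utils import readFlow, writeFlow

flow = np.random.randn(48, 64, 2).astype(np.float32)  # [H, W, 2]
writeFlow('tmp_check.flo', flow)
back = readFlow('tmp_check.flo')
print(np.allclose(flow, back))  # True
```

--------------------------------------------------------------------------------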
/RAFT_core/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 | def __init__(self, dims, mode='sintel'):
10 | self.ht, self.wd = dims[-2:]
11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13 | if mode == 'sintel':
14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15 | else:
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17 |
18 | def pad(self, x):
19 | #return F.pad(x, self._pad, mode='replicate')
20 | return F.pad(x, self._pad, mode='constant')
21 |
22 | def unpad(self,x):
23 | ht, wd = x.shape[-2:]
24 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
25 | return x[:,:, c[0]:c[1], c[2]:c[3]]
26 |
27 | def forward_interpolate(flow):
28 | flow = flow.detach().cpu().numpy()
29 | dx, dy = flow[0], flow[1]
30 |
31 | ht, wd = dx.shape
32 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
33 |
34 | x1 = x0 + dx
35 | y1 = y0 + dy
36 |
37 | x1 = x1.reshape(-1)
38 | y1 = y1.reshape(-1)
39 | dx = dx.reshape(-1)
40 | dy = dy.reshape(-1)
41 |
42 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
43 | x1 = x1[valid]
44 | y1 = y1[valid]
45 | dx = dx[valid]
46 | dy = dy[valid]
47 |
48 | flow_x = interpolate.griddata(
49 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
50 |
51 | flow_y = interpolate.griddata(
52 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
53 |
54 | flow = np.stack([flow_x, flow_y], axis=0)
55 | return torch.from_numpy(flow).float()
56 |
57 |
58 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
59 | """ Wrapper for grid_sample, uses pixel coordinates """
60 | H, W = img.shape[-2:]
61 | xgrid, ygrid = coords.split([1,1], dim=-1)
62 | xgrid = 2*xgrid/(W-1) - 1
63 | ygrid = 2*ygrid/(H-1) - 1
64 |
65 | grid = torch.cat([xgrid, ygrid], dim=-1)
66 | img = F.grid_sample(img, grid, align_corners=True)
67 |
68 | if mask:
69 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
70 | return img, mask.float()
71 |
72 | return img
73 |
74 |
75 | def coords_grid(batch, ht, wd):
76 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
77 | coords = torch.stack(coords[::-1], dim=0).float()
78 | return coords[None].repeat(batch, 1, 1, 1)
79 |
80 |
81 | def upflow8(flow, mode='bilinear'):
82 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
83 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
84 |
--------------------------------------------------------------------------------
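Note: an example of how InputPadder is meant to be driven (TC_cal.py below uses the same pattern): pad so H and W are multiples of 8, run the model, then crop the output back:

```
import torch
from RAFT_core.utils.utils import InputPadder

img = torch.randn(1, 3, 436, 1024)   # Sintel-sized frame; 436 is not divisible by 8
padder = InputPadder(img.shape)
padded = padder.pad(img)
print(padded.shape)                  # torch.Size([1, 3, 440, 1024])
restored = padder.unpad(padded)
print(restored.shape)                # torch.Size([1, 3, 436, 1024])
```

--------------------------------------------------------------------------------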
/README.md:
--------------------------------------------------------------------------------
1 | # VSPW: A Large-scale Dataset for Video Scene Parsing in the Wild
2 |
3 | A pytorch implementation of the CVPR2021 paper "VSPW: A Large-scale Dataset for Video Scene Parsing in the Wild"
4 |
5 | # Preparation
6 |
7 | ## Download VSPW dataset
8 |
9 | The VSPW dataset with extracted frames and masks is available [here](https://github.com/sssdddwww2/vspw_dataset_download). You can now directly download the [VSPW_480P dataset](https://github.com/sssdddwww2/vspw_dataset_download).
10 |
11 | ## Dependencies
12 | - Python 3.7
13 | - Pytorch 1.3.1
14 | - Numpy
15 |
16 | Download the ImageNet-pretrained models from [this link](https://drive.google.com/file/d/1VFmObwlx4d_K7FOjFNk5LhEb3jP8_NaD/view?usp=sharing). Put the archive in the root folder and decompress it.
17 |
18 | # Train and Test
19 |
20 | Resize the frames and masks of the VSPW dataset to *480p*.
21 |
22 | ```
23 | python change2_480p.py
24 | ```
25 |
26 | Edit the *.sh* files in *scripts/* and set **$DATAROOT** to your path to VSPW_480p.
27 |
28 | ## Image-based methods
29 |
30 | PSPNet
31 |
32 | ```
33 | sh scripts/run_psp.sh
34 | ```
35 |
36 | OCRNet
37 |
38 | ```
39 | sh scripts/run_ocr.sh
40 | ```
41 |
42 | ## Video-based methods
43 |
44 | TCB-PSP
45 |
46 | ```
47 | sh run_temporal_psp.sh
48 | ```
49 |
50 | TCB-OCR
51 |
52 | ```
53 | sh run_temporal_ocr.sh
54 | ```
55 |
56 | ## Evaluation on TC and VC
57 |
58 | Change the data root and the prediction root in *TC_cal.py* and *VC_perclip.py*.
59 |
60 | ```
61 | python TC_cal.py
62 | ```
63 |
64 | ```
65 | python VC_perclip.py
66 | ```
67 |
68 | This implementation builds on [this code](https://github.com/CSAILVision/semantic-segmentation-pytorch) and [RAFT](https://github.com/princeton-vl/RAFT).
69 |
70 |
71 |
72 | # Citation
73 |
74 | ```
75 | @inproceedings{miao2021vspw,
76 |   title={VSPW: A Large-scale Dataset for Video Scene Parsing in the Wild},
77 |   author={Miao, Jiaxu and Wei, Yunchao and Wu, Yu and Liang, Chen and Li, Guangrui and Yang, Yi},
78 |   booktitle={Proceedings of the {IEEE} Conference on Computer Vision and Pattern Recognition},
79 |   year={2021}
80 | }
81 | ```
--------------------------------------------------------------------------------
/TC_cal.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from RAFT_core.raft import RAFT
4 | from RAFT_core.utils.utils import InputPadder
5 | from collections import OrderedDict
6 | from utils import Evaluator
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import sys
11 |
12 | def flowwarp(x, flo):
13 | """
14 | warp an image/tensor (im2) back to im1, according to the optical flow
15 | x: [B, C, H, W] (im2)
16 | flo: [B, 2, H, W] flow
17 | """
18 | B, C, H, W = x.size()
19 | # mesh grid
20 | xx = torch.arange(0, W).view(1,-1).repeat(H,1)
21 | yy = torch.arange(0, H).view(-1,1).repeat(1,W)
22 | xx = xx.view(1,1,H,W).repeat(B,1,1,1)
23 | yy = yy.view(1,1,H,W).repeat(B,1,1,1)
24 | grid = torch.cat((xx,yy),1).float()
25 |
26 | if x.is_cuda:
27 | grid = grid.to(x.device)
28 | vgrid = grid + flo
29 |
30 | # scale grid to [-1,1]
31 | vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0
32 | vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0
33 |
34 | vgrid = vgrid.permute(0,2,3,1)
35 | output = nn.functional.grid_sample(x, vgrid,mode='nearest',align_corners=False)
36 |
37 | return output
38 |
39 |
40 |
41 | num_class=124
42 |
43 | DIR_='/your/path/to/VSPW_480p'
44 |
45 | data_dir=DIR_+'/data'
46 | result_dir='./prediction'
47 | #list_=['1001_5z_ijQjUf_0','1002_QXQ_QoswLOs']
48 |
49 | split='val.txt'
50 | with open(os.path.join(DIR_,split),'r') as f:
51 |
52 | list_ = f.readlines()
53 | list_ = [v[:-1] for v in list_]
54 |
55 | ###
56 | gpu=0
57 | model_raft = RAFT()
58 | to_load = torch.load('./RAFT_core/raft-things.pth-no-zip')
59 | new_state_dict = OrderedDict()
60 | for k, v in to_load.items():
61 |     name = k[7:] # strip the leading `module.` (first 7 characters) left by DataParallel so the keys match a plain model
62 |     new_state_dict[name] = v # keep the weight value under the stripped key
63 | model_raft.load_state_dict(new_state_dict)
64 | model_raft = model_raft.cuda(gpu)
65 | ###
66 | total_TC=0.
67 | evaluator = Evaluator(num_class)
68 | for video in list_[:100]:
69 | if video[0]=='.':
70 | continue
71 | imglist_ = sorted(os.listdir(os.path.join(data_dir,video,'origin')))
72 | for i,img in enumerate(imglist_[:-1]):
73 | if img[0]=='.':
74 | continue
75 | #print('processing video : {} image: {}'.format(video,img))
76 | next_img = imglist_[i+1]
77 | imgname = img
78 | next_imgname = next_img
79 | img = Image.open(os.path.join(data_dir,video,'origin',img))
80 | next_img =Image.open(os.path.join(data_dir,video,'origin',next_img))
81 | image1 = torch.from_numpy(np.array(img))
82 | image2 = torch.from_numpy(np.array(next_img))
83 | padder = InputPadder(image1.size()[:2])
84 | image1 = image1.unsqueeze(0).permute(0,3,1,2)
85 | image2 = image2.unsqueeze(0).permute(0,3,1,2)
86 | image1 = padder.pad(image1)
87 | image2 = padder.pad(image2)
88 | image1 = image1.cuda(gpu)
89 | image2 = image2.cuda(gpu)
90 | with torch.no_grad():
91 | model_raft.eval()
92 | _,flow = model_raft(image1,image2,iters=20, test_mode=True)
93 | flow = padder.unpad(flow)
94 |
95 | flow = flow.data.cpu()
96 | pred = Image.open(os.path.join(result_dir,video,imgname.split('.')[0]+'.png'))
97 | next_pred = Image.open(os.path.join(result_dir,video,next_imgname.split('.')[0]+'.png'))
98 | pred =torch.from_numpy(np.array(pred))
99 | next_pred = torch.from_numpy(np.array(next_pred))
100 | next_pred = next_pred.unsqueeze(0).unsqueeze(0).float()
101 | # print(next_pred)
102 |
103 | warp_pred = flowwarp(next_pred,flow)
104 | # print(warp_pred)
105 | warp_pred = warp_pred.int().squeeze(1).numpy()
106 | pred = pred.unsqueeze(0).numpy()
107 | evaluator.add_batch(pred, warp_pred)
108 | # v_mIoU = evaluator.Mean_Intersection_over_Union()
109 | # total_TC+=v_mIoU
110 | # print('processed video : {} score:{}'.format(video,v_mIoU))
111 |
112 | #TC = total_TC/len(list_)
113 | TC = evaluator.Mean_Intersection_over_Union()
114 |
115 | print("TC score is {}".format(TC))
116 |
117 | print(split)
118 | print(result_dir)
119 |
120 |
121 |
122 |
123 |
124 |
125 |
--------------------------------------------------------------------------------
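Note: a shape-level sanity check for the flowwarp helper above (values are dummies, and the flowwarp definition is assumed to be in scope; TC_cal.py is a script, so copy the function rather than importing the module). Warping a [B, C, H, W] map with a [B, 2, H, W] flow returns a tensor of the same shape, which is how the script warps the next frame's prediction back onto the current frame:

```
import torch

seg  = torch.randint(0, 124, (1, 1, 480, 640)).float()  # a fake label map
flow = torch.randn(1, 2, 480, 640)
print(flowwarp(seg, flow).shape)  # torch.Size([1, 1, 480, 640])
```

--------------------------------------------------------------------------------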
/VC_perclip.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from PIL import Image
4 | #from utils import Evaluator
5 | import sys
6 |
7 | def get_common(list_,predlist,clip_num,h,w):
8 | accs = []
9 | for i in range(len(list_)-clip_num):
10 | global_common = np.ones((h,w))
11 | predglobal_common = np.ones((h,w))
12 |
13 |
14 | for j in range(1,clip_num):
15 | common = (list_[i] == list_[i+j])
16 | global_common = np.logical_and(global_common,common)
17 | pred_common = (predlist[i]==predlist[i+j])
18 | predglobal_common = np.logical_and(predglobal_common,pred_common)
19 | pred = (predglobal_common*global_common)
20 |
21 | acc = pred.sum()/global_common.sum()
22 | accs.append(acc)
23 | return accs
24 |
25 |
26 |
27 | DIR='/your/path/to/VSPW_480p'
28 |
29 | Pred='./predicts'
30 | split = 'val.txt'
31 |
32 | with open(os.path.join(DIR,split),'r') as f:
33 | lines = f.readlines()
34 | # strip the trailing newline from each video name
35 | videolist = [line[:-1] for line in lines]
36 | total_acc=[]
37 |
38 | clip_num=16
39 |
40 |
41 | for video in videolist:
42 | if video[0]=='.':
43 | continue
44 | imglist = []
45 | predlist = []
46 |
47 | images = sorted(os.listdir(os.path.join(DIR,'data',video,'mask')))
48 |
49 | if len(images)<=clip_num:
50 | continue
51 | for imgname in images:
52 | if imgname[0]=='.':
53 | continue
54 | img = Image.open(os.path.join(DIR,'data',video,'mask',imgname))
55 | w,h = img.size
56 | img = np.array(img)
57 | imglist.append(img)
58 | pred = Image.open(os.path.join(Pred,video,imgname))
59 | pred = np.array(pred)
60 | predlist.append(pred)
61 |
62 | accs = get_common(imglist,predlist,clip_num,h,w)
63 | print(sum(accs)/len(accs))
64 | total_acc.extend(accs)
65 | Acc = np.array(total_acc)
66 | Acc = np.nanmean(Acc)
67 | print(Pred)
68 | print('*'*10)
69 | print('VC{} score: {} on {} set'.format(clip_num,Acc,split))
70 | print('*'*10)
71 |
72 |
--------------------------------------------------------------------------------
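Note: a tiny worked example of get_common (assuming the definition above is in scope). For each clip, VC looks at the pixels whose ground-truth label is stable across the whole clip and measures what fraction of them also keep a stable, matching prediction. With three 1x2 frames and clip_num=2, only one clip is scored: the left pixel is the only ground-truth-stable one, the prediction is stable there too, so the clip scores 1.0:

```
import numpy as np

gt   = [np.array([[0, 1]]), np.array([[0, 2]]), np.array([[0, 2]])]
pred = [np.array([[0, 1]]), np.array([[0, 1]]), np.array([[0, 1]])]
print(get_common(gt, pred, clip_num=2, h=1, w=2))  # [1.0]
```

--------------------------------------------------------------------------------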
/change2_480p.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from multiprocessing import Pool
4 |
5 |
6 | DIR='/your/path/to/VSPW'
7 |
8 | Target_Dir = '/your/path/to/VSPW_480p'
9 |
10 |
11 | def change(DIR,video,image):
12 | img = Image.open(os.path.join(DIR,'data',video,'origin',image))
13 | w,h = img.size
14 |
15 | if not os.path.exists(os.path.join(Target_Dir,'data',video,'origin')):
16 | os.makedirs(os.path.join(Target_Dir,'data',video,'origin'))
17 | img = img.resize((int(480*w/h),480),Image.BILINEAR)
18 | img.save(os.path.join(Target_Dir,'data',video,'origin',image))
19 |
20 | if os.path.isfile(os.path.join(DIR,'data',video,'mask',image.split('.')[0]+'.png')):
21 |
22 |
23 | mask = Image.open(os.path.join(DIR,'data',video,'mask',image.split('.')[0]+'.png'))
24 | mask = mask.resize((int(480*w/h),480),Image.NEAREST)
25 |
26 | if not os.path.exists(os.path.join(Target_Dir,'data',video,'mask')):
27 | os.makedirs(os.path.join(Target_Dir,'data',video,'mask'))
28 |
29 | mask.save(os.path.join(Target_Dir,'data',video,'mask',image.split('.')[0]+'.png'))
30 | print('Processing video {} image {}'.format(video,image))
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 | #p = Pool(8)
39 | for video in sorted(os.listdir(os.path.join(DIR,'data'))):
40 | if video[0]=='.':
41 | continue
42 | for image in sorted(os.listdir(os.path.join(DIR,'data',video,'origin'))):
43 | if image[0]=='.':
44 | continue
45 | # p.apply_async(change,args=(DIR,video,image,))
46 | change(DIR,video,image)
47 | #p.close()
48 | #p.join()
49 | print('finish')
50 |
51 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .defaults import _C as cfg
2 |
--------------------------------------------------------------------------------
/config/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/config/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/config/__pycache__/defaults.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/config/__pycache__/defaults.cpython-37.pyc
--------------------------------------------------------------------------------
/config/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | # -----------------------------------------------------------------------------
4 | # Config definition
5 | # -----------------------------------------------------------------------------
6 |
7 | _C = CN()
8 | _C.DIR = "ckpt/ade20k-resnet50dilated-ppm_deepsup"
9 |
10 | # -----------------------------------------------------------------------------
11 | # Dataset
12 | # -----------------------------------------------------------------------------
13 | _C.DATASET = CN()
14 | _C.DATASET.root_dataset = "./data/"
15 | _C.DATASET.list_train = "./data/training.odgt"
16 | _C.DATASET.list_val = "./data/validation.odgt"
17 | _C.DATASET.num_class = 150
18 | # multiscale train/test, size of short edge (int or tuple)
19 | _C.DATASET.imgSizes = (300, 375, 450, 525, 600)
20 | # maximum input image size of long edge
21 | _C.DATASET.imgMaxSize = 1000
22 | # maximum downsampling rate of the network
23 | _C.DATASET.padding_constant = 8
24 | # downsampling rate of the segmentation label
25 | _C.DATASET.segm_downsampling_rate = 8
26 | # randomly horizontally flip images when train/test
27 | _C.DATASET.random_flip = True
28 |
29 | # -----------------------------------------------------------------------------
30 | # Model
31 | # -----------------------------------------------------------------------------
32 | _C.MODEL = CN()
33 | # architecture of net_encoder
34 | _C.MODEL.arch_encoder = "resnet50dilated"
35 | # architecture of net_decoder
36 | _C.MODEL.arch_decoder = "ppm_deepsup"
37 | # weights to finetune net_encoder
38 | _C.MODEL.weights_encoder = ""
39 | # weights to finetune net_decoder
40 | _C.MODEL.weights_decoder = ""
41 | # number of feature channels between encoder and decoder
42 | _C.MODEL.fc_dim = 2048
43 |
44 | # -----------------------------------------------------------------------------
45 | # Training
46 | # -----------------------------------------------------------------------------
47 | _C.TRAIN = CN()
48 | _C.TRAIN.batch_size_per_gpu = 2
49 | # epochs to train for
50 | _C.TRAIN.num_epoch = 20
51 | # epoch to start training. useful if continue from a checkpoint
52 | _C.TRAIN.start_epoch = 0
53 | # iterations of each epoch (irrelevant to batch size)
54 | _C.TRAIN.epoch_iters = 5000
55 |
56 | _C.TRAIN.optim = "SGD"
57 | _C.TRAIN.lr_encoder = 0.02
58 | _C.TRAIN.lr_decoder = 0.02
59 | # power in poly to drop LR
60 | _C.TRAIN.lr_pow = 0.9
61 | # momentum for sgd, beta1 for adam
62 | _C.TRAIN.beta1 = 0.9
63 | # weights regularizer
64 | _C.TRAIN.weight_decay = 1e-4
65 | # the weighting of deep supervision loss
66 | _C.TRAIN.deep_sup_scale = 0.4
67 | # fix bn params, only under finetuning
68 | _C.TRAIN.fix_bn = False
69 | # number of data loading workers
70 | _C.TRAIN.workers = 16
71 |
72 | # frequency to display
73 | _C.TRAIN.disp_iter = 20
74 | # manual seed
75 | _C.TRAIN.seed = 304
76 |
77 | # -----------------------------------------------------------------------------
78 | # Validation
79 | # -----------------------------------------------------------------------------
80 | _C.VAL = CN()
81 | # currently only supports 1
82 | _C.VAL.batch_size = 1
83 | # output visualization during validation
84 | _C.VAL.visualize = False
85 | # the checkpoint to evaluate on
86 | _C.VAL.checkpoint = "epoch_20.pth"
87 |
88 | # -----------------------------------------------------------------------------
89 | # Testing
90 | # -----------------------------------------------------------------------------
91 | _C.TEST = CN()
92 | # currently only supports 1
93 | _C.TEST.batch_size = 1
94 | # the checkpoint to test on
95 | _C.TEST.checkpoint = "epoch_20.pth"
96 | # folder to output visualization results
97 | _C.TEST.result = "./"
98 |
--------------------------------------------------------------------------------
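Note: a minimal sketch of how these defaults are typically combined with one of the YAML files below via yacs (the actual merge calls live in the training and evaluation scripts; the path and override value here are illustrative):

```
from config import cfg

cfg.merge_from_file("config/vsp-resnet101dilated-ppm_deepsup.yaml")
cfg.merge_from_list(["TRAIN.batch_size_per_gpu", 4])  # command-line style override
cfg.freeze()                                          # make the config immutable
print(cfg.MODEL.arch_encoder, cfg.TRAIN.num_epoch)    # resnet101dilated 25
```

--------------------------------------------------------------------------------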
/config/vsp-hrnetv2.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 32
9 | segm_downsampling_rate: 4
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "hrnetv2"
14 | arch_decoder: "c1"
15 | fc_dim: 720
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 30
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_30.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_30.pth"
40 | result: "./"
41 |
42 | DIR: "/home/miaojiaxu/jiaxu_2/semantic_seg/ade20k-hrnetv2-c1_pretrain"
43 |
--------------------------------------------------------------------------------
/config/vsp-mobilenetv2dilated-c1_deepsup.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "mobilenetv2dilated"
14 | arch_decoder: "c1_deepsup"
15 | fc_dim: 320
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 3
19 | num_epoch: 20
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_20.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_20.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-mobilenetv2dilated-c1_deepsup"
43 |
--------------------------------------------------------------------------------
/config/vsp-mobilenetv2dilated-ppm_deepsup.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "mobilenetv2dilated"
14 | arch_decoder: "ppm_deepsup"
15 | fc_dim: 320
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 3
19 | num_epoch: 20
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_20.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_20.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-mobilenetv2dilated-c1_deepsup"
43 |
--------------------------------------------------------------------------------
/config/vsp-resnet101-upernet.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 32
9 | segm_downsampling_rate: 4
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet101"
14 | arch_decoder: "upernet"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 50
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_40.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_40.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet101-upernet"
43 |
--------------------------------------------------------------------------------
/config/vsp-resnet101dilated-deeplab.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet101dilated"
14 | arch_decoder: "deeplab"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 25
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_25.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_25.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup"
43 |
--------------------------------------------------------------------------------
/config/vsp-resnet101dilated-nonlocal2d.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet101dilated"
14 | arch_decoder: "nonlocal2d"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 25
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_25.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_25.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup"
43 |
--------------------------------------------------------------------------------
/config/vsp-resnet101dilated-ocr_deepsup.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet101dilated"
14 | arch_decoder: "ocrnet_deepsup"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 25
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_25.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_25.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup"
43 |
--------------------------------------------------------------------------------
/config/vsp-resnet101dilated-ppm_clip.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet101dilated"
14 | arch_decoder: "ppm_clip"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 25
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_25.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_25.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup"
43 |
--------------------------------------------------------------------------------
/config/vsp-resnet101dilated-ppm_deepsup.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet101dilated"
14 | arch_decoder: "ppm_deepsup"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 25
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_25.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_25.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup"
43 |
--------------------------------------------------------------------------------
/config/vsp-resnet101dilated-ppm_deepsup_clip.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet101dilated"
14 | arch_decoder: "ppm_deepsup_clip"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 25
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_25.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_25.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup"
43 |
--------------------------------------------------------------------------------
/config/vsp-resnet101dilated_tdnet.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet101"
14 | arch_decoder: "deeplab"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 25
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_25.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_25.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup"
43 |
--------------------------------------------------------------------------------
/config/vsp-resnet18dilated-ppm_deepsup.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet18dilated"
14 | arch_decoder: "ppm_deepsup"
15 | fc_dim: 512
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 20
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_20.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_20.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet18dilated-ppm_deepsup"
43 |
--------------------------------------------------------------------------------
/config/vsp-resnet18dilated-ppm_deepsup_clip.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet18"
14 | arch_decoder: "ppm_deepsup_clip"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 25
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_25.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_25.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup"
43 |
--------------------------------------------------------------------------------
/config/vsp-resnet50-upernet.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 32
9 | segm_downsampling_rate: 4
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet50"
14 | arch_decoder: "upernet"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 30
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_30.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_30.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet50-upernet"
43 |
--------------------------------------------------------------------------------
/config/vsp-resnet50dilated-deeplab.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet50dilated"
14 | arch_decoder: "deeplab"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 25
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_25.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_25.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup"
43 |
--------------------------------------------------------------------------------
/config/vsp-resnet50dilated-ppm_deepsup.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet50dilated"
14 | arch_decoder: "ppm_deepsup"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 20
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_20.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_20.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup"
43 |
--------------------------------------------------------------------------------
/config/vsp-resnet50dilated-ppm_deepsup_clip.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet50dilated"
14 | arch_decoder: "ppm_deepsup_clip"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 25
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_25.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_25.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup"
43 |
--------------------------------------------------------------------------------
/config/vsp-resnet50dilated-tdnet.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet50"
14 | arch_decoder: "deeplab"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 25
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_25.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_25.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup"
43 |
--------------------------------------------------------------------------------
/lib/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules import *
2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
3 |
--------------------------------------------------------------------------------
/lib/nn/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/nn/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/nn/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
/lib/nn/modules/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/nn/modules/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/nn/modules/__pycache__/batchnorm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/nn/modules/__pycache__/batchnorm.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/nn/modules/__pycache__/comm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/nn/modules/__pycache__/comm.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/nn/modules/__pycache__/replicate.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/nn/modules/__pycache__/replicate.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/nn/modules/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 |     - During replication, since data parallel triggers a callback on each module, all slave devices should
60 |     call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are
62 |     collected and passed to a registered callback.
63 |     - After receiving the messages, the master device gathers the information and determines the message to be
64 |     passed back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def register_slave(self, identifier):
79 | """
80 |         Register a slave device.
81 |
82 | Args:
83 |             identifier: an identifier, usually the device id.
84 |
85 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
86 |
87 | """
88 | if self._activated:
89 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
90 | self._activated = False
91 | self._registry.clear()
92 | future = FutureResult()
93 | self._registry[identifier] = _MasterRegistry(future)
94 | return SlavePipe(identifier, self._queue, future)
95 |
96 | def run_master(self, master_msg):
97 | """
98 | Main entry for the master device in each forward pass.
99 |         Messages are first collected from every device (including the master device), and then
100 |         a callback is invoked to compute the messages to be sent back to each device
101 |         (including the master device).
102 |
103 | Args:
104 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 |
107 | Returns: the message to be sent back to the master device.
108 |
109 | """
110 | self._activated = True
111 |
112 | intermediates = [(0, master_msg)]
113 | for i in range(self.nr_slaves):
114 | intermediates.append(self._queue.get())
115 |
116 | results = self._master_callback(intermediates)
117 |         assert results[0][0] == 0, 'The first result should belong to the master.'
118 |
119 | for i, res in results:
120 | if i == 0:
121 | continue
122 | self._registry[i].result.put(res)
123 |
124 | for i in range(self.nr_slaves):
125 | assert self._queue.get() is True
126 |
127 | return results[0][1]
128 |
129 | @property
130 | def nr_slaves(self):
131 | return len(self._registry)
132 |
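
A minimal round-trip sketch of the protocol described in the `SyncMaster`
docstring (illustrative, not part of the repository; the `comm` import path is
an assumption): two slave threads and the master exchange messages, and the
callback echoes the sum back to every participant.

    import threading
    from comm import SyncMaster  # assumes this file is importable as `comm`

    def master_callback(intermediates):
        # intermediates is a list of (identifier, msg) pairs, master first.
        total = sum(msg for _, msg in intermediates)
        return [(i, total) for i, _ in intermediates]

    master = SyncMaster(master_callback)
    pipes = [master.register_slave(i) for i in (1, 2)]

    replies = {}
    def slave(pipe, value):
        replies[pipe.identifier] = pipe.run_slave(value)

    threads = [threading.Thread(target=slave, args=(p, v))
               for p, v in zip(pipes, (10, 20))]
    for t in threads:
        t.start()
    print(master.run_master(1))  # 31 == 1 + 10 + 20
    for t in threads:
        t.join()
    print(replies)               # {1: 31, 2: 31}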
--------------------------------------------------------------------------------
/lib/nn/modules/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 |     Note that, as all modules are isomorphic, we assign each sub-module a context
34 | (shared among multiple copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 |     We guarantee that the callback on the master copy (the first copy) will be called before the callback
38 |     of any slave copies.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by
55 |     the original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 | Useful when you have customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
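
A sketch of a module that opts into this mechanism (the class and attribute
names are illustrative assumptions): `DataParallelWithCallback` invokes
`__data_parallel_replicate__(ctx, copy_id)` on every replica after
`replicate()`, master copy first, so replicas can share state through `ctx`.

    import torch.nn as nn

    class CtxSharingModule(nn.Module):
        def __data_parallel_replicate__(self, ctx, copy_id):
            # `ctx` is shared by all copies of this sub-module across devices;
            # the master copy (copy_id == 0) is guaranteed to run first.
            self._sync_ctx = ctx
            self._copy_id = copy_id
            if copy_id == 0:
                ctx.master_ready = True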
--------------------------------------------------------------------------------
/lib/nn/modules/tests/test_numeric_batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : test_numeric_batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 |
9 | import unittest
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable
14 |
15 | from sync_batchnorm.unittest import TorchTestCase
16 |
17 |
18 | def handy_var(a, unbias=True):
19 | n = a.size(0)
20 | asum = a.sum(dim=0)
21 | as_sum = (a ** 2).sum(dim=0) # a square sum
22 | sumvar = as_sum - asum * asum / n
23 | if unbias:
24 | return sumvar / (n - 1)
25 | else:
26 | return sumvar / n
27 |
28 |
29 | class NumericTestCase(TorchTestCase):
30 | def testNumericBatchNorm(self):
31 | a = torch.rand(16, 10)
32 |         bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False)
33 | bn.train()
34 |
35 | a_var1 = Variable(a, requires_grad=True)
36 | b_var1 = bn(a_var1)
37 | loss1 = b_var1.sum()
38 | loss1.backward()
39 |
40 | a_var2 = Variable(a, requires_grad=True)
41 | a_mean2 = a_var2.mean(dim=0, keepdim=True)
42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
44 | b_var2 = (a_var2 - a_mean2) / a_std2
45 | loss2 = b_var2.sum()
46 | loss2.backward()
47 |
48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0))
49 | self.assertTensorClose(bn.running_var, handy_var(a))
50 | self.assertTensorClose(a_var1.data, a_var2.data)
51 | self.assertTensorClose(b_var1.data, b_var2.data)
52 | self.assertTensorClose(a_var1.grad, a_var2.grad)
53 |
54 |
55 | if __name__ == '__main__':
56 | unittest.main()
57 |
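
`handy_var` uses the sum-of-squares identity Var(x) = (sum(x^2) - (sum(x))^2 / n) / (n - 1).
A quick equivalence check against torch's own estimator (a sketch, assuming
`handy_var` from this file is in scope):

    import torch
    a = torch.rand(16, 10)
    assert torch.allclose(handy_var(a), a.var(dim=0, unbiased=True), atol=1e-5)
    assert torch.allclose(handy_var(a, unbias=False), a.var(dim=0, unbiased=False), atol=1e-5)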
--------------------------------------------------------------------------------
/lib/nn/modules/tests/test_sync_batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : test_sync_batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 |
9 | import unittest
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable
14 |
15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
16 | from sync_batchnorm.unittest import TorchTestCase
17 |
18 |
19 | def handy_var(a, unbias=True):
20 | n = a.size(0)
21 | asum = a.sum(dim=0)
22 | as_sum = (a ** 2).sum(dim=0) # a square sum
23 | sumvar = as_sum - asum * asum / n
24 | if unbias:
25 | return sumvar / (n - 1)
26 | else:
27 | return sumvar / n
28 |
29 |
30 | def _find_bn(module):
31 | for m in module.modules():
32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
33 | return m
34 |
35 |
36 | class SyncTestCase(TorchTestCase):
37 | def _syncParameters(self, bn1, bn2):
38 | bn1.reset_parameters()
39 | bn2.reset_parameters()
40 | if bn1.affine and bn2.affine:
41 | bn2.weight.data.copy_(bn1.weight.data)
42 | bn2.bias.data.copy_(bn1.bias.data)
43 |
44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
45 | """Check the forward and backward for the customized batch normalization."""
46 | bn1.train(mode=is_train)
47 | bn2.train(mode=is_train)
48 |
49 | if cuda:
50 | input = input.cuda()
51 |
52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2))
53 |
54 | input1 = Variable(input, requires_grad=True)
55 | output1 = bn1(input1)
56 | output1.sum().backward()
57 | input2 = Variable(input, requires_grad=True)
58 | output2 = bn2(input2)
59 | output2.sum().backward()
60 |
61 | self.assertTensorClose(input1.data, input2.data)
62 | self.assertTensorClose(output1.data, output2.data)
63 | self.assertTensorClose(input1.grad, input2.grad)
64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
66 |
67 | def testSyncBatchNormNormalTrain(self):
68 | bn = nn.BatchNorm1d(10)
69 | sync_bn = SynchronizedBatchNorm1d(10)
70 |
71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)
72 |
73 | def testSyncBatchNormNormalEval(self):
74 | bn = nn.BatchNorm1d(10)
75 | sync_bn = SynchronizedBatchNorm1d(10)
76 |
77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
78 |
79 | def testSyncBatchNormSyncTrain(self):
80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
83 |
84 | bn.cuda()
85 | sync_bn.cuda()
86 |
87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)
88 |
89 | def testSyncBatchNormSyncEval(self):
90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
93 |
94 | bn.cuda()
95 | sync_bn.cuda()
96 |
97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
98 |
99 | def testSyncBatchNorm2DSyncTrain(self):
100 | bn = nn.BatchNorm2d(10)
101 | sync_bn = SynchronizedBatchNorm2d(10)
102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
103 |
104 | bn.cuda()
105 | sync_bn.cuda()
106 |
107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
108 |
109 |
110 | if __name__ == '__main__':
111 | unittest.main()
112 |
--------------------------------------------------------------------------------
/lib/nn/modules/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 |             np.allclose(npa, npb, rtol=rtol, atol=atol),
28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 | )
30 |
--------------------------------------------------------------------------------
/lib/nn/parallel/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
2 |
--------------------------------------------------------------------------------
/lib/nn/parallel/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/nn/parallel/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/nn/parallel/__pycache__/data_parallel.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/nn/parallel/__pycache__/data_parallel.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/nn/parallel/data_parallel.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf8 -*-
2 |
3 | import torch.cuda as cuda
4 | import torch.nn as nn
5 | import torch
6 | import collections.abc
7 | from torch.nn.parallel._functions import Gather
8 |
9 |
10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']
11 |
12 |
13 | def async_copy_to(obj, dev, main_stream=None):
14 | if torch.is_tensor(obj):
15 | v = obj.cuda(dev, non_blocking=True)
16 | if main_stream is not None:
17 | v.data.record_stream(main_stream)
18 | return v
19 |     elif isinstance(obj, collections.abc.Mapping):
20 |         return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
21 |     elif isinstance(obj, collections.abc.Sequence):
22 | return [async_copy_to(o, dev, main_stream) for o in obj]
23 | else:
24 | return obj
25 |
26 |
27 | def dict_gather(outputs, target_device, dim=0):
28 | """
29 | Gathers variables from different GPUs on a specified device
30 | (-1 means the CPU), with dictionary support.
31 | """
32 | def gather_map(outputs):
33 | out = outputs[0]
34 | if torch.is_tensor(out):
35 | # MJY(20180330) HACK:: force nr_dims > 0
36 | if out.dim() == 0:
37 | outputs = [o.unsqueeze(0) for o in outputs]
38 | return Gather.apply(target_device, dim, *outputs)
39 | elif out is None:
40 | return None
41 |         elif isinstance(out, collections.abc.Mapping):
42 |             return {k: gather_map([o[k] for o in outputs]) for k in out}
43 |         elif isinstance(out, collections.abc.Sequence):
44 | return type(out)(map(gather_map, zip(*outputs)))
45 | return gather_map(outputs)
46 |
47 |
48 | class DictGatherDataParallel(nn.DataParallel):
49 | def gather(self, outputs, output_device):
50 | return dict_gather(outputs, output_device, dim=self.dim)
51 |
52 |
53 | class UserScatteredDataParallel(DictGatherDataParallel):
54 | def scatter(self, inputs, kwargs, device_ids):
55 | assert len(inputs) == 1
56 | inputs = inputs[0]
57 | inputs = _async_copy_stream(inputs, device_ids)
58 | inputs = [[i] for i in inputs]
59 | assert len(kwargs) == 0
60 | kwargs = [{} for _ in range(len(inputs))]
61 |
62 | return inputs, kwargs
63 |
64 |
65 | def user_scattered_collate(batch):
66 | return batch
67 |
68 |
69 | def _async_copy(inputs, device_ids):
70 | nr_devs = len(device_ids)
71 | assert type(inputs) in (tuple, list)
72 | assert len(inputs) == nr_devs
73 |
74 | outputs = []
75 | for i, dev in zip(inputs, device_ids):
76 | with cuda.device(dev):
77 | outputs.append(async_copy_to(i, dev))
78 |
79 | return tuple(outputs)
80 |
81 |
82 | def _async_copy_stream(inputs, device_ids):
83 | nr_devs = len(device_ids)
84 | assert type(inputs) in (tuple, list)
85 | assert len(inputs) == nr_devs
86 |
87 | outputs = []
88 | streams = [_get_stream(d) for d in device_ids]
89 | for i, dev, stream in zip(inputs, device_ids, streams):
90 | with cuda.device(dev):
91 | main_stream = cuda.current_stream()
92 | with cuda.stream(stream):
93 | outputs.append(async_copy_to(i, dev, main_stream=main_stream))
94 | main_stream.wait_stream(stream)
95 |
96 | return outputs
97 |
98 |
99 | """Adapted from: torch/nn/parallel/_functions.py"""
100 | # background streams used for copying
101 | _streams = None
102 |
103 |
104 | def _get_stream(device):
105 | """Gets a background stream for copying between CPU and GPU"""
106 | global _streams
107 | if device == -1:
108 | return None
109 | if _streams is None:
110 | _streams = [None] * cuda.device_count()
111 | if _streams[device] is None: _streams[device] = cuda.Stream(device)
112 | return _streams[device]
113 |
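
Typical wiring for these helpers (a sketch; `my_dataset` and `net` are
placeholders, not names from this repository): `user_scattered_collate` leaves
the batch as a plain list with one sample per GPU, so `scatter` only performs
the asynchronous device copies instead of splitting tensors itself.

    from torch.utils.data import DataLoader

    loader = DataLoader(my_dataset, batch_size=2,
                        collate_fn=user_scattered_collate, num_workers=2)
    model = UserScatteredDataParallel(net, device_ids=[0, 1]).cuda()
    for batch in loader:
        # batch is a list with one element per device, e.g.
        # [{'img': t0, 'label': y0}, {'img': t1, 'label': y1}]
        out = model(batch)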
--------------------------------------------------------------------------------
/lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .th import *
2 |
--------------------------------------------------------------------------------
/lib/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/th.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/lib/utils/__pycache__/th.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .dataset import Dataset, TensorDataset, ConcatDataset
3 | from .dataloader import DataLoader
4 |
--------------------------------------------------------------------------------
/lib/utils/data/dataset.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import warnings
3 |
4 | from torch._utils import _accumulate
5 | from torch import randperm
6 |
7 |
8 | class Dataset(object):
9 | """An abstract class representing a Dataset.
10 |
11 |     All other datasets should subclass it. All subclasses should override
12 |     ``__len__``, which provides the size of the dataset, and ``__getitem__``,
13 |     supporting integer indexing in the range from 0 to len(self), exclusive.
14 | """
15 |
16 | def __getitem__(self, index):
17 | raise NotImplementedError
18 |
19 | def __len__(self):
20 | raise NotImplementedError
21 |
22 | def __add__(self, other):
23 | return ConcatDataset([self, other])
24 |
25 |
26 | class TensorDataset(Dataset):
27 | """Dataset wrapping data and target tensors.
28 |
29 | Each sample will be retrieved by indexing both tensors along the first
30 | dimension.
31 |
32 | Arguments:
33 | data_tensor (Tensor): contains sample data.
34 | target_tensor (Tensor): contains sample targets (labels).
35 | """
36 |
37 | def __init__(self, data_tensor, target_tensor):
38 | assert data_tensor.size(0) == target_tensor.size(0)
39 | self.data_tensor = data_tensor
40 | self.target_tensor = target_tensor
41 |
42 | def __getitem__(self, index):
43 | return self.data_tensor[index], self.target_tensor[index]
44 |
45 | def __len__(self):
46 | return self.data_tensor.size(0)
47 |
48 |
49 | class ConcatDataset(Dataset):
50 | """
51 | Dataset to concatenate multiple datasets.
52 |     Useful for assembling different existing datasets, possibly
53 |     large-scale ones, since the concatenation operation is performed
54 |     on the fly.
55 |
56 | Arguments:
57 | datasets (iterable): List of datasets to be concatenated
58 | """
59 |
60 | @staticmethod
61 | def cumsum(sequence):
62 | r, s = [], 0
63 | for e in sequence:
64 | l = len(e)
65 | r.append(l + s)
66 | s += l
67 | return r
68 |
69 | def __init__(self, datasets):
70 | super(ConcatDataset, self).__init__()
71 | assert len(datasets) > 0, 'datasets should not be an empty iterable'
72 | self.datasets = list(datasets)
73 | self.cumulative_sizes = self.cumsum(self.datasets)
74 |
75 | def __len__(self):
76 | return self.cumulative_sizes[-1]
77 |
78 | def __getitem__(self, idx):
79 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
80 | if dataset_idx == 0:
81 | sample_idx = idx
82 | else:
83 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
84 | return self.datasets[dataset_idx][sample_idx]
85 |
86 | @property
87 | def cummulative_sizes(self):
88 | warnings.warn("cummulative_sizes attribute is renamed to "
89 | "cumulative_sizes", DeprecationWarning, stacklevel=2)
90 | return self.cumulative_sizes
91 |
92 |
93 | class Subset(Dataset):
94 | def __init__(self, dataset, indices):
95 | self.dataset = dataset
96 | self.indices = indices
97 |
98 | def __getitem__(self, idx):
99 | return self.dataset[self.indices[idx]]
100 |
101 | def __len__(self):
102 | return len(self.indices)
103 |
104 |
105 | def random_split(dataset, lengths):
106 | """
107 |     Randomly split a dataset into non-overlapping new datasets of given lengths.
108 | 
109 |
110 | Arguments:
111 | dataset (Dataset): Dataset to be split
112 | lengths (iterable): lengths of splits to be produced
113 | """
114 | if sum(lengths) != len(dataset):
115 | raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
116 |
117 | indices = randperm(sum(lengths))
118 | return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
119 |
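
A worked example of the bisect-based index mapping in `ConcatDataset`, plus a
`random_split` call (a sketch):

    import torch
    ds = ConcatDataset([TensorDataset(torch.zeros(3, 1), torch.zeros(3)),
                        TensorDataset(torch.ones(2, 1), torch.ones(2))])
    assert ds.cumulative_sizes == [3, 5] and len(ds) == 5
    x, y = ds[4]   # bisect_right([3, 5], 4) == 1 -> second dataset, sample 4 - 3 = 1
    assert float(y) == 1.0

    train, val = random_split(ds, [4, 1])   # non-overlapping 4/1 split
    assert len(train) == 4 and len(val) == 1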
--------------------------------------------------------------------------------
/lib/utils/data/distributed.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from .sampler import Sampler
4 | from torch.distributed import get_world_size, get_rank
5 |
6 |
7 | class DistributedSampler(Sampler):
8 | """Sampler that restricts data loading to a subset of the dataset.
9 |
10 | It is especially useful in conjunction with
11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
12 | process can pass a DistributedSampler instance as a DataLoader sampler,
13 | and load a subset of the original dataset that is exclusive to it.
14 |
15 | .. note::
16 | Dataset is assumed to be of constant size.
17 |
18 | Arguments:
19 | dataset: Dataset used for sampling.
20 | num_replicas (optional): Number of processes participating in
21 | distributed training.
22 | rank (optional): Rank of the current process within num_replicas.
23 | """
24 |
25 | def __init__(self, dataset, num_replicas=None, rank=None):
26 | if num_replicas is None:
27 | num_replicas = get_world_size()
28 | if rank is None:
29 | rank = get_rank()
30 | self.dataset = dataset
31 | self.num_replicas = num_replicas
32 | self.rank = rank
33 | self.epoch = 0
34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
35 | self.total_size = self.num_samples * self.num_replicas
36 |
37 | def __iter__(self):
38 | # deterministically shuffle based on epoch
39 | g = torch.Generator()
40 | g.manual_seed(self.epoch)
41 | indices = list(torch.randperm(len(self.dataset), generator=g))
42 |
43 | # add extra samples to make it evenly divisible
44 | indices += indices[:(self.total_size - len(indices))]
45 | assert len(indices) == self.total_size
46 |
47 | # subsample
48 | offset = self.num_samples * self.rank
49 | indices = indices[offset:offset + self.num_samples]
50 | assert len(indices) == self.num_samples
51 |
52 | return iter(indices)
53 |
54 | def __len__(self):
55 | return self.num_samples
56 |
57 | def set_epoch(self, epoch):
58 | self.epoch = epoch
59 |
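
A worked example of the padding arithmetic (a sketch): with 10 samples and 3
replicas, num_samples == ceil(10 / 3) == 4 and total_size == 12, so two indices
are repeated to give every rank an equal share.

    sampler = DistributedSampler(list(range(10)), num_replicas=3, rank=1)
    sampler.set_epoch(0)          # reseeds the deterministic shuffle each epoch
    indices = list(sampler)       # the 4 indices assigned to rank 1
    assert len(indices) == sampler.num_samples == 4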
--------------------------------------------------------------------------------
/lib/utils/data/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Sampler(object):
5 | """Base class for all Samplers.
6 |
7 | Every Sampler subclass has to provide an __iter__ method, providing a way
8 | to iterate over indices of dataset elements, and a __len__ method that
9 |     returns the length of the returned iterator.
10 | """
11 |
12 | def __init__(self, data_source):
13 | pass
14 |
15 | def __iter__(self):
16 | raise NotImplementedError
17 |
18 | def __len__(self):
19 | raise NotImplementedError
20 |
21 |
22 | class SequentialSampler(Sampler):
23 | """Samples elements sequentially, always in the same order.
24 |
25 | Arguments:
26 | data_source (Dataset): dataset to sample from
27 | """
28 |
29 | def __init__(self, data_source):
30 | self.data_source = data_source
31 |
32 | def __iter__(self):
33 | return iter(range(len(self.data_source)))
34 |
35 | def __len__(self):
36 | return len(self.data_source)
37 |
38 |
39 | class RandomSampler(Sampler):
40 | """Samples elements randomly, without replacement.
41 |
42 | Arguments:
43 | data_source (Dataset): dataset to sample from
44 | """
45 |
46 | def __init__(self, data_source):
47 | self.data_source = data_source
48 |
49 | def __iter__(self):
50 | return iter(torch.randperm(len(self.data_source)).long())
51 |
52 | def __len__(self):
53 | return len(self.data_source)
54 |
55 |
56 | class SubsetRandomSampler(Sampler):
57 | """Samples elements randomly from a given list of indices, without replacement.
58 |
59 | Arguments:
60 | indices (list): a list of indices
61 | """
62 |
63 | def __init__(self, indices):
64 | self.indices = indices
65 |
66 | def __iter__(self):
67 | return (self.indices[i] for i in torch.randperm(len(self.indices)))
68 |
69 | def __len__(self):
70 | return len(self.indices)
71 |
72 |
73 | class WeightedRandomSampler(Sampler):
74 | """Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
75 |
76 | Arguments:
77 |         weights (list): a list of weights, not necessarily summing up to one
78 | num_samples (int): number of samples to draw
79 | replacement (bool): if ``True``, samples are drawn with replacement.
80 | If not, they are drawn without replacement, which means that when a
81 | sample index is drawn for a row, it cannot be drawn again for that row.
82 | """
83 |
84 | def __init__(self, weights, num_samples, replacement=True):
85 | self.weights = torch.DoubleTensor(weights)
86 | self.num_samples = num_samples
87 | self.replacement = replacement
88 |
89 | def __iter__(self):
90 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement))
91 |
92 | def __len__(self):
93 | return self.num_samples
94 |
95 |
96 | class BatchSampler(object):
97 | """Wraps another sampler to yield a mini-batch of indices.
98 |
99 | Args:
100 | sampler (Sampler): Base sampler.
101 | batch_size (int): Size of mini-batch.
102 | drop_last (bool): If ``True``, the sampler will drop the last batch if
103 | its size would be less than ``batch_size``
104 |
105 | Example:
106 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
107 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
108 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
109 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
110 | """
111 |
112 | def __init__(self, sampler, batch_size, drop_last):
113 | self.sampler = sampler
114 | self.batch_size = batch_size
115 | self.drop_last = drop_last
116 |
117 | def __iter__(self):
118 | batch = []
119 | for idx in self.sampler:
120 | batch.append(idx)
121 | if len(batch) == self.batch_size:
122 | yield batch
123 | batch = []
124 | if len(batch) > 0 and not self.drop_last:
125 | yield batch
126 |
127 | def __len__(self):
128 | if self.drop_last:
129 | return len(self.sampler) // self.batch_size
130 | else:
131 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size
132 |
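
A composition sketch mirroring the `BatchSampler` doctest above, but with
shuffling supplied by `RandomSampler`:

    batches = list(BatchSampler(RandomSampler(list(range(7))),
                                batch_size=3, drop_last=False))
    assert [len(b) for b in batches] == [3, 3, 1]
    assert sorted(int(i) for b in batches for i in b) == list(range(7))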
--------------------------------------------------------------------------------
/lib/utils/th.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | import numpy as np
4 | import collections.abc
5 |
6 | __all__ = ['as_variable', 'as_numpy', 'mark_volatile']
7 |
8 | def as_variable(obj):
9 | if isinstance(obj, Variable):
10 | return obj
11 |     if isinstance(obj, collections.abc.Sequence):
12 |         return [as_variable(v) for v in obj]
13 |     elif isinstance(obj, collections.abc.Mapping):
14 | return {k: as_variable(v) for k, v in obj.items()}
15 | else:
16 | return Variable(obj)
17 |
18 | def as_numpy(obj):
19 |     if isinstance(obj, collections.abc.Sequence):
20 |         return [as_numpy(v) for v in obj]
21 |     elif isinstance(obj, collections.abc.Mapping):
22 | return {k: as_numpy(v) for k, v in obj.items()}
23 | elif isinstance(obj, Variable):
24 | return obj.data.cpu().numpy()
25 | elif torch.is_tensor(obj):
26 | return obj.cpu().numpy()
27 | else:
28 | return np.array(obj)
29 |
30 | def mark_volatile(obj):
31 | if torch.is_tensor(obj):
32 | obj = Variable(obj)
33 | if isinstance(obj, Variable):
34 | obj.no_grad = True
35 | return obj
36 |     elif isinstance(obj, collections.abc.Mapping):
37 |         return {k: mark_volatile(o) for k, o in obj.items()}
38 |     elif isinstance(obj, collections.abc.Sequence):
39 | return [mark_volatile(o) for o in obj]
40 | else:
41 | return obj
42 |
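
`as_numpy` (like the other helpers here) recurses through nested containers; a
small sketch of the behavior:

    import torch
    out = as_numpy({'pred': torch.zeros(2, 3), 'ids': [torch.tensor(1)]})
    assert out['pred'].shape == (2, 3) and int(out['ids'][0]) == 1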
--------------------------------------------------------------------------------
/models/.non_local2d.py.swp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/.non_local2d.py.swp
--------------------------------------------------------------------------------
/models/.propnet.py.swo:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/.propnet.py.swo
--------------------------------------------------------------------------------
/models/.propnet.py.swp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/.propnet.py.swp
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import ModelBuilder, SegmentationModule, SegmentationModule_clip, SegmentationModule_allclip, ClipWarpNet
2 | from .netwarp import NetWarp
3 | from .ETC import ETC
4 | from .non_local_models import Non_local3d, Non_local2d
5 | from .propnet import PropNet
6 | from .warp_our_merge import OurWarpMerge
7 | from .clip_psp import Clip_PSP
8 | from .clip_ocr import ClipOCRNet
9 | from .netwarp_ocr import NetWarp_ocr
10 | from .ETC_ocr import ETC_ocr
11 |
--------------------------------------------------------------------------------
/models/__pycache__/BiConvLSTM.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/BiConvLSTM.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/ETC.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/ETC.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/ETC_ocr.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/ETC_ocr.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/clip_ocr.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/clip_ocr.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/clip_psp.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/clip_psp.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/deeplab.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/deeplab.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/hrnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/hrnet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/hrnet_clip.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/hrnet_clip.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/mobilenet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/mobilenet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/models.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/models.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/netwarp.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/netwarp.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/netwarp_ocr.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/netwarp_ocr.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/non_local.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/non_local.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/non_local_models.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/non_local_models.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/ocrnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/ocrnet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/propnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/propnet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnext.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/resnext.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/warp_our.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/warp_our.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/warp_our_merge.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/__pycache__/warp_our_merge.cpython-37.pyc
--------------------------------------------------------------------------------
/models/deeplabv3/aspp.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from models.sync_batchnorm import SynchronizedBatchNorm2d
6 |
7 | class _ASPPModule(nn.Module):
8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
9 | super(_ASPPModule, self).__init__()
10 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
11 | stride=1, padding=padding, dilation=dilation, bias=False)
12 | self.bn = BatchNorm(planes)
13 | self.relu = nn.ReLU()
14 |
15 | self._init_weight()
16 |
17 | def forward(self, x):
18 | x = self.atrous_conv(x)
19 | x = self.bn(x)
20 |
21 | return self.relu(x)
22 |
23 | def _init_weight(self):
24 | for m in self.modules():
25 | if isinstance(m, nn.Conv2d):
26 | torch.nn.init.kaiming_normal_(m.weight)
27 | elif isinstance(m, SynchronizedBatchNorm2d):
28 | m.weight.data.fill_(1)
29 | m.bias.data.zero_()
30 | elif isinstance(m, nn.BatchNorm2d):
31 | m.weight.data.fill_(1)
32 | m.bias.data.zero_()
33 |
34 | class ASPP(nn.Module):
35 | def __init__(self, backbone, output_stride, BatchNorm):
36 | super(ASPP, self).__init__()
37 | if backbone == 'drn':
38 | inplanes = 512
39 | elif backbone == 'mobilenet':
40 | inplanes = 320
41 | else:
42 | inplanes = 2048
43 | if output_stride == 16:
44 | dilations = [1, 6, 12, 18]
45 | elif output_stride == 8:
46 | dilations = [1, 12, 24, 36]
47 | else:
48 | raise NotImplementedError
49 |
50 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
51 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
52 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
53 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)
54 |
55 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
56 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
57 | BatchNorm(256),
58 | nn.ReLU())
59 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
60 | self.bn1 = BatchNorm(256)
61 | self.relu = nn.ReLU()
62 | self.dropout = nn.Dropout(0.5)
63 | self._init_weight()
64 |
65 | def forward(self, x):
66 | x1 = self.aspp1(x)
67 | x2 = self.aspp2(x)
68 | x3 = self.aspp3(x)
69 | x4 = self.aspp4(x)
70 | x5 = self.global_avg_pool(x)
71 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
72 | x = torch.cat((x1, x2, x3, x4, x5), dim=1)
73 |
74 | x = self.conv1(x)
75 | x = self.bn1(x)
76 | x = self.relu(x)
77 |
78 | return self.dropout(x)
79 |
80 | def _init_weight(self):
81 | for m in self.modules():
82 | if isinstance(m, nn.Conv2d):
83 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
84 | # m.weight.data.normal_(0, math.sqrt(2. / n))
85 | torch.nn.init.kaiming_normal_(m.weight)
86 | elif isinstance(m, SynchronizedBatchNorm2d):
87 | m.weight.data.fill_(1)
88 | m.bias.data.zero_()
89 | elif isinstance(m, nn.BatchNorm2d):
90 | m.weight.data.fill_(1)
91 | m.bias.data.zero_()
92 |
93 |
94 | def build_aspp(backbone, output_stride, BatchNorm):
95 | return ASPP(backbone, output_stride, BatchNorm)
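
A shape sketch for the module above (plain `nn.BatchNorm2d` is injected here,
and the input size is an assumption): the five parallel branches each emit 256
channels, which is why the fusion conv maps 1280 -> 256.

    import torch
    import torch.nn as nn

    aspp = build_aspp(backbone='resnet', output_stride=16, BatchNorm=nn.BatchNorm2d)
    feat = torch.randn(1, 2048, 32, 32)          # backbone output at stride 16
    assert aspp(feat).shape == (1, 256, 32, 32)  # spatial size is preserved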
--------------------------------------------------------------------------------
/models/deeplabv3/decoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from models.sync_batchnorm import SynchronizedBatchNorm2d
6 |
7 | class Decoder(nn.Module):
8 | def __init__(self, num_classes, backbone, BatchNorm,args):
9 | super(Decoder, self).__init__()
10 | self.args = args
11 | if backbone == 'resnet' or backbone == 'drn':
12 | low_level_inplanes = 256
13 | elif backbone == 'xception':
14 | low_level_inplanes = 128
15 | elif backbone == 'mobilenet':
16 | low_level_inplanes = 24
17 | else:
18 | raise NotImplementedError
19 |
20 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
21 | self.bn1 = BatchNorm(48)
22 | self.relu = nn.ReLU()
23 | if self.args.deeplab_as_base:
24 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
25 | BatchNorm(256),
26 | nn.ReLU(),
27 | nn.Dropout(0.5),
28 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
29 | BatchNorm(256),
30 | nn.ReLU()
31 | )
32 | self.lastlast_conv = nn.Sequential(
33 | nn.Dropout(0.1),
34 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1)
35 | )
36 | # nn.Dropout(0.1),
37 | # nn.Conv2d(256, num_classes, kernel_size=1, stride=1))
38 | else:
39 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
40 | BatchNorm(256),
41 | nn.ReLU(),
42 | nn.Dropout(0.5),
43 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
44 | BatchNorm(256),
45 | nn.ReLU(),
46 | nn.Dropout(0.1),
47 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1))
48 | self._init_weight()
49 |
50 |
51 | def forward(self, x, low_level_feat):
52 | low_level_feat = self.conv1(low_level_feat)
53 | low_level_feat = self.bn1(low_level_feat)
54 | low_level_feat = self.relu(low_level_feat)
55 |
56 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
57 | x = torch.cat((x, low_level_feat), dim=1)
58 | if self.args.deeplab_as_base:
59 | x = self.last_conv(x)
60 | y = self.lastlast_conv(x)
61 | return y,x
62 | else:
63 | x = self.last_conv(x)
64 |
65 | return x
66 |
67 | def _init_weight(self):
68 | for m in self.modules():
69 | if isinstance(m, nn.Conv2d):
70 | torch.nn.init.kaiming_normal_(m.weight)
71 | elif isinstance(m, SynchronizedBatchNorm2d):
72 | m.weight.data.fill_(1)
73 | m.bias.data.zero_()
74 | elif isinstance(m, nn.BatchNorm2d):
75 | m.weight.data.fill_(1)
76 | m.bias.data.zero_()
77 |
78 | def build_decoder(num_classes, backbone, BatchNorm,args):
79 | return Decoder(num_classes, backbone, BatchNorm,args)
80 |
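
A shape sketch for the decoder (`args` is a placeholder namespace): the
48-channel projected low-level feature concatenates with the 256-channel ASPP
output to form the 304-channel input of `last_conv`.

    import torch
    import torch.nn as nn
    from argparse import Namespace

    dec = build_decoder(num_classes=150, backbone='resnet',
                        BatchNorm=nn.BatchNorm2d, args=Namespace(deeplab_as_base=False))
    x = torch.randn(1, 256, 32, 32)       # ASPP output
    low = torch.randn(1, 256, 128, 128)   # stride-4 backbone feature
    assert dec(x, low).shape == (1, 150, 128, 128)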
--------------------------------------------------------------------------------
/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | """
2 | This MobileNetV2 implementation is modified from the following repository:
3 | https://github.com/tonylins/pytorch-mobilenet-v2
4 | """
5 |
6 | import torch.nn as nn
7 | import math
8 | from .utils import load_url
9 | from models.sync_batchnorm import SynchronizedBatchNorm2d
10 |
11 | BatchNorm2d = SynchronizedBatchNorm2d
12 |
13 |
14 | __all__ = ['mobilenetv2']
15 |
16 |
17 | model_urls = {
18 | 'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar',
19 | }
20 |
21 |
22 | def conv_bn(inp, oup, stride):
23 | return nn.Sequential(
24 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
25 | BatchNorm2d(oup),
26 | nn.ReLU6(inplace=True)
27 | )
28 |
29 |
30 | def conv_1x1_bn(inp, oup):
31 | return nn.Sequential(
32 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
33 | BatchNorm2d(oup),
34 | nn.ReLU6(inplace=True)
35 | )
36 |
37 |
38 | class InvertedResidual(nn.Module):
39 | def __init__(self, inp, oup, stride, expand_ratio):
40 | super(InvertedResidual, self).__init__()
41 | self.stride = stride
42 | assert stride in [1, 2]
43 |
44 | hidden_dim = round(inp * expand_ratio)
45 | self.use_res_connect = self.stride == 1 and inp == oup
46 |
47 | if expand_ratio == 1:
48 | self.conv = nn.Sequential(
49 | # dw
50 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
51 | BatchNorm2d(hidden_dim),
52 | nn.ReLU6(inplace=True),
53 | # pw-linear
54 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
55 | BatchNorm2d(oup),
56 | )
57 | else:
58 | self.conv = nn.Sequential(
59 | # pw
60 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
61 | BatchNorm2d(hidden_dim),
62 | nn.ReLU6(inplace=True),
63 | # dw
64 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
65 | BatchNorm2d(hidden_dim),
66 | nn.ReLU6(inplace=True),
67 | # pw-linear
68 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
69 | BatchNorm2d(oup),
70 | )
71 |
72 | def forward(self, x):
73 | if self.use_res_connect:
74 | return x + self.conv(x)
75 | else:
76 | return self.conv(x)
77 |
78 |
79 | class MobileNetV2(nn.Module):
80 | def __init__(self, n_class=1000, input_size=224, width_mult=1.):
81 | super(MobileNetV2, self).__init__()
82 | block = InvertedResidual
83 | input_channel = 32
84 | last_channel = 1280
85 |         inverted_residual_setting = [
86 | # t, c, n, s
87 | [1, 16, 1, 1],
88 | [6, 24, 2, 2],
89 | [6, 32, 3, 2],
90 | [6, 64, 4, 2],
91 | [6, 96, 3, 1],
92 | [6, 160, 3, 2],
93 | [6, 320, 1, 1],
94 | ]
95 |
96 | # building first layer
97 | assert input_size % 32 == 0
98 | input_channel = int(input_channel * width_mult)
99 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
100 | self.features = [conv_bn(3, input_channel, 2)]
101 | # building inverted residual blocks
102 |         for t, c, n, s in inverted_residual_setting:
103 | output_channel = int(c * width_mult)
104 | for i in range(n):
105 | if i == 0:
106 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
107 | else:
108 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
109 | input_channel = output_channel
110 | # building last several layers
111 | self.features.append(conv_1x1_bn(input_channel, self.last_channel))
112 | # make it nn.Sequential
113 | self.features = nn.Sequential(*self.features)
114 |
115 | # building classifier
116 | self.classifier = nn.Sequential(
117 | nn.Dropout(0.2),
118 | nn.Linear(self.last_channel, n_class),
119 | )
120 |
121 | self._initialize_weights()
122 |
123 | def forward(self, x):
124 | x = self.features(x)
125 | x = x.mean(3).mean(2)
126 | x = self.classifier(x)
127 | return x
128 |
129 | def _initialize_weights(self):
130 | for m in self.modules():
131 | if isinstance(m, nn.Conv2d):
132 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
133 | m.weight.data.normal_(0, math.sqrt(2. / n))
134 | if m.bias is not None:
135 | m.bias.data.zero_()
136 | elif isinstance(m, BatchNorm2d):
137 | m.weight.data.fill_(1)
138 | m.bias.data.zero_()
139 | elif isinstance(m, nn.Linear):
140 | n = m.weight.size(1)
141 | m.weight.data.normal_(0, 0.01)
142 | m.bias.data.zero_()
143 |
144 |
145 | def mobilenetv2(pretrained=False, **kwargs):
146 | """Constructs a MobileNet_V2 model.
147 |
148 | Args:
149 | pretrained (bool): If True, returns a model pre-trained on ImageNet
150 | """
151 | model = MobileNetV2(n_class=1000, **kwargs)
152 | if pretrained:
153 | model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False)
154 | return model
155 |
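
A usage sketch: each (t, c, n, s) row expands channels by factor t, outputs c
channels, repeats n times, and applies stride s in its first block; inputs must
be multiples of 32 because of the five stride-2 stages. (SynchronizedBatchNorm2d
falls back to ordinary batch norm when the model is not replicated.)

    import torch
    net = mobilenetv2(pretrained=False)
    net.eval()
    with torch.no_grad():
        assert net(torch.randn(1, 3, 224, 224)).shape == (1, 1000)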
--------------------------------------------------------------------------------
/models/ocr_modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/ocr_modules/__init__.py
--------------------------------------------------------------------------------
/models/ocr_modules/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/ocr_modules/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/ocr_modules/__pycache__/spatial_ocr_block.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/ocr_modules/__pycache__/spatial_ocr_block.cpython-37.pyc
--------------------------------------------------------------------------------
/models/ocrnet.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: RainbowSecret
3 | ## Microsoft Research
4 | ## yuyua@microsoft.com
5 | ## Copyright (c) 2018
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 | import pdb
11 | import torch
12 | import torch.nn as nn
13 | from torch.nn import functional as F
14 | from models.sync_batchnorm import SynchronizedBatchNorm2d
15 | BatchNorm2d = SynchronizedBatchNorm2d
16 | BN_MOMENTUM = 0.1
17 |
18 | #from lib.models.backbones.backbone_selector import BackboneSelector
19 | #from lib.models.tools.module_helper import ModuleHelper
20 |
21 |
22 | class SpatialOCRNet(nn.Module):
23 | """
24 | Object-Contextual Representations for Semantic Segmentation,
25 | Yuan, Yuhui and Chen, Xilin and Wang, Jingdong
26 | """
27 | def __init__(self, num_class):
28 | self.inplanes = 128
29 | super(SpatialOCRNet, self).__init__()
30 | self.num_classes=num_class
31 | in_channels = [1024, 2048]
32 | self.conv_3x3 = nn.Sequential(
33 | nn.Conv2d(in_channels[1], 512, kernel_size=3, stride=1, padding=1),
34 | BatchNorm2d(512),
35 | nn.ReLU(inplace=True)
36 | )
37 |
38 | from models.ocr_modules.spatial_ocr_block import SpatialGather_Module, SpatialOCR_Module
39 | self.spatial_context_head = SpatialGather_Module(self.num_classes)
40 | self.spatial_ocr_head = SpatialOCR_Module(in_channels=512,
41 | key_channels=256,
42 | out_channels=512,
43 | scale=1,
44 | dropout=0.05
45 | )
46 |
47 | self.head = nn.Conv2d(512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True)
48 | self.dsn_head = nn.Sequential(
49 | nn.Conv2d(in_channels[0], 512, kernel_size=3, stride=1, padding=1),
50 | BatchNorm2d(512),
51 | nn.ReLU(inplace=True),
52 | nn.Dropout2d(0.05),
53 | nn.Conv2d(512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True)
54 | )
55 |
56 |     def forward(self, x, segSize=None):
57 |
58 | x_dsn = self.dsn_head(x[-2])
59 | x = self.conv_3x3(x[-1])
60 | context = self.spatial_context_head(x, x_dsn)
61 | x = self.spatial_ocr_head(x, context)
62 | x = self.head(x)
63 |
64 |         if segSize is not None:  # segSize is provided during inference
65 | x = F.interpolate(
66 | x, size=segSize, mode='bilinear', align_corners=False)
67 | x = F.softmax(x, dim=1)
68 | return x
69 | else:
70 | x = F.log_softmax(x, dim=1)
71 | x_dsn = F.log_softmax(x_dsn, dim=1)
72 | return x, x_dsn
73 |
74 |
75 | #class ASPOCRNet(nn.Module):
76 | # """
77 | # Object-Contextual Representations for Semantic Segmentation,
78 | # Yuan, Yuhui and Chen, Xilin and Wang, Jingdong
79 | # """
80 | # def __init__(self, configer):
81 | # self.inplanes = 128
82 | # super(ASPOCRNet, self).__init__()
83 | # self.configer = configer
84 | # self.num_classes = self.configer.get('data', 'num_classes')
85 | # self.backbone = BackboneSelector(configer).get_backbone()
86 | #
87 | # # extra added layers
88 | # if "wide_resnet38" in self.configer.get('network', 'backbone'):
89 | # in_channels = [2048, 4096]
90 | # else:
91 | # in_channels = [1024, 2048]
92 | #
93 | # # we should increase the dilation rates as the output stride is larger
94 | # from lib.models.modules.spatial_ocr_block import SpatialOCR_ASP_Module
95 | # self.asp_ocr_head = SpatialOCR_ASP_Module(features=2048,
96 | # hidden_features=256,
97 | # out_features=256,
98 | # num_classes=self.num_classes,
99 | # bn_type=self.configer.get('network', 'bn_type'))
100 | #
101 | # self.head = nn.Conv2d(256, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True)
102 | # self.dsn_head = nn.Sequential(
103 | # nn.Conv2d(in_channels[0], 512, kernel_size=3, stride=1, padding=1),
104 | # ModuleHelper.BNReLU(512, bn_type=self.configer.get('network', 'bn_type')),
105 | # nn.Dropout2d(0.1),
106 | # nn.Conv2d(512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True)
107 | # )
108 | #
109 | # def forward(self, x_):
110 | # x = self.backbone(x_)
111 | # x_dsn = self.dsn_head(x[-2])
112 | # x = self.asp_ocr_head(x[-1], x_dsn)
113 | # x = self.head(x)
114 | # x_dsn = F.interpolate(x_dsn, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True)
115 | # x = F.interpolate(x, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True)
116 | # return x_dsn, x
117 |
--------------------------------------------------------------------------------
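
Note on the forward pass above: it hinges on `SpatialGather_Module` (imported from models.ocr_modules.spatial_ocr_block), which pools pixel features into one context vector per class, softly weighted by the coarse `x_dsn` prediction. Below is a minimal self-contained sketch of that gathering step; the helper name and shapes are illustrative assumptions based on how the module is called above, not the repo's exact implementation.

import torch
import torch.nn.functional as F

def spatial_gather(feats, probs):
    # feats: (B, C, H, W) pixel features (`x` after conv_3x3 above)
    # probs: (B, K, H, W) coarse class logits (`x_dsn` above)
    # returns (B, K, C): one pooled context vector per class
    b, c = feats.shape[:2]
    k = probs.shape[1]
    probs = F.softmax(probs.view(b, k, -1), dim=2)  # per-class attention over pixels
    feats = feats.view(b, c, -1).permute(0, 2, 1)   # (B, HW, C)
    return torch.bmm(probs, feats)                  # (B, K, C)

ctx = spatial_gather(torch.randn(2, 512, 60, 60), torch.randn(2, 19, 60, 60))
print(ctx.shape)  # torch.Size([2, 19, 512])
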
/models/resnext.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | from .utils import load_url
4 | from models.sync_batchnorm import SynchronizedBatchNorm2d
5 | BatchNorm2d = SynchronizedBatchNorm2d
6 |
7 |
8 | __all__ = ['ResNeXt', 'resnext101'] # support resnext 101
9 |
10 |
11 | model_urls = {
12 | #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth',
13 | 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth'
14 | }
15 |
16 |
17 | def conv3x3(in_planes, out_planes, stride=1):
18 | "3x3 convolution with padding"
19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
20 | padding=1, bias=False)
21 |
22 |
23 | class GroupBottleneck(nn.Module):
24 | expansion = 2
25 |
26 | def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None):
27 | super(GroupBottleneck, self).__init__()
28 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
29 | self.bn1 = BatchNorm2d(planes)
30 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
31 | padding=1, groups=groups, bias=False)
32 | self.bn2 = BatchNorm2d(planes)
33 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False)
34 | self.bn3 = BatchNorm2d(planes * 2)
35 | self.relu = nn.ReLU(inplace=True)
36 | self.downsample = downsample
37 | self.stride = stride
38 |
39 | def forward(self, x):
40 | residual = x
41 |
42 | out = self.conv1(x)
43 | out = self.bn1(out)
44 | out = self.relu(out)
45 |
46 | out = self.conv2(out)
47 | out = self.bn2(out)
48 | out = self.relu(out)
49 |
50 | out = self.conv3(out)
51 | out = self.bn3(out)
52 |
53 | if self.downsample is not None:
54 | residual = self.downsample(x)
55 |
56 | out += residual
57 | out = self.relu(out)
58 |
59 | return out
60 |
61 |
62 | class ResNeXt(nn.Module):
63 |
64 | def __init__(self, block, layers, groups=32, num_classes=1000):
65 | self.inplanes = 128
66 | super(ResNeXt, self).__init__()
67 | self.conv1 = conv3x3(3, 64, stride=2)
68 | self.bn1 = BatchNorm2d(64)
69 | self.relu1 = nn.ReLU(inplace=True)
70 | self.conv2 = conv3x3(64, 64)
71 | self.bn2 = BatchNorm2d(64)
72 | self.relu2 = nn.ReLU(inplace=True)
73 | self.conv3 = conv3x3(64, 128)
74 | self.bn3 = BatchNorm2d(128)
75 | self.relu3 = nn.ReLU(inplace=True)
76 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
77 |
78 | self.layer1 = self._make_layer(block, 128, layers[0], groups=groups)
79 | self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups)
80 | self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups)
81 | self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups)
82 | self.avgpool = nn.AvgPool2d(7, stride=1)
83 | self.fc = nn.Linear(1024 * block.expansion, num_classes)
84 |
85 | for m in self.modules():
86 | if isinstance(m, nn.Conv2d):
87 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
88 | m.weight.data.normal_(0, math.sqrt(2. / n))
89 | elif isinstance(m, BatchNorm2d):
90 | m.weight.data.fill_(1)
91 | m.bias.data.zero_()
92 |
93 | def _make_layer(self, block, planes, blocks, stride=1, groups=1):
94 | downsample = None
95 | if stride != 1 or self.inplanes != planes * block.expansion:
96 | downsample = nn.Sequential(
97 | nn.Conv2d(self.inplanes, planes * block.expansion,
98 | kernel_size=1, stride=stride, bias=False),
99 | BatchNorm2d(planes * block.expansion),
100 | )
101 |
102 | layers = []
103 | layers.append(block(self.inplanes, planes, stride, groups, downsample))
104 | self.inplanes = planes * block.expansion
105 | for i in range(1, blocks):
106 | layers.append(block(self.inplanes, planes, groups=groups))
107 |
108 | return nn.Sequential(*layers)
109 |
110 | def forward(self, x):
111 | x = self.relu1(self.bn1(self.conv1(x)))
112 | x = self.relu2(self.bn2(self.conv2(x)))
113 | x = self.relu3(self.bn3(self.conv3(x)))
114 | x = self.maxpool(x)
115 |
116 | x = self.layer1(x)
117 | x = self.layer2(x)
118 | x = self.layer3(x)
119 | x = self.layer4(x)
120 |
121 | x = self.avgpool(x)
122 | x = x.view(x.size(0), -1)
123 | x = self.fc(x)
124 |
125 | return x
126 |
127 |
128 | '''
129 | def resnext50(pretrained=False, **kwargs):
130 | """Constructs a ResNet-50 model.
131 |
132 | Args:
133 | pretrained (bool): If True, returns a model pre-trained on Places
134 | """
135 | model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs)
136 | if pretrained:
137 | model.load_state_dict(load_url(model_urls['resnext50']), strict=False)
138 | return model
139 | '''
140 |
141 |
142 | def resnext101(pretrained=False, **kwargs):
143 | """Constructs a ResNet-101 model.
144 |
145 | Args:
146 | pretrained (bool): If True, returns a model pre-trained on Places
147 | """
148 | model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs)
149 | if pretrained:
150 | model.load_state_dict(load_url(model_urls['resnext101']), strict=False)
151 | return model
152 |
153 |
154 | # def resnext152(pretrained=False, **kwargs):
155 | # """Constructs a ResNeXt-152 model.
156 | #
157 | # Args:
158 | # pretrained (bool): If True, returns a model pre-trained on Places
159 | # """
160 | # model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs)
161 | # if pretrained:
162 | # model.load_state_dict(load_url(model_urls['resnext152']))
163 | # return model
164 |
--------------------------------------------------------------------------------
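
What separates `GroupBottleneck` above from a plain ResNet bottleneck is the grouped 3x3 convolution (groups=32 by default in `ResNeXt.__init__`). A quick self-contained check, not from the repo, of how grouping cuts the 3x3 weight count by the group factor:

import torch.nn as nn

dense = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False)               # groups=1, plain ResNet style
grouped = nn.Conv2d(256, 256, kernel_size=3, padding=1, groups=32, bias=False)  # ResNeXt style

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(dense), count(grouped))  # 589824 vs 18432, a 32x reduction
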
/models/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .batchnorm import patch_sync_batchnorm, convert_model
13 | from .replicate import DataParallelWithCallback, patch_replication_callback
14 |
--------------------------------------------------------------------------------
/models/sync_batchnorm/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/sync_batchnorm/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc
--------------------------------------------------------------------------------
/models/sync_batchnorm/__pycache__/comm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/sync_batchnorm/__pycache__/comm.cpython-37.pyc
--------------------------------------------------------------------------------
/models/sync_batchnorm/__pycache__/replicate.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/sync_batchnorm/__pycache__/replicate.cpython-37.pyc
--------------------------------------------------------------------------------
/models/sync_batchnorm/batchnorm_reimpl.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : batchnorm_reimpl.py
4 | # Author : acgtyrant
5 | # Date : 11/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 |
15 | __all__ = ['BatchNorm2dReimpl']
16 |
17 |
18 | class BatchNorm2dReimpl(nn.Module):
19 | """
20 | A re-implementation of batch normalization, used for testing the numerical
21 | stability.
22 |
23 | Author: acgtyrant
24 | See also:
25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
26 | """
27 | def __init__(self, num_features, eps=1e-5, momentum=0.1):
28 | super().__init__()
29 |
30 | self.num_features = num_features
31 | self.eps = eps
32 | self.momentum = momentum
33 | self.weight = nn.Parameter(torch.empty(num_features))
34 | self.bias = nn.Parameter(torch.empty(num_features))
35 | self.register_buffer('running_mean', torch.zeros(num_features))
36 | self.register_buffer('running_var', torch.ones(num_features))
37 | self.reset_parameters()
38 |
39 | def reset_running_stats(self):
40 | self.running_mean.zero_()
41 | self.running_var.fill_(1)
42 |
43 | def reset_parameters(self):
44 | self.reset_running_stats()
45 | init.uniform_(self.weight)
46 | init.zeros_(self.bias)
47 |
48 | def forward(self, input_):
49 | batchsize, channels, height, width = input_.size()
50 | numel = batchsize * height * width
51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
52 | sum_ = input_.sum(1)
53 | sum_of_square = input_.pow(2).sum(1)
54 | mean = sum_ / numel
55 | sumvar = sum_of_square - sum_ * mean
56 |
57 | self.running_mean = (
58 | (1 - self.momentum) * self.running_mean
59 | + self.momentum * mean.detach()
60 | )
61 | unbias_var = sumvar / (numel - 1)
62 | self.running_var = (
63 | (1 - self.momentum) * self.running_var
64 | + self.momentum * unbias_var.detach()
65 | )
66 |
67 | bias_var = sumvar / numel
68 | inv_std = 1 / (bias_var + self.eps).pow(0.5)
69 | output = (
70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
72 |
73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
74 |
75 |
--------------------------------------------------------------------------------
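
A detail worth spelling out in the forward above: normalization uses the biased variance (`bias_var`, divided by numel) while the running statistics store the unbiased estimate (`unbias_var`, divided by numel - 1), the same split torch.nn.BatchNorm2d makes. A small self-contained check of that relationship:

import torch

x = torch.randn(4, 3, 8, 8)
n = x.numel() // x.shape[1]                          # batchsize * height * width, as in the forward above
biased_var = x.var(dim=(0, 2, 3), unbiased=False)    # what the output is normalized with
unbiased_var = x.var(dim=(0, 2, 3), unbiased=True)   # what running_var accumulates

# unbiased = biased * n / (n - 1), i.e. sumvar/(n-1) vs sumvar/n
assert torch.allclose(unbiased_var, biased_var * n / (n - 1), atol=1e-6)
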
/models/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 | assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 | - During the replication, as data parallel will trigger a callback on each module, all slave devices should
60 | call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
61 | - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are
62 | collected and passed to a registered callback.
63 | - After receiving the messages, the master device gathers the information and determines the message to be
64 | passed back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def __getstate__(self):
79 | return {'master_callback': self._master_callback}
80 |
81 | def __setstate__(self, state):
82 | self.__init__(state['master_callback'])
83 |
84 | def register_slave(self, identifier):
85 | """
86 | Register a slave device.
87 |
88 | Args:
89 | identifier: an identifier, usually the device id.
90 |
91 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
92 |
93 | """
94 | if self._activated:
95 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
96 | self._activated = False
97 | self._registry.clear()
98 | future = FutureResult()
99 | self._registry[identifier] = _MasterRegistry(future)
100 | return SlavePipe(identifier, self._queue, future)
101 |
102 | def run_master(self, master_msg):
103 | """
104 | Main entry for the master device in each forward pass.
105 | The messages are first collected from each device (including the master device), and then
106 | a callback is invoked to compute the message to be sent back to each device
107 | (including the master device).
108 |
109 | Args:
110 | master_msg: the message that the master wants to send to itself. This will be placed as the first
111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112 |
113 | Returns: the message to be sent back to the master device.
114 |
115 | """
116 | self._activated = True
117 |
118 | intermediates = [(0, master_msg)]
119 | for i in range(self.nr_slaves):
120 | intermediates.append(self._queue.get())
121 |
122 | results = self._master_callback(intermediates)
123 | assert results[0][0] == 0, 'The first result should belong to the master.'
124 |
125 | for i, res in results:
126 | if i == 0:
127 | continue
128 | self._registry[i].result.put(res)
129 |
130 | for i in range(self.nr_slaves):
131 | assert self._queue.get() is True
132 |
133 | return results[0][1]
134 |
135 | @property
136 | def nr_slaves(self):
137 | return len(self._registry)
138 |
--------------------------------------------------------------------------------
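
The master/slave protocol documented in `SyncMaster` above is easier to see in miniature. The toy below re-enacts the handshake with standard-library primitives: each slave pushes a message into a shared queue, the master collects one message per slave, runs a callback over them, and answers each slave through its own one-shot future. All names and values here are illustrative stand-ins, not from the repo (and the final True-acknowledgement step of `run_slave` is omitted).

import queue
import threading

msg_queue = queue.Queue()
futures = {}

def slave(identifier):
    fut = queue.Queue(maxsize=1)                  # stands in for FutureResult
    futures[identifier] = fut                     # register_slave
    msg_queue.put((identifier, identifier * 10))  # run_slave: send local stats
    print('slave', identifier, 'received', fut.get())

threads = [threading.Thread(target=slave, args=(i,)) for i in (1, 2)]
for t in threads:
    t.start()

intermediates = [msg_queue.get() for _ in threads]  # run_master: gather
reply = sum(msg for _, msg in intermediates)        # the "master callback"
for identifier, _ in intermediates:
    futures[identifier].put(reply)                  # send result back to each slave
for t in threads:
    t.join()
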
/models/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 | Note that, as all replicated modules are isomorphic, we assign each sub-module a context
34 | (shared among multiple copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of the callbacks
38 | of any slave copies.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after each copy is
55 | created by the original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 | Useful when you have customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
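
The callback mechanism above needs no GPU to understand. Below is a GPU-free sketch of what `execute_replication_callbacks` arranges: every replica's `__data_parallel_replicate__` hook receives the same shared context object, so copies on different devices can find each other (this is how the synchronized batch norm locates its master copy). `ToySyncBN` and `Ctx` are illustrative stand-ins, not classes from the repo.

import torch.nn as nn

class Ctx:  # plays the role of CallbackContext
    pass

class ToySyncBN(nn.Module):
    def __data_parallel_replicate__(self, ctx, copy_id):
        self.copy_id = copy_id
        if copy_id == 0:
            ctx.master = self  # the master copy publishes itself on the shared context
        self.shared_ctx = ctx

replicas = [ToySyncBN(), ToySyncBN()]  # pretend these are copies on two devices
ctx = Ctx()
for copy_id, module in enumerate(replicas):  # what execute_replication_callbacks does per sub-module
    module.__data_parallel_replicate__(ctx, copy_id)

print(replicas[1].shared_ctx.master is replicas[0])  # True: copy 1 can reach the master copy
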
/models/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 | import torch
13 |
14 |
15 | class TorchTestCase(unittest.TestCase):
16 | def assertTensorClose(self, x, y):
17 | adiff = float((x - y).abs().max())
18 | if (y == 0).all():
19 | rdiff = 'NaN'
20 | else:
21 | rdiff = float((adiff / y).abs().max())
22 |
23 | message = (
24 | 'Tensor close check failed\n'
25 | 'adiff={}\n'
26 | 'rdiff={}\n'
27 | ).format(adiff, rdiff)
28 | self.assertTrue(torch.allclose(x, y), message)
29 |
30 |
--------------------------------------------------------------------------------
/models/td4_psp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/td4_psp/__init__.py
--------------------------------------------------------------------------------
/models/td4_psp/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/td4_psp/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/td4_psp/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/td4_psp/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/models/td4_psp/__pycache__/td4_psp.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/td4_psp/__pycache__/td4_psp.cpython-37.pyc
--------------------------------------------------------------------------------
/models/td4_psp/__pycache__/transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/td4_psp/__pycache__/transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/models/td4_psp/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 |
5 | __all__ = ['SegmentationLosses', 'OhemCELoss2D']
6 |
7 | class SegmentationLosses(nn.CrossEntropyLoss):
8 | """2D Cross Entropy Loss with Auxilary Loss"""
9 | def __init__(self,
10 | weight=None,
11 | ignore_index=-1):
12 |
13 | super(SegmentationLosses, self).__init__(weight, None, ignore_index)
14 |
15 | def forward(self, pred, target):
16 | return super(SegmentationLosses, self).forward(pred, target)
17 |
18 |
19 |
20 |
21 | class OhemCELoss2D(nn.CrossEntropyLoss):
22 | """2D Cross Entropy Loss with Auxilary Loss"""
23 | def __init__(self,
24 | n_min,
25 | thresh=0.7,
26 | ignore_index=-1):
27 |
28 | super(OhemCELoss2D, self).__init__(None, None, ignore_index, reduction='none')
29 |
30 | self.thresh = -math.log(thresh)
31 | self.n_min = n_min
32 | self.ignore_index = ignore_index
33 |
34 | def forward(self, pred, target):
35 | return self.OhemCELoss(pred, target)
36 |
37 | def OhemCELoss(self, logits, labels):
38 | loss = super(OhemCELoss2D, self).forward(logits, labels).view(-1)
39 | loss, _ = torch.sort(loss, descending=True)
40 | if loss[self.n_min] > self.thresh:
41 | loss = loss[loss>self.thresh]
42 | else:
43 | loss = loss[:self.n_min]
44 | return torch.mean(loss)
--------------------------------------------------------------------------------
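
The selection rule in `OhemCELoss` above keeps every pixel loss above -log(thresh) when more than n_min pixels exceed that bound, and otherwise falls back to the n_min hardest pixels. A self-contained restatement with functional cross entropy; `ohem_ce` and the toy shapes are illustrative, not part of the repo:

import math
import torch
import torch.nn.functional as F

def ohem_ce(logits, labels, n_min, thresh=0.7, ignore_index=-1):
    # Same selection rule as OhemCELoss above, written with F.cross_entropy.
    loss = F.cross_entropy(logits, labels, ignore_index=ignore_index,
                           reduction='none').view(-1)
    loss, _ = torch.sort(loss, descending=True)
    thresh = -math.log(thresh)      # CE of a pixel predicted with probability `thresh`
    if loss[n_min] > thresh:        # enough hard pixels: keep everything above the threshold
        loss = loss[loss > thresh]
    else:                           # otherwise fall back to the n_min hardest pixels
        loss = loss[:n_min]
    return loss.mean()

logits = torch.randn(2, 19, 32, 32)          # (batch, classes, H, W)
labels = torch.randint(0, 19, (2, 32, 32))
print(ohem_ce(logits, labels, n_min=2 * 32 * 32 // 16))
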
/models/td4_psp/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sssdddwww2/CVPR2021_VSPW_Implement/3ead088804fe34d3b826b656e026d50562a3f7bd/models/td4_psp/utils/__init__.py
--------------------------------------------------------------------------------
/models/td4_psp/utils/files.py:
--------------------------------------------------------------------------------
1 | import os
2 | import requests
3 | import errno
4 | import shutil
5 | import hashlib
6 | from tqdm import tqdm
7 | import torch
8 |
9 | __all__ = ['save_checkpoint', 'download', 'mkdir', 'check_sha1']
10 |
11 | def save_checkpoint(state, args, is_best, filename='pretrained.pth.tar'):
12 | """Saves pretrained to disk"""
13 | directory = "runs/%s/%s/%s/"%(args.dataset, args.model, args.checkname)
14 | if not os.path.exists(directory):
15 | os.makedirs(directory)
16 | filename = directory + filename
17 | torch.save(state, filename)
18 | if is_best:
19 | shutil.copyfile(filename, directory + 'model_best.pth.tar')
20 |
21 |
22 | def download(url, path=None, overwrite=False, sha1_hash=None):
23 | """Download an given URL
24 | Parameters
25 | ----------
26 | url : str
27 | URL to download
28 | path : str, optional
29 | Destination path to store downloaded file. By default stores to the
30 | current directory with same name as in url.
31 | overwrite : bool, optional
32 | Whether to overwrite the destination file if it already exists.
33 | sha1_hash : str, optional
34 | Expected sha1 hash in hexadecimal digits. An existing file will be ignored when a hash is specified
35 | but doesn't match.
36 | Returns
37 | -------
38 | str
39 | The file path of the downloaded file.
40 | """
41 | if path is None:
42 | fname = url.split('/')[-1]
43 | else:
44 | path = os.path.expanduser(path)
45 | if os.path.isdir(path):
46 | fname = os.path.join(path, url.split('/')[-1])
47 | else:
48 | fname = path
49 |
50 | if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
51 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
52 | if not os.path.exists(dirname):
53 | os.makedirs(dirname)
54 |
55 | print('Downloading %s from %s...'%(fname, url))
56 | r = requests.get(url, stream=True)
57 | if r.status_code != 200:
58 | raise RuntimeError("Failed downloading url %s"%url)
59 | total_length = r.headers.get('content-length')
60 | with open(fname, 'wb') as f:
61 | if total_length is None: # no content length header
62 | for chunk in r.iter_content(chunk_size=1024):
63 | if chunk: # filter out keep-alive new chunks
64 | f.write(chunk)
65 | else:
66 | total_length = int(total_length)
67 | for chunk in tqdm(r.iter_content(chunk_size=1024),
68 | total=int(total_length / 1024. + 0.5),
69 | unit='KB', unit_scale=False, dynamic_ncols=True):
70 | f.write(chunk)
71 |
72 | if sha1_hash and not check_sha1(fname, sha1_hash):
73 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \
74 | 'The repo may be outdated or download may be incomplete. ' \
75 | 'If the "repo_url" is overridden, consider switching to ' \
76 | 'the default repo.'.format(fname))
77 |
78 | return fname
79 |
80 |
81 | def check_sha1(filename, sha1_hash):
82 | """Check whether the sha1 hash of the file content matches the expected hash.
83 | Parameters
84 | ----------
85 | filename : str
86 | Path to the file.
87 | sha1_hash : str
88 | Expected sha1 hash in hexadecimal digits.
89 | Returns
90 | -------
91 | bool
92 | Whether the file content matches the expected hash.
93 | """
94 | sha1 = hashlib.sha1()
95 | with open(filename, 'rb') as f:
96 | while True:
97 | data = f.read(1048576)
98 | if not data:
99 | break
100 | sha1.update(data)
101 |
102 | return sha1.hexdigest() == sha1_hash
103 |
104 |
105 | def mkdir(path):
106 | """make dir exists okay"""
107 | try:
108 | os.makedirs(path)
109 | except OSError as exc: # Python >2.5
110 | if exc.errno == errno.EEXIST and os.path.isdir(path):
111 | pass
112 | else:
113 | raise
114 |
--------------------------------------------------------------------------------
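
A quick self-contained roundtrip of the hash verification `download` relies on, using the same 1 MiB chunked reading as `check_sha1` above (temp file and payload are illustrative):

import hashlib
import os
import tempfile

payload = b'hello vspw'
expected = hashlib.sha1(payload).hexdigest()   # digest computed up front, no network needed

with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(payload)
    fname = f.name

sha1 = hashlib.sha1()
with open(fname, 'rb') as f:
    while True:
        data = f.read(1048576)  # same chunk size as check_sha1
        if not data:
            break
        sha1.update(data)
os.remove(fname)

print(sha1.hexdigest() == expected)  # True
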
/models/td4_psp/utils/model_store.py:
--------------------------------------------------------------------------------
1 | """Model store which provides pretrained models."""
2 | from __future__ import print_function
3 | __all__ = ['get_model_file', 'purge']
4 | import os
5 | import zipfile
6 |
7 | from .files import download, check_sha1
8 |
9 | _model_sha1 = {name: checksum for checksum, name in [
10 | ('25c4b50959ef024fcc050213a06b614899f94b3d', 'resnet50'),
11 | ('2a57e44de9c853fa015b172309a1ee7e2d0e4e2a', 'resnet101'),
12 | ('0d43d698c66aceaa2bc0309f55efdd7ff4b143af', 'resnet152'),
13 | ('da4785cfc837bf00ef95b52fb218feefe703011f', 'wideresnet38'),
14 | ('b41562160173ee2e979b795c551d3c7143b1e5b5', 'wideresnet50'),
15 | ('1225f149519c7a0113c43a056153c1bb15468ac0', 'deepten_resnet50_minc'),
16 | ('662e979de25a389f11c65e9f1df7e06c2c356381', 'fcn_resnet50_ade'),
17 | ('eeed8e582f0fdccdba8579e7490570adc6d85c7c', 'fcn_resnet50_pcontext'),
18 | ('54f70c772505064e30efd1ddd3a14e1759faa363', 'psp_resnet50_ade'),
19 | ('075195c5237b778c718fd73ceddfa1376c18dfd0', 'deeplab_resnet50_ade'),
20 | ('5ee47ee28b480cc781a195d13b5806d5bbc616bf', 'encnet_resnet101_coco'),
21 | ('4de91d5922d4d3264f678b663f874da72e82db00', 'encnet_resnet50_pcontext'),
22 | ('9f27ea13d514d7010e59988341bcbd4140fcc33d', 'encnet_resnet101_pcontext'),
23 | ('07ac287cd77e53ea583f37454e17d30ce1509a4a', 'encnet_resnet50_ade'),
24 | ('3f54fa3b67bac7619cd9b3673f5c8227cf8f4718', 'encnet_resnet101_ade'),
25 | ]}
26 |
27 | encoding_repo_url = 'https://hangzh.s3.amazonaws.com/'
28 | _url_format = '{repo_url}encoding/models/{file_name}.zip'
29 |
30 | def short_hash(name):
31 | if name not in _model_sha1:
32 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
33 | return _model_sha1[name][:8]
34 |
35 | def get_model_file(name, root=os.path.join('~', '.encoding', 'models')):
36 | r"""Return location for the pretrained on local file system.
37 |
38 | This function will download from the online model zoo when the model cannot be found or its hash mismatches.
39 | The root directory will be created if it doesn't exist.
40 |
41 | Parameters
42 | ----------
43 | name : str
44 | Name of the model.
45 | root : str, default '~/.encoding/models'
46 | Location for keeping the model parameters.
47 |
48 | Returns
49 | -------
50 | file_path
51 | Path to the requested pretrained model file.
52 | """
53 | file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name))
54 | root = os.path.expanduser(root)
55 | file_path = os.path.join(root, file_name+'.pth')
56 | sha1_hash = _model_sha1[name]
57 | if os.path.exists(file_path):
58 | if check_sha1(file_path, sha1_hash):
59 | return file_path
60 | else:
61 | print('Mismatch in the content of model file {} detected.'
62 | ' Downloading again.'.format(file_path))
63 | else:
64 | print('Model file {} is not found. Downloading.'.format(file_path))
65 |
66 | if not os.path.exists(root):
67 | os.makedirs(root)
68 |
69 | zip_file_path = os.path.join(root, file_name+'.zip')
70 | repo_url = os.environ.get('ENCODING_REPO', encoding_repo_url)
71 | if repo_url[-1] != '/':
72 | repo_url = repo_url + '/'
73 | download(_url_format.format(repo_url=repo_url, file_name=file_name),
74 | path=zip_file_path,
75 | overwrite=True)
76 | with zipfile.ZipFile(zip_file_path) as zf:
77 | zf.extractall(root)
78 | os.remove(zip_file_path)
79 |
80 | if check_sha1(file_path, sha1_hash):
81 | return file_path
82 | else:
83 | raise ValueError('Downloaded file has different hash. Please try again.')
84 |
85 | def purge(root=os.path.join('~', '.encoding', 'models')):
86 | r"""Purge all pretrained model files in local file store.
87 |
88 | Parameters
89 | ----------
90 | root : str, default '~/.encoding/models'
91 | Location for keeping the model parameters.
92 | """
93 | root = os.path.expanduser(root)
94 | files = os.listdir(root)
95 | for f in files:
96 | if f.endswith(".pth"):
97 | os.remove(os.path.join(root, f))
98 |
99 | def pretrained_model_list():
100 | return list(_model_sha1.keys())
101 |
--------------------------------------------------------------------------------
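
The cache naming used by `get_model_file` above is easy to reproduce by hand: the first eight hex digits of a model's sha1 entry are appended to its name (cf. `short_hash`). For the 'resnet101' entry in the table above:

sha1 = '2a57e44de9c853fa015b172309a1ee7e2d0e4e2a'  # the 'resnet101' entry above
file_name = '{name}-{short_hash}'.format(name='resnet101', short_hash=sha1[:8])
print(file_name + '.pth')  # resnet101-2a57e44d.pth
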
/models/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | try:
4 | from urllib import urlretrieve
5 | except ImportError:
6 | from urllib.request import urlretrieve
7 | import torch
8 |
9 |
10 | def load_url(url, model_dir='./pretrained', map_location=None):
11 | if not os.path.exists(model_dir):
12 | os.makedirs(model_dir)
13 | filename = url.split('/')[-1]
14 | cached_file = os.path.join(model_dir, filename)
15 | if not os.path.exists(cached_file):
16 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
17 | urlretrieve(url, cached_file)
18 | return torch.load(cached_file, map_location=map_location)
19 |
--------------------------------------------------------------------------------
/scripts/run_etc.sh:
--------------------------------------------------------------------------------
1 | DATAROOT="your/path/to/VSPW_480p"
2 |
3 |
4 |
5 | SAVE="./savemodel"
6 | DATAROOT2='data2'
7 | BATCHSIZE=8
8 | WORKERS=12
9 | CROPSIZE=479
10 |
11 |
12 |
13 | START_GPU=0
14 | GPU_NUM=2
15 | TRAINFPS=1
16 | EPOCH=120
17 | LR=0.002
18 | VAL=False
19 | USETWODATA=False
20 | LESSLABEL=False
21 | CLIPNUM=2
22 | DILATION=0
23 | CLIPUP=False
24 | CLIPMIDDLE=False
25 | OTHERGT=False
26 | PROPCLIP2=False
27 | EARLYFUSE=True
28 | EARLYCAT=False
29 | CONVLSTM=False
30 | NON_LOCAL=False
31 |
32 |
33 |
34 | FIX=False
35 | ALLSUP=True
36 | ALLSUPSCALE=0.5
37 | LINEAR_COM=True
38 | DISTSOFTMAX=False
39 | DISTNEAREST=False
40 | TEMP=0.05
41 |
42 | DILATION2="3,6,9"
43 |
44 | CLIPOCR_ALL=False
45 | USEMEMORY=True
46 |
47 | METHOD='etc'
48 |
49 |
50 |
51 | PREROOT=''
52 | PRE_ENC="./imgnetpre/resnet101-imagenet.pth"
53 | MAXDIST='3'
54 | #########
55 | ARCH=resnet101
56 | CFG='vsp-'$ARCH'dilated-ppm_deepsup_clip.yaml'
57 | #CFG='vsp-'$ARCH'dilated-ppm_clip.yaml'
58 | #CFG="vsp-"$ARCH"dilated_tdnet.yaml"
59 | PREDIR="../data/imgnetpre/"$ARCH"-imagenet.pth"
60 |
61 | NAME='newjob_lr'$LR'_bs'$BATCHSIZE'_epoch'$EPOCH'_FPS'$TRAINFPS'_clipnum'$CLIPNUM"_dilation"$DILATION"_fix"$FIX"_tdnet"$TDNET"_arch"$ARCH'_method'$METHOD'_DISTSOFTMAX'$DISTSOFTMAX'_DISTNEAREST'$DISTNEAREST"_CLIPOCR_ALL"$CLIPOCR_ALL"_USEMEMORY"$USEMEMORY"imgnetpre"
62 |
63 |
64 | SAVEROOT=$SAVE"/"$NAME
65 | python train_clip2.py --cfg config/$CFG --predir $PREDIR --batchsize $BATCHSIZE --workers $WORKERS --start_gpu $START_GPU --gpu_num $GPU_NUM --dataroot $DATAROOT --trainfps $TRAINFPS --lr $LR --multi_scale True --saveroot $SAVEROOT --totalepoch $EPOCH --dataroot2 $DATAROOT2 --usetwodata $USETWODATA --cropsize $CROPSIZE --validation $VAL --lesslabel $LESSLABEL --clip_num $CLIPNUM --dilation_num $DILATION --clip_up $CLIPUP --clip_middle $CLIPMIDDLE --fix $FIX --othergt $OTHERGT --propclip2 $PROPCLIP2 --earlyfuse $EARLYFUSE --early_usecat $EARLYCAT --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --pre_enc $PRE_ENC --max_distances $MAXDIST --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY
66 |
67 |
68 |
69 | ###inference
70 | echo 'val'
71 | BATCHSIZE=1
72 | GPU_NUM=1
73 | ISSAVE=True
74 | LESSLABLE=False
75 | USE720p=False
76 | EARLYFUSE=False
77 | EARLYCAT=False
78 |
79 | #CLIPNUM=5
80 |
81 | IMGSAVEROOT='./clipsaveimg/'$NAME
82 |
83 | LOAD=$SAVEROOT'/model_epoch_'$EPOCH'.pth'
84 |
85 | python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'val' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY
86 |
87 | echo 'test'
88 |
89 | python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'test' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY
90 |
91 |
92 |
93 |
94 |
--------------------------------------------------------------------------------
/scripts/run_netwarp.sh:
--------------------------------------------------------------------------------
1 | DATAROOT="your/path/to/VSPW_480p"
2 |
3 |
4 |
5 | SAVE="./savemodel"
6 | DATAROOT2='data2'
7 | BATCHSIZE=8
8 | WORKERS=12
9 | CROPSIZE=479
10 |
11 |
12 |
13 | START_GPU=0
14 | GPU_NUM=2
15 | TRAINFPS=1
16 | EPOCH=120
17 | LR=0.002
18 | VAL=False
19 | USETWODATA=False
20 | LESSLABEL=False
21 | CLIPNUM=2
22 | DILATION=0
23 | CLIPUP=False
24 | CLIPMIDDLE=False
25 | OTHERGT=False
26 | PROPCLIP2=False
27 | EARLYFUSE=True
28 | EARLYCAT=False
29 | CONVLSTM=False
30 | NON_LOCAL=False
31 |
32 |
33 |
34 | FIX=False
35 | ALLSUP=True
36 | ALLSUPSCALE=0.5
37 | LINEAR_COM=True
38 | DISTSOFTMAX=False
39 | DISTNEAREST=False
40 | TEMP=0.05
41 |
42 | DILATION2="3,6,9"
43 |
44 | CLIPOCR_ALL=False
45 | USEMEMORY=True
46 |
47 | METHOD='netwarp'
48 |
49 |
50 |
51 | PREROOT=''
52 | PRE_ENC="./imgnetpre/resnet101-imagenet.pth"
53 | MAXDIST='3'
54 | #########
55 | ARCH=resnet101
56 | CFG='vsp-'$ARCH'dilated-ppm_deepsup_clip.yaml'
57 | #CFG='vsp-'$ARCH'dilated-ppm_clip.yaml'
58 | #CFG="vsp-"$ARCH"dilated_tdnet.yaml"
59 | PREDIR="../data/imgnetpre/"$ARCH"-imagenet.pth"
60 |
61 | NAME='newjob_lr'$LR'_bs'$BATCHSIZE'_epoch'$EPOCH'_FPS'$TRAINFPS'_clipnum'$CLIPNUM"_dilation"$DILATION"_fix"$FIX"_tdnet"$TDNET"_arch"$ARCH'_method'$METHOD'_DISTSOFTMAX'$DISTSOFTMAX'_DISTNEAREST'$DISTNEAREST"_CLIPOCR_ALL"$CLIPOCR_ALL"_USEMEMORY"$USEMEMORY"imgnetpre"
62 |
63 |
64 | SAVEROOT=$SAVE"/"$NAME
65 | python train_clip2.py --cfg config/$CFG --predir $PREDIR --batchsize $BATCHSIZE --workers $WORKERS --start_gpu $START_GPU --gpu_num $GPU_NUM --dataroot $DATAROOT --trainfps $TRAINFPS --lr $LR --multi_scale True --saveroot $SAVEROOT --totalepoch $EPOCH --dataroot2 $DATAROOT2 --usetwodata $USETWODATA --cropsize $CROPSIZE --validation $VAL --lesslabel $LESSLABEL --clip_num $CLIPNUM --dilation_num $DILATION --clip_up $CLIPUP --clip_middle $CLIPMIDDLE --fix $FIX --othergt $OTHERGT --propclip2 $PROPCLIP2 --earlyfuse $EARLYFUSE --early_usecat $EARLYCAT --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --pre_enc $PRE_ENC --max_distances $MAXDIST --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY
66 |
67 |
68 |
69 | ###inference
70 | echo 'val'
71 | BATCHSIZE=1
72 | GPU_NUM=1
73 | ISSAVE=True
74 | LESSLABLE=False
75 | USE720p=False
76 | EARLYFUSE=False
77 | EARLYCAT=False
78 |
79 | #CLIPNUM=5
80 |
81 | IMGSAVEROOT='./clipsaveimg/'$NAME
82 |
83 | LOAD=$SAVEROOT'/model_epoch_'$EPOCH'.pth'
84 |
85 | python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'val' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY
86 |
87 | echo 'test'
88 |
89 | python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'test' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY
90 |
91 |
92 |
93 |
94 |
--------------------------------------------------------------------------------
/scripts/run_ocr.sh:
--------------------------------------------------------------------------------
1 | DATAROOT="/your/path/to/LVSP_plus_data_label124_480p"
2 |
3 |
4 | #####
5 | ARCH=res101_ocrnet
6 | CFG="config/vsp-resnet101dilated-ocr_deepsup.yaml"
7 | ####
8 |
9 |
10 |
11 | PREDIR='./imgnetpre/resnet101-imagenet.pth'
12 |
13 |
14 | SAVE="./savemodel"
15 |
16 |
17 | DATAROOT2=../data/adeour
18 | BATCHSIZE=8
19 | WORKERS=12
20 | USETWODATA=False
21 | START_GPU=0
22 | GPU_NUM=2
23 | TRAINFPS=2
24 | LR=0.002
25 |
26 |
27 |
28 | CROPSIZE=479
29 |
30 | LESSLABEL=False
31 |
32 |
33 | USE_CLIPDATASET=True
34 | EPOCH=120
35 | NAME='job_lr'$LR'batchsize'$BATCHSIZE'_EPOCH'$EPOCH'_FPS'$TRAINFPS"_arch"$ARCH"new124_gpu"$GPU_NUM"_480p""USE_CLIPDATASET"$USE_CLIPDATASET
36 | SAVEROOT=$SAVE"/"$NAME
37 | VAL=False
38 | echo $CFG
39 |
40 |
41 | echo 'train...'
42 | python train.py --cfg $CFG --predir $PREDIR --batchsize $BATCHSIZE --workers $WORKERS --start_gpu $START_GPU --gpu_num $GPU_NUM --dataroot $DATAROOT --trainfps $TRAINFPS --lr $LR --multi_scale True --saveroot $SAVEROOT --totalepoch $EPOCH --dataroot2 $DATAROOT2 --usetwodata $USETWODATA --cropsize $CROPSIZE --validation $VAL --lesslabel $LESSLABEL --use_clipdataset $USE_CLIPDATASET
43 |
44 |
45 | LOAD_EN=$SAVEROOT'/encoder_epoch_'$EPOCH'.pth'
46 | LOAD_DE=$SAVEROOT'/decoder_epoch_'$EPOCH'.pth'
47 |
48 |
49 |
50 | TESTBATCHSIZE=2
51 | ISSAVE=False
52 | IMGSAVEROOT='./saveimg/'$NAME'_train'
53 | USE720p=False
54 | LESSLABLE=False
55 |
56 | echo 'val...'
57 | python test.py --cfg $CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --load_en $LOAD_EN --load_de $LOAD_DE --batchsize $TESTBATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --split 'val'
58 | echo 'test...'
59 |
60 | python test.py --cfg $CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --load_en $LOAD_EN --load_de $LOAD_DE --batchsize $TESTBATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --split 'test'
61 |
62 |
--------------------------------------------------------------------------------
/scripts/run_psp.sh:
--------------------------------------------------------------------------------
1 | DATAROOT="/your/path/to/LVSP_plus_data_label124_480p"
2 |
3 |
4 | #####
5 | #ARCH=res101_ocrnet
6 | #CFG="config/vsp-resnet101dilated-ocr_deepsup.yaml"
7 | ####
8 |
9 | #ARCH=res101_deeplab
10 | #CFG="config/vsp-resnet101dilated-deeplab.yaml"
11 | #CFG="config/vsp-resnet50dilated-deeplab.yaml"
12 |
13 | ######
14 | #ARCH=res101_nonlocal2d_nodown
15 | #CFG="config/vsp-resnet101dilated-nonlocal2d.yaml"
16 | ########
17 | #ARCH=mobile_ppm
18 | ARCH=res101_ppm
19 | CFG="config/vsp-resnet101dilated-ppm_deepsup.yaml"
20 | #CFG="config/ade20k-resnet50dilated-ppm_deepsup.yaml"
21 | #CFG="config/ade20k-mobilenetv2dilated-ppm_deepsup.yaml"
22 |
23 | #ARCH=resnet101uper
24 | #CFG="config/ade20k-resnet101-upernet.yaml"
25 | #CFG="config/ade20k-resnet50-upernet.yaml"
26 |
27 |
28 | #PREDIR="../data/imgnetpre/resnet101-imagenet.pth"
29 | PREDIR='./imgnetpre/resnet101-imagenet.pth'
30 | #PREDIR="../data/imgnetpre/resnet50-imagenet.pth"
31 | #PREDIR="../data/imgnetpre/mobilenet_v2.pth.tar"
32 |
33 |
34 | #SAVE="../afs/video_seg/vsp_124"
35 | SAVE="./savemodel"
36 |
37 | #ARCH='hrnet'
38 | #CFG="config/ade20k-hrnetv2.yaml"
39 | #PREDIR="../data/imgnetpre/hrnetv2_w48-imagenet.pth"
40 |
41 | DATAROOT2=../data/adeour
42 | BATCHSIZE=8
43 | WORKERS=12
44 | USETWODATA=False
45 | START_GPU=0
46 | GPU_NUM=2
47 | TRAINFPS=2
48 | LR=0.002
49 |
50 |
51 |
52 | CROPSIZE=479
53 |
54 | LESSLABEL=False
55 |
56 |
57 | USE_CLIPDATASET=True
58 | EPOCH=120
59 | NAME='job_lr'$LR'batchsize'$BATCHSIZE'_EPOCH'$EPOCH'_FPS'$TRAINFPS"_arch"$ARCH"new124_gpu"$GPU_NUM"_480p""USE_CLIPDATASET"$USE_CLIPDATASET
60 | SAVEROOT=$SAVE"/"$NAME
61 | VAL=False
62 | echo $CFG
63 |
64 |
65 | echo 'train...'
66 | python train.py --cfg $CFG --predir $PREDIR --batchsize $BATCHSIZE --workers $WORKERS --start_gpu $START_GPU --gpu_num $GPU_NUM --dataroot $DATAROOT --trainfps $TRAINFPS --lr $LR --multi_scale True --saveroot $SAVEROOT --totalepoch $EPOCH --dataroot2 $DATAROOT2 --usetwodata $USETWODATA --cropsize $CROPSIZE --validation $VAL --lesslabel $LESSLABEL --use_clipdataset $USE_CLIPDATASET
67 |
68 |
69 | LOAD_EN=$SAVEROOT'/encoder_epoch_'$EPOCH'.pth'
70 | LOAD_DE=$SAVEROOT'/decoder_epoch_'$EPOCH'.pth'
71 |
72 |
73 |
74 | TESTBATCHSIZE=2
75 | ISSAVE=True
76 | IMGSAVEROOT='./saveimg/'$NAME'_train'
77 | USE720p=False
78 | LESSLABLE=False
79 |
80 | echo 'val...'
81 | python test.py --cfg $CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --load_en $LOAD_EN --load_de $LOAD_DE --batchsize $TESTBATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --split 'val'
82 | echo 'test...'
83 |
84 | python test.py --cfg $CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --load_en $LOAD_EN --load_de $LOAD_DE --batchsize $TESTBATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --split 'test'
85 |
86 |
--------------------------------------------------------------------------------
/scripts/run_temporal_ocr.sh:
--------------------------------------------------------------------------------
1 | DATAROOT="your/path/to/VSPW_480p"
2 |
3 |
4 |
5 | SAVE="./savemodel"
6 | DATAROOT2='data2'
7 | BATCHSIZE=8
8 | WORKERS=12
9 | CROPSIZE=479
10 |
11 |
12 |
13 | START_GPU=0
14 | GPU_NUM=4
15 | TRAINFPS=1
16 | EPOCH=120
17 | LR=0.002
18 | VAL=False
19 | USETWODATA=False
20 | LESSLABEL=False
21 | CLIPNUM=4
22 | DILATION=0
23 | CLIPUP=False
24 | CLIPMIDDLE=False
25 | OTHERGT=False
26 | PROPCLIP2=False
27 | EARLYFUSE=True
28 | EARLYCAT=False
29 | CONVLSTM=False
30 | NON_LOCAL=False
31 |
32 |
33 |
34 | FIX=False
35 | ALLSUP=True
36 | ALLSUPSCALE=0.5
37 | LINEAR_COM=True
38 | DISTSOFTMAX=False
39 | DISTNEAREST=False
40 | TEMP=0.05
41 |
42 | DILATION2="3,6,9"
43 |
44 | CLIPOCR_ALL=False
45 | USEMEMORY=True
46 |
47 | METHOD='clip_ocr'
48 |
49 |
50 |
51 | PREROOT=''
52 | PRE_ENC="./imgnetpre/resnet101-imagenet.pth"
53 | MAXDIST='3'
54 | #########
55 | ARCH=resnet101
56 | CFG='vsp-'$ARCH'dilated-ppm_deepsup_clip.yaml'
57 | #CFG='vsp-'$ARCH'dilated-ppm_clip.yaml'
58 | #CFG="vsp-"$ARCH"dilated_tdnet.yaml"
59 | PREDIR="../data/imgnetpre/"$ARCH"-imagenet.pth"
60 |
61 | NAME='newjob_lr'$LR'_bs'$BATCHSIZE'_epoch'$EPOCH'_FPS'$TRAINFPS'_clipnum'$CLIPNUM"_dilation"$DILATION"_fix"$FIX"_tdnet"$TDNET"_arch"$ARCH'_method'$METHOD'_DISTSOFTMAX'$DISTSOFTMAX'_DISTNEAREST'$DISTNEAREST"_CLIPOCR_ALL"$CLIPOCR_ALL"_USEMEMORY"$USEMEMORY"imgnetpre"
62 |
63 |
64 | SAVEROOT=$SAVE"/"$NAME
65 | python train_clip2.py --cfg config/$CFG --predir $PREDIR --batchsize $BATCHSIZE --workers $WORKERS --start_gpu $START_GPU --gpu_num $GPU_NUM --dataroot $DATAROOT --trainfps $TRAINFPS --lr $LR --multi_scale True --saveroot $SAVEROOT --totalepoch $EPOCH --dataroot2 $DATAROOT2 --usetwodata $USETWODATA --cropsize $CROPSIZE --validation $VAL --lesslabel $LESSLABEL --clip_num $CLIPNUM --dilation_num $DILATION --clip_up $CLIPUP --clip_middle $CLIPMIDDLE --fix $FIX --othergt $OTHERGT --propclip2 $PROPCLIP2 --earlyfuse $EARLYFUSE --early_usecat $EARLYCAT --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --pre_enc $PRE_ENC --max_distances $MAXDIST --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY
66 |
67 |
68 |
69 | ###inference
70 | echo 'val'
71 | BATCHSIZE=1
72 | GPU_NUM=1
73 | ISSAVE=True
74 | LESSLABLE=False
75 | USE720p=False
76 | EARLYFUSE=False
77 | EARLYCAT=False
78 |
79 | #CLIPNUM=5
80 |
81 | IMGSAVEROOT='./clipsaveimg/'$NAME
82 |
83 | LOAD=$SAVEROOT'/model_epoch_'$EPOCH'.pth'
84 |
85 | python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'val' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY
86 |
87 | echo 'test'
88 |
89 | python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'test' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY
90 |
91 | #python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'valtest' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --use_memory $USEMEMORY
92 |
93 |
94 |
95 |
--------------------------------------------------------------------------------
/scripts/run_temporal_psp.sh:
--------------------------------------------------------------------------------
1 | DATAROOT="your/path/to/VSPW_480p"
2 |
3 |
4 |
5 | SAVE="./savemodel"
6 | DATAROOT2='data2'
7 | BATCHSIZE=8
8 | WORKERS=12
9 | CROPSIZE=479
10 |
11 |
12 |
13 | START_GPU=0
14 | GPU_NUM=4
15 | TRAINFPS=1
16 | EPOCH=120
17 | LR=0.002
18 | VAL=False
19 | USETWODATA=False
20 | LESSLABEL=False
21 | CLIPNUM=4
22 | DILATION=0
23 | CLIPUP=False
24 | CLIPMIDDLE=False
25 | OTHERGT=False
26 | PROPCLIP2=False
27 | EARLYFUSE=True
28 | EARLYCAT=False
29 | CONVLSTM=False
30 | NON_LOCAL=False
31 |
32 |
33 |
34 | FIX=False
35 | ALLSUP=True
36 | ALLSUPSCALE=0.5
37 | LINEAR_COM=True
38 | DISTSOFTMAX=False
39 | DISTNEAREST=False
40 | TEMP=0.05
41 |
42 | DILATION2="3,6,9"
43 |
44 | CLIPOCR_ALL=False
45 | USEMEMORY=True
46 |
47 | METHOD='clip_psp'
48 |
49 |
50 |
51 | PREROOT=''
52 | PRE_ENC="./imgnetpre/resnet101-imagenet.pth"
53 | MAXDIST='3'
54 | #########
55 | ARCH=resnet101
56 | CFG='vsp-'$ARCH'dilated-ppm_deepsup_clip.yaml'
57 | #CFG='vsp-'$ARCH'dilated-ppm_clip.yaml'
58 | #CFG="vsp-"$ARCH"dilated_tdnet.yaml"
59 | PREDIR="../data/imgnetpre/"$ARCH"-imagenet.pth"
60 |
61 | NAME='newjob_lr'$LR'_bs'$BATCHSIZE'_epoch'$EPOCH'_FPS'$TRAINFPS'_clipnum'$CLIPNUM"_dilation"$DILATION"_fix"$FIX"_tdnet"$TDNET"_arch"$ARCH'_method'$METHOD'_DISTSOFTMAX'$DISTSOFTMAX'_DISTNEAREST'$DISTNEAREST"_CLIPOCR_ALL"$CLIPOCR_ALL"_USEMEMORY"$USEMEMORY"imgnetpre"
62 |
63 |
64 | SAVEROOT=$SAVE"/"$NAME
65 | python train_clip2.py --cfg config/$CFG --predir $PREDIR --batchsize $BATCHSIZE --workers $WORKERS --start_gpu $START_GPU --gpu_num $GPU_NUM --dataroot $DATAROOT --trainfps $TRAINFPS --lr $LR --multi_scale True --saveroot $SAVEROOT --totalepoch $EPOCH --dataroot2 $DATAROOT2 --usetwodata $USETWODATA --cropsize $CROPSIZE --validation $VAL --lesslabel $LESSLABEL --clip_num $CLIPNUM --dilation_num $DILATION --clip_up $CLIPUP --clip_middle $CLIPMIDDLE --fix $FIX --othergt $OTHERGT --propclip2 $PROPCLIP2 --earlyfuse $EARLYFUSE --early_usecat $EARLYCAT --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --pre_enc $PRE_ENC --max_distances $MAXDIST --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY
66 |
67 |
68 |
69 | ###inference
70 | echo 'val'
71 | BATCHSIZE=1
72 | GPU_NUM=1
73 | ISSAVE=True
74 | LESSLABLE=False
75 | USE720p=False
76 | EARLYFUSE=False
77 | EARLYCAT=False
78 |
79 | #CLIPNUM=5
80 |
81 | IMGSAVEROOT='./clipsaveimg/'$NAME
82 |
83 | LOAD=$SAVEROOT'/model_epoch_'$EPOCH'.pth'
84 |
85 | python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'val' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY
86 |
87 | echo 'test'
88 |
89 | python test_clip2.py --cfg config/$CFG --start_gpu $START_GPU --dataroot $DATAROOT --saveroot $IMGSAVEROOT --batchsize $BATCHSIZE --is_save $ISSAVE --lesslabel $LESSLABLE --use_720p $USE720p --clip_num $CLIPNUM --dilation_num $DILATION --load $LOAD --split 'test' --allsup $ALLSUP --allsup_scale $ALLSUPSCALE --linear_combine $LINEAR_COM --distsoftmax $DISTSOFTMAX --distnearest $DISTNEAREST --temp $TEMP --max_distances $MAXDIST --gpu_num $GPU_NUM --method $METHOD --dilation2 $DILATION2 --clipocr_all $CLIPOCR_ALL --use_memory $USEMEMORY
90 |
91 |
92 |
93 |
94 |
--------------------------------------------------------------------------------