├── LICENSE
├── MODEL_ZOO.md
├── README.md
├── configs
│   ├── default.py
│   ├── models
│   │   ├── aotb.py
│   │   ├── aotl.py
│   │   ├── aots.py
│   │   ├── aott.py
│   │   ├── deaotb.py
│   │   ├── deaotl.py
│   │   ├── deaots.py
│   │   ├── deaott.py
│   │   ├── default.py
│   │   ├── default_deaot.py
│   │   ├── r101_aotl.py
│   │   ├── r50_aotl.py
│   │   ├── r50_deaotl.py
│   │   ├── rs101_aotl.py
│   │   ├── swinb_aotl.py
│   │   └── swinb_deaotl.py
│   ├── pre.py
│   ├── pre_dav.py
│   ├── pre_ytb.py
│   ├── pre_ytb_dav.py
│   └── ytb.py
├── dataloaders
│   ├── __init__.py
│   ├── eval_datasets.py
│   ├── image_transforms.py
│   ├── train_datasets.py
│   └── video_transforms.py
├── datasets
│   ├── DAVIS
│   │   └── README.md
│   ├── Demo
│   │   ├── images
│   │   │   ├── 1001_3iEIq5HBY1s
│   │   │   │   ├── 00002058.jpg
│   │   │   │   ├── 00002059.jpg
│   │   │   │   ├── 00002060.jpg
│   │   │   │   ├── 00002061.jpg
│   │   │   │   ├── 00002062.jpg
│   │   │   │   ├── 00002063.jpg
│   │   │   │   ├── 00002064.jpg
│   │   │   │   ├── 00002065.jpg
│   │   │   │   ├── 00002066.jpg
│   │   │   │   ├── 00002067.jpg
│   │   │   │   ├── 00002068.jpg
│   │   │   │   ├── 00002069.jpg
│   │   │   │   ├── 00002070.jpg
│   │   │   │   ├── 00002071.jpg
│   │   │   │   ├── 00002072.jpg
│   │   │   │   ├── 00002073.jpg
│   │   │   │   ├── 00002074.jpg
│   │   │   │   ├── 00002075.jpg
│   │   │   │   ├── 00002076.jpg
│   │   │   │   ├── 00002077.jpg
│   │   │   │   ├── 00002078.jpg
│   │   │   │   ├── 00002079.jpg
│   │   │   │   ├── 00002080.jpg
│   │   │   │   ├── 00002081.jpg
│   │   │   │   ├── 00002082.jpg
│   │   │   │   ├── 00002083.jpg
│   │   │   │   ├── 00002084.jpg
│   │   │   │   ├── 00002085.jpg
│   │   │   │   ├── 00002086.jpg
│   │   │   │   ├── 00002087.jpg
│   │   │   │   ├── 00002088.jpg
│   │   │   │   ├── 00002089.jpg
│   │   │   │   ├── 00002090.jpg
│   │   │   │   ├── 00002091.jpg
│   │   │   │   ├── 00002092.jpg
│   │   │   │   ├── 00002093.jpg
│   │   │   │   ├── 00002094.jpg
│   │   │   │   ├── 00002095.jpg
│   │   │   │   ├── 00002096.jpg
│   │   │   │   ├── 00002097.jpg
│   │   │   │   ├── 00002098.jpg
│   │   │   │   ├── 00002099.jpg
│   │   │   │   ├── 00002100.jpg
│   │   │   │   ├── 00002101.jpg
│   │   │   │   └── 00002102.jpg
│   │   │   └── 1007_YCTBBdbKSSg
│   │   │       ├── 00000693.jpg
│   │   │       ├── 00000694.jpg
│   │   │       ├── 00000695.jpg
│   │   │       ├── 00000696.jpg
│   │   │       ├── 00000697.jpg
│   │   │       ├── 00000698.jpg
│   │   │       ├── 00000699.jpg
│   │   │       ├── 00000700.jpg
│   │   │       ├── 00000701.jpg
│   │   │       ├── 00000702.jpg
│   │   │       ├── 00000703.jpg
│   │   │       ├── 00000704.jpg
│   │   │       ├── 00000705.jpg
│   │   │       ├── 00000706.jpg
│   │   │       ├── 00000707.jpg
│   │   │       ├── 00000708.jpg
│   │   │       ├── 00000709.jpg
│   │   │       ├── 00000710.jpg
│   │   │       ├── 00000711.jpg
│   │   │       ├── 00000712.jpg
│   │   │       ├── 00000713.jpg
│   │   │       ├── 00000714.jpg
│   │   │       ├── 00000715.jpg
│   │   │       ├── 00000716.jpg
│   │   │       ├── 00000717.jpg
│   │   │       ├── 00000718.jpg
│   │   │       ├── 00000719.jpg
│   │   │       ├── 00000720.jpg
│   │   │       ├── 00000721.jpg
│   │   │       ├── 00000722.jpg
│   │   │       ├── 00000723.jpg
│   │   │       ├── 00000724.jpg
│   │   │       ├── 00000725.jpg
│   │   │       ├── 00000726.jpg
│   │   │       ├── 00000727.jpg
│   │   │       ├── 00000728.jpg
│   │   │       ├── 00000729.jpg
│   │   │       ├── 00000730.jpg
│   │   │       ├── 00000731.jpg
│   │   │       ├── 00000732.jpg
│   │   │       ├── 00000733.jpg
│   │   │       ├── 00000734.jpg
│   │   │       ├── 00000735.jpg
│   │   │       └── 00000736.jpg
│   │   └── masks
│   │       ├── 1001_3iEIq5HBY1s
│   │       │   └── 00002058.png
│   │       └── 1007_YCTBBdbKSSg
│   │           └── 00000693.png
│   ├── Static
│   │   └── README.md
│   └── YTB
│       ├── 2018
│       │   ├── train
│       │   │   └── README.md
│       │   ├── valid
│       │   │   └── README.md
│       │   └── valid_all_frames
│       │       └── README.md
│       └── 2019
│           ├── train
│           │   └── README.md
│           ├── valid
│           │   └── README.md
│           └── valid_all_frames
│               └── README.md
├── networks
│   ├── __init__.py
│   ├── decoders
│   │   ├── __init__.py
│   │   └── fpn.py
│   ├── encoders
│   │   ├── __init__.py
│   │   ├── mobilenetv2.py
│   │   ├── mobilenetv3.py
│   │   ├── resnest
│   │   │   ├── __init__.py
│   │   │   ├── resnest.py
│   │   │   ├── resnet.py
│   │   │   └── splat.py
│   │   ├── resnet.py
│   │   └── swin
│   │       ├── __init__.py
│   │       ├── build.py
│   │       └── swin_transformer.py
│   ├── engines
│   │   ├── __init__.py
│   │   ├── aot_engine.py
│   │   └── deaot_engine.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── basic.py
│   │   ├── loss.py
│   │   ├── normalization.py
│   │   ├── position.py
│   │   └── transformer.py
│   ├── managers
│   │   ├── evaluator.py
│   │   └── trainer.py
│   └── models
│       ├── __init__.py
│       ├── aot.py
│       └── deaot.py
├── pretrain_models
│   └── README.md
├── source
│   ├── 1001_3iEIq5HBY1s.gif
│   ├── 1007_YCTBBdbKSSg.gif
│   ├── kobe.gif
│   ├── messi.gif
│   ├── overview.png
│   ├── overview_deaot.png
│   └── some_results.png
├── tools
│   ├── demo.py
│   ├── eval.py
│   └── train.py
├── train_eval.sh
└── utils
    ├── __init__.py
    ├── checkpoint.py
    ├── cp_ckpt.py
    ├── ema.py
    ├── eval.py
    ├── image.py
    ├── learning.py
    ├── math.py
    ├── meters.py
    └── metric.py
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020, z-x-yang
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/MODEL_ZOO.md:
--------------------------------------------------------------------------------
1 | ## Model Zoo and Results
2 |
3 | ### Environment and Settings
4 | - 4 NVIDIA V100 GPUs for training and 1 for evaluation.
5 | - Auto-mixed precision was enabled during training but disabled during evaluation.
6 | - Test-time augmentations were not used.
7 | - The inference resolution of DAVIS/YouTube-VOS was 480p/1.3x480p, following [CFBI](https://github.com/z-x-yang/CFBI).
8 | - Fully online inference: all modules process the video frame by frame.
9 | - Multi-object FPS was recorded instead of single-object FPS.
10 |
11 | ### Pre-trained Models
12 | Stages:
13 |
14 | - `PRE`: the pre-training stage with static images.
15 |
16 | - `PRE_YTB_DAV`: the main-training stage with YouTube-VOS and DAVIS. All kinds of evaluation share an **identical** model and the **same** parameters.
17 |
18 |
19 | | Model | Param (M) | PRE | PRE_YTB_DAV |
20 | |:---------- |:---------:|:--------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------:|
21 | | AOTT | 5.7 | [gdrive](https://drive.google.com/file/d/1_513h8Hok9ySQPMs_dHgX5sPexUhyCmy/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1owPmwV4owd_ll6GuilzklqTyAd0ZvbCu/view?usp=sharing) |
22 | | AOTS | 7.0 | [gdrive](https://drive.google.com/file/d/1QUP0-VED-lOF1oX_ppYWnXyBjvUzJJB7/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1beU5E6Mdnr_pPrgjWvdWurKAIwJSz1xf/view?usp=sharing) |
23 | | AOTB | 8.3 | [gdrive](https://drive.google.com/file/d/11Bx8n_INAha1IdpHjueGpf7BrKmCJDvK/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1hH-GOn4GAxHkV8ARcQzsUy8Ax6ndot-A/view?usp=sharing) |
24 | | AOTL | 8.3 | [gdrive](https://drive.google.com/file/d/1WL6QCsYeT7Bt-Gain9ZIrNNXpR2Hgh29/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1L1N2hkSPqrwGgnW9GyFHuG59_EYYfTG4/view?usp=sharing) |
25 | | R50-AOTL | 14.9 | [gdrive](https://drive.google.com/file/d/1hS4JIvOXeqvbs-CokwV6PwZV-EvzE6x8/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1qJDYn3Ibpquu4ffYoQmVjg1YCbr2JQep/view?usp=sharing) |
26 | | SwinB-AOTL | 65.4 | [gdrive](https://drive.google.com/file/d/1LlhKQiXD8JyZGGs3hZiNzcaCLqyvL9tj/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/192jCGQZdnuTsvX-CVra-KVZl2q1ZR0vW/view?usp=sharing) |
27 |
28 | | Model | Param (M) | PRE | PRE_YTB_DAV |
29 | |:---------- |:---------:|:--------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------:|
30 | | DeAOTT | 7.2 | [gdrive](https://drive.google.com/file/d/11C1ZBoFpL3ztKtINS8qqwPSldfYXexFK/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1ThWIZQS03cYWx1EKNN8MIMnJS5eRowzr/view?usp=sharing) |
31 | | DeAOTS | 10.2 | [gdrive](https://drive.google.com/file/d/1uUidrWVoaP9A5B5-EzQLbielUnRLRF3j/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1YwIAV5tBtn5spSFxKLBQBEQGwPHyQlHi/view?usp=sharing) |
32 | | DeAOTB | 13.2 | [gdrive](https://drive.google.com/file/d/1bEQr6vIgQMVITrSOtxWTMgycKpS0cor9/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1BHxsonnvJXylqHlZ1zJHHc-ymKyq-CFf/view?usp=sharing) |
33 | | DeAOTL | 13.2 | [gdrive](https://drive.google.com/file/d/1_vBL4KJlmBy0oBE4YFDOvsYL1ZtpEL32/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/18elNz_wi9JyVBcIUYKhRdL08MA-FqHD5/view?usp=sharing) |
34 | | R50-DeAOTL | 19.8 | [gdrive](https://drive.google.com/file/d/1sTRQ1g0WCpqVCdavv7uJiZNkXunBt3-R/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1QoChMkTVxdYZ_eBlZhK2acq9KMQZccPJ/view?usp=sharing) |
35 | | SwinB-DeAOTL | 70.3 | [gdrive](https://drive.google.com/file/d/16BZEE53no8CxT-pPLDC2q1d6Xlg8mWPU/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1g4E-F0RPOx9Nd6J7tU9AE1TjsouL4oZq/view?usp=sharing) |
36 |
37 | To run inference with one of our pre-trained models, the simplest way is to set `--model` and `--ckpt_path` to the downloaded checkpoint's model type and file path when running `eval.py`.
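For example, a minimal sketch (assuming the `--model` value matches the corresponding config filename in `configs/models/`; the checkpoint filename below is hypothetical, so substitute the file you actually downloaded):

```bash
# Evaluate a downloaded R50-AOTL (PRE_YTB_DAV) checkpoint with eval.py.
python tools/eval.py --model r50_aotl --ckpt_path pretrain_models/R50_AOTL_PRE_YTB_DAV.pth
```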
38 |
39 | ### YouTube-VOS 2018 val
40 | `ALL-F`: all frames. The default evaluation setting of YouTube-VOS is 6fps, but 30fps sequences (all frames) are also supplied by the dataset organizers. We noticed that many VOS methods prefer to evaluate with 30fps videos, so we also supply those results here. Denser video sequences can significantly improve VOS performance when using the memory reading strategy (as in AOTL, R50-AOTL, and SwinB-AOTL), but efficiency drops since more memorized frames are stored for object matching.
41 | | Model | Stage | FPS | All-F | Mean | J Seen | F Seen | J Unseen | F Unseen | Predictions |
42 | |:------------ |:-----------:|:--------:|:-----:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------------------------------------------------------------------------------------------:|
43 | | AOTT | PRE_YTB_DAV | 41.0 | | 80.2 | 80.4 | 85.0 | 73.6 | 81.7 | [gdrive](https://drive.google.com/file/d/1u8mvPRT08ENZHsw9Xf_4C6Sv9BoCzENR/view?usp=sharing) |
44 | | AOTT | PRE_YTB_DAV | 41.0 | √ | 80.9 | 80.0 | 84.7 | 75.2 | 83.5 | [gdrive](https://drive.google.com/file/d/1RGMI5-29Z0odq73rt26eCxOUYUd-fvVv/view?usp=sharing) |
45 | | DeAOTT | PRE_YTB_DAV | **53.4** | | **82.0** | **81.6** | **86.3** | **75.8** | **84.2** | - |
46 | | AOTS | PRE_YTB_DAV | 27.1 | | 82.9 | 82.3 | 87.0 | 77.1 | 85.1 | [gdrive](https://drive.google.com/file/d/1a4-rNnxjMuPBq21IKo31WDYZXMPgS7r2/view?usp=sharing) |
47 | | AOTS | PRE_YTB_DAV | 27.1 | √ | 83.0 | 82.2 | 87.0 | 77.3 | 85.7 | [gdrive](https://drive.google.com/file/d/1Z0cndyoCw5Na6u-VFRE8CyiIG2RbMIUO/view?usp=sharing) |
48 | | DeAOTS | PRE_YTB_DAV | **38.7** | | **84.0** | **83.3** | **88.3** | **77.9** | **86.6** | - |
49 | | AOTB | PRE_YTB_DAV | 20.5 | | 84.0 | 83.2 | 88.1 | 78.0 | 86.5 | [gdrive](https://drive.google.com/file/d/1J5nhuQbbjVLYNXViBIgo21ddQy-MiOLG/view?usp=sharing) |
50 | | AOTB | PRE_YTB_DAV | 20.5 | √ | 84.1 | 83.6 | 88.5 | 78.0 | 86.5 | [gdrive](https://drive.google.com/file/d/1gFaweB_GTJjHzSD61v_ZsY9K7UEND30O/view?usp=sharing) |
51 | | DeAOTB | PRE_YTB_DAV | **30.4** | | **84.6** | **83.9** | **88.9** | **78.5** | **87.0** | - |
52 | | AOTL | PRE_YTB_DAV | 16.0 | | 84.1 | 83.2 | 88.2 | 78.2 | 86.8 | [gdrive](https://drive.google.com/file/d/1kS8KWQ2L3wzxt44ROLTxwZOT7ZpT8Igc/view?usp=sharing) |
53 | | AOTL | PRE_YTB_DAV | 6.5 | √ | 84.5 | 83.7 | 88.8 | 78.4 | **87.1** | [gdrive](https://drive.google.com/file/d/1Rpm3e215kJOUvb562lJ2kYg2I3hkrxiM/view?usp=sharing) |
54 | | DeAOTL | PRE_YTB_DAV | **24.7** | | **84.8** | **84.2** | **89.4** | **78.6** | 87.0 | - |
55 | | R50-AOTL | PRE_YTB_DAV | 14.9 | | 84.6 | 83.7 | 88.5 | 78.8 | 87.3 | [gdrive](https://drive.google.com/file/d/1nbJZ1bbmEgyK-bg6HQ8LwCz5gVJ6wzIZ/view?usp=sharing) |
56 | | R50-AOTL | PRE_YTB_DAV | 6.4 | √ | 85.5 | 84.5 | 89.5 | 79.6 | 88.2 | [gdrive](https://drive.google.com/file/d/1NbB54ZhYvfJh38KFOgovYYPjWopd-2TE/view?usp=sharing) |
57 | | R50-DeAOTL | PRE_YTB_DAV | **22.4** | | **86.0** | **84.9** | **89.9** | **80.4** | **88.7** | - |
58 | | SwinB-AOTL | PRE_YTB_DAV | 9.3 | | 84.7 | 84.5 | 89.5 | 78.1 | 86.7 | [gdrive](https://drive.google.com/file/d/1QFowulSY0LHfpsjUV8ZE9rYc55L9DOC7/view?usp=sharing) |
59 | | SwinB-AOTL | PRE_YTB_DAV | 5.2 | √ | 85.1 | 85.1 | 90.1 | 78.4 | 86.9 | [gdrive](https://drive.google.com/file/d/1TulhVOhh01rkssNYbOQASeWKu7CQ5Azx/view?usp=sharing) |
60 | | SwinB-DeAOTL | PRE_YTB_DAV | **11.9** | | **86.2** | **85.6** | **90.6** | **80.0** | **88.4** | - |
61 |
62 | ### YouTube-VOS 2019 val
63 | | Model | Stage | FPS | All-F | Mean | J Seen | F Seen | J Unseen | F Unseen | Predictions |
64 | |:------------ |:-----------:|:--------:|:-----:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------------------------------------------------------------------------------------------:|
65 | | AOTT | PRE_YTB_DAV | 41.0 | | 80.0 | 79.8 | 84.2 | 74.1 | 82.1 | [gdrive](https://drive.google.com/file/d/1zzyhN1XYtajte5nbZ7opOdfXeDJgCxC5/view?usp=sharing) |
66 | | AOTT | PRE_YTB_DAV | 41.0 | √ | 80.9 | 79.9 | 84.4 | 75.6 | 83.8 | [gdrive](https://drive.google.com/file/d/1V_5vi9dAXOis_WrDieacSESm7OX20Bv-/view?usp=sharing) |
67 | | DeAOTT | PRE_YTB_DAV | **53.4** | | **82.0** | **81.2** | **85.6** | **76.4** | **84.7** | - |
68 | | AOTS | PRE_YTB_DAV | 27.1 | | 82.7 | 81.9 | 86.5 | 77.3 | 85.2 | [gdrive](https://drive.google.com/file/d/11YdkUeyjkTv8Uw7xMgPCBzJs6v5SDt6n/view?usp=sharing) |
69 | | AOTS | PRE_YTB_DAV | 27.1 | √ | 82.8 | 81.9 | 86.5 | 77.3 | 85.6 | [gdrive](https://drive.google.com/file/d/1UhyurGTJeAw412czU3_ebzNwF8xQ4QG_/view?usp=sharing) |
70 | | DeAOTS | PRE_YTB_DAV | **38.7** | | **83.8** | **82.8** | **87.5** | **78.1** | **86.8** | - |
71 | | AOTB | PRE_YTB_DAV | 20.5 | | 84.0 | 83.1 | 87.7 | 78.5 | 86.8 | [gdrive](https://drive.google.com/file/d/1NeI8cT4kVqTqVWAwtwiga1rkrvksNWaO/view?usp=sharing) |
72 | | AOTB | PRE_YTB_DAV | 20.5 | √ | 84.1 | 83.3 | 88.0 | 78.2 | 86.7 | [gdrive](https://drive.google.com/file/d/1kpYV2XFR0sOfLWD-wMhd-nUO6CFiLjlL/view?usp=sharing) |
73 | | DeAOTB | PRE_YTB_DAV | **30.4** | | **84.6** | **83.5** | **88.3** | **79.1** | **87.5** | - |
74 | | AOTL | PRE_YTB_DAV | 16.0 | | 84.0 | 82.8 | 87.6 | 78.6 | 87.1 | [gdrive](https://drive.google.com/file/d/1qKLlNXxmT31bW0weEHI_zAf4QwU8Lhou/view?usp=sharing) |
75 | | AOTL | PRE_YTB_DAV | 6.5 | √ | 84.2 | 83.0 | 87.8 | 78.7 | 87.3 | [gdrive](https://drive.google.com/file/d/1o3fwZ0cH71bqHSA3bYNjhP4GGv9Vyuwa/view?usp=sharing) |
76 | | DeAOTL | PRE_YTB_DAV | **24.7** | | **84.7** | **83.8** | **88.8** | **79.0** | **87.2** | - |
77 | | R50-AOTL | PRE_YTB_DAV | 14.9 | | 84.4 | 83.4 | 88.1 | 78.7 | 87.2 | [gdrive](https://drive.google.com/file/d/1I7ooSp8EYfU6fvkP6QcCMaxeencA68AH/view?usp=sharing) |
78 | | R50-AOTL | PRE_YTB_DAV | 6.4 | √ | 85.3 | 83.9 | 88.8 | 79.9 | 88.5 | [gdrive](https://drive.google.com/file/d/1OGqlkEu0uXa8QVWIVz_M5pmXXiYR2sh3/view?usp=sharing) |
79 | | R50-DeAOTL | PRE_YTB_DAV | **22.4** | | **85.9** | **84.6** | **89.4** | **80.8** | **88.9** | - |
80 | | SwinB-AOTL | PRE_YTB_DAV | 9.3 | | 84.7 | 84.0 | 88.8 | 78.7 | 87.1 | [gdrive](https://drive.google.com/file/d/1fPzCxi5GM7N2sLKkhoTC2yoY_oTQCHp1/view?usp=sharing) |
81 | | SwinB-AOTL | PRE_YTB_DAV | 5.2 | √ | 85.3 | 84.6 | 89.5 | 79.3 | 87.7 | [gdrive](https://drive.google.com/file/d/1e3D22s_rJ7Y2X2MHo7x5lcNtwmHFlwYB/view?usp=sharing) |
82 | | SwinB-DeAOTL | PRE_YTB_DAV | **11.9** | | **86.1** | **85.3** | **90.2** | **80.4** | **88.6** | - |
83 |
84 | ### DAVIS-2017 test
85 |
86 | | Model | Stage | FPS | Mean | J Score | F Score | Predictions |
87 | | ---------- |:-----------:|:----:|:--------:|:--------:|:--------:|:----:|
88 | | AOTT | PRE_YTB_DAV | **51.4** | 73.7 | 70.0 | 77.3 | [gdrive](https://drive.google.com/file/d/14Pu-6Uz4rfmJ_WyL2yl57KTx_pSSUNAf/view?usp=sharing) |
89 | | AOTS | PRE_YTB_DAV | 40.0 | 75.2 | 71.4 | 78.9 | [gdrive](https://drive.google.com/file/d/1zzAPZCRLgnBWuAXqejPPEYLqBxu67Rj1/view?usp=sharing) |
90 | | AOTB | PRE_YTB_DAV | 29.6 | 77.4 | 73.7 | 81.1 | [gdrive](https://drive.google.com/file/d/1WpQ-_Jrs7Ssfw0oekrejM2OVWEx_tBN1/view?usp=sharing) |
91 | | AOTL | PRE_YTB_DAV | 18.7 | 79.3 | 75.5 | 83.2 | [gdrive](https://drive.google.com/file/d/1rP1Zdgc0N1d8RR2EaXMz3F-o5zqcNVe8/view?usp=sharing) |
92 | | R50-AOTL | PRE_YTB_DAV | 18.0 | 79.5 | 76.0 | 83.0 | [gdrive](https://drive.google.com/file/d/1iQ5iNlvlS-In586ZNc4LIZMSdNIWDvle/view?usp=sharing) |
93 | | SwinB-AOTL | PRE_YTB_DAV | 12.1 | **82.1** | **78.2** | **85.9** | [gdrive](https://drive.google.com/file/d/1oVt4FPcZdfVHiOxjYYKef0q7Ovy4f5Q_/view?usp=sharing) |
94 |
95 | ### DAVIS-2017 val
96 |
97 | | Model | Stage | FPS | Mean | J Score | F Score | Predictions |
98 | | ---------- |:-----------:|:----:|:--------:|:--------:|:---------:|:----:|
99 | | AOTT | PRE_YTB_DAV | **51.4** | 79.2 | 76.5 | 81.9 | [gdrive](https://drive.google.com/file/d/10OUFhK2Sz-hOJrTDoTI0mA45KO1qodZt/view?usp=sharing) |
100 | | AOTS | PRE_YTB_DAV | 40.0 | 82.1 | 79.3 | 84.8 | [gdrive](https://drive.google.com/file/d/1T-JTYyksWlq45jxcLjnRaBvvYUhWgHFH/view?usp=sharing) |
101 | | AOTB | PRE_YTB_DAV | 29.6 | 83.3 | 80.6 | 85.9 | [gdrive](https://drive.google.com/file/d/1EVUnxQm9TLBTuwK82QyiSKk9R9V8NwRL/view?usp=sharing) |
102 | | AOTL | PRE_YTB_DAV | 18.7 | 83.6 | 80.8 | 86.3 | [gdrive](https://drive.google.com/file/d/1CFauSni2BxAe_fcl8W_6bFByuwJRbDYm/view?usp=sharing) |
103 | | R50-AOTL | PRE_YTB_DAV | 18.0 | 85.2 | 82.5 | 87.9 | [gdrive](https://drive.google.com/file/d/1vjloxnP8R4PZdsH2DDizfU2CrkdRHHyo/view?usp=sharing) |
104 | | SwinB-AOTL | PRE_YTB_DAV | 12.1 | **85.9** | **82.9** | **88.9** | [gdrive](https://drive.google.com/file/d/1tYCbKOas0i7Et2iyUAyDwaXnaD9YWxLr/view?usp=sharing) |
105 |
106 | ### DAVIS-2016 val
107 |
108 | | Model | Stage | FPS | Mean | J Score | F Score | Predictions |
109 | | ---------- |:-----------:|:----:|:--------:|:--------:|:--------:|:----:|
110 | | AOTT | PRE_YTB_DAV | **51.4** | 87.5 | 86.5 | 88.4 | [gdrive](https://drive.google.com/file/d/1LeW8WQhnylZ3umT7E379KdII92uUsGA9/view?usp=sharing) |
111 | | AOTS | PRE_YTB_DAV | 40.0 | 89.6 | 88.6 | 90.5 | [gdrive](https://drive.google.com/file/d/1vqGei5tLu1FPVrTi5bwRAsaGy3Upf7B1/view?usp=sharing) |
112 | | AOTB | PRE_YTB_DAV | 29.6 | 90.9 | 89.6 | 92.1 | [gdrive](https://drive.google.com/file/d/1qAppo2uOVu0FbE9t1FBUpymC3yWgw1LM/view?usp=sharing) |
113 | | AOTL | PRE_YTB_DAV | 18.7 | 91.1 | 89.5 | 92.7 | [gdrive](https://drive.google.com/file/d/1g6cjYhgBWjMaY3RGAm31qm3SPEF3QcKV/view?usp=sharing) |
114 | | R50-AOTL | PRE_YTB_DAV | 18.0 | 91.7 | 90.4 | 93.0 | [gdrive](https://drive.google.com/file/d/1QzxojqWKsvRf53K2AgKsK523ZVuYU4O-/view?usp=sharing) |
115 | | SwinB-AOTL | PRE_YTB_DAV | 12.1 | **92.2** | **90.6** | **93.8** | [gdrive](https://drive.google.com/file/d/1RIqUtAyVnopeogfT520d7a0yiULg1obp/view?usp=sharing) |
116 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AOT Series Frameworks in PyTorch
2 |
3 | [](https://paperswithcode.com/sota/semi-supervised-video-object-segmentation-on-15?p=decoupling-features-in-hierarchical)
4 | [](https://paperswithcode.com/sota/video-object-segmentation-on-youtube-vos?p=associating-objects-with-scalable)
5 | [](https://paperswithcode.com/sota/semi-supervised-video-object-segmentation-on-18?p=associating-objects-with-scalable)
6 | [](https://paperswithcode.com/sota/semi-supervised-video-object-segmentation-on-1?p=associating-objects-with-scalable)
7 | [](https://paperswithcode.com/sota/visual-object-tracking-on-davis-2017?p=associating-objects-with-scalable)
8 | [](https://paperswithcode.com/sota/visual-object-tracking-on-davis-2016?p=associating-objects-with-scalable)
9 |
10 | ## News
11 | - `2024/03`: **AOST** - [AOST](https://arxiv.org/abs/2203.11442), the journal extension of AOT, has been accepted by TPAMI. AOST is the first scalable VOS framework supporting run-time speed-accuracy trade-offs, from real-time efficiency to SOTA performance.
12 | - `2023/07`: **Pyramid/Panoptic AOT** - The code of PAOT has been released in [paot](https://github.com/yoxu515/aot-benchmark/tree/paot) branch of this repository. We propose a benchmark [**VIPOSeg**](https://github.com/yoxu515/VIPOSeg-Benchmark) for panoptic VOS, and PAOT is designed to tackle the challenges in panoptic VOS and achieves SOTA performance. PAOT consists of a multi-scale architecture of LSTT (same as MS-AOT in VOT2022) and panoptic ID banks for thing and stuff. Please refer to the [paper](https://arxiv.org/abs/2305.04470) for more details.
13 | - `2023/07`: **WINNER** - DeAOT-based Tracker ranked **1st** in the [**VOTS 2023**](https://www.votchallenge.net/vots2023/) challenge ([leaderboard](https://eu.aihub.ml/competitions/201#results)). In detail, our [DMAOT](https://eu.aihub.ml/my/competition/submission/1139/detailed_results/) improves DeAOT by storing object-wise long-term memories instead of frame-wise long-term memories. This avoids the memory growth problem when processing long video sequences and produces better results when handling multiple objects.
14 | - `2023/06`: **WINNER** - DeAOT-based Tracker ranked **1st** in **two tracks** of [**EPIC-Kitchens**](https://epic-kitchens.github.io/2023) challenges ([leaderboard](http://epic-kitchens.github.io/2023)). In detail, our MS-DeAOT is a multi-scale version of DeAOT and is the winner of Semi-Supervised Video Object Segmentation (segmentation-based tracking) and TREK-150 Object Tracking (BBox-based tracking). Technical reports are coming soon.
15 | - `2023/04`: **SAM-Track** - We are pleased to announce the release of our latest project, [Segment and Track Anything (SAM-Track)](https://github.com/z-x-yang/Segment-and-Track-Anything). This innovative project merges two kinds of models, [SAM](https://github.com/facebookresearch/segment-anything) and our [DeAOT](https://github.com/yoxu515/aot-benchmark), to achieve seamless segmentation and efficient tracking of any objects in videos.
16 | - `2022/10`: **WINNER** - AOT-based Tracker ranked **1st** in **four tracks** of the **VOT 2022** challenge ([presentation of results](https://data.votchallenge.net/vot2022/vot2022_st_rt.pdf)). In detail, our MS-AOT is the winner of two segmentation tracks, VOT-STs2022 and VOT-RTs2022 (real-time). In addition, the bounding box results of MS-AOT (initialized by [AlphaRef](https://github.com/MasterBin-IIAU/AlphaRefine), and output is bounding box fitted to mask prediction) surpass the winners of two bounding box tracks, VOT-STb2022 and VOT-RTb2022 (real-time). The bounding box results were required by the organizers after the competition deadline but were highlighted in the [workshop presentation](https://data.votchallenge.net/vot2022/vot2022_st_rt.pdf) (ECCV 2022).
17 |
18 | ## Intro
19 | A modular reference PyTorch implementation of AOT series frameworks:
20 | - **DeAOT**: Decoupling Features in Hierarchical Propagation for Video Object Segmentation (NeurIPS 2022, Spotlight) [[OpenReview](https://openreview.net/forum?id=DgM7-7eMkq0)][[PDF](https://arxiv.org/pdf/2210.09782.pdf)]
21 |
22 |
23 | - **AOT**: Associating Objects with Transformers for Video Object Segmentation (NeurIPS 2021, Score 8/8/7/8) [[OpenReview](https://openreview.net/forum?id=hl3v8io3ZYt)][[PDF](https://arxiv.org/abs/2106.02638)]
24 |
25 |
26 | An extension of AOT, [AOST](https://arxiv.org/abs/2203.11442) (TPAMI 2024), is also available. AOST is a more robust and flexible framework, supporting run-time speed-accuracy trade-offs.
27 |
28 | ## Examples
29 | Benchmark examples:
30 |
31 |
32 |
33 | General examples (Messi and Kobe):
34 |
35 |
36 |
37 | ## Highlights
38 | - **High performance:** up to **85.5%** ([R50-AOTL](MODEL_ZOO.md#youtube-vos-2018-val)) on YouTube-VOS 2018 and **82.1%** ([SwinB-AOTL](MODEL_ZOO.md#davis-2017-test)) on DAVIS-2017 Test-dev under standard settings (without any test-time augmentation or post-processing).
39 | - **High efficiency:** up to **51fps** ([AOTT](MODEL_ZOO.md#davis-2017-test)) on DAVIS-2017 (480p), even with **10** objects, and **41fps** on YouTube-VOS (1.3x480p). AOT can process multiple objects (up to a pre-defined maximum, 10 by default) as efficiently as a single object. This project also supports inferring any number of objects together within a video by automatically separating and aggregating them.
40 | - **Multi-GPU training and inference**
41 | - **Mixed precision training and inference**
42 | - **Test-time augmentation:** multi-scale and flipping augmentations are supported.
43 |
44 | ## Requirements
45 | * Python 3
46 | * PyTorch >= 1.7.0 and torchvision
47 | * opencv-python
48 | * Pillow
49 | * Pytorch Correlation. We recommend installing it from [source](https://github.com/ClementPinard/Pytorch-Correlation-extension) instead of using `pip`:
50 | ```bash
51 | git clone https://github.com/ClementPinard/Pytorch-Correlation-extension.git
52 | cd Pytorch-Correlation-extension
53 | python setup.py install
54 | cd -
55 | ```
56 |
57 | Optional:
58 | * scikit-image (required for running our **Demo**)
59 |
60 | ## Model Zoo and Results
61 | Pre-trained models, benchmark scores, and pre-computed results reproduced by this project can be found in [MODEL_ZOO.md](MODEL_ZOO.md).
62 |
63 | ## Demo - Panoptic Propagation
64 | We provide a simple demo to demonstrate AOT's effectiveness. The demo will propagate more than **40** objects, including semantic regions (such as sky) and instances (such as people), together within a single complex scenario and predict its video panoptic segmentation.
65 |
66 | To run the demo, download the [checkpoint](https://drive.google.com/file/d/1qJDYn3Ibpquu4ffYoQmVjg1YCbr2JQep/view?usp=sharing) of R50-AOTL into [pretrain_models](pretrain_models), and then run:
67 | ```bash
68 | python tools/demo.py
69 | ```
70 | which predicts the given scenarios at a resolution of 1.3x480p. You can also run this demo with other AOT models ([MODEL_ZOO.md](MODEL_ZOO.md)) by setting `--model` (model type) and `--ckpt_path` (checkpoint path).
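For example, a minimal sketch of running the demo with AOTT instead (the `--model` value is assumed to match the config filenames in `configs/models/`, and the checkpoint filename is hypothetical):

```bash
# Panoptic propagation demo with a lighter AOT variant.
python tools/demo.py --model aott --ckpt_path pretrain_models/AOTT_PRE_YTB_DAV.pth
```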
71 |
72 | Two scenarios from [VSPW](https://www.vspwdataset.com/home) are supplied in [datasets/Demo](datasets/Demo):
73 |
74 | - 1001_3iEIq5HBY1s: 44 objects. 1080P.
75 | - 1007_YCTBBdbKSSg: 43 objects. 1080P.
76 |
77 | Results:
78 |
79 |
80 |
81 |
82 | ## Getting Started
83 | 0. Prepare a valid environment following the [requirements](#requirements).
84 |
85 | 1. Prepare datasets:
86 |
87 | Please follow the instructions below to prepare each dataset in its corresponding folder.
88 | * **Static**
89 |
90 | [datasets/Static](datasets/Static): the pre-training dataset of static images. Guidance on preparing it can be found in [AFB-URR](https://github.com/xmlyqing00/AFB-URR), which we referred to when implementing the pre-training.
91 | * **YouTube-VOS**
92 |
93 | A commonly-used large-scale VOS dataset.
94 |
95 | [datasets/YTB/2019](datasets/YTB/2019): version 2019, download [link](https://drive.google.com/drive/folders/1BWzrCWyPEmBEKm0lOHe5KLuBuQxUSwqz?usp=sharing). `train` is required for training. `valid` (6fps) and `valid_all_frames` (30fps, optional) are used for evaluation.
96 |
97 | [datasets/YTB/2018](datasets/YTB/2018): version 2018, download [link](https://drive.google.com/drive/folders/1bI5J1H3mxsIGo7Kp-pPZU8i6rnykOw7f?usp=sharing). Only `valid` (6fps) and `valid_all_frames` (30fps, optional) are required for this project and used for evaluation.
98 |
99 | * **DAVIS**
100 |
101 | A commonly-used small-scale VOS dataset.
102 |
103 | [datasets/DAVIS](datasets/DAVIS): [TrainVal](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip) (480p) contains both the training and validation splits. [Test-Dev](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-test-dev-480p.zip) (480p) contains the Test-dev split. The [full-resolution version](https://davischallenge.org/davis2017/code.html) is also supported for training and evaluation but is not required.
104 |
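For reference, the expected layout under `datasets/` (taken from this repository's directory tree; for YouTube-VOS 2018 only the validation splits are needed):

```
datasets/
├── Static/                  # pre-training images (see AFB-URR guidance)
├── DAVIS/                   # DAVIS 2017 TrainVal / Test-Dev (480p)
└── YTB/
    ├── 2018/                # valid/ and valid_all_frames/ are sufficient
    │   ├── valid/
    │   └── valid_all_frames/
    └── 2019/
        ├── train/
        ├── valid/
        └── valid_all_frames/
```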
105 |
106 | 2. Prepare ImageNet pre-trained encoders
107 |
108 | Select and download the checkpoints below into [pretrain_models](pretrain_models):
109 |
110 | - [MobileNet-V2](https://download.pytorch.org/models/mobilenet_v2-b0353104.pth) (default encoder)
111 | - [MobileNet-V3](https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth)
112 | - [ResNet-50](https://download.pytorch.org/models/resnet50-0676ba61.pth)
113 | - [ResNet-101](https://download.pytorch.org/models/resnet101-63fe2227.pth)
114 | - [ResNeSt-50](https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest50-528c19ca.pth)
115 | - [ResNeSt-101](https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest101-22405ba7.pth)
116 | - [Swin-Base](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)
117 |
118 | The current default training configs are not optimized for encoders larger than ResNet-50. If you want to use a larger encoder, we recommend early-stopping the main-training stage at 80,000 iterations (instead of the default 100,000) to avoid over-fitting to the seen classes of YouTube-VOS.
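One way to do this (a sketch under assumptions, not code from this repository) is to derive a stage config that lowers `TRAIN_TOTAL_STEPS`, mirroring how the stage configs in `configs/` already subclass each other; the class below and its default `model` are illustrative only:

```python
# Hypothetical stage config: early-stop the main-training stage at 80k steps
# when using a large encoder such as SwinB.
from configs.pre_ytb_dav import EngineConfig as PreYtbDavConfig


class EngineConfig(PreYtbDavConfig):
    def __init__(self, exp_name='default', model='swinb_aotl'):
        super().__init__(exp_name, model)
        self.TRAIN_TOTAL_STEPS = 80000  # default is 100000 (see configs/default.py)
```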
119 |
120 |
121 |
122 | 3. Training and Evaluation
123 |
124 | The [example script](train_eval.sh) will train AOTT in two stages using 4 GPUs and auto-mixed precision (`--amp`). The first stage is a pre-training stage using the `Static` dataset, and the second stage is a main-training stage using both `YouTube-VOS 2019 train` and `DAVIS-2017 train`, resulting in a model that generalizes to different domains (YouTube-VOS and DAVIS) and different frame rates (6fps, 24fps, and 30fps).
125 |
126 | Notably, you can use only the `YouTube-VOS 2019 train` split in the second stage by changing `pre_ytb_dav` to `pre_ytb`, which leads to better YouTube-VOS performance on unseen classes. In addition, if you want to skip the first stage, you can start training from the `ytb` stage, but the performance will drop by about 1-2% in absolute terms.
127 |
128 | After the training is finished (about 0.6 days for each stage with 4 Tesla V100 GPUs), the [example script](train_eval.sh) will evaluate the model on YouTube-VOS and DAVIS, and the results will be packed into Zip files. For calculating scores, please use official YouTube-VOS servers ([2018 server](https://competitions.codalab.org/competitions/19544) and [2019 server](https://competitions.codalab.org/competitions/20127)), official [DAVIS toolkit](https://github.com/davisvideochallenge/davis-2017) (for Val), and official [DAVIS server](https://competitions.codalab.org/competitions/20516#learn_the_details) (for Test-dev).
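For reference, a minimal sketch of launching the two stages by hand, in the spirit of [train_eval.sh](train_eval.sh). Only `--amp`, `--model`, and `--ckpt_path` are confirmed in this README; the `--stage` flag name and the checkpoint path below are assumptions, so check `tools/train.py`, `tools/eval.py`, and the script itself for the authoritative interface:

```bash
# Stage 1: pre-training on static images (configs/pre.py).
python tools/train.py --amp --model aott --stage pre
# Stage 2: main training on YouTube-VOS 2019 train + DAVIS 2017 train (configs/pre_ytb_dav.py).
python tools/train.py --amp --model aott --stage pre_ytb_dav
# Evaluate a chosen checkpoint (replace the path with your own).
python tools/eval.py --model aott --ckpt_path path/to/checkpoint.pth
```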
129 |
130 | ## Adding your own dataset
131 | Coming soon.
132 |
133 | ## Troubleshooting
134 | Waiting
135 |
136 | ## TODO
137 | - [ ] Code documentation
138 | - [ ] Adding your own dataset
139 | - [ ] Results with test-time augmentations in Model Zoo
140 | - [ ] Support gradient accumulation
141 | - [x] Demo tool
142 |
143 | ## Citations
144 | Please consider citing the related paper(s) in your publications if they help your research.
145 | ```
146 | @article{yang2021aost,
147 | title={Scalable Video Object Segmentation with Identification Mechanism},
148 | author={Yang, Zongxin and Miao, Jiaxu and Wei, Yunchao and Wang, Wenguan and Wang, Xiaohan and Yang, Yi},
149 | journal={TPAMI},
150 | year={2024}
151 | }
152 | @inproceedings{xu2023video,
153 | title={Video object segmentation in panoptic wild scenes},
154 | author={Xu, Yuanyou and Yang, Zongxin and Yang, Yi},
155 | booktitle={IJCAI},
156 | year={2023}
157 | }
158 | @inproceedings{yang2022deaot,
159 | title={Decoupling Features in Hierarchical Propagation for Video Object Segmentation},
160 | author={Yang, Zongxin and Yang, Yi},
161 | booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
162 | year={2022}
163 | }
164 | @inproceedings{yang2021aot,
165 | title={Associating Objects with Transformers for Video Object Segmentation},
166 | author={Yang, Zongxin and Wei, Yunchao and Yang, Yi},
167 | booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
168 | year={2021}
169 | }
170 | ```
171 |
172 | ## License
173 | This project is released under the BSD-3-Clause license. See [LICENSE](LICENSE) for additional details.
174 |
--------------------------------------------------------------------------------
/configs/default.py:
--------------------------------------------------------------------------------
1 | import os
2 | import importlib
3 |
4 |
5 | class DefaultEngineConfig():
6 | def __init__(self, exp_name='default', model='aott'):
7 | model_cfg = importlib.import_module('configs.models.' +
8 | model).ModelConfig()
9 | self.__dict__.update(model_cfg.__dict__) # add model config
10 |
11 | self.EXP_NAME = exp_name + '_' + self.MODEL_NAME
12 |
13 | self.STAGE_NAME = 'YTB'
14 |
15 | self.DATASETS = ['youtubevos']
16 | self.DATA_WORKERS = 8
17 | self.DATA_RANDOMCROP = (465,
18 | 465) if self.MODEL_ALIGN_CORNERS else (464,
19 | 464)
20 | self.DATA_RANDOMFLIP = 0.5
21 | self.DATA_MAX_CROP_STEPS = 10
22 | self.DATA_SHORT_EDGE_LEN = 480
23 | self.DATA_MIN_SCALE_FACTOR = 0.7
24 | self.DATA_MAX_SCALE_FACTOR = 1.3
25 | self.DATA_RANDOM_REVERSE_SEQ = True
26 | self.DATA_SEQ_LEN = 5
27 | self.DATA_DAVIS_REPEAT = 5
28 | self.DATA_RANDOM_GAP_DAVIS = 12 # max frame interval between two sampled frames for DAVIS (24fps)
29 | self.DATA_RANDOM_GAP_YTB = 3 # max frame interval between two sampled frames for YouTube-VOS (6fps)
30 | self.DATA_DYNAMIC_MERGE_PROB = 0.3
31 |
32 | self.PRETRAIN = True
33 | self.PRETRAIN_FULL = False # if False, load encoder only
34 | self.PRETRAIN_MODEL = './data_wd/pretrain_model/mobilenet_v2.pth'
35 | # self.PRETRAIN_MODEL = './pretrain_models/mobilenet_v2-b0353104.pth'
36 |
37 | self.TRAIN_TOTAL_STEPS = 100000
38 | self.TRAIN_START_STEP = 0
39 | self.TRAIN_WEIGHT_DECAY = 0.07
40 | self.TRAIN_WEIGHT_DECAY_EXCLUSIVE = {
41 | # 'encoder.': 0.01
42 | }
43 | self.TRAIN_WEIGHT_DECAY_EXEMPTION = [
44 | 'absolute_pos_embed', 'relative_position_bias_table',
45 | 'relative_emb_v', 'conv_out'
46 | ]
47 | self.TRAIN_LR = 2e-4
48 | self.TRAIN_LR_MIN = 2e-5 if 'mobilenetv2' in self.MODEL_ENCODER else 1e-5
49 | self.TRAIN_LR_POWER = 0.9
50 | self.TRAIN_LR_ENCODER_RATIO = 0.1
51 | self.TRAIN_LR_WARM_UP_RATIO = 0.05
52 | self.TRAIN_LR_COSINE_DECAY = False
53 | self.TRAIN_LR_RESTART = 1
54 | self.TRAIN_LR_UPDATE_STEP = 1
55 | self.TRAIN_AUX_LOSS_WEIGHT = 1.0
56 | self.TRAIN_AUX_LOSS_RATIO = 1.0
57 | self.TRAIN_OPT = 'adamw'
58 | self.TRAIN_SGD_MOMENTUM = 0.9
59 | self.TRAIN_GPUS = 4
60 | self.TRAIN_BATCH_SIZE = 16
61 | self.TRAIN_TBLOG = False
62 | self.TRAIN_TBLOG_STEP = 50
63 | self.TRAIN_LOG_STEP = 20
64 | self.TRAIN_IMG_LOG = True
65 | self.TRAIN_TOP_K_PERCENT_PIXELS = 0.15
66 | self.TRAIN_SEQ_TRAINING_FREEZE_PARAMS = ['patch_wise_id_bank']
67 | self.TRAIN_SEQ_TRAINING_START_RATIO = 0.5
68 | self.TRAIN_HARD_MINING_RATIO = 0.5
69 | self.TRAIN_EMA_RATIO = 0.1
70 | self.TRAIN_CLIP_GRAD_NORM = 5.
71 | self.TRAIN_SAVE_STEP = 5000
72 | self.TRAIN_MAX_KEEP_CKPT = 8
73 | self.TRAIN_RESUME = False
74 | self.TRAIN_RESUME_CKPT = None
75 | self.TRAIN_RESUME_STEP = 0
76 | self.TRAIN_AUTO_RESUME = True
77 | self.TRAIN_DATASET_FULL_RESOLUTION = False
78 | self.TRAIN_ENABLE_PREV_FRAME = False
79 | self.TRAIN_ENCODER_FREEZE_AT = 2
80 | self.TRAIN_LSTT_EMB_DROPOUT = 0.
81 | self.TRAIN_LSTT_ID_DROPOUT = 0.
82 | self.TRAIN_LSTT_DROPPATH = 0.1
83 | self.TRAIN_LSTT_DROPPATH_SCALING = False
84 | self.TRAIN_LSTT_DROPPATH_LST = False
85 | self.TRAIN_LSTT_LT_DROPOUT = 0.
86 | self.TRAIN_LSTT_ST_DROPOUT = 0.
87 |
88 | self.TEST_GPU_ID = 0
89 | self.TEST_GPU_NUM = 1
90 | self.TEST_FRAME_LOG = False
91 | self.TEST_DATASET = 'youtubevos'
92 | self.TEST_DATASET_FULL_RESOLUTION = False
93 | self.TEST_DATASET_SPLIT = 'val'
94 | self.TEST_CKPT_PATH = None
95 | # if "None", evaluate the latest checkpoint.
96 | self.TEST_CKPT_STEP = None
97 | self.TEST_FLIP = False
98 | self.TEST_MULTISCALE = [1]
99 | self.TEST_MAX_SHORT_EDGE = None
100 | self.TEST_MAX_LONG_EDGE = 800 * 1.3
101 | self.TEST_WORKERS = 4
102 |
103 | # GPU distribution
104 | self.DIST_ENABLE = True
105 | self.DIST_BACKEND = "nccl" # "gloo"
106 | self.DIST_URL = "tcp://127.0.0.1:13241"
107 | self.DIST_START_GPU = 0
108 |
109 | def init_dir(self):
110 | self.DIR_DATA = '../VOS02/datasets'#'./datasets'
111 | self.DIR_DAVIS = os.path.join(self.DIR_DATA, 'DAVIS')
112 | self.DIR_YTB = os.path.join(self.DIR_DATA, 'YTB')
113 | self.DIR_STATIC = os.path.join(self.DIR_DATA, 'Static')
114 |
115 | self.DIR_ROOT = './'#'./data_wd/youtube_vos_jobs'
116 |
117 | self.DIR_RESULT = os.path.join(self.DIR_ROOT, 'result', self.EXP_NAME,
118 | self.STAGE_NAME)
119 | self.DIR_CKPT = os.path.join(self.DIR_RESULT, 'ckpt')
120 | self.DIR_EMA_CKPT = os.path.join(self.DIR_RESULT, 'ema_ckpt')
121 | self.DIR_LOG = os.path.join(self.DIR_RESULT, 'log')
122 | self.DIR_TB_LOG = os.path.join(self.DIR_RESULT, 'log', 'tensorboard')
123 | # self.DIR_IMG_LOG = os.path.join(self.DIR_RESULT, 'log', 'img')
124 | # self.DIR_EVALUATION = os.path.join(self.DIR_RESULT, 'eval')
125 | self.DIR_IMG_LOG = './img_logs'
126 | self.DIR_EVALUATION = './results'
127 |
128 | for path in [
129 | self.DIR_RESULT, self.DIR_CKPT, self.DIR_EMA_CKPT,
130 | self.DIR_LOG, self.DIR_EVALUATION, self.DIR_IMG_LOG,
131 | self.DIR_TB_LOG
132 | ]:
133 | if not os.path.isdir(path):
134 | try:
135 | os.makedirs(path)
136 | except Exception as inst:
137 | print(inst)
138 | print('Failed to make dir: {}.'.format(path))
139 |
--------------------------------------------------------------------------------
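Side note (not part of the repository files): `DefaultEngineConfig` resolves a model config by name via `importlib`, and the stage configs that follow (`pre.py`, `pre_ytb.py`, `pre_ytb_dav.py`, `ytb.py`) subclass it. A minimal usage sketch, assuming the repository root is on `PYTHONPATH`:

```python
# Hypothetical: build the fully resolved PRE_YTB_DAV config for R50-AOTL.
import importlib

stage, model = 'pre_ytb_dav', 'r50_aotl'
engine_module = importlib.import_module('configs.' + stage)

# Instantiating the config also runs init_dir(), which creates the result/log directories.
cfg = engine_module.EngineConfig(exp_name='demo', model=model)

print(cfg.EXP_NAME)    # 'demo_R50_AOTL' (exp_name + '_' + MODEL_NAME)
print(cfg.STAGE_NAME)  # 'PRE_YTB_DAV'
print(cfg.DIR_CKPT)    # where training checkpoints are written
```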
/configs/models/aotb.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .default import DefaultModelConfig
3 |
4 | class ModelConfig(DefaultModelConfig):
5 | def __init__(self):
6 | super().__init__()
7 | self.MODEL_NAME = 'AOTB'
8 |
9 | self.MODEL_LSTT_NUM = 3
10 |
--------------------------------------------------------------------------------
/configs/models/aotl.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .default import DefaultModelConfig
3 |
4 | class ModelConfig(DefaultModelConfig):
5 | def __init__(self):
6 | super().__init__()
7 | self.MODEL_NAME = 'AOTL'
8 |
9 | self.MODEL_LSTT_NUM = 3
10 |
11 | self.TRAIN_LONG_TERM_MEM_GAP = 2
12 |
13 | self.TEST_LONG_TERM_MEM_GAP = 5
--------------------------------------------------------------------------------
/configs/models/aots.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .default import DefaultModelConfig
3 |
4 | class ModelConfig(DefaultModelConfig):
5 | def __init__(self):
6 | super().__init__()
7 | self.MODEL_NAME = 'AOTS'
8 |
9 | self.MODEL_LSTT_NUM = 2
10 |
--------------------------------------------------------------------------------
/configs/models/aott.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .default import DefaultModelConfig
3 |
4 | class ModelConfig(DefaultModelConfig):
5 | def __init__(self):
6 | super().__init__()
7 | self.MODEL_NAME = 'AOTT'
8 |
--------------------------------------------------------------------------------
/configs/models/deaotb.py:
--------------------------------------------------------------------------------
1 | from .default_deaot import DefaultModelConfig
2 |
3 |
4 | class ModelConfig(DefaultModelConfig):
5 | def __init__(self):
6 | super().__init__()
7 | self.MODEL_NAME = 'DeAOTB'
8 |
9 | self.MODEL_LSTT_NUM = 3
10 |
--------------------------------------------------------------------------------
/configs/models/deaotl.py:
--------------------------------------------------------------------------------
1 | from .default_deaot import DefaultModelConfig
2 |
3 |
4 | class ModelConfig(DefaultModelConfig):
5 | def __init__(self):
6 | super().__init__()
7 | self.MODEL_NAME = 'DeAOTL'
8 |
9 | self.MODEL_LSTT_NUM = 3
10 |
11 | self.TRAIN_LONG_TERM_MEM_GAP = 2
12 |
13 | self.TEST_LONG_TERM_MEM_GAP = 5
14 |
--------------------------------------------------------------------------------
/configs/models/deaots.py:
--------------------------------------------------------------------------------
1 | from .default_deaot import DefaultModelConfig
2 |
3 |
4 | class ModelConfig(DefaultModelConfig):
5 | def __init__(self):
6 | super().__init__()
7 | self.MODEL_NAME = 'DeAOTS'
8 |
9 | self.MODEL_LSTT_NUM = 2
10 |
--------------------------------------------------------------------------------
/configs/models/deaott.py:
--------------------------------------------------------------------------------
1 | from .default_deaot import DefaultModelConfig
2 |
3 |
4 | class ModelConfig(DefaultModelConfig):
5 | def __init__(self):
6 | super().__init__()
7 | self.MODEL_NAME = 'DeAOTT'
8 |
--------------------------------------------------------------------------------
/configs/models/default.py:
--------------------------------------------------------------------------------
1 | class DefaultModelConfig():
2 | def __init__(self):
3 | self.MODEL_NAME = 'AOTDefault'
4 |
5 | self.MODEL_VOS = 'aot'
6 | self.MODEL_ENGINE = 'aotengine'
7 | self.MODEL_ALIGN_CORNERS = True
8 | self.MODEL_ENCODER = 'mobilenetv2'
9 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/mobilenet_v2-b0353104.pth'
10 | self.MODEL_ENCODER_DIM = [24, 32, 96, 1280] # 4x, 8x, 16x, 16x
11 | self.MODEL_ENCODER_EMBEDDING_DIM = 256
12 | self.MODEL_DECODER_INTERMEDIATE_LSTT = True
13 | self.MODEL_FREEZE_BN = True
14 | self.MODEL_FREEZE_BACKBONE = False
15 | self.MODEL_MAX_OBJ_NUM = 10
16 | self.MODEL_SELF_HEADS = 8
17 | self.MODEL_ATT_HEADS = 8
18 | self.MODEL_LSTT_NUM = 1
19 | self.MODEL_EPSILON = 1e-5
20 | self.MODEL_USE_PREV_PROB = False
21 |
22 | self.TRAIN_LONG_TERM_MEM_GAP = 9999
23 | self.TRAIN_AUG_TYPE = 'v1'
24 |
25 | self.TEST_LONG_TERM_MEM_GAP = 9999
26 |
27 | self.TEST_SHORT_TERM_MEM_SKIP = 1
28 |
--------------------------------------------------------------------------------
/configs/models/default_deaot.py:
--------------------------------------------------------------------------------
1 | from .default import DefaultModelConfig as BaseConfig
2 |
3 |
4 | class DefaultModelConfig(BaseConfig):
5 | def __init__(self):
6 | super().__init__()
7 | self.MODEL_NAME = 'DeAOTDefault'
8 |
9 | self.MODEL_VOS = 'deaot'
10 | self.MODEL_ENGINE = 'deaotengine'
11 |
12 | self.MODEL_DECODER_INTERMEDIATE_LSTT = False
13 |
14 | self.MODEL_SELF_HEADS = 1
15 | self.MODEL_ATT_HEADS = 1
16 |
17 | self.TRAIN_AUG_TYPE = 'v2'
18 |
--------------------------------------------------------------------------------
/configs/models/r101_aotl.py:
--------------------------------------------------------------------------------
1 | from .default import DefaultModelConfig
2 |
3 |
4 | class ModelConfig(DefaultModelConfig):
5 | def __init__(self):
6 | super().__init__()
7 | self.MODEL_NAME = 'R101_AOTL'
8 |
9 | self.MODEL_ENCODER = 'resnet101'
10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet101-63fe2227.pth' # https://download.pytorch.org/models/resnet101-63fe2227.pth
11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x
12 | self.MODEL_LSTT_NUM = 3
13 |
14 | self.TRAIN_LONG_TERM_MEM_GAP = 2
15 |
16 | self.TEST_LONG_TERM_MEM_GAP = 5
--------------------------------------------------------------------------------
/configs/models/r50_aotl.py:
--------------------------------------------------------------------------------
1 | from .default import DefaultModelConfig
2 |
3 |
4 | class ModelConfig(DefaultModelConfig):
5 | def __init__(self):
6 | super().__init__()
7 | self.MODEL_NAME = 'R50_AOTL'
8 |
9 | self.MODEL_ENCODER = 'resnet50'
10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet50-0676ba61.pth' # https://download.pytorch.org/models/resnet50-0676ba61.pth
11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x
12 | self.MODEL_LSTT_NUM = 3
13 |
14 | self.TRAIN_LONG_TERM_MEM_GAP = 2
15 |
16 | self.TEST_LONG_TERM_MEM_GAP = 5
--------------------------------------------------------------------------------
/configs/models/r50_deaotl.py:
--------------------------------------------------------------------------------
1 | from .default_deaot import DefaultModelConfig
2 |
3 |
4 | class ModelConfig(DefaultModelConfig):
5 | def __init__(self):
6 | super().__init__()
7 | self.MODEL_NAME = 'R50_DeAOTL'
8 |
9 | self.MODEL_ENCODER = 'resnet50'
10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet50-0676ba61.pth' # https://download.pytorch.org/models/resnet50-0676ba61.pth
11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x
12 |
13 | self.MODEL_LSTT_NUM = 3
14 |
15 | self.TRAIN_LONG_TERM_MEM_GAP = 2
16 |
17 | self.TEST_LONG_TERM_MEM_GAP = 5
18 |
--------------------------------------------------------------------------------
/configs/models/rs101_aotl.py:
--------------------------------------------------------------------------------
1 | from .default import DefaultModelConfig
2 |
3 |
4 | class ModelConfig(DefaultModelConfig):
5 | def __init__(self):
6 | super().__init__()
7 |         self.MODEL_NAME = 'RS101_AOTL'  # ResNeSt-101 variant; avoids clashing with r101_aotl.py's 'R101_AOTL'
8 |
9 | self.MODEL_ENCODER = 'resnest101'
10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnest101-22405ba7.pth' # https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest101-22405ba7.pth
11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x
12 | self.MODEL_LSTT_NUM = 3
13 |
14 | self.TRAIN_LONG_TERM_MEM_GAP = 2
15 |
16 | self.TEST_LONG_TERM_MEM_GAP = 5
--------------------------------------------------------------------------------
/configs/models/swinb_aotl.py:
--------------------------------------------------------------------------------
1 | from .default import DefaultModelConfig
2 |
3 |
4 | class ModelConfig(DefaultModelConfig):
5 | def __init__(self):
6 | super().__init__()
7 | self.MODEL_NAME = 'SwinB_AOTL'
8 |
9 | self.MODEL_ENCODER = 'swin_base'
10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/swin_base_patch4_window7_224_22k.pth' # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth
11 | self.MODEL_ALIGN_CORNERS = False
12 | self.MODEL_ENCODER_DIM = [128, 256, 512, 512] # 4x, 8x, 16x, 16x
13 | self.MODEL_LSTT_NUM = 3
14 |
15 | self.TRAIN_LONG_TERM_MEM_GAP = 2
16 |
17 | self.TEST_LONG_TERM_MEM_GAP = 5
--------------------------------------------------------------------------------
/configs/models/swinb_deaotl.py:
--------------------------------------------------------------------------------
1 | from .default_deaot import DefaultModelConfig
2 |
3 |
4 | class ModelConfig(DefaultModelConfig):
5 | def __init__(self):
6 | super().__init__()
7 | self.MODEL_NAME = 'SwinB_DeAOTL'
8 |
9 | self.MODEL_ENCODER = 'swin_base'
10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/swin_base_patch4_window7_224_22k.pth' # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth
11 | self.MODEL_ALIGN_CORNERS = False
12 | self.MODEL_ENCODER_DIM = [128, 256, 512, 512] # 4x, 8x, 16x, 16x
13 |
14 | self.MODEL_LSTT_NUM = 3
15 |
16 | self.TRAIN_LONG_TERM_MEM_GAP = 2
17 |
18 | self.TEST_LONG_TERM_MEM_GAP = 5
19 |
--------------------------------------------------------------------------------
/configs/pre.py:
--------------------------------------------------------------------------------
1 | from .default import DefaultEngineConfig
2 |
3 |
4 | class EngineConfig(DefaultEngineConfig):
5 | def __init__(self, exp_name='default', model='AOTT'):
6 | super().__init__(exp_name, model)
7 | self.STAGE_NAME = 'PRE'
8 |
9 | self.init_dir()
10 |
11 | self.DATASETS = ['static']
12 |
13 | self.DATA_DYNAMIC_MERGE_PROB = 1.0
14 |
15 | self.TRAIN_LR = 4e-4
16 | self.TRAIN_LR_MIN = 2e-5
17 | self.TRAIN_WEIGHT_DECAY = 0.03
18 | self.TRAIN_SEQ_TRAINING_START_RATIO = 1.0
19 | self.TRAIN_AUX_LOSS_RATIO = 0.1
20 |
--------------------------------------------------------------------------------
/configs/pre_dav.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .default import DefaultEngineConfig
3 |
4 |
5 | class EngineConfig(DefaultEngineConfig):
6 | def __init__(self, exp_name='default', model='AOTT'):
7 | super().__init__(exp_name, model)
8 | self.STAGE_NAME = 'PRE_DAV'
9 |
10 | self.init_dir()
11 |
12 | self.DATASETS = ['davis2017']
13 |
14 | self.TRAIN_TOTAL_STEPS = 50000
15 |
16 | pretrain_stage = 'PRE'
17 | pretrain_ckpt = 'save_step_100000.pth'
18 | self.PRETRAIN_FULL = True # if False, load encoder only
19 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result',
20 | self.EXP_NAME, pretrain_stage,
21 | 'ema_ckpt', pretrain_ckpt)
22 |
--------------------------------------------------------------------------------
/configs/pre_ytb.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .default import DefaultEngineConfig
3 |
4 |
5 | class EngineConfig(DefaultEngineConfig):
6 | def __init__(self, exp_name='default', model='AOTT'):
7 | super().__init__(exp_name, model)
8 | self.STAGE_NAME = 'PRE_YTB'
9 |
10 | self.init_dir()
11 |
12 | pretrain_stage = 'PRE'
13 | pretrain_ckpt = 'save_step_100000.pth'
14 | self.PRETRAIN_FULL = True # if False, load encoder only
15 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result',
16 | self.EXP_NAME, pretrain_stage,
17 | 'ema_ckpt', pretrain_ckpt)
18 |
--------------------------------------------------------------------------------
/configs/pre_ytb_dav.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .default import DefaultEngineConfig
3 |
4 |
5 | class EngineConfig(DefaultEngineConfig):
6 | def __init__(self, exp_name='default', model='AOTT'):
7 | super().__init__(exp_name, model)
8 | self.STAGE_NAME = 'PRE_YTB_DAV'
9 |
10 | self.init_dir()
11 |
12 | self.DATASETS = ['youtubevos', 'davis2017']
13 |
14 | pretrain_stage = 'PRE'
15 | pretrain_ckpt = 'save_step_100000.pth'
16 | self.PRETRAIN_FULL = True # if False, load encoder only
17 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result',
18 | self.EXP_NAME, pretrain_stage,
19 | 'ema_ckpt', pretrain_ckpt)
20 |
--------------------------------------------------------------------------------
/configs/ytb.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .default import DefaultEngineConfig
3 |
4 |
5 | class EngineConfig(DefaultEngineConfig):
6 | def __init__(self, exp_name='default', model='AOTT'):
7 | super().__init__(exp_name, model)
8 | self.STAGE_NAME = 'YTB'
9 |
10 | self.init_dir()
11 |
--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/dataloaders/__init__.py
--------------------------------------------------------------------------------
/dataloaders/eval_datasets.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import os
3 | import shutil
4 | import json
5 | import cv2
6 | from PIL import Image
7 |
8 | import numpy as np
9 | from torch.utils.data import Dataset
10 |
11 | from utils.image import _palette
12 |
13 |
14 | class VOSTest(Dataset):
15 | def __init__(self,
16 | image_root,
17 | label_root,
18 | seq_name,
19 | images,
20 | labels,
21 | rgb=True,
22 | transform=None,
23 | single_obj=False,
24 | resolution=None):
25 | self.image_root = image_root
26 | self.label_root = label_root
27 | self.seq_name = seq_name
28 | self.images = images
29 | self.labels = labels
30 | self.obj_num = 1
31 | self.num_frame = len(self.images)
32 | self.transform = transform
33 | self.rgb = rgb
34 | self.single_obj = single_obj
35 | self.resolution = resolution
36 |
37 | self.obj_nums = []
38 | self.obj_indices = []
39 |
40 | curr_objs = [0]
41 | for img_name in self.images:
42 | self.obj_nums.append(len(curr_objs) - 1)
43 | current_label_name = img_name.split('.')[0] + '.png'
44 | if current_label_name in self.labels:
45 | current_label = self.read_label(current_label_name)
46 | curr_obj = list(np.unique(current_label))
47 | for obj_idx in curr_obj:
48 | if obj_idx not in curr_objs:
49 | curr_objs.append(obj_idx)
50 | self.obj_indices.append(curr_objs.copy())
51 |
52 | self.obj_nums[0] = self.obj_nums[1]
53 |
54 | def __len__(self):
55 | return len(self.images)
56 |
57 | def read_image(self, idx):
58 | img_name = self.images[idx]
59 | img_path = os.path.join(self.image_root, self.seq_name, img_name)
60 | img = cv2.imread(img_path)
61 | img = np.array(img, dtype=np.float32)
62 | if self.rgb:
63 | img = img[:, :, [2, 1, 0]]
64 | return img
65 |
66 | def read_label(self, label_name, squeeze_idx=None):
67 | label_path = os.path.join(self.label_root, self.seq_name, label_name)
68 | label = Image.open(label_path)
69 | label = np.array(label, dtype=np.uint8)
70 | if self.single_obj:
71 | label = (label > 0).astype(np.uint8)
72 | elif squeeze_idx is not None:
73 | squeezed_label = label * 0
74 | for idx in range(len(squeeze_idx)):
75 | obj_id = squeeze_idx[idx]
76 | if obj_id == 0:
77 | continue
78 | mask = label == obj_id
79 | squeezed_label += (mask * idx).astype(np.uint8)
80 | label = squeezed_label
81 | return label
82 |
83 | def __getitem__(self, idx):
84 | img_name = self.images[idx]
85 | current_img = self.read_image(idx)
86 | height, width, channels = current_img.shape
87 | if self.resolution is not None:
88 | width = int(np.ceil(
89 | float(width) * self.resolution / float(height)))
90 | height = int(self.resolution)
91 |
92 | current_label_name = img_name.split('.')[0] + '.png'
93 | obj_num = self.obj_nums[idx]
94 | obj_idx = self.obj_indices[idx]
95 |
96 | if current_label_name in self.labels:
97 | current_label = self.read_label(current_label_name, obj_idx)
98 | sample = {
99 | 'current_img': current_img,
100 | 'current_label': current_label
101 | }
102 | else:
103 | sample = {'current_img': current_img}
104 |
105 | sample['meta'] = {
106 | 'seq_name': self.seq_name,
107 | 'frame_num': self.num_frame,
108 | 'obj_num': obj_num,
109 | 'current_name': img_name,
110 | 'height': height,
111 | 'width': width,
112 | 'flip': False,
113 | 'obj_idx': obj_idx
114 | }
115 |
116 | if self.transform is not None:
117 | sample = self.transform(sample)
118 | return sample
119 |
120 |
121 | class YOUTUBEVOS_Test(object):
122 | def __init__(self,
123 | root='./datasets/YTB',
124 | year=2018,
125 | split='val',
126 | transform=None,
127 | rgb=True,
128 | result_root=None):
129 | if split == 'val':
130 | split = 'valid'
131 | root = os.path.join(root, str(year), split)
132 | self.db_root_dir = root
133 | self.result_root = result_root
134 | self.rgb = rgb
135 | self.transform = transform
136 | self.seq_list_file = os.path.join(self.db_root_dir, 'meta.json')
137 | self._check_preprocess()
138 | self.seqs = list(self.ann_f.keys())
139 | self.image_root = os.path.join(root, 'JPEGImages')
140 | self.label_root = os.path.join(root, 'Annotations')
141 |
142 | def __len__(self):
143 | return len(self.seqs)
144 |
145 | def __getitem__(self, idx):
146 | seq_name = self.seqs[idx]
147 | data = self.ann_f[seq_name]['objects']
148 | obj_names = list(data.keys())
149 | images = []
150 | labels = []
151 | for obj_n in obj_names:
152 | images += map(lambda x: x + '.jpg', list(data[obj_n]["frames"]))
153 | labels.append(data[obj_n]["frames"][0] + '.png')
154 | images = np.sort(np.unique(images))
155 | labels = np.sort(np.unique(labels))
156 |
157 | try:
158 | if not os.path.isfile(
159 | os.path.join(self.result_root, seq_name, labels[0])):
160 | if not os.path.exists(os.path.join(self.result_root,
161 | seq_name)):
162 | os.makedirs(os.path.join(self.result_root, seq_name))
163 | shutil.copy(
164 | os.path.join(self.label_root, seq_name, labels[0]),
165 | os.path.join(self.result_root, seq_name, labels[0]))
166 | except Exception as inst:
167 | print(inst)
168 | print('Failed to create a result folder for sequence {}.'.format(
169 | seq_name))
170 |
171 | seq_dataset = VOSTest(self.image_root,
172 | self.label_root,
173 | seq_name,
174 | images,
175 | labels,
176 | transform=self.transform,
177 | rgb=self.rgb)
178 | return seq_dataset
179 |
180 | def _check_preprocess(self):
181 | _seq_list_file = self.seq_list_file
182 | if not os.path.isfile(_seq_list_file):
183 | print(_seq_list_file)
184 | return False
185 | else:
186 | self.ann_f = json.load(open(self.seq_list_file, 'r'))['videos']
187 | return True
188 |
189 |
190 | class YOUTUBEVOS_DenseTest(object):
191 | def __init__(self,
192 | root='./datasets/YTB',
193 | year=2018,
194 | split='val',
195 | transform=None,
196 | rgb=True,
197 | result_root=None):
198 | if split == 'val':
199 | split = 'valid'
200 | root_sparse = os.path.join(root, str(year), split)
201 | root_dense = root_sparse + '_all_frames'
202 | self.db_root_dir = root_dense
203 | self.result_root = result_root
204 | self.rgb = rgb
205 | self.transform = transform
206 | self.seq_list_file = os.path.join(root_sparse, 'meta.json')
207 | self._check_preprocess()
208 | self.seqs = list(self.ann_f.keys())
209 | self.image_root = os.path.join(root_dense, 'JPEGImages')
210 | self.label_root = os.path.join(root_sparse, 'Annotations')
211 |
212 | def __len__(self):
213 | return len(self.seqs)
214 |
215 | def __getitem__(self, idx):
216 | seq_name = self.seqs[idx]
217 |
218 | data = self.ann_f[seq_name]['objects']
219 | obj_names = list(data.keys())
220 | images_sparse = []
221 | for obj_n in obj_names:
222 | images_sparse += map(lambda x: x + '.jpg',
223 | list(data[obj_n]["frames"]))
224 | images_sparse = np.sort(np.unique(images_sparse))
225 |
226 | images = np.sort(
227 | list(os.listdir(os.path.join(self.image_root, seq_name))))
228 | start_img = images_sparse[0]
229 | end_img = images_sparse[-1]
230 | for start_idx in range(len(images)):
231 | if start_img in images[start_idx]:
232 | break
233 | for end_idx in range(len(images))[::-1]:
234 | if end_img in images[end_idx]:
235 | break
236 | images = images[start_idx:(end_idx + 1)]
237 | labels = np.sort(
238 | list(os.listdir(os.path.join(self.label_root, seq_name))))
239 |
240 | try:
241 | if not os.path.isfile(
242 | os.path.join(self.result_root, seq_name, labels[0])):
243 | if not os.path.exists(os.path.join(self.result_root,
244 | seq_name)):
245 | os.makedirs(os.path.join(self.result_root, seq_name))
246 | shutil.copy(
247 | os.path.join(self.label_root, seq_name, labels[0]),
248 | os.path.join(self.result_root, seq_name, labels[0]))
249 | except Exception as inst:
250 | print(inst)
251 | print('Failed to create a result folder for sequence {}.'.format(
252 | seq_name))
253 |
254 | seq_dataset = VOSTest(self.image_root,
255 | self.label_root,
256 | seq_name,
257 | images,
258 | labels,
259 | transform=self.transform,
260 | rgb=self.rgb)
261 | seq_dataset.images_sparse = images_sparse
262 |
263 | return seq_dataset
264 |
265 | def _check_preprocess(self):
266 | _seq_list_file = self.seq_list_file
267 | if not os.path.isfile(_seq_list_file):
268 | print(_seq_list_file)
269 | return False
270 | else:
271 | self.ann_f = json.load(open(self.seq_list_file, 'r'))['videos']
272 | return True
273 |
274 |
275 | class DAVIS_Test(object):
276 | def __init__(self,
277 | split=['val'],
278 | root='./DAVIS',
279 | year=2017,
280 | transform=None,
281 | rgb=True,
282 | full_resolution=False,
283 | result_root=None):
284 | self.transform = transform
285 | self.rgb = rgb
286 | self.result_root = result_root
287 | if year == 2016:
288 | self.single_obj = True
289 | else:
290 | self.single_obj = False
291 | if full_resolution:
292 | resolution = 'Full-Resolution'
293 | else:
294 | resolution = '480p'
295 | self.image_root = os.path.join(root, 'JPEGImages', resolution)
296 | self.label_root = os.path.join(root, 'Annotations', resolution)
297 | seq_names = []
298 | for spt in split:
299 | if spt == 'test':
300 | spt = 'test-dev'
301 | with open(os.path.join(root, 'ImageSets', str(year),
302 | spt + '.txt')) as f:
303 | seqs_tmp = f.readlines()
304 | seqs_tmp = list(map(lambda elem: elem.strip(), seqs_tmp))
305 | seq_names.extend(seqs_tmp)
306 | self.seqs = list(np.unique(seq_names))
307 |
308 | def __len__(self):
309 | return len(self.seqs)
310 |
311 | def __getitem__(self, idx):
312 | seq_name = self.seqs[idx]
313 | images = list(
314 | np.sort(os.listdir(os.path.join(self.image_root, seq_name))))
315 | labels = [images[0].replace('jpg', 'png')]
316 |
317 | if not os.path.isfile(
318 | os.path.join(self.result_root, seq_name, labels[0])):
319 | seq_result_folder = os.path.join(self.result_root, seq_name)
320 | try:
321 | if not os.path.exists(seq_result_folder):
322 | os.makedirs(seq_result_folder)
323 | except Exception as inst:
324 | print(inst)
325 | print(
326 | 'Failed to create a result folder for sequence {}.'.format(
327 | seq_name))
328 | source_label_path = os.path.join(self.label_root, seq_name,
329 | labels[0])
330 | result_label_path = os.path.join(self.result_root, seq_name,
331 | labels[0])
332 | if self.single_obj:
333 | label = Image.open(source_label_path)
334 | label = np.array(label, dtype=np.uint8)
335 | label = (label > 0).astype(np.uint8)
336 | label = Image.fromarray(label).convert('P')
337 | label.putpalette(_palette)
338 | label.save(result_label_path)
339 | else:
340 | shutil.copy(source_label_path, result_label_path)
341 |
342 | seq_dataset = VOSTest(self.image_root,
343 | self.label_root,
344 | seq_name,
345 | images,
346 | labels,
347 | transform=self.transform,
348 | rgb=self.rgb,
349 | single_obj=self.single_obj,
350 | resolution=480)
351 | return seq_dataset
352 |
353 |
354 | class _EVAL_TEST(Dataset):
355 | def __init__(self, transform, seq_name):
356 | self.seq_name = seq_name
357 | self.num_frame = 10
358 | self.transform = transform
359 |
360 | def __len__(self):
361 | return self.num_frame
362 |
363 | def __getitem__(self, idx):
364 | current_frame_obj_num = 2
365 | height = 400
366 | width = 400
367 | img_name = 'test{}.jpg'.format(idx)
368 | current_img = np.zeros((height, width, 3)).astype(np.float32)
369 | if idx == 0:
370 | current_label = (current_frame_obj_num * np.ones(
371 | (height, width))).astype(np.uint8)
372 | sample = {
373 | 'current_img': current_img,
374 | 'current_label': current_label
375 | }
376 | else:
377 | sample = {'current_img': current_img}
378 |
379 | sample['meta'] = {
380 | 'seq_name': self.seq_name,
381 | 'frame_num': self.num_frame,
382 | 'obj_num': current_frame_obj_num,
383 | 'current_name': img_name,
384 | 'height': height,
385 | 'width': width,
386 | 'flip': False
387 | }
388 |
389 | if self.transform is not None:
390 | sample = self.transform(sample)
391 | return sample
392 |
393 |
394 | class EVAL_TEST(object):
395 | def __init__(self, transform=None, result_root=None):
396 | self.transform = transform
397 | self.result_root = result_root
398 |
399 | self.seqs = ['test1', 'test2', 'test3']
400 |
401 | def __len__(self):
402 | return len(self.seqs)
403 |
404 | def __getitem__(self, idx):
405 | seq_name = self.seqs[idx]
406 |
407 | if not os.path.exists(os.path.join(self.result_root, seq_name)):
408 | os.makedirs(os.path.join(self.result_root, seq_name))
409 |
410 | seq_dataset = _EVAL_TEST(self.transform, seq_name)
411 | return seq_dataset
412 |
--------------------------------------------------------------------------------
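Note: the wrappers above (YOUTUBEVOS_Test, YOUTUBEVOS_DenseTest, DAVIS_Test, EVAL_TEST) are two-level datasets: indexing a wrapper yields a per-sequence VOSTest (or _EVAL_TEST) dataset whose items are individual frames. A minimal sketch of how an evaluation loop might walk them; the paths below and the absent transform are illustrative, not values taken from the repo's configs.

```python
from dataloaders.eval_datasets import YOUTUBEVOS_Test

# Illustrative arguments; real values come from the evaluation configs.
seq_wrapper = YOUTUBEVOS_Test(root='./datasets/YTB',
                              year=2018,
                              split='val',
                              transform=None,
                              result_root='./results/ytb2018_val')

for seq_idx in range(len(seq_wrapper)):
    seq_dataset = seq_wrapper[seq_idx]        # per-sequence VOSTest
    for frame_idx in range(len(seq_dataset)):
        sample = seq_dataset[frame_idx]
        img = sample['current_img']           # HxWx3 float32 array (no transform)
        meta = sample['meta']                 # seq_name, obj_num, height, width, ...
        has_ref = 'current_label' in sample   # only frames with reference masks
        # a real evaluator would run the model on `img` here and write a mask
        # for meta['current_name'] under result_root/seq_name/
```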
/datasets/DAVIS/README.md:
--------------------------------------------------------------------------------
1 | Put DAVIS 2017 here.
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002058.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002058.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002059.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002059.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002060.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002060.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002061.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002061.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002062.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002062.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002063.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002063.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002064.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002064.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002065.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002065.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002066.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002066.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002067.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002067.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002068.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002068.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002069.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002069.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002070.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002070.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002071.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002071.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002072.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002072.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002073.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002073.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002074.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002074.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002075.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002075.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002076.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002076.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002077.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002077.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002078.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002078.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002079.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002079.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002080.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002080.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002081.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002081.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002082.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002082.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002083.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002083.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002084.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002084.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002085.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002085.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002086.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002086.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002087.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002087.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002088.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002088.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002089.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002089.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002090.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002090.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002091.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002091.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002092.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002092.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002093.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002093.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002094.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002094.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002095.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002095.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002096.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002096.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002097.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002097.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002098.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002098.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002099.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002099.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002100.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002100.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002101.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002101.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1001_3iEIq5HBY1s/00002102.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002102.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000693.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000693.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000694.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000694.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000695.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000695.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000696.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000696.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000697.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000697.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000698.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000698.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000699.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000699.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000700.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000700.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000701.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000701.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000702.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000702.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000703.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000703.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000704.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000704.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000705.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000705.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000706.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000706.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000707.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000707.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000708.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000708.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000709.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000709.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000710.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000710.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000711.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000711.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000712.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000712.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000713.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000713.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000714.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000714.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000715.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000715.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000716.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000716.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000717.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000717.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000718.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000718.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000719.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000719.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000720.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000720.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000721.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000721.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000722.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000722.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000723.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000723.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000724.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000724.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000725.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000725.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000726.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000726.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000727.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000727.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000728.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000728.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000729.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000729.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000730.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000730.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000731.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000731.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000732.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000732.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000733.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000733.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000734.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000734.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000735.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000735.jpg
--------------------------------------------------------------------------------
/datasets/Demo/images/1007_YCTBBdbKSSg/00000736.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000736.jpg
--------------------------------------------------------------------------------
/datasets/Demo/masks/1001_3iEIq5HBY1s/00002058.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/masks/1001_3iEIq5HBY1s/00002058.png
--------------------------------------------------------------------------------
/datasets/Demo/masks/1007_YCTBBdbKSSg/00000693.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/masks/1007_YCTBBdbKSSg/00000693.png
--------------------------------------------------------------------------------
/datasets/Static/README.md:
--------------------------------------------------------------------------------
1 | Put the static dataset here. Guidance can be found in [AFB-URR](https://github.com/xmlyqing00/AFB-URR), whose implementation we referred to for the pre-training.
2 |
--------------------------------------------------------------------------------
/datasets/YTB/2018/train/README.md:
--------------------------------------------------------------------------------
1 | Put the training split of YouTube-VOS 2018 here.
--------------------------------------------------------------------------------
/datasets/YTB/2018/valid/README.md:
--------------------------------------------------------------------------------
1 | Put the validation split of YouTube-VOS 2018 here.
--------------------------------------------------------------------------------
/datasets/YTB/2018/valid_all_frames/README.md:
--------------------------------------------------------------------------------
1 | Put the all-frame validation split of YouTube-VOS 2018 here.
--------------------------------------------------------------------------------
/datasets/YTB/2019/train/README.md:
--------------------------------------------------------------------------------
1 | Put the training split of YouTube-VOS 2019 here.
--------------------------------------------------------------------------------
/datasets/YTB/2019/valid/README.md:
--------------------------------------------------------------------------------
1 | Put the validation split of YouTube-VOS 2019 here.
--------------------------------------------------------------------------------
/datasets/YTB/2019/valid_all_frames/README.md:
--------------------------------------------------------------------------------
1 | Put the all-frame validation split of YouTube-VOS 2019 here.
--------------------------------------------------------------------------------
/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/networks/__init__.py
--------------------------------------------------------------------------------
/networks/decoders/__init__.py:
--------------------------------------------------------------------------------
1 | from networks.decoders.fpn import FPNSegmentationHead
2 |
3 |
4 | def build_decoder(name, **kwargs):
5 |
6 | if name == 'fpn':
7 | return FPNSegmentationHead(**kwargs)
8 | else:
9 | raise NotImplementedError
10 |
--------------------------------------------------------------------------------
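build_decoder is a thin factory around FPNSegmentationHead; any name other than 'fpn' raises NotImplementedError. A short usage sketch; the channel values are illustrative (loosely following the MobileNetV2-style shortcut_dims default in fpn.py), not the repo's configured ones.

```python
from networks.decoders import build_decoder

decoder = build_decoder('fpn',
                        in_dim=256,                       # channels of the decoder input
                        out_dim=11,                       # e.g. background + 10 object logits
                        decode_intermediate_input=False,
                        hidden_dim=256,
                        shortcut_dims=[24, 32, 96, 1280])
```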
/networks/decoders/fpn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from networks.layers.basic import ConvGN
5 |
6 |
7 | class FPNSegmentationHead(nn.Module):
8 | def __init__(self,
9 | in_dim,
10 | out_dim,
11 | decode_intermediate_input=True,
12 | hidden_dim=256,
13 | shortcut_dims=[24, 32, 96, 1280],
14 | align_corners=True):
15 | super().__init__()
16 | self.align_corners = align_corners
17 |
18 | self.decode_intermediate_input = decode_intermediate_input
19 |
20 | self.conv_in = ConvGN(in_dim, hidden_dim, 1)
21 |
22 | self.conv_16x = ConvGN(hidden_dim, hidden_dim, 3)
23 | self.conv_8x = ConvGN(hidden_dim, hidden_dim // 2, 3)
24 | self.conv_4x = ConvGN(hidden_dim // 2, hidden_dim // 2, 3)
25 |
26 | self.adapter_16x = nn.Conv2d(shortcut_dims[-2], hidden_dim, 1)
27 | self.adapter_8x = nn.Conv2d(shortcut_dims[-3], hidden_dim, 1)
28 | self.adapter_4x = nn.Conv2d(shortcut_dims[-4], hidden_dim // 2, 1)
29 |
30 | self.conv_out = nn.Conv2d(hidden_dim // 2, out_dim, 1)
31 |
32 | self._init_weight()
33 |
34 | def forward(self, inputs, shortcuts):
35 |
36 | if self.decode_intermediate_input:
37 | x = torch.cat(inputs, dim=1)
38 | else:
39 | x = inputs[-1]
40 |
41 | x = F.relu_(self.conv_in(x))
42 | x = F.relu_(self.conv_16x(self.adapter_16x(shortcuts[-2]) + x))
43 |
44 | x = F.interpolate(x,
45 | size=shortcuts[-3].size()[-2:],
46 | mode="bilinear",
47 | align_corners=self.align_corners)
48 | x = F.relu_(self.conv_8x(self.adapter_8x(shortcuts[-3]) + x))
49 |
50 | x = F.interpolate(x,
51 | size=shortcuts[-4].size()[-2:],
52 | mode="bilinear",
53 | align_corners=self.align_corners)
54 | x = F.relu_(self.conv_4x(self.adapter_4x(shortcuts[-4]) + x))
55 |
56 | x = self.conv_out(x)
57 |
58 | return x
59 |
60 | def _init_weight(self):
61 | for p in self.parameters():
62 | if p.dim() > 1:
63 | nn.init.xavier_uniform_(p)
64 |
--------------------------------------------------------------------------------
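FPNSegmentationHead fuses the stride-16 input with the stride-16/8/4 encoder shortcuts and returns logits at 1/4 resolution. A quick shape check with dummy tensors, assuming MobileNetV2-style shortcut channels and a 480x480 image.

```python
import torch
from networks.decoders.fpn import FPNSegmentationHead

head = FPNSegmentationHead(in_dim=256, out_dim=11,
                           decode_intermediate_input=False,
                           hidden_dim=256,
                           shortcut_dims=[24, 32, 96, 1280])

# Encoder features at strides 4 / 8 / 16 / 32 for a 480x480 input (assumed sizes).
shortcuts = [torch.randn(1, 24, 120, 120),
             torch.randn(1, 32, 60, 60),
             torch.randn(1, 96, 30, 30),
             torch.randn(1, 1280, 15, 15)]
x = torch.randn(1, 256, 30, 30)   # decoder input at stride 16

logits = head([x], shortcuts)
print(logits.shape)               # torch.Size([1, 11, 120, 120])
```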
/networks/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | from networks.encoders.mobilenetv2 import MobileNetV2
2 | from networks.encoders.mobilenetv3 import MobileNetV3Large
3 | from networks.encoders.resnet import ResNet101, ResNet50
4 | from networks.encoders.resnest import resnest
5 | from networks.encoders.swin import build_swin_model
6 | from networks.layers.normalization import FrozenBatchNorm2d
7 | from torch import nn
8 |
9 |
10 | def build_encoder(name, frozen_bn=True, freeze_at=-1):
11 | if frozen_bn:
12 | BatchNorm = FrozenBatchNorm2d
13 | else:
14 | BatchNorm = nn.BatchNorm2d
15 |
16 | if name == 'mobilenetv2':
17 | return MobileNetV2(16, BatchNorm, freeze_at=freeze_at)
18 | elif name == 'mobilenetv3':
19 | return MobileNetV3Large(16, BatchNorm, freeze_at=freeze_at)
20 | elif name == 'resnet50':
21 | return ResNet50(16, BatchNorm, freeze_at=freeze_at)
22 | elif name == 'resnet101':
23 | return ResNet101(16, BatchNorm, freeze_at=freeze_at)
24 | elif name == 'resnest50':
25 | return resnest.resnest50(norm_layer=BatchNorm,
26 | dilation=2,
27 | freeze_at=freeze_at)
28 | elif name == 'resnest101':
29 | return resnest.resnest101(norm_layer=BatchNorm,
30 | dilation=2,
31 | freeze_at=freeze_at)
32 | elif 'swin' in name:
33 | return build_swin_model(name, freeze_at=freeze_at)
34 | else:
35 | raise NotImplementedError
36 |
--------------------------------------------------------------------------------
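A sketch of the encoder factory; the backbone name must match one of the branches above, and freeze_at=0 leaves every stage trainable (the 480x480 input size is arbitrary).

```python
import torch
from networks.encoders import build_encoder

backbone = build_encoder('mobilenetv2', frozen_bn=True, freeze_at=0)
feats = backbone(torch.randn(1, 3, 480, 480))
print([f.shape[1] for f in feats])   # per-stage channels: [24, 32, 96, 1280]
```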
/networks/encoders/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch import Tensor
3 | from typing import Callable, Optional, List
4 | from utils.learning import freeze_params
5 |
6 | __all__ = ['MobileNetV2']
7 |
8 |
9 | def _make_divisible(v: float,
10 | divisor: int,
11 | min_value: Optional[int] = None) -> int:
12 | """
13 | This function is taken from the original tf repo.
14 | It ensures that all layers have a channel number that is divisible by 8
15 | It can be seen here:
16 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
17 | """
18 | if min_value is None:
19 | min_value = divisor
20 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
21 | # Make sure that round down does not go down by more than 10%.
22 | if new_v < 0.9 * v:
23 | new_v += divisor
24 | return new_v
25 |
26 |
27 | class ConvBNActivation(nn.Sequential):
28 | def __init__(
29 | self,
30 | in_planes: int,
31 | out_planes: int,
32 | kernel_size: int = 3,
33 | stride: int = 1,
34 | groups: int = 1,
35 | padding: int = -1,
36 | norm_layer: Optional[Callable[..., nn.Module]] = None,
37 | activation_layer: Optional[Callable[..., nn.Module]] = None,
38 | dilation: int = 1,
39 | ) -> None:
40 | if padding == -1:
41 | padding = (kernel_size - 1) // 2 * dilation
42 | if norm_layer is None:
43 | norm_layer = nn.BatchNorm2d
44 | if activation_layer is None:
45 | activation_layer = nn.ReLU6
46 | super().__init__(
47 | nn.Conv2d(in_planes,
48 | out_planes,
49 | kernel_size,
50 | stride,
51 | padding,
52 | dilation=dilation,
53 | groups=groups,
54 | bias=False), norm_layer(out_planes),
55 | activation_layer(inplace=True))
56 | self.out_channels = out_planes
57 |
58 |
59 | # necessary for backwards compatibility
60 | ConvBNReLU = ConvBNActivation
61 |
62 |
63 | class InvertedResidual(nn.Module):
64 | def __init__(
65 | self,
66 | inp: int,
67 | oup: int,
68 | stride: int,
69 | dilation: int,
70 | expand_ratio: int,
71 | norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
72 | super(InvertedResidual, self).__init__()
73 | self.stride = stride
74 | assert stride in [1, 2]
75 |
76 | if norm_layer is None:
77 | norm_layer = nn.BatchNorm2d
78 |
79 | self.kernel_size = 3
80 | self.dilation = dilation
81 |
82 | hidden_dim = int(round(inp * expand_ratio))
83 | self.use_res_connect = self.stride == 1 and inp == oup
84 |
85 | layers: List[nn.Module] = []
86 | if expand_ratio != 1:
87 | # pw
88 | layers.append(
89 | ConvBNReLU(inp,
90 | hidden_dim,
91 | kernel_size=1,
92 | norm_layer=norm_layer))
93 | layers.extend([
94 | # dw
95 | ConvBNReLU(hidden_dim,
96 | hidden_dim,
97 | stride=stride,
98 | dilation=dilation,
99 | groups=hidden_dim,
100 | norm_layer=norm_layer),
101 | # pw-linear
102 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
103 | norm_layer(oup),
104 | ])
105 | self.conv = nn.Sequential(*layers)
106 | self.out_channels = oup
107 | self._is_cn = stride > 1
108 |
109 | def forward(self, x: Tensor) -> Tensor:
110 | if self.use_res_connect:
111 | return x + self.conv(x)
112 | else:
113 | return self.conv(x)
114 |
115 |
116 | class MobileNetV2(nn.Module):
117 | def __init__(self,
118 | output_stride=8,
119 | norm_layer: Optional[Callable[..., nn.Module]] = None,
120 | width_mult: float = 1.0,
121 | inverted_residual_setting: Optional[List[List[int]]] = None,
122 | round_nearest: int = 8,
123 | block: Optional[Callable[..., nn.Module]] = None,
124 | freeze_at=0) -> None:
125 | """
126 | MobileNet V2 main class
127 | Args:
128 | num_classes (int): Number of classes
129 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
130 | inverted_residual_setting: Network structure
131 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number
132 | Set to 1 to turn off rounding
133 | block: Module specifying inverted residual building block for mobilenet
134 | norm_layer: Module specifying the normalization layer to use
135 | """
136 | super(MobileNetV2, self).__init__()
137 |
138 | if block is None:
139 | block = InvertedResidual
140 |
141 | if norm_layer is None:
142 | norm_layer = nn.BatchNorm2d
143 |
144 | last_channel = 1280
145 | input_channel = 32
146 | current_stride = 1
147 | rate = 1
148 |
149 | if inverted_residual_setting is None:
150 | inverted_residual_setting = [
151 | # t, c, n, s
152 | [1, 16, 1, 1],
153 | [6, 24, 2, 2],
154 | [6, 32, 3, 2],
155 | [6, 64, 4, 2],
156 | [6, 96, 3, 1],
157 | [6, 160, 3, 2],
158 | [6, 320, 1, 1],
159 | ]
160 |
161 | # only check the first element, assuming user knows t,c,n,s are required
162 | if len(inverted_residual_setting) == 0 or len(
163 | inverted_residual_setting[0]) != 4:
164 |             raise ValueError("inverted_residual_setting should be a non-empty "
165 |                              "list of 4-element lists, got {}".format(
166 | inverted_residual_setting))
167 |
168 | # building first layer
169 | input_channel = _make_divisible(input_channel * width_mult,
170 | round_nearest)
171 | self.last_channel = _make_divisible(
172 | last_channel * max(1.0, width_mult), round_nearest)
173 | features: List[nn.Module] = [
174 | ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)
175 | ]
176 | current_stride *= 2
177 | # building inverted residual blocks
178 | for t, c, n, s in inverted_residual_setting:
179 | if current_stride == output_stride:
180 | stride = 1
181 | dilation = rate
182 | rate *= s
183 | else:
184 | stride = s
185 | dilation = 1
186 | current_stride *= s
187 | output_channel = _make_divisible(c * width_mult, round_nearest)
188 | for i in range(n):
189 | if i == 0:
190 | features.append(
191 | block(input_channel, output_channel, stride, dilation,
192 | t, norm_layer))
193 | else:
194 | features.append(
195 | block(input_channel, output_channel, 1, rate, t,
196 | norm_layer))
197 | input_channel = output_channel
198 |
199 | # building last several layers
200 | features.append(
201 | ConvBNReLU(input_channel,
202 | self.last_channel,
203 | kernel_size=1,
204 | norm_layer=norm_layer))
205 | # make it nn.Sequential
206 | self.features = nn.Sequential(*features)
207 |
208 | self._initialize_weights()
209 |
210 | feature_4x = self.features[0:4]
211 |         feature_8x = self.features[4:7]
212 | feature_16x = self.features[7:14]
213 | feature_32x = self.features[14:]
214 |
215 |         self.stages = [feature_4x, feature_8x, feature_16x, feature_32x]
216 |
217 | self.freeze(freeze_at)
218 |
219 | def forward(self, x):
220 | xs = []
221 | for stage in self.stages:
222 | x = stage(x)
223 | xs.append(x)
224 | return xs
225 |
226 | def _initialize_weights(self):
227 | # weight initialization
228 | for m in self.modules():
229 | if isinstance(m, nn.Conv2d):
230 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
231 | if m.bias is not None:
232 | nn.init.zeros_(m.bias)
233 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
234 | nn.init.ones_(m.weight)
235 | nn.init.zeros_(m.bias)
236 | elif isinstance(m, nn.Linear):
237 | nn.init.normal_(m.weight, 0, 0.01)
238 | nn.init.zeros_(m.bias)
239 |
240 | def freeze(self, freeze_at):
241 | if freeze_at >= 1:
242 | for m in self.stages[0][0]:
243 | freeze_params(m)
244 |
245 | for idx, stage in enumerate(self.stages, start=2):
246 | if freeze_at >= idx:
247 | freeze_params(stage)
248 |
--------------------------------------------------------------------------------
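Once current_stride reaches output_stride during construction, later blocks switch from striding to dilation, so the deepest stage keeps the stride-16 spatial resolution instead of dropping to 1/32. A small check (sizes shown for a 480x480 input):

```python
import torch
from networks.encoders.mobilenetv2 import MobileNetV2

net = MobileNetV2(output_stride=16)   # same output stride the encoder factory uses
feats = net(torch.randn(1, 3, 480, 480))
print([tuple(f.shape[-2:]) for f in feats])
# [(120, 120), (60, 60), (30, 30), (30, 30)]  <- last stage stays at 1/16
```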
/networks/encoders/mobilenetv3.py:
--------------------------------------------------------------------------------
1 | """
2 | Creates a MobileNetV3 Model as defined in:
3 | Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam. (2019).
4 | Searching for MobileNetV3
5 | arXiv preprint arXiv:1905.02244.
6 | """
7 |
8 | import torch.nn as nn
9 | import math
10 | from utils.learning import freeze_params
11 |
12 |
13 | def _make_divisible(v, divisor, min_value=None):
14 | """
15 | This function is taken from the original tf repo.
16 | It ensures that all layers have a channel number that is divisible by 8
17 | It can be seen here:
18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
19 | :param v:
20 | :param divisor:
21 | :param min_value:
22 | :return:
23 | """
24 | if min_value is None:
25 | min_value = divisor
26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
27 | # Make sure that round down does not go down by more than 10%.
28 | if new_v < 0.9 * v:
29 | new_v += divisor
30 | return new_v
31 |
32 |
33 | class h_sigmoid(nn.Module):
34 | def __init__(self, inplace=True):
35 | super(h_sigmoid, self).__init__()
36 | self.relu = nn.ReLU6(inplace=inplace)
37 |
38 | def forward(self, x):
39 | return self.relu(x + 3) / 6
40 |
41 |
42 | class h_swish(nn.Module):
43 | def __init__(self, inplace=True):
44 | super(h_swish, self).__init__()
45 | self.sigmoid = h_sigmoid(inplace=inplace)
46 |
47 | def forward(self, x):
48 | return x * self.sigmoid(x)
49 |
50 |
51 | class SELayer(nn.Module):
52 | def __init__(self, channel, reduction=4):
53 | super(SELayer, self).__init__()
54 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
55 | self.fc = nn.Sequential(
56 | nn.Linear(channel, _make_divisible(channel // reduction, 8)),
57 | nn.ReLU(inplace=True),
58 | nn.Linear(_make_divisible(channel // reduction, 8), channel),
59 | h_sigmoid())
60 |
61 | def forward(self, x):
62 | b, c, _, _ = x.size()
63 | y = self.avg_pool(x).view(b, c)
64 | y = self.fc(y).view(b, c, 1, 1)
65 | return x * y
66 |
67 |
68 | def conv_3x3_bn(inp, oup, stride, norm_layer=nn.BatchNorm2d):
69 | return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
70 | norm_layer(oup), h_swish())
71 |
72 |
73 | def conv_1x1_bn(inp, oup, norm_layer=nn.BatchNorm2d):
74 | return nn.Sequential(nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
75 | norm_layer(oup), h_swish())
76 |
77 |
78 | class InvertedResidual(nn.Module):
79 | def __init__(self,
80 | inp,
81 | hidden_dim,
82 | oup,
83 | kernel_size,
84 | stride,
85 | use_se,
86 | use_hs,
87 | dilation=1,
88 | norm_layer=nn.BatchNorm2d):
89 | super(InvertedResidual, self).__init__()
90 | assert stride in [1, 2]
91 |
92 | self.identity = stride == 1 and inp == oup
93 |
94 | if inp == hidden_dim:
95 | self.conv = nn.Sequential(
96 | # dw
97 | nn.Conv2d(hidden_dim,
98 | hidden_dim,
99 | kernel_size,
100 | stride, (kernel_size - 1) // 2 * dilation,
101 | dilation=dilation,
102 | groups=hidden_dim,
103 | bias=False),
104 | norm_layer(hidden_dim),
105 | h_swish() if use_hs else nn.ReLU(inplace=True),
106 | # Squeeze-and-Excite
107 | SELayer(hidden_dim) if use_se else nn.Identity(),
108 | # pw-linear
109 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
110 | norm_layer(oup),
111 | )
112 | else:
113 | self.conv = nn.Sequential(
114 | # pw
115 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
116 | norm_layer(hidden_dim),
117 | h_swish() if use_hs else nn.ReLU(inplace=True),
118 | # dw
119 | nn.Conv2d(hidden_dim,
120 | hidden_dim,
121 | kernel_size,
122 | stride, (kernel_size - 1) // 2 * dilation,
123 | dilation=dilation,
124 | groups=hidden_dim,
125 | bias=False),
126 | norm_layer(hidden_dim),
127 | # Squeeze-and-Excite
128 | SELayer(hidden_dim) if use_se else nn.Identity(),
129 | h_swish() if use_hs else nn.ReLU(inplace=True),
130 | # pw-linear
131 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
132 | norm_layer(oup),
133 | )
134 |
135 | def forward(self, x):
136 | if self.identity:
137 | return x + self.conv(x)
138 | else:
139 | return self.conv(x)
140 |
141 |
142 | class MobileNetV3Large(nn.Module):
143 | def __init__(self,
144 | output_stride=16,
145 | norm_layer=nn.BatchNorm2d,
146 | width_mult=1.,
147 | freeze_at=0):
148 | super(MobileNetV3Large, self).__init__()
149 | """
150 | Constructs a MobileNetV3-Large model
151 | """
152 | cfgs = [
153 | # k, t, c, SE, HS, s
154 | [3, 1, 16, 0, 0, 1],
155 | [3, 4, 24, 0, 0, 2],
156 | [3, 3, 24, 0, 0, 1],
157 | [5, 3, 40, 1, 0, 2],
158 | [5, 3, 40, 1, 0, 1],
159 | [5, 3, 40, 1, 0, 1],
160 | [3, 6, 80, 0, 1, 2],
161 | [3, 2.5, 80, 0, 1, 1],
162 | [3, 2.3, 80, 0, 1, 1],
163 | [3, 2.3, 80, 0, 1, 1],
164 | [3, 6, 112, 1, 1, 1],
165 | [3, 6, 112, 1, 1, 1],
166 | [5, 6, 160, 1, 1, 2],
167 | [5, 6, 160, 1, 1, 1],
168 | [5, 6, 160, 1, 1, 1]
169 | ]
170 | self.cfgs = cfgs
171 |
172 | # building first layer
173 | input_channel = _make_divisible(16 * width_mult, 8)
174 | layers = [conv_3x3_bn(3, input_channel, 2, norm_layer)]
175 | # building inverted residual blocks
176 | block = InvertedResidual
177 | now_stride = 2
178 | rate = 1
179 | for k, t, c, use_se, use_hs, s in self.cfgs:
180 | if now_stride == output_stride:
181 | dilation = rate
182 | rate *= s
183 | s = 1
184 | else:
185 | dilation = 1
186 | now_stride *= s
187 | output_channel = _make_divisible(c * width_mult, 8)
188 | exp_size = _make_divisible(input_channel * t, 8)
189 | layers.append(
190 | block(input_channel, exp_size, output_channel, k, s, use_se,
191 | use_hs, dilation, norm_layer))
192 | input_channel = output_channel
193 |
194 | self.features = nn.Sequential(*layers)
195 | self.conv = conv_1x1_bn(input_channel, exp_size, norm_layer)
196 | # building last several layers
197 |
198 | self._initialize_weights()
199 |
200 | feature_4x = self.features[0:4]
201 | feature_8x = self.features[4:7]
202 | feature_16x = self.features[7:13]
203 | feature_32x = self.features[13:]
204 |
205 | self.stages = [feature_4x, feature_8x, feature_16x, feature_32x]
206 |
207 | self.freeze(freeze_at)
208 |
209 | def forward(self, x):
210 | xs = []
211 | for stage in self.stages:
212 | x = stage(x)
213 | xs.append(x)
214 | xs[-1] = self.conv(xs[-1])
215 | return xs
216 |
217 | def _initialize_weights(self):
218 | for m in self.modules():
219 | if isinstance(m, nn.Conv2d):
220 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
221 | m.weight.data.normal_(0, math.sqrt(2. / n))
222 | if m.bias is not None:
223 | m.bias.data.zero_()
224 | elif isinstance(m, nn.BatchNorm2d):
225 | m.weight.data.fill_(1)
226 | m.bias.data.zero_()
227 | elif isinstance(m, nn.Linear):
228 | n = m.weight.size(1)
229 | m.weight.data.normal_(0, 0.01)
230 | m.bias.data.zero_()
231 |
232 | def freeze(self, freeze_at):
233 | if freeze_at >= 1:
234 | for m in self.stages[0][0]:
235 | freeze_params(m)
236 |
237 | for idx, stage in enumerate(self.stages, start=2):
238 | if freeze_at >= idx:
239 | freeze_params(stage)
240 |
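All of the channel arithmetic above goes through _make_divisible, which rounds to the nearest multiple of the divisor (ties round up), clamps at min_value, and bumps the result up by one divisor if plain rounding would lose more than 10% of the original value. A few concrete values, checked against the implementation above:

    from networks.encoders.mobilenetv3 import _make_divisible

    assert _make_divisible(37, 8) == 40    # nearest multiple of 8
    assert _make_divisible(20, 8) == 24    # ties round up
    assert _make_divisible(2.8, 8) == 8    # clamped to min_value (= divisor by default)
    assert _make_divisible(36, 32) == 64   # 32 would lose more than 10% of 36, so it is bumped up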
--------------------------------------------------------------------------------
/networks/encoders/resnest/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnest import *
2 |
--------------------------------------------------------------------------------
/networks/encoders/resnest/resnest.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .resnet import ResNet, Bottleneck
3 |
4 | __all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269']
5 |
6 | _url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth'
7 |
8 | _model_sha256 = {
9 | name: checksum
10 | for checksum, name in [
11 | ('528c19ca', 'resnest50'),
12 | ('22405ba7', 'resnest101'),
13 | ('75117900', 'resnest200'),
14 | ('0cc87c48', 'resnest269'),
15 | ]
16 | }
17 |
18 |
19 | def short_hash(name):
20 | if name not in _model_sha256:
21 | raise ValueError(
22 | 'Pretrained model for {name} is not available.'.format(name=name))
23 | return _model_sha256[name][:8]
24 |
25 |
26 | resnest_model_urls = {
27 | name: _url_format.format(name, short_hash(name))
28 | for name in _model_sha256.keys()
29 | }
30 |
31 |
32 | def resnest50(pretrained=False, root='~/.encoding/models', **kwargs):
33 | model = ResNet(Bottleneck, [3, 4, 6, 3],
34 | radix=2,
35 | groups=1,
36 | bottleneck_width=64,
37 | deep_stem=True,
38 | stem_width=32,
39 | avg_down=True,
40 | avd=True,
41 | avd_first=False,
42 | **kwargs)
43 | if pretrained:
44 | model.load_state_dict(
45 | torch.hub.load_state_dict_from_url(resnest_model_urls['resnest50'],
46 | progress=True,
47 | check_hash=True))
48 | return model
49 |
50 |
51 | def resnest101(pretrained=False, root='~/.encoding/models', **kwargs):
52 | model = ResNet(Bottleneck, [3, 4, 23, 3],
53 | radix=2,
54 | groups=1,
55 | bottleneck_width=64,
56 | deep_stem=True,
57 | stem_width=64,
58 | avg_down=True,
59 | avd=True,
60 | avd_first=False,
61 | **kwargs)
62 | if pretrained:
63 | model.load_state_dict(
64 | torch.hub.load_state_dict_from_url(
65 | resnest_model_urls['resnest101'],
66 | progress=True,
67 | check_hash=True))
68 | return model
69 |
70 |
71 | def resnest200(pretrained=False, root='~/.encoding/models', **kwargs):
72 | model = ResNet(Bottleneck, [3, 24, 36, 3],
73 | radix=2,
74 | groups=1,
75 | bottleneck_width=64,
76 | deep_stem=True,
77 | stem_width=64,
78 | avg_down=True,
79 | avd=True,
80 | avd_first=False,
81 | **kwargs)
82 | if pretrained:
83 | model.load_state_dict(
84 | torch.hub.load_state_dict_from_url(
85 | resnest_model_urls['resnest200'],
86 | progress=True,
87 | check_hash=True))
88 | return model
89 |
90 |
91 | def resnest269(pretrained=False, root='~/.encoding/models', **kwargs):
92 | model = ResNet(Bottleneck, [3, 30, 48, 8],
93 | radix=2,
94 | groups=1,
95 | bottleneck_width=64,
96 | deep_stem=True,
97 | stem_width=64,
98 | avg_down=True,
99 | avd=True,
100 | avd_first=False,
101 | **kwargs)
102 | if pretrained:
103 | model.load_state_dict(
104 | torch.hub.load_state_dict_from_url(
105 | resnest_model_urls['resnest269'],
106 | progress=True,
107 | check_hash=True))
108 | return model
109 |
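The four constructors differ only in block counts and stem width; extra keyword arguments are forwarded to the ResNeSt ResNet. A minimal sketch, assuming the remaining ResNet arguments keep their upstream defaults:

    from networks.encoders.resnest import resnest101
    from networks.encoders.resnest.resnest import resnest_model_urls

    model = resnest101(pretrained=False)   # randomly initialised backbone
    # pretrained=True downloads the hash-checked checkpoint from the URL table above:
    assert resnest_model_urls['resnest101'].endswith('resnest101-22405ba7.pth')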
--------------------------------------------------------------------------------
/networks/encoders/resnest/splat.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.nn import Conv2d, Module, ReLU
5 | from torch.nn.modules.utils import _pair
6 |
7 | __all__ = ['SplAtConv2d', 'DropBlock2D']
8 |
9 |
10 | class DropBlock2D(object):
11 | def __init__(self, *args, **kwargs):
12 | raise NotImplementedError
13 |
14 |
15 | class SplAtConv2d(Module):
16 | """Split-Attention Conv2d
17 | """
18 | def __init__(self,
19 | in_channels,
20 | channels,
21 | kernel_size,
22 | stride=(1, 1),
23 | padding=(0, 0),
24 | dilation=(1, 1),
25 | groups=1,
26 | bias=True,
27 | radix=2,
28 | reduction_factor=4,
29 | rectify=False,
30 | rectify_avg=False,
31 | norm_layer=None,
32 | dropblock_prob=0.0,
33 | **kwargs):
34 | super(SplAtConv2d, self).__init__()
35 | padding = _pair(padding)
36 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
37 | self.rectify_avg = rectify_avg
38 | inter_channels = max(in_channels * radix // reduction_factor, 32)
39 | self.radix = radix
40 | self.cardinality = groups
41 | self.channels = channels
42 | self.dropblock_prob = dropblock_prob
43 | if self.rectify:
44 | from rfconv import RFConv2d
45 | self.conv = RFConv2d(in_channels,
46 | channels * radix,
47 | kernel_size,
48 | stride,
49 | padding,
50 | dilation,
51 | groups=groups * radix,
52 | bias=bias,
53 | average_mode=rectify_avg,
54 | **kwargs)
55 | else:
56 | self.conv = Conv2d(in_channels,
57 | channels * radix,
58 | kernel_size,
59 | stride,
60 | padding,
61 | dilation,
62 | groups=groups * radix,
63 | bias=bias,
64 | **kwargs)
65 | self.use_bn = norm_layer is not None
66 | if self.use_bn:
67 | self.bn0 = norm_layer(channels * radix)
68 | self.relu = ReLU(inplace=True)
69 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
70 | if self.use_bn:
71 | self.bn1 = norm_layer(inter_channels)
72 | self.fc2 = Conv2d(inter_channels,
73 | channels * radix,
74 | 1,
75 | groups=self.cardinality)
76 | if dropblock_prob > 0.0:
77 | self.dropblock = DropBlock2D(dropblock_prob, 3)
78 | self.rsoftmax = rSoftMax(radix, groups)
79 |
80 | def forward(self, x):
81 | x = self.conv(x)
82 | if self.use_bn:
83 | x = self.bn0(x)
84 | if self.dropblock_prob > 0.0:
85 | x = self.dropblock(x)
86 | x = self.relu(x)
87 |
88 | batch, rchannel = x.shape[:2]
89 | if self.radix > 1:
90 | if torch.__version__ < '1.5':
91 | splited = torch.split(x, int(rchannel // self.radix), dim=1)
92 | else:
93 | splited = torch.split(x, rchannel // self.radix, dim=1)
94 | gap = sum(splited)
95 | else:
96 | gap = x
97 | gap = F.adaptive_avg_pool2d(gap, 1)
98 | gap = self.fc1(gap)
99 |
100 | if self.use_bn:
101 | gap = self.bn1(gap)
102 | gap = self.relu(gap)
103 |
104 | atten = self.fc2(gap)
105 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
106 |
107 | if self.radix > 1:
108 | if torch.__version__ < '1.5':
109 | attens = torch.split(atten, int(rchannel // self.radix), dim=1)
110 | else:
111 | attens = torch.split(atten, rchannel // self.radix, dim=1)
112 | out = sum([att * split for (att, split) in zip(attens, splited)])
113 | else:
114 | out = atten * x
115 | return out.contiguous()
116 |
117 |
118 | class rSoftMax(nn.Module):
119 | def __init__(self, radix, cardinality):
120 | super().__init__()
121 | self.radix = radix
122 | self.cardinality = cardinality
123 |
124 | def forward(self, x):
125 | batch = x.size(0)
126 | if self.radix > 1:
127 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
128 | x = F.softmax(x, dim=1)
129 | x = x.reshape(batch, -1)
130 | else:
131 | x = torch.sigmoid(x)
132 | return x
133 |
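A quick shape check of the split-attention block above: the grouped convolution produces channels * radix maps, the pooled descriptor is squeezed to inter_channels and expanded back, and the radix-softmax weights fuse the splits into channels output maps:

    import torch
    from torch import nn
    from networks.encoders.resnest.splat import SplAtConv2d

    conv = SplAtConv2d(64, 64, kernel_size=3, padding=1, radix=2,
                       norm_layer=nn.BatchNorm2d)
    out = conv(torch.randn(2, 64, 32, 32))
    assert out.shape == (2, 64, 32, 32)   # the two radix splits are recombined into 64 channels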
--------------------------------------------------------------------------------
/networks/encoders/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 | from utils.learning import freeze_params
4 |
5 |
6 | class Bottleneck(nn.Module):
7 | expansion = 4
8 |
9 | def __init__(self,
10 | inplanes,
11 | planes,
12 | stride=1,
13 | dilation=1,
14 | downsample=None,
15 | BatchNorm=None):
16 | super(Bottleneck, self).__init__()
17 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
18 | self.bn1 = BatchNorm(planes)
19 | self.conv2 = nn.Conv2d(planes,
20 | planes,
21 | kernel_size=3,
22 | stride=stride,
23 | dilation=dilation,
24 | padding=dilation,
25 | bias=False)
26 | self.bn2 = BatchNorm(planes)
27 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
28 | self.bn3 = BatchNorm(planes * 4)
29 | self.relu = nn.ReLU(inplace=True)
30 | self.downsample = downsample
31 | self.stride = stride
32 | self.dilation = dilation
33 |
34 | def forward(self, x):
35 | residual = x
36 |
37 | out = self.conv1(x)
38 | out = self.bn1(out)
39 | out = self.relu(out)
40 |
41 | out = self.conv2(out)
42 | out = self.bn2(out)
43 | out = self.relu(out)
44 |
45 | out = self.conv3(out)
46 | out = self.bn3(out)
47 |
48 | if self.downsample is not None:
49 | residual = self.downsample(x)
50 |
51 | out += residual
52 | out = self.relu(out)
53 |
54 | return out
55 |
56 |
57 | class ResNet(nn.Module):
58 | def __init__(self, block, layers, output_stride, BatchNorm, freeze_at=0):
59 | self.inplanes = 64
60 | super(ResNet, self).__init__()
61 |
62 | if output_stride == 16:
63 | strides = [1, 2, 2, 1]
64 | dilations = [1, 1, 1, 2]
65 | elif output_stride == 8:
66 | strides = [1, 2, 1, 1]
67 | dilations = [1, 1, 2, 4]
68 | else:
69 | raise NotImplementedError
70 |
71 | # Modules
72 | self.conv1 = nn.Conv2d(3,
73 | 64,
74 | kernel_size=7,
75 | stride=2,
76 | padding=3,
77 | bias=False)
78 | self.bn1 = BatchNorm(64)
79 | self.relu = nn.ReLU(inplace=True)
80 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
81 |
82 | self.layer1 = self._make_layer(block,
83 | 64,
84 | layers[0],
85 | stride=strides[0],
86 | dilation=dilations[0],
87 | BatchNorm=BatchNorm)
88 | self.layer2 = self._make_layer(block,
89 | 128,
90 | layers[1],
91 | stride=strides[1],
92 | dilation=dilations[1],
93 | BatchNorm=BatchNorm)
94 | self.layer3 = self._make_layer(block,
95 | 256,
96 | layers[2],
97 | stride=strides[2],
98 | dilation=dilations[2],
99 | BatchNorm=BatchNorm)
100 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
101 |
102 | self.stem = [self.conv1, self.bn1]
103 | self.stages = [self.layer1, self.layer2, self.layer3]
104 |
105 | self._init_weight()
106 | self.freeze(freeze_at)
107 |
108 | def _make_layer(self,
109 | block,
110 | planes,
111 | blocks,
112 | stride=1,
113 | dilation=1,
114 | BatchNorm=None):
115 | downsample = None
116 | if stride != 1 or self.inplanes != planes * block.expansion:
117 | downsample = nn.Sequential(
118 | nn.Conv2d(self.inplanes,
119 | planes * block.expansion,
120 | kernel_size=1,
121 | stride=stride,
122 | bias=False),
123 | BatchNorm(planes * block.expansion),
124 | )
125 |
126 | layers = []
127 | layers.append(
128 | block(self.inplanes, planes, stride, max(dilation // 2, 1),
129 | downsample, BatchNorm))
130 | self.inplanes = planes * block.expansion
131 | for i in range(1, blocks):
132 | layers.append(
133 | block(self.inplanes,
134 | planes,
135 | dilation=dilation,
136 | BatchNorm=BatchNorm))
137 |
138 | return nn.Sequential(*layers)
139 |
140 | def forward(self, input):
141 | x = self.conv1(input)
142 | x = self.bn1(x)
143 | x = self.relu(x)
144 | x = self.maxpool(x)
145 |
146 | xs = []
147 |
148 | x = self.layer1(x)
149 | xs.append(x) # 4X
150 | x = self.layer2(x)
151 | xs.append(x) # 8X
152 | x = self.layer3(x)
153 | xs.append(x) # 16X
154 | # Following STMVOS, we drop stage 5.
155 | xs.append(x) # 16X
156 |
157 | return xs
158 |
159 | def _init_weight(self):
160 | for m in self.modules():
161 | if isinstance(m, nn.Conv2d):
162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
163 | m.weight.data.normal_(0, math.sqrt(2. / n))
164 | elif isinstance(m, nn.BatchNorm2d):
165 | m.weight.data.fill_(1)
166 | m.bias.data.zero_()
167 |
168 | def freeze(self, freeze_at):
169 | if freeze_at >= 1:
170 | for m in self.stem:
171 | freeze_params(m)
172 |
173 | for idx, stage in enumerate(self.stages, start=2):
174 | if freeze_at >= idx:
175 | freeze_params(stage)
176 |
177 |
178 | def ResNet50(output_stride, BatchNorm, freeze_at=0):
179 |     """Constructs a ResNet-50 backbone (stage 5 is dropped; see forward()).
180 |     Args:
181 |         output_stride (int): 8 or 16. BatchNorm: normalization layer class. freeze_at (int): 0 trains everything; 1 freezes the stem; larger values also freeze layer1, layer2, ... cumulatively.
182 |     """
183 | model = ResNet(Bottleneck, [3, 4, 6, 3],
184 | output_stride,
185 | BatchNorm,
186 | freeze_at=freeze_at)
187 | return model
188 |
189 |
190 | def ResNet101(output_stride, BatchNorm, freeze_at=0):
191 |     """Constructs a ResNet-101 backbone (stage 5 is dropped; see forward()).
192 |     Args:
193 |         output_stride (int): 8 or 16. BatchNorm: normalization layer class. freeze_at (int): 0 trains everything; 1 freezes the stem; larger values also freeze layer1, layer2, ... cumulatively.
194 |     """
195 | model = ResNet(Bottleneck, [3, 4, 23, 3],
196 | output_stride,
197 | BatchNorm,
198 | freeze_at=freeze_at)
199 | return model
200 |
201 |
202 | if __name__ == "__main__":
203 | import torch
204 | model = ResNet101(BatchNorm=nn.BatchNorm2d, output_stride=8)
205 | input = torch.rand(1, 3, 512, 512)
206 |     feats = model(input)  # the backbone returns a list of four feature maps
207 |     for feat in feats:
208 |         print(feat.size())
209 |
--------------------------------------------------------------------------------
/networks/encoders/swin/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_swin_model
--------------------------------------------------------------------------------
/networks/encoders/swin/build.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | from .swin_transformer import SwinTransformer
9 |
10 |
11 | def build_swin_model(model_type, freeze_at=0):
12 | if model_type == 'swin_base':
13 | model = SwinTransformer(embed_dim=128,
14 | depths=[2, 2, 18, 2],
15 | num_heads=[4, 8, 16, 32],
16 | window_size=7,
17 | drop_path_rate=0.3,
18 | out_indices=(0, 1, 2),
19 | ape=False,
20 | patch_norm=True,
21 | frozen_stages=freeze_at,
22 | use_checkpoint=False)
23 |
24 | else:
25 | raise NotImplementedError(f"Unknown model: {model_type}")
26 |
27 | return model
28 |
--------------------------------------------------------------------------------
/networks/engines/__init__.py:
--------------------------------------------------------------------------------
1 | from networks.engines.aot_engine import AOTEngine, AOTInferEngine
2 | from networks.engines.deaot_engine import DeAOTEngine, DeAOTInferEngine
3 |
4 |
5 | def build_engine(name, phase='train', **kwargs):
6 | if name == 'aotengine':
7 | if phase == 'train':
8 | return AOTEngine(**kwargs)
9 | elif phase == 'eval':
10 | return AOTInferEngine(**kwargs)
11 | else:
12 | raise NotImplementedError
13 | elif name == 'deaotengine':
14 | if phase == 'train':
15 | return DeAOTEngine(**kwargs)
16 | elif phase == 'eval':
17 | return DeAOTInferEngine(**kwargs)
18 | else:
19 | raise NotImplementedError
20 | else:
21 | raise NotImplementedError
22 |
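build_engine is a thin dispatcher over the two engine families. The sketch below mirrors what tools/demo.py does later in this document; the stage and model names are the demo defaults, and any other config under configs/ works the same way:

    import importlib

    from networks.models import build_vos_model
    from networks.engines import build_engine

    cfg = importlib.import_module('configs.pre_ytb_dav').EngineConfig('default', 'r50_aotl')
    model = build_vos_model(cfg.MODEL_VOS, cfg)
    engine = build_engine(cfg.MODEL_ENGINE,
                          phase='eval',
                          aot_model=model,
                          gpu_id=0,
                          long_term_mem_gap=cfg.TEST_LONG_TERM_MEM_GAP)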
--------------------------------------------------------------------------------
/networks/engines/deaot_engine.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from utils.image import one_hot_mask
4 |
5 | from networks.layers.basic import seq_to_2d
6 | from networks.engines.aot_engine import AOTEngine, AOTInferEngine
7 |
8 |
9 | class DeAOTEngine(AOTEngine):
10 | def __init__(self,
11 | aot_model,
12 | gpu_id=0,
13 | long_term_mem_gap=9999,
14 | short_term_mem_skip=1,
15 | layer_loss_scaling_ratio=2.):
16 | super().__init__(aot_model, gpu_id, long_term_mem_gap,
17 | short_term_mem_skip)
18 | self.layer_loss_scaling_ratio = layer_loss_scaling_ratio
19 |
20 | def update_short_term_memory(self, curr_mask, curr_id_emb=None, skip_long_term_update=False):
21 |
22 | if curr_id_emb is None:
23 | if len(curr_mask.size()) == 3 or curr_mask.size()[0] == 1:
24 | curr_one_hot_mask = one_hot_mask(curr_mask, self.max_obj_num)
25 | else:
26 | curr_one_hot_mask = curr_mask
27 | curr_id_emb = self.assign_identity(curr_one_hot_mask)
28 |
29 | lstt_curr_memories = self.curr_lstt_output[1]
30 | lstt_curr_memories_2d = []
31 | for layer_idx in range(len(lstt_curr_memories)):
32 | curr_k, curr_v, curr_id_k, curr_id_v = lstt_curr_memories[
33 | layer_idx]
34 | curr_id_k, curr_id_v = self.AOT.LSTT.layers[
35 | layer_idx].fuse_key_value_id(curr_id_k, curr_id_v, curr_id_emb)
36 | lstt_curr_memories[layer_idx][2], lstt_curr_memories[layer_idx][
37 | 3] = curr_id_k, curr_id_v
38 | local_curr_id_k = seq_to_2d(
39 | curr_id_k, self.enc_size_2d) if curr_id_k is not None else None
40 | local_curr_id_v = seq_to_2d(curr_id_v, self.enc_size_2d)
41 | lstt_curr_memories_2d.append([
42 | seq_to_2d(curr_k, self.enc_size_2d),
43 | seq_to_2d(curr_v, self.enc_size_2d), local_curr_id_k,
44 | local_curr_id_v
45 | ])
46 |
47 | self.short_term_memories_list.append(lstt_curr_memories_2d)
48 | self.short_term_memories_list = self.short_term_memories_list[
49 | -self.short_term_mem_skip:]
50 | self.short_term_memories = self.short_term_memories_list[0]
51 |
52 | if self.frame_step - self.last_mem_step >= self.long_term_mem_gap:
53 | # optionally skip the long-term memory update
54 | if not skip_long_term_update:
55 | self.update_long_term_memory(lstt_curr_memories)
56 | self.last_mem_step = self.frame_step
57 |
58 |
59 | class DeAOTInferEngine(AOTInferEngine):
60 | def __init__(self,
61 | aot_model,
62 | gpu_id=0,
63 | long_term_mem_gap=9999,
64 | short_term_mem_skip=1,
65 | max_aot_obj_num=None):
66 | super().__init__(aot_model, gpu_id, long_term_mem_gap,
67 | short_term_mem_skip, max_aot_obj_num)
68 |
69 | def add_reference_frame(self, img, mask, obj_nums, frame_step=-1):
70 | if isinstance(obj_nums, list):
71 | obj_nums = obj_nums[0]
72 | self.obj_nums = obj_nums
73 | aot_num = max(np.ceil(obj_nums / self.max_aot_obj_num), 1)
74 | while (aot_num > len(self.aot_engines)):
75 | new_engine = DeAOTEngine(self.AOT, self.gpu_id,
76 | self.long_term_mem_gap,
77 | self.short_term_mem_skip)
78 | new_engine.eval()
79 | self.aot_engines.append(new_engine)
80 |
81 | separated_masks, separated_obj_nums = self.separate_mask(
82 | mask, obj_nums)
83 | img_embs = None
84 | for aot_engine, separated_mask, separated_obj_num in zip(
85 | self.aot_engines, separated_masks, separated_obj_nums):
86 | aot_engine.add_reference_frame(img,
87 | separated_mask,
88 | obj_nums=[separated_obj_num],
89 | frame_step=frame_step,
90 | img_embs=img_embs)
91 | if img_embs is None:  # cache the embeddings so later sub-engines reuse them
92 | img_embs = aot_engine.curr_enc_embs
93 |
94 | self.update_size()
95 |
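add_reference_frame spreads objects over several sub-engines when a frame contains more objects than a single AOT/DeAOT instance can handle. The count is just the np.ceil above; a worked sketch with an assumed max_aot_obj_num of 10:

    import numpy as np

    max_aot_obj_num = 10                     # assumed value of self.max_aot_obj_num
    for obj_nums in (1, 10, 11, 23):
        aot_num = int(max(np.ceil(obj_nums / max_aot_obj_num), 1))
        print(obj_nums, '->', aot_num)       # 1 -> 1, 10 -> 1, 11 -> 2, 23 -> 3 sub-engines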
--------------------------------------------------------------------------------
/networks/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/networks/layers/__init__.py
--------------------------------------------------------------------------------
/networks/layers/basic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 |
6 | class GroupNorm1D(nn.Module):
7 | def __init__(self, indim, groups=8):
8 | super().__init__()
9 | self.gn = nn.GroupNorm(groups, indim)
10 |
11 | def forward(self, x):
12 | return self.gn(x.permute(1, 2, 0)).permute(2, 0, 1)
13 |
14 |
15 | class GNActDWConv2d(nn.Module):
16 | def __init__(self, indim, gn_groups=32):
17 | super().__init__()
18 | self.gn = nn.GroupNorm(gn_groups, indim)
19 | self.conv = nn.Conv2d(indim,
20 | indim,
21 | 5,
22 | dilation=1,
23 | padding=2,
24 | groups=indim,
25 | bias=False)
26 |
27 | def forward(self, x, size_2d):
28 | h, w = size_2d
29 | _, bs, c = x.size()
30 | x = x.view(h, w, bs, c).permute(2, 3, 0, 1)
31 | x = self.gn(x)
32 | x = F.gelu(x)
33 | x = self.conv(x)
34 | x = x.view(bs, c, h * w).permute(2, 0, 1)
35 | return x
36 |
37 |
38 | class DWConv2d(nn.Module):
39 | def __init__(self, indim, dropout=0.1):
40 | super().__init__()
41 | self.conv = nn.Conv2d(indim,
42 | indim,
43 | 5,
44 | dilation=1,
45 | padding=2,
46 | groups=indim,
47 | bias=False)
48 | self.dropout = nn.Dropout2d(p=dropout, inplace=True)
49 |
50 | def forward(self, x, size_2d):
51 | h, w = size_2d
52 | _, bs, c = x.size()
53 | x = x.view(h, w, bs, c).permute(2, 3, 0, 1)
54 | x = self.conv(x)
55 | x = self.dropout(x)
56 | x = x.view(bs, c, h * w).permute(2, 0, 1)
57 | return x
58 |
59 |
60 | class ScaleOffset(nn.Module):
61 | def __init__(self, indim):
62 | super().__init__()
63 | self.gamma = nn.Parameter(torch.ones(indim))
64 | # torch.nn.init.normal_(self.gamma, std=0.02)
65 | self.beta = nn.Parameter(torch.zeros(indim))
66 |
67 | def forward(self, x):
68 | if len(x.size()) == 3:
69 | return x * self.gamma + self.beta
70 | else:
71 | return x * self.gamma.view(1, -1, 1, 1) + self.beta.view(
72 | 1, -1, 1, 1)
73 |
74 |
75 | class ConvGN(nn.Module):
76 | def __init__(self, indim, outdim, kernel_size, gn_groups=8):
77 | super().__init__()
78 | self.conv = nn.Conv2d(indim,
79 | outdim,
80 | kernel_size,
81 | padding=kernel_size // 2)
82 | self.gn = nn.GroupNorm(gn_groups, outdim)
83 |
84 | def forward(self, x):
85 | return self.gn(self.conv(x))
86 |
87 |
88 | def seq_to_2d(tensor, size_2d):
89 | h, w = size_2d
90 | _, n, c = tensor.size()
91 | tensor = tensor.view(h, w, n, c).permute(2, 3, 0, 1).contiguous()
92 | return tensor
93 |
94 |
95 | def drop_path(x, drop_prob: float = 0., training: bool = False):
96 | if drop_prob == 0. or not training:
97 | return x
98 | keep_prob = 1 - drop_prob
99 | shape = (
100 | x.shape[0],
101 | x.shape[1],
102 | ) + (1, ) * (x.ndim - 2
103 | ) # work with diff dim tensors, not just 2D ConvNets
104 | random_tensor = keep_prob + torch.rand(
105 | shape, dtype=x.dtype, device=x.device)
106 | random_tensor.floor_() # binarize
107 | output = x.div(keep_prob) * random_tensor
108 | return output
109 |
110 |
111 | def mask_out(x, y, mask_rate=0.15, training=False):
112 | if mask_rate == 0. or not training:
113 | return x
114 |
115 | keep_prob = 1 - mask_rate
116 | shape = (
117 | x.shape[0],
118 | x.shape[1],
119 | ) + (1, ) * (x.ndim - 2
120 | ) # work with diff dim tensors, not just 2D ConvNets
121 | random_tensor = keep_prob + torch.rand(
122 | shape, dtype=x.dtype, device=x.device)
123 | random_tensor.floor_() # binarize
124 | output = x * random_tensor + y * (1 - random_tensor)
125 |
126 | return output
127 |
128 |
129 | class DropPath(nn.Module):
130 | def __init__(self, drop_prob=None, batch_dim=0):
131 | super(DropPath, self).__init__()
132 | self.drop_prob = drop_prob
133 | self.batch_dim = batch_dim
134 |
135 | def forward(self, x):
136 | return self.drop_path(x, self.drop_prob)
137 |
138 | def drop_path(self, x, drop_prob):
139 | if drop_prob == 0. or not self.training:
140 | return x
141 | keep_prob = 1 - drop_prob
142 | shape = [1 for _ in range(x.ndim)]
143 | shape[self.batch_dim] = x.shape[self.batch_dim]
144 | random_tensor = keep_prob + torch.rand(
145 | shape, dtype=x.dtype, device=x.device)
146 | random_tensor.floor_() # binarize
147 | output = x.div(keep_prob) * random_tensor
148 | return output
149 |
150 |
151 | class DropOutLogit(nn.Module):
152 | def __init__(self, drop_prob=None):
153 | super(DropOutLogit, self).__init__()
154 | self.drop_prob = drop_prob
155 |
156 | def forward(self, x):
157 | return self.drop_logit(x, self.drop_prob)
158 |
159 | def drop_logit(self, x, drop_prob):
160 | if drop_prob == 0. or not self.training:
161 | return x
162 | random_tensor = drop_prob + torch.rand(
163 | x.shape, dtype=x.dtype, device=x.device)
164 | random_tensor.floor_() # binarize
165 | mask = random_tensor * 1e+8 if (
166 | x.dtype == torch.float32) else random_tensor * 1e+4
167 | output = x - mask
168 | return output
169 |
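Two helpers above are used throughout the engines: seq_to_2d turns an (HW, batch, channel) token sequence back into a feature map, and drop_path zeroes whole paths while rescaling the survivors so the expected value is preserved. A small sanity-check sketch:

    import torch
    from networks.layers.basic import seq_to_2d, drop_path

    tokens = torch.randn(30 * 30, 2, 256)               # (HW, batch, channels)
    assert seq_to_2d(tokens, (30, 30)).shape == (2, 256, 30, 30)

    x = torch.ones(10000, 4, 8)
    y = drop_path(x, drop_prob=0.2, training=True)
    # ~20% of the paths are zeroed and the rest are scaled by 1 / 0.8,
    # so the mean stays close to 1 in expectation
    assert abs(y.mean().item() - 1.0) < 0.05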
--------------------------------------------------------------------------------
/networks/layers/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | try:
6 | from itertools import ifilterfalse
7 | except ImportError: # py3k
8 | from itertools import filterfalse as ifilterfalse
9 |
10 |
11 | def dice_loss(probas, labels, smooth=1):
12 |
13 | C = probas.size(1)
14 | losses = []
15 | for c in list(range(C)):
16 | fg = (labels == c).float()
17 | if fg.sum() == 0:
18 | continue
19 | class_pred = probas[:, c]
20 | p0 = class_pred
21 | g0 = fg
22 | numerator = 2 * torch.sum(p0 * g0) + smooth
23 | denominator = torch.sum(p0) + torch.sum(g0) + smooth
24 | losses.append(1 - ((numerator) / (denominator)))
25 | return mean(losses)
26 |
27 |
28 | def tversky_loss(probas, labels, alpha=0.5, beta=0.5, epsilon=1e-6):
29 | '''
30 | Tversky loss function.
31 | probas: [P, C] Tensor, class probabilities at each prediction (between 0 and 1)
32 | labels: [P] Tensor, ground truth labels (between 0 and C - 1)
33 |
34 | Same as soft dice loss when alpha=beta=0.5.
35 | Same as Jaccard loss when alpha=beta=1.0.
36 | See `Tversky loss function for image segmentation using 3D fully convolutional deep networks`
37 | https://arxiv.org/pdf/1706.05721.pdf
38 | '''
39 | C = probas.size(1)
40 | losses = []
41 | for c in list(range(C)):
42 | fg = (labels == c).float()
43 | if fg.sum() == 0:
44 | continue
45 | class_pred = probas[:, c]
46 | p0 = class_pred
47 | p1 = 1 - class_pred
48 | g0 = fg
49 | g1 = 1 - fg
50 | numerator = torch.sum(p0 * g0)
51 | denominator = numerator + alpha * \
52 | torch.sum(p0*g1) + beta*torch.sum(p1*g0)
53 | losses.append(1 - ((numerator) / (denominator + epsilon)))
54 | return mean(losses)
55 |
56 |
57 | def flatten_probas(probas, labels, ignore=255):
58 | """
59 | Flattens predictions in the batch
60 | """
61 | B, C, H, W = probas.size()
62 | probas = probas.permute(0, 2, 3,
63 | 1).contiguous().view(-1, C) # B * H * W, C = P, C
64 | labels = labels.view(-1)
65 | if ignore is None:
66 | return probas, labels
67 | valid = (labels != ignore)
68 | vprobas = probas[valid.view(-1, 1).expand(-1, C)].reshape(-1, C)
69 | # vprobas = probas[torch.nonzero(valid).squeeze()]
70 | vlabels = labels[valid]
71 | return vprobas, vlabels
72 |
73 |
74 | def isnan(x):
75 | return x != x
76 |
77 |
78 | def mean(l, ignore_nan=False, empty=0):
79 | """
80 | nanmean compatible with generators.
81 | """
82 | l = iter(l)
83 | if ignore_nan:
84 | l = ifilterfalse(isnan, l)
85 | try:
86 | n = 1
87 | acc = next(l)
88 | except StopIteration:
89 | if empty == 'raise':
90 | raise ValueError('Empty mean')
91 | return empty
92 | for n, v in enumerate(l, 2):
93 | acc += v
94 | if n == 1:
95 | return acc
96 | return acc / n
97 |
98 |
99 | class DiceLoss(nn.Module):
100 | def __init__(self, ignore_index=255):
101 | super(DiceLoss, self).__init__()
102 | self.ignore_index = ignore_index
103 |
104 | def forward(self, tmp_dic, label_dic, step=None):
105 | total_loss = []
106 | for idx in range(len(tmp_dic)):
107 | pred = tmp_dic[idx]
108 | label = label_dic[idx]
109 | pred = F.softmax(pred, dim=1)
110 | label = label.view(1, 1, pred.size()[2], pred.size()[3])
111 | loss = dice_loss(
112 | *flatten_probas(pred, label, ignore=self.ignore_index))
113 | total_loss.append(loss.unsqueeze(0))
114 | total_loss = torch.cat(total_loss, dim=0)
115 | return total_loss
116 |
117 |
118 | class SoftJaccordLoss(nn.Module):
119 | def __init__(self, ignore_index=255):
120 | super(SoftJaccordLoss, self).__init__()
121 | self.ignore_index = ignore_index
122 |
123 | def forward(self, tmp_dic, label_dic, step=None):
124 | total_loss = []
125 | for idx in range(len(tmp_dic)):
126 | pred = tmp_dic[idx]
127 | label = label_dic[idx]
128 | pred = F.softmax(pred, dim=1)
129 | label = label.view(1, 1, pred.size()[2], pred.size()[3])
130 | loss = tversky_loss(*flatten_probas(pred,
131 | label,
132 | ignore=self.ignore_index),
133 | alpha=1.0,
134 | beta=1.0)
135 | total_loss.append(loss.unsqueeze(0))
136 | total_loss = torch.cat(total_loss, dim=0)
137 | return total_loss
138 |
139 |
140 | class CrossEntropyLoss(nn.Module):
141 | def __init__(self,
142 | top_k_percent_pixels=None,
143 | hard_example_mining_step=100000):
144 | super(CrossEntropyLoss, self).__init__()
145 | self.top_k_percent_pixels = top_k_percent_pixels
146 | if top_k_percent_pixels is not None:
147 | assert (top_k_percent_pixels > 0 and top_k_percent_pixels < 1)
148 | self.hard_example_mining_step = hard_example_mining_step + 1e-5
149 | if self.top_k_percent_pixels is None:
150 | self.celoss = nn.CrossEntropyLoss(ignore_index=255,
151 | reduction='mean')
152 | else:
153 | self.celoss = nn.CrossEntropyLoss(ignore_index=255,
154 | reduction='none')
155 |
156 | def forward(self, dic_tmp, y, step):
157 | total_loss = []
158 | for i in range(len(dic_tmp)):
159 | pred_logits = dic_tmp[i]
160 | gts = y[i]
161 | if self.top_k_percent_pixels is None:
162 | final_loss = self.celoss(pred_logits, gts)
163 | else:
164 | # Only compute the loss for the top k percent of pixels.
165 | # First, compute the loss for all pixels; the criterion was built
166 | # with reduction='none' so the per-pixel shape is kept.
167 | num_pixels = float(pred_logits.size(2) * pred_logits.size(3))
168 | pred_logits = pred_logits.view(
169 | -1, pred_logits.size(1),
170 | pred_logits.size(2) * pred_logits.size(3))
171 | gts = gts.view(-1, gts.size(1) * gts.size(2))
172 | pixel_losses = self.celoss(pred_logits, gts)
173 | if self.hard_example_mining_step == 0:
174 | top_k_pixels = int(self.top_k_percent_pixels * num_pixels)
175 | else:
176 | ratio = min(1.0,
177 | step / float(self.hard_example_mining_step))
178 | top_k_pixels = int((ratio * self.top_k_percent_pixels +
179 | (1.0 - ratio)) * num_pixels)
180 | top_k_loss, top_k_indices = torch.topk(pixel_losses,
181 | k=top_k_pixels,
182 | dim=1)
183 |
184 | final_loss = torch.mean(top_k_loss)
185 | final_loss = final_loss.unsqueeze(0)
186 | total_loss.append(final_loss)
187 | total_loss = torch.cat(total_loss, dim=0)
188 | return total_loss
189 |
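The top-k branch of CrossEntropyLoss anneals from using every pixel to using only the hardest top_k_percent_pixels over hard_example_mining_step steps. Worked through with hypothetical settings (top_k_percent_pixels=0.15 on a 465x465 prediction; the default of None disables the branch entirely):

    def top_k_pixels(step, num_pixels, top_k_percent=0.15, mining_step=100000):
        # same arithmetic as CrossEntropyLoss.forward above
        ratio = min(1.0, step / float(mining_step))
        return int((ratio * top_k_percent + (1.0 - ratio)) * num_pixels)

    n = 465 * 465
    print(top_k_pixels(0, n))        # 216225 -> all pixels at the start of training
    print(top_k_pixels(50000, n))    # 124329 -> ratio 0.5, i.e. 57.5% of the pixels
    print(top_k_pixels(100000, n))   # 32433  -> fully annealed to the hardest 15%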
--------------------------------------------------------------------------------
/networks/layers/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FrozenBatchNorm2d(nn.Module):
7 | """
8 | BatchNorm2d where the batch statistics and the affine parameters
9 | are fixed
10 | """
11 | def __init__(self, n, epsilon=1e-5):
12 | super(FrozenBatchNorm2d, self).__init__()
13 | self.register_buffer("weight", torch.ones(n))
14 | self.register_buffer("bias", torch.zeros(n))
15 | self.register_buffer("running_mean", torch.zeros(n))
16 | self.register_buffer("running_var", torch.ones(n) - epsilon)
17 | self.epsilon = epsilon
18 |
19 | def forward(self, x):
20 | """
21 | Refer to Detectron2 (https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/batch_norm.py)
22 | """
23 | if x.requires_grad:
24 | # When gradients are needed, F.batch_norm will use extra memory
25 | # because its backward op computes gradients for weight/bias as well.
26 | scale = self.weight * (self.running_var + self.epsilon).rsqrt()
27 | bias = self.bias - self.running_mean * scale
28 | scale = scale.reshape(1, -1, 1, 1)
29 | bias = bias.reshape(1, -1, 1, 1)
30 | out_dtype = x.dtype # may be half
31 | return x * scale.to(out_dtype) + bias.to(out_dtype)
32 | else:
33 | # When gradients are not needed, F.batch_norm is a single fused op
34 | # and provide more optimization opportunities.
35 | return F.batch_norm(
36 | x,
37 | self.running_mean,
38 | self.running_var,
39 | self.weight,
40 | self.bias,
41 | training=False,
42 | eps=self.epsilon,
43 | )
44 |
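With its freshly registered buffers (weight = 1, bias = 0, running_mean = 0, running_var = 1 - eps), FrozenBatchNorm2d is an identity map; in practice the buffers are overwritten by the pretrained checkpoint and then stay fixed. A quick check of the at-initialisation behaviour:

    import torch
    from networks.layers.normalization import FrozenBatchNorm2d

    bn = FrozenBatchNorm2d(64)
    x = torch.randn(2, 64, 16, 16)
    assert torch.allclose(bn(x), x, atol=1e-6)     # scale = ((1 - eps) + eps) ** -0.5 = 1, bias = 0
    assert len(list(bn.parameters())) == 0         # buffers only, nothing is trainable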
--------------------------------------------------------------------------------
/networks/layers/position.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from utils.math import truncated_normal_
8 |
9 |
10 | class Downsample2D(nn.Module):
11 | def __init__(self, mode='nearest', scale=4):
12 | super().__init__()
13 | self.mode = mode
14 | self.scale = scale
15 |
16 | def forward(self, x):
17 | n, c, h, w = x.size()
18 | x = F.interpolate(x,
19 | size=(h // self.scale + 1, w // self.scale + 1),
20 | mode=self.mode)
21 | return x
22 |
23 |
24 | def generate_coord(x):
25 | _, _, h, w = x.size()
26 | device = x.device
27 | col = torch.arange(0, h, device=device)
28 | row = torch.arange(0, w, device=device)
29 | grid_h, grid_w = torch.meshgrid(col, row)
30 | return grid_h, grid_w
31 |
32 |
33 | class PositionEmbeddingSine(nn.Module):
34 | def __init__(self,
35 | num_pos_feats=64,
36 | temperature=10000,
37 | normalize=False,
38 | scale=None):
39 | super().__init__()
40 | self.num_pos_feats = num_pos_feats
41 | self.temperature = temperature
42 | self.normalize = normalize
43 | if scale is not None and normalize is False:
44 | raise ValueError("normalize should be True if scale is passed")
45 | if scale is None:
46 | scale = 2 * math.pi
47 | self.scale = scale
48 |
49 | def forward(self, x):
50 | grid_y, grid_x = generate_coord(x)
51 |
52 | y_embed = grid_y.unsqueeze(0).float()
53 | x_embed = grid_x.unsqueeze(0).float()
54 |
55 | if self.normalize:
56 | eps = 1e-6
57 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
58 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
59 |
60 | dim_t = torch.arange(self.num_pos_feats,
61 | dtype=torch.float32,
62 | device=x.device)
63 | dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats)
64 |
65 | pos_x = x_embed[:, :, :, None] / dim_t
66 | pos_y = y_embed[:, :, :, None] / dim_t
67 | pos_x = torch.stack(
68 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
69 | dim=4).flatten(3)
70 | pos_y = torch.stack(
71 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
72 | dim=4).flatten(3)
73 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
74 | return pos
75 |
76 |
77 | class PositionEmbeddingLearned(nn.Module):
78 | def __init__(self, num_pos_feats=64, H=30, W=30):
79 | super().__init__()
80 | self.H = H
81 | self.W = W
82 | self.pos_emb = nn.Parameter(
83 | truncated_normal_(torch.zeros(1, num_pos_feats, H, W)))
84 |
85 | def forward(self, x):
86 | bs, _, h, w = x.size()
87 | pos_emb = self.pos_emb
88 | if h != self.H or w != self.W:
89 | pos_emb = F.interpolate(pos_emb, size=(h, w), mode="bilinear")
90 | return pos_emb
91 |
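PositionEmbeddingSine builds its map from coordinate grids, so the output always has batch size 1 (it is broadcast against the features later) and 2 * num_pos_feats channels; AOT below constructs it with num_pos_feats = MODEL_ENCODER_EMBEDDING_DIM // 2 so that the channel count matches the embedding dimension. A shape sketch with an assumed embedding dimension of 256:

    import torch
    from networks.layers.position import PositionEmbeddingSine

    pos_generator = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
    feat = torch.randn(4, 256, 30, 30)                      # (batch, channels, H, W)
    assert pos_generator(feat).shape == (1, 256, 30, 30)    # y and x halves concatenated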
--------------------------------------------------------------------------------
/networks/models/__init__.py:
--------------------------------------------------------------------------------
1 | from networks.models.aot import AOT
2 | from networks.models.deaot import DeAOT
3 |
4 |
5 | def build_vos_model(name, cfg, **kwargs):
6 | if name == 'aot':
7 | return AOT(cfg, encoder=cfg.MODEL_ENCODER, **kwargs)
8 | elif name == 'deaot':
9 | return DeAOT(cfg, encoder=cfg.MODEL_ENCODER, **kwargs)
10 | else:
11 | raise NotImplementedError
12 |
--------------------------------------------------------------------------------
/networks/models/aot.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from networks.encoders import build_encoder
4 | from networks.layers.transformer import LongShortTermTransformer
5 | from networks.decoders import build_decoder
6 | from networks.layers.position import PositionEmbeddingSine
7 |
8 |
9 | class AOT(nn.Module):
10 | def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'):
11 | super().__init__()
12 | self.cfg = cfg
13 | self.max_obj_num = cfg.MODEL_MAX_OBJ_NUM
14 | self.epsilon = cfg.MODEL_EPSILON
15 |
16 | self.encoder = build_encoder(encoder,
17 | frozen_bn=cfg.MODEL_FREEZE_BN,
18 | freeze_at=cfg.TRAIN_ENCODER_FREEZE_AT)
19 | self.encoder_projector = nn.Conv2d(cfg.MODEL_ENCODER_DIM[-1],
20 | cfg.MODEL_ENCODER_EMBEDDING_DIM,
21 | kernel_size=1)
22 |
23 | self.LSTT = LongShortTermTransformer(
24 | cfg.MODEL_LSTT_NUM,
25 | cfg.MODEL_ENCODER_EMBEDDING_DIM,
26 | cfg.MODEL_SELF_HEADS,
27 | cfg.MODEL_ATT_HEADS,
28 | emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT,
29 | droppath=cfg.TRAIN_LSTT_DROPPATH,
30 | lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT,
31 | st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT,
32 | droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST,
33 | droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING,
34 | intermediate_norm=cfg.MODEL_DECODER_INTERMEDIATE_LSTT,
35 | return_intermediate=True)
36 |
37 | decoder_indim = cfg.MODEL_ENCODER_EMBEDDING_DIM * \
38 | (cfg.MODEL_LSTT_NUM +
39 | 1) if cfg.MODEL_DECODER_INTERMEDIATE_LSTT else cfg.MODEL_ENCODER_EMBEDDING_DIM
40 |
41 | self.decoder = build_decoder(
42 | decoder,
43 | in_dim=decoder_indim,
44 | out_dim=cfg.MODEL_MAX_OBJ_NUM + 1,
45 | decode_intermediate_input=cfg.MODEL_DECODER_INTERMEDIATE_LSTT,
46 | hidden_dim=cfg.MODEL_ENCODER_EMBEDDING_DIM,
47 | shortcut_dims=cfg.MODEL_ENCODER_DIM,
48 | align_corners=cfg.MODEL_ALIGN_CORNERS)
49 |
50 | if cfg.MODEL_ALIGN_CORNERS:
51 | self.patch_wise_id_bank = nn.Conv2d(
52 | cfg.MODEL_MAX_OBJ_NUM + 1,
53 | cfg.MODEL_ENCODER_EMBEDDING_DIM,
54 | kernel_size=17,
55 | stride=16,
56 | padding=8)
57 | else:
58 | self.patch_wise_id_bank = nn.Conv2d(
59 | cfg.MODEL_MAX_OBJ_NUM + 1,
60 | cfg.MODEL_ENCODER_EMBEDDING_DIM,
61 | kernel_size=16,
62 | stride=16,
63 | padding=0)
64 |
65 | self.id_dropout = nn.Dropout(cfg.TRAIN_LSTT_ID_DROPOUT, True)
66 |
67 | self.pos_generator = PositionEmbeddingSine(
68 | cfg.MODEL_ENCODER_EMBEDDING_DIM // 2, normalize=True)
69 |
70 | self._init_weight()
71 |
72 | def get_pos_emb(self, x):
73 | pos_emb = self.pos_generator(x)
74 | return pos_emb
75 |
76 | def get_id_emb(self, x):
77 | id_emb = self.patch_wise_id_bank(x)
78 | id_emb = self.id_dropout(id_emb)
79 | return id_emb
80 |
81 | def encode_image(self, img):
82 | xs = self.encoder(img)
83 | xs[-1] = self.encoder_projector(xs[-1])
84 | return xs
85 |
86 | def decode_id_logits(self, lstt_emb, shortcuts):
87 | n, c, h, w = shortcuts[-1].size()
88 | decoder_inputs = [shortcuts[-1]]
89 | for emb in lstt_emb:
90 | decoder_inputs.append(emb.view(h, w, n, c).permute(2, 3, 0, 1))
91 | pred_logit = self.decoder(decoder_inputs, shortcuts)
92 | return pred_logit
93 |
94 | def LSTT_forward(self,
95 | curr_embs,
96 | long_term_memories,
97 | short_term_memories,
98 | curr_id_emb=None,
99 | pos_emb=None,
100 | size_2d=(30, 30)):
101 | n, c, h, w = curr_embs[-1].size()
102 | curr_emb = curr_embs[-1].view(n, c, h * w).permute(2, 0, 1)
103 | lstt_embs, lstt_memories = self.LSTT(curr_emb, long_term_memories,
104 | short_term_memories, curr_id_emb,
105 | pos_emb, size_2d)
106 | lstt_curr_memories, lstt_long_memories, lstt_short_memories = zip(
107 | *lstt_memories)
108 | return lstt_embs, lstt_curr_memories, lstt_long_memories, lstt_short_memories
109 |
110 | def _init_weight(self):
111 | nn.init.xavier_uniform_(self.encoder_projector.weight)
112 | nn.init.orthogonal_(
113 | self.patch_wise_id_bank.weight.view(
114 | self.cfg.MODEL_ENCODER_EMBEDDING_DIM, -1).permute(0, 1),
115 | gain=17**-2 if self.cfg.MODEL_ALIGN_CORNERS else 16**-2)
116 |
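The patch_wise_id_bank projects a one-hot mask with MODEL_MAX_OBJ_NUM + 1 channels onto identity embeddings at the same 1/16 resolution as the encoder tokens; the kernel-17 / stride-16 / padding-8 variant is the MODEL_ALIGN_CORNERS case and maps an input of size 16k + 1 to exactly k + 1 tokens. A standalone sketch with assumed config values (max_obj_num = 10, embedding dim = 256):

    import torch
    import torch.nn as nn

    # align_corners variant of the id bank above, with assumed config values
    id_bank = nn.Conv2d(10 + 1, 256, kernel_size=17, stride=16, padding=8)

    one_hot_mask = torch.zeros(1, 11, 465, 465)   # 465 = 16 * 29 + 1 (align_corners sizing)
    assert id_bank(one_hot_mask).shape == (1, 256, 30, 30)   # matches the 1/16 token grid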
--------------------------------------------------------------------------------
/networks/models/deaot.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from networks.layers.transformer import DualBranchGPM
4 | from networks.models.aot import AOT
5 | from networks.decoders import build_decoder
6 |
7 |
8 | class DeAOT(AOT):
9 | def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'):
10 | super().__init__(cfg, encoder, decoder)
11 |
12 | self.LSTT = DualBranchGPM(
13 | cfg.MODEL_LSTT_NUM,
14 | cfg.MODEL_ENCODER_EMBEDDING_DIM,
15 | cfg.MODEL_SELF_HEADS,
16 | cfg.MODEL_ATT_HEADS,
17 | emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT,
18 | droppath=cfg.TRAIN_LSTT_DROPPATH,
19 | lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT,
20 | st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT,
21 | droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST,
22 | droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING,
23 | intermediate_norm=cfg.MODEL_DECODER_INTERMEDIATE_LSTT,
24 | return_intermediate=True)
25 |
26 | decoder_indim = cfg.MODEL_ENCODER_EMBEDDING_DIM * \
27 | (cfg.MODEL_LSTT_NUM * 2 +
28 | 1) if cfg.MODEL_DECODER_INTERMEDIATE_LSTT else cfg.MODEL_ENCODER_EMBEDDING_DIM * 2
29 |
30 | self.decoder = build_decoder(
31 | decoder,
32 | in_dim=decoder_indim,
33 | out_dim=cfg.MODEL_MAX_OBJ_NUM + 1,
34 | decode_intermediate_input=cfg.MODEL_DECODER_INTERMEDIATE_LSTT,
35 | hidden_dim=cfg.MODEL_ENCODER_EMBEDDING_DIM,
36 | shortcut_dims=cfg.MODEL_ENCODER_DIM,
37 | align_corners=cfg.MODEL_ALIGN_CORNERS)
38 |
39 | self.id_norm = nn.LayerNorm(cfg.MODEL_ENCODER_EMBEDDING_DIM)
40 |
41 | self._init_weight()
42 |
43 | def decode_id_logits(self, lstt_emb, shortcuts):
44 | n, c, h, w = shortcuts[-1].size()
45 | decoder_inputs = [shortcuts[-1]]
46 | for emb in lstt_emb:
47 | decoder_inputs.append(emb.view(h, w, n, -1).permute(2, 3, 0, 1))
48 | pred_logit = self.decoder(decoder_inputs, shortcuts)
49 | return pred_logit
50 |
51 | def get_id_emb(self, x):
52 | id_emb = self.patch_wise_id_bank(x)
53 | id_emb = self.id_norm(id_emb.permute(2, 3, 0, 1)).permute(2, 3, 0, 1)
54 | id_emb = self.id_dropout(id_emb)
55 | return id_emb
56 |
--------------------------------------------------------------------------------
/pretrain_models/README.md:
--------------------------------------------------------------------------------
1 | Put pretrained models here.
--------------------------------------------------------------------------------
/source/1001_3iEIq5HBY1s.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/source/1001_3iEIq5HBY1s.gif
--------------------------------------------------------------------------------
/source/1007_YCTBBdbKSSg.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/source/1007_YCTBBdbKSSg.gif
--------------------------------------------------------------------------------
/source/kobe.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/source/kobe.gif
--------------------------------------------------------------------------------
/source/messi.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/source/messi.gif
--------------------------------------------------------------------------------
/source/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/source/overview.png
--------------------------------------------------------------------------------
/source/overview_deaot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/source/overview_deaot.png
--------------------------------------------------------------------------------
/source/some_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/source/some_results.png
--------------------------------------------------------------------------------
/tools/demo.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import sys
3 | import os
4 |
5 | sys.path.append('.')
6 | sys.path.append('..')
7 |
8 | import cv2
9 | from PIL import Image
10 | from skimage.morphology.binary import binary_dilation
11 |
12 | import numpy as np
13 | import torch
14 | import torch.nn.functional as F
15 | from torch.utils.data import DataLoader
16 | from torchvision import transforms
17 |
18 | from networks.models import build_vos_model
19 | from networks.engines import build_engine
20 | from utils.checkpoint import load_network
21 |
22 | from dataloaders.eval_datasets import VOSTest
23 | import dataloaders.video_transforms as tr
24 | from utils.image import save_mask
25 |
26 | _palette = [
27 | 255, 0, 0, 0, 0, 139, 255, 255, 84, 0, 255, 0, 139, 0, 139, 0, 128, 128,
28 | 128, 128, 128, 139, 0, 0, 218, 165, 32, 144, 238, 144, 160, 82, 45, 148, 0,
29 | 211, 255, 0, 255, 30, 144, 255, 255, 218, 185, 85, 107, 47, 255, 140, 0,
30 | 50, 205, 50, 123, 104, 238, 240, 230, 140, 72, 61, 139, 128, 128, 0, 0, 0,
31 | 205, 221, 160, 221, 143, 188, 143, 127, 255, 212, 176, 224, 230, 244, 164,
32 | 96, 250, 128, 114, 70, 130, 180, 0, 128, 0, 173, 255, 47, 255, 105, 180,
33 | 238, 130, 238, 154, 205, 50, 220, 20, 60, 176, 48, 96, 0, 206, 209, 0, 191,
34 | 255, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43, 43, 43, 44, 44, 44, 45, 45,
35 | 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49, 49, 50, 50, 50, 51, 51, 51,
36 | 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55, 56, 56, 56, 57, 57, 57, 58,
37 | 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62, 62, 62, 63, 63, 63, 64, 64,
38 | 64, 65, 65, 65, 66, 66, 66, 67, 67, 67, 68, 68, 68, 69, 69, 69, 70, 70, 70,
39 | 71, 71, 71, 72, 72, 72, 73, 73, 73, 74, 74, 74, 75, 75, 75, 76, 76, 76, 77,
40 | 77, 77, 78, 78, 78, 79, 79, 79, 80, 80, 80, 81, 81, 81, 82, 82, 82, 83, 83,
41 | 83, 84, 84, 84, 85, 85, 85, 86, 86, 86, 87, 87, 87, 88, 88, 88, 89, 89, 89,
42 | 90, 90, 90, 91, 91, 91, 92, 92, 92, 93, 93, 93, 94, 94, 94, 95, 95, 95, 96,
43 | 96, 96, 97, 97, 97, 98, 98, 98, 99, 99, 99, 100, 100, 100, 101, 101, 101,
44 | 102, 102, 102, 103, 103, 103, 104, 104, 104, 105, 105, 105, 106, 106, 106,
45 | 107, 107, 107, 108, 108, 108, 109, 109, 109, 110, 110, 110, 111, 111, 111,
46 | 112, 112, 112, 113, 113, 113, 114, 114, 114, 115, 115, 115, 116, 116, 116,
47 | 117, 117, 117, 118, 118, 118, 119, 119, 119, 120, 120, 120, 121, 121, 121,
48 | 122, 122, 122, 123, 123, 123, 124, 124, 124, 125, 125, 125, 126, 126, 126,
49 | 127, 127, 127, 128, 128, 128, 129, 129, 129, 130, 130, 130, 131, 131, 131,
50 | 132, 132, 132, 133, 133, 133, 134, 134, 134, 135, 135, 135, 136, 136, 136,
51 | 137, 137, 137, 138, 138, 138, 139, 139, 139, 140, 140, 140, 141, 141, 141,
52 | 142, 142, 142, 143, 143, 143, 144, 144, 144, 145, 145, 145, 146, 146, 146,
53 | 147, 147, 147, 148, 148, 148, 149, 149, 149, 150, 150, 150, 151, 151, 151,
54 | 152, 152, 152, 153, 153, 153, 154, 154, 154, 155, 155, 155, 156, 156, 156,
55 | 157, 157, 157, 158, 158, 158, 159, 159, 159, 160, 160, 160, 161, 161, 161,
56 | 162, 162, 162, 163, 163, 163, 164, 164, 164, 165, 165, 165, 166, 166, 166,
57 | 167, 167, 167, 168, 168, 168, 169, 169, 169, 170, 170, 170, 171, 171, 171,
58 | 172, 172, 172, 173, 173, 173, 174, 174, 174, 175, 175, 175, 176, 176, 176,
59 | 177, 177, 177, 178, 178, 178, 179, 179, 179, 180, 180, 180, 181, 181, 181,
60 | 182, 182, 182, 183, 183, 183, 184, 184, 184, 185, 185, 185, 186, 186, 186,
61 | 187, 187, 187, 188, 188, 188, 189, 189, 189, 190, 190, 190, 191, 191, 191,
62 | 192, 192, 192, 193, 193, 193, 194, 194, 194, 195, 195, 195, 196, 196, 196,
63 | 197, 197, 197, 198, 198, 198, 199, 199, 199, 200, 200, 200, 201, 201, 201,
64 | 202, 202, 202, 203, 203, 203, 204, 204, 204, 205, 205, 205, 206, 206, 206,
65 | 207, 207, 207, 208, 208, 208, 209, 209, 209, 210, 210, 210, 211, 211, 211,
66 | 212, 212, 212, 213, 213, 213, 214, 214, 214, 215, 215, 215, 216, 216, 216,
67 | 217, 217, 217, 218, 218, 218, 219, 219, 219, 220, 220, 220, 221, 221, 221,
68 | 222, 222, 222, 223, 223, 223, 224, 224, 224, 225, 225, 225, 226, 226, 226,
69 | 227, 227, 227, 228, 228, 228, 229, 229, 229, 230, 230, 230, 231, 231, 231,
70 | 232, 232, 232, 233, 233, 233, 234, 234, 234, 235, 235, 235, 236, 236, 236,
71 | 237, 237, 237, 238, 238, 238, 239, 239, 239, 240, 240, 240, 241, 241, 241,
72 | 242, 242, 242, 243, 243, 243, 244, 244, 244, 245, 245, 245, 246, 246, 246,
73 | 247, 247, 247, 248, 248, 248, 249, 249, 249, 250, 250, 250, 251, 251, 251,
74 | 252, 252, 252, 253, 253, 253, 254, 254, 254, 255, 255, 255, 0, 0, 0
75 | ]
76 | color_palette = np.array(_palette).reshape(-1, 3)
77 |
78 |
79 | def overlay(image, mask, colors=[255, 0, 0], cscale=1, alpha=0.4):
80 | colors = np.atleast_2d(colors) * cscale
81 |
82 | im_overlay = image.copy()
83 | object_ids = np.unique(mask)
84 |
85 | for object_id in object_ids[1:]:
86 | # Overlay color on binary mask
87 |
88 | foreground = image * alpha + np.ones(
89 | image.shape) * (1 - alpha) * np.array(colors[object_id])
90 | binary_mask = mask == object_id
91 |
92 | # Compose image
93 | im_overlay[binary_mask] = foreground[binary_mask]
94 |
95 | contours = binary_dilation(binary_mask) ^ binary_mask
96 | im_overlay[contours, :] = 0
97 |
98 | return im_overlay.astype(image.dtype)
99 |
100 |
101 | def demo(cfg):
102 | video_fps = 15
103 | gpu_id = cfg.TEST_GPU_ID
104 |
105 | # Load pre-trained model
106 | print('Build AOT model.')
107 | model = build_vos_model(cfg.MODEL_VOS, cfg).cuda(gpu_id)
108 |
109 | print('Load checkpoint from {}'.format(cfg.TEST_CKPT_PATH))
110 | model, _ = load_network(model, cfg.TEST_CKPT_PATH, gpu_id)
111 |
112 | print('Build AOT engine.')
113 | engine = build_engine(cfg.MODEL_ENGINE,
114 | phase='eval',
115 | aot_model=model,
116 | gpu_id=gpu_id,
117 | long_term_mem_gap=cfg.TEST_LONG_TERM_MEM_GAP)
118 |
119 | # Prepare datasets for each sequence
120 | transform = transforms.Compose([
121 | tr.MultiRestrictSize(cfg.TEST_MIN_SIZE, cfg.TEST_MAX_SIZE,
122 | cfg.TEST_FLIP, cfg.TEST_MULTISCALE,
123 | cfg.MODEL_ALIGN_CORNERS),
124 | tr.MultiToTensor()
125 | ])
126 | image_root = os.path.join(cfg.TEST_DATA_PATH, 'images')
127 | label_root = os.path.join(cfg.TEST_DATA_PATH, 'masks')
128 |
129 | sequences = os.listdir(image_root)
130 | seq_datasets = []
131 | for seq_name in sequences:
132 | print('Build a dataset for sequence {}.'.format(seq_name))
133 | seq_images = np.sort(os.listdir(os.path.join(image_root, seq_name)))
134 | seq_labels = [seq_images[0].replace('jpg', 'png')]
135 | seq_dataset = VOSTest(image_root,
136 | label_root,
137 | seq_name,
138 | seq_images,
139 | seq_labels,
140 | transform=transform)
141 | seq_datasets.append(seq_dataset)
142 |
143 | # Infer
144 | output_root = cfg.TEST_OUTPUT_PATH
145 | output_mask_root = os.path.join(output_root, 'pred_masks')
146 | if not os.path.exists(output_mask_root):
147 | os.makedirs(output_mask_root)
148 |
149 | for seq_dataset in seq_datasets:
150 | seq_name = seq_dataset.seq_name
151 | image_seq_root = os.path.join(image_root, seq_name)
152 | output_mask_seq_root = os.path.join(output_mask_root, seq_name)
153 | if not os.path.exists(output_mask_seq_root):
154 | os.makedirs(output_mask_seq_root)
155 | print('Build a dataloader for sequence {}.'.format(seq_name))
156 | seq_dataloader = DataLoader(seq_dataset,
157 | batch_size=1,
158 | shuffle=False,
159 | num_workers=cfg.TEST_WORKERS,
160 | pin_memory=True)
161 |
162 | fourcc = cv2.VideoWriter_fourcc(*'XVID')
163 | output_video_path = os.path.join(
164 | output_root, '{}_{}fps.avi'.format(seq_name, video_fps))
165 |
166 | print('Start the inference of sequence {}:'.format(seq_name))
167 | model.eval()
168 | engine.restart_engine()
169 | with torch.no_grad():
170 | for frame_idx, samples in enumerate(seq_dataloader):
171 | sample = samples[0]
172 | img_name = sample['meta']['current_name'][0]
173 |
174 | obj_nums = sample['meta']['obj_num']
175 | output_height = sample['meta']['height']
176 | output_width = sample['meta']['width']
177 | obj_idx = sample['meta']['obj_idx']
178 |
179 | obj_nums = [int(obj_num) for obj_num in obj_nums]
180 | obj_idx = [int(_obj_idx) for _obj_idx in obj_idx]
181 |
182 | current_img = sample['current_img']
183 | current_img = current_img.cuda(gpu_id, non_blocking=True)
184 |
185 | if frame_idx == 0:
186 | videoWriter = cv2.VideoWriter(
187 | output_video_path, fourcc, video_fps,
188 | (int(output_width), int(output_height)))
189 | print(
190 | 'Object number: {}. Inference size: {}x{}. Output size: {}x{}.'
191 | .format(obj_nums[0],
192 | current_img.size()[2],
193 | current_img.size()[3], int(output_height),
194 | int(output_width)))
195 | current_label = sample['current_label'].cuda(
196 | gpu_id, non_blocking=True).float()
197 | current_label = F.interpolate(current_label,
198 | size=current_img.size()[2:],
199 | mode="nearest")
200 | # add reference frame
201 | engine.add_reference_frame(current_img,
202 | current_label,
203 | frame_step=0,
204 | obj_nums=obj_nums)
205 | else:
206 | print('Processing image {}...'.format(img_name))
207 | # predict segmentation
208 | engine.match_propogate_one_frame(current_img)
209 | pred_logit = engine.decode_current_logits(
210 | (output_height, output_width))
211 | pred_prob = torch.softmax(pred_logit, dim=1)
212 | pred_label = torch.argmax(pred_prob, dim=1,
213 | keepdim=True).float()
214 | _pred_label = F.interpolate(pred_label,
215 | size=engine.input_size_2d,
216 | mode="nearest")
217 | # update memory
218 | engine.update_memory(_pred_label)
219 |
220 | # save results
221 | input_image_path = os.path.join(image_seq_root, img_name)
222 | output_mask_path = os.path.join(
223 | output_mask_seq_root,
224 | img_name.split('.')[0] + '.png')
225 |
226 | pred_label = Image.fromarray(
227 | pred_label.squeeze(0).squeeze(0).cpu().numpy().astype(
228 | 'uint8')).convert('P')
229 | pred_label.putpalette(_palette)
230 | pred_label.save(output_mask_path)
231 |
232 | input_image = Image.open(input_image_path)
233 |
234 | overlayed_image = overlay(
235 | np.array(input_image, dtype=np.uint8),
236 | np.array(pred_label, dtype=np.uint8), color_palette)
237 | videoWriter.write(overlayed_image[..., [2, 1, 0]])
238 |
239 | print('Save a visualization video to {}.'.format(output_video_path))
240 | videoWriter.release()
241 |
242 |
243 | def main():
244 | import argparse
245 | parser = argparse.ArgumentParser(description="AOT Demo")
246 | parser.add_argument('--exp_name', type=str, default='default')
247 |
248 | parser.add_argument('--stage', type=str, default='pre_ytb_dav')
249 | parser.add_argument('--model', type=str, default='r50_aotl')
250 |
251 | parser.add_argument('--gpu_id', type=int, default=0)
252 |
253 | parser.add_argument('--data_path', type=str, default='./datasets/Demo')
254 | parser.add_argument('--output_path', type=str, default='./demo_output')
255 | parser.add_argument('--ckpt_path',
256 | type=str,
257 | default='./pretrain_models/R50_AOTL_PRE_YTB_DAV.pth')
258 |
259 | parser.add_argument('--max_resolution', type=float, default=480 * 1.3)
260 |
261 | parser.add_argument('--amp', action='store_true')
262 | parser.set_defaults(amp=False)
263 |
264 | args = parser.parse_args()
265 |
266 | engine_config = importlib.import_module('configs.' + args.stage)
267 | cfg = engine_config.EngineConfig(args.exp_name, args.model)
268 |
269 | cfg.TEST_GPU_ID = args.gpu_id
270 |
271 | cfg.TEST_CKPT_PATH = args.ckpt_path
272 | cfg.TEST_DATA_PATH = args.data_path
273 | cfg.TEST_OUTPUT_PATH = args.output_path
274 |
275 | cfg.TEST_MIN_SIZE = None
276 | cfg.TEST_MAX_SIZE = args.max_resolution * 800. / 480.
277 |
278 | if args.amp:
279 | with torch.cuda.amp.autocast(enabled=True):
280 | demo(cfg)
281 | else:
282 | demo(cfg)
283 |
284 |
285 | if __name__ == '__main__':
286 | main()
287 |
--------------------------------------------------------------------------------
/tools/eval.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import sys
3 |
4 | sys.path.append('.')
5 | sys.path.append('..')
6 |
7 | import torch
8 | import torch.multiprocessing as mp
9 |
10 | from networks.managers.evaluator import Evaluator
11 |
12 |
13 | def main_worker(gpu, cfg, seq_queue=None, info_queue=None, enable_amp=False):
14 |     # Initialize an evaluation manager
15 | evaluator = Evaluator(rank=gpu,
16 | cfg=cfg,
17 | seq_queue=seq_queue,
18 | info_queue=info_queue)
19 | # Start evaluation
20 | if enable_amp:
21 | with torch.cuda.amp.autocast(enabled=True):
22 | evaluator.evaluating()
23 | else:
24 | evaluator.evaluating()
25 |
26 |
27 | def main():
28 | import argparse
29 | parser = argparse.ArgumentParser(description="Eval VOS")
30 | parser.add_argument('--exp_name', type=str, default='default')
31 |
32 | parser.add_argument('--stage', type=str, default='pre')
33 | parser.add_argument('--model', type=str, default='aott')
34 | parser.add_argument('--lstt_num', type=int, default=-1)
35 | parser.add_argument('--lt_gap', type=int, default=-1)
36 | parser.add_argument('--st_skip', type=int, default=-1)
37 |     parser.add_argument('--max_id_num', type=int, default=-1)
38 |
39 | parser.add_argument('--gpu_id', type=int, default=0)
40 | parser.add_argument('--gpu_num', type=int, default=1)
41 |
42 | parser.add_argument('--ckpt_path', type=str, default='')
43 | parser.add_argument('--ckpt_step', type=int, default=-1)
44 |
45 | parser.add_argument('--dataset', type=str, default='')
46 | parser.add_argument('--split', type=str, default='')
47 |
48 | parser.add_argument('--ema', action='store_true')
49 | parser.set_defaults(ema=False)
50 |
51 | parser.add_argument('--flip', action='store_true')
52 | parser.set_defaults(flip=False)
53 | parser.add_argument('--ms', nargs='+', type=float, default=[1.])
54 |
55 | parser.add_argument('--max_resolution', type=float, default=480 * 1.3)
56 |
57 | parser.add_argument('--amp', action='store_true')
58 | parser.set_defaults(amp=False)
59 |
60 | args = parser.parse_args()
61 |
62 | engine_config = importlib.import_module('configs.' + args.stage)
63 | cfg = engine_config.EngineConfig(args.exp_name, args.model)
64 |
65 | cfg.TEST_EMA = args.ema
66 |
67 | cfg.TEST_GPU_ID = args.gpu_id
68 | cfg.TEST_GPU_NUM = args.gpu_num
69 |
70 | if args.lstt_num > 0:
71 | cfg.MODEL_LSTT_NUM = args.lstt_num
72 | if args.lt_gap > 0:
73 | cfg.TEST_LONG_TERM_MEM_GAP = args.lt_gap
74 | if args.st_skip > 0:
75 | cfg.TEST_SHORT_TERM_MEM_SKIP = args.st_skip
76 |
77 | if args.max_id_num > 0:
78 | cfg.MODEL_MAX_OBJ_NUM = args.max_id_num
79 |
80 | if args.ckpt_path != '':
81 | cfg.TEST_CKPT_PATH = args.ckpt_path
82 | if args.ckpt_step > 0:
83 | cfg.TEST_CKPT_STEP = args.ckpt_step
84 |
85 | if args.dataset != '':
86 | cfg.TEST_DATASET = args.dataset
87 |
88 | if args.split != '':
89 | cfg.TEST_DATASET_SPLIT = args.split
90 |
91 | cfg.TEST_FLIP = args.flip
92 | cfg.TEST_MULTISCALE = args.ms
93 |
94 | if cfg.TEST_MULTISCALE != [1.]:
95 |         cfg.TEST_MAX_SHORT_EDGE = args.max_resolution  # cap the short edge to prevent OOM in multi-scale testing
96 | else:
97 | cfg.TEST_MAX_SHORT_EDGE = None # the default resolution setting of CFBI and AOT
98 | cfg.TEST_MAX_LONG_EDGE = args.max_resolution * 800. / 480.
99 |
100 | if args.gpu_num > 1:
101 | mp.set_start_method('spawn')
102 | seq_queue = mp.Queue()
103 | info_queue = mp.Queue()
104 | mp.spawn(main_worker,
105 | nprocs=cfg.TEST_GPU_NUM,
106 | args=(cfg, seq_queue, info_queue, args.amp))
107 | else:
108 | main_worker(0, cfg, enable_amp=args.amp)
109 |
110 |
111 | if __name__ == '__main__':
112 | main()
113 |
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import random
3 | import sys
4 |
5 | sys.setrecursionlimit(10000)
6 | sys.path.append('.')
7 | sys.path.append('..')
8 |
9 | import torch.multiprocessing as mp
10 |
11 | from networks.managers.trainer import Trainer
12 |
13 |
14 | def main_worker(gpu, cfg, enable_amp=True):
15 |     # Initialize a training manager
16 | trainer = Trainer(rank=gpu, cfg=cfg, enable_amp=enable_amp)
17 | # Start Training
18 | trainer.sequential_training()
19 |
20 |
21 | def main():
22 | import argparse
23 | parser = argparse.ArgumentParser(description="Train VOS")
24 | parser.add_argument('--exp_name', type=str, default='')
25 | parser.add_argument('--stage', type=str, default='pre')
26 | parser.add_argument('--model', type=str, default='aott')
27 |     parser.add_argument('--max_id_num', type=int, default=-1)
28 |
29 | parser.add_argument('--start_gpu', type=int, default=0)
30 | parser.add_argument('--gpu_num', type=int, default=-1)
31 | parser.add_argument('--batch_size', type=int, default=-1)
32 | parser.add_argument('--dist_url', type=str, default='')
33 | parser.add_argument('--amp', action='store_true')
34 | parser.set_defaults(amp=False)
35 |
36 | parser.add_argument('--pretrained_path', type=str, default='')
37 |
38 | parser.add_argument('--datasets', nargs='+', type=str, default=[])
39 | parser.add_argument('--lr', type=float, default=-1.)
40 |     parser.add_argument('--total_step', type=int, default=-1)
41 |     parser.add_argument('--start_step', type=int, default=-1)
42 |
43 | args = parser.parse_args()
44 |
45 | engine_config = importlib.import_module('configs.' + args.stage)
46 |
47 | cfg = engine_config.EngineConfig(args.exp_name, args.model)
48 |
49 | if len(args.datasets) > 0:
50 | cfg.DATASETS = args.datasets
51 |
52 | cfg.DIST_START_GPU = args.start_gpu
53 | if args.gpu_num > 0:
54 | cfg.TRAIN_GPUS = args.gpu_num
55 | if args.batch_size > 0:
56 | cfg.TRAIN_BATCH_SIZE = args.batch_size
57 |
58 | if args.pretrained_path != '':
59 | cfg.PRETRAIN_MODEL = args.pretrained_path
60 |
61 | if args.max_id_num > 0:
62 | cfg.MODEL_MAX_OBJ_NUM = args.max_id_num
63 |
64 | if args.lr > 0:
65 | cfg.TRAIN_LR = args.lr
66 |
67 | if args.total_step > 0:
68 | cfg.TRAIN_TOTAL_STEPS = args.total_step
69 |
70 | if args.start_step > 0:
71 | cfg.TRAIN_START_STEP = args.start_step
72 |
73 | if args.dist_url == '':
74 | cfg.DIST_URL = 'tcp://127.0.0.1:123' + str(random.randint(0, 9)) + str(
75 | random.randint(0, 9))
76 | else:
77 | cfg.DIST_URL = args.dist_url
78 |
79 | if cfg.TRAIN_GPUS > 1:
80 | # Use torch.multiprocessing.spawn to launch distributed processes
81 | mp.spawn(main_worker, nprocs=cfg.TRAIN_GPUS, args=(cfg, args.amp))
82 | else:
83 | cfg.TRAIN_GPUS = 1
84 | main_worker(0, cfg, args.amp)
85 |
86 | if __name__ == '__main__':
87 | main()
88 |
--------------------------------------------------------------------------------
/train_eval.sh:
--------------------------------------------------------------------------------
1 | exp="default"
2 | gpu_num="4"
3 |
4 | model="aott"
5 | # model="aots"
6 | # model="aotb"
7 | # model="aotl"
8 | # model="r50_aotl"
9 | # model="swinb_aotl"
10 |
11 | ## Training ##
12 | stage="pre"
13 | python tools/train.py --amp \
14 | --exp_name ${exp} \
15 | --stage ${stage} \
16 | --model ${model} \
17 | --gpu_num ${gpu_num}
18 |
19 | stage="pre_ytb_dav"
20 | python tools/train.py --amp \
21 | --exp_name ${exp} \
22 | --stage ${stage} \
23 | --model ${model} \
24 | --gpu_num ${gpu_num}
25 |
26 | ## Evaluation ##
27 | dataset="davis2017"
28 | split="test"
29 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \
30 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num}
31 |
32 | dataset="davis2017"
33 | split="val"
34 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \
35 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num}
36 |
37 | dataset="davis2016"
38 | split="val"
39 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \
40 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num}
41 |
42 | dataset="youtubevos2018"
43 | split="val" # or "val_all_frames"
44 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \
45 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num}
46 |
47 | dataset="youtubevos2019"
48 | split="val" # or "val_all_frames"
49 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \
50 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num}
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/utils/__init__.py
--------------------------------------------------------------------------------
/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import shutil
4 | import numpy as np
5 |
6 |
7 | def load_network_and_optimizer(net, opt, pretrained_dir, gpu, scaler=None):
8 | pretrained = torch.load(pretrained_dir,
9 | map_location=torch.device("cuda:" + str(gpu)))
10 | pretrained_dict = pretrained['state_dict']
11 | model_dict = net.state_dict()
12 | pretrained_dict_update = {}
13 | pretrained_dict_remove = []
14 | for k, v in pretrained_dict.items():
15 | if k in model_dict:
16 | pretrained_dict_update[k] = v
17 | elif k[:7] == 'module.':
18 | if k[7:] in model_dict:
19 | pretrained_dict_update[k[7:]] = v
20 | else:
21 | pretrained_dict_remove.append(k)
22 | model_dict.update(pretrained_dict_update)
23 | net.load_state_dict(model_dict)
24 | opt.load_state_dict(pretrained['optimizer'])
25 | if scaler is not None and 'scaler' in pretrained.keys():
26 | scaler.load_state_dict(pretrained['scaler'])
27 | del (pretrained)
28 | return net.cuda(gpu), opt, pretrained_dict_remove
29 |
30 |
31 | def load_network_and_optimizer_v2(net, opt, pretrained_dir, gpu, scaler=None):
32 | pretrained = torch.load(pretrained_dir,
33 | map_location=torch.device("cuda:" + str(gpu)))
34 | # load model
35 | pretrained_dict = pretrained['state_dict']
36 | model_dict = net.state_dict()
37 | pretrained_dict_update = {}
38 | pretrained_dict_remove = []
39 | for k, v in pretrained_dict.items():
40 | if k in model_dict:
41 | pretrained_dict_update[k] = v
42 | elif k[:7] == 'module.':
43 | if k[7:] in model_dict:
44 | pretrained_dict_update[k[7:]] = v
45 | else:
46 | pretrained_dict_remove.append(k)
47 | model_dict.update(pretrained_dict_update)
48 | net.load_state_dict(model_dict)
49 |
50 | # load optimizer
51 | opt_dict = opt.state_dict()
52 | all_params = {
53 | param_group['name']: param_group['params'][0]
54 | for param_group in opt_dict['param_groups']
55 | }
56 | pretrained_opt_dict = {'state': {}, 'param_groups': []}
57 | for idx in range(len(pretrained['optimizer']['param_groups'])):
58 | param_group = pretrained['optimizer']['param_groups'][idx]
59 | if param_group['name'] in all_params.keys():
60 | pretrained_opt_dict['state'][all_params[
61 | param_group['name']]] = pretrained['optimizer']['state'][
62 | param_group['params'][0]]
63 | param_group['params'][0] = all_params[param_group['name']]
64 | pretrained_opt_dict['param_groups'].append(param_group)
65 |
66 | opt_dict.update(pretrained_opt_dict)
67 | opt.load_state_dict(opt_dict)
68 |
69 | # load scaler
70 | if scaler is not None and 'scaler' in pretrained.keys():
71 | scaler.load_state_dict(pretrained['scaler'])
72 | del (pretrained)
73 | return net.cuda(gpu), opt, pretrained_dict_remove
74 |
75 |
76 | def load_network(net, pretrained_dir, gpu):
77 | pretrained = torch.load(pretrained_dir,
78 | map_location=torch.device("cuda:" + str(gpu)))
79 | if 'state_dict' in pretrained.keys():
80 | pretrained_dict = pretrained['state_dict']
81 | elif 'model' in pretrained.keys():
82 | pretrained_dict = pretrained['model']
83 | else:
84 | pretrained_dict = pretrained
85 | model_dict = net.state_dict()
86 | pretrained_dict_update = {}
87 | pretrained_dict_remove = []
88 | for k, v in pretrained_dict.items():
89 | if k in model_dict:
90 | pretrained_dict_update[k] = v
91 | elif k[:7] == 'module.':
92 | if k[7:] in model_dict:
93 | pretrained_dict_update[k[7:]] = v
94 | else:
95 | pretrained_dict_remove.append(k)
96 | model_dict.update(pretrained_dict_update)
97 | net.load_state_dict(model_dict)
98 | del (pretrained)
99 | return net.cuda(gpu), pretrained_dict_remove
100 |
101 |
102 | def save_network(net,
103 | opt,
104 | step,
105 | save_path,
106 | max_keep=8,
107 | backup_dir='./saved_models',
108 | scaler=None):
109 | ckpt = {'state_dict': net.state_dict(), 'optimizer': opt.state_dict()}
110 | if scaler is not None:
111 | ckpt['scaler'] = scaler.state_dict()
112 |
113 | try:
114 | if not os.path.exists(save_path):
115 | os.makedirs(save_path)
116 | save_file = 'save_step_%s.pth' % (step)
117 | save_dir = os.path.join(save_path, save_file)
118 | torch.save(ckpt, save_dir)
119 |     except Exception:  # fall back to the backup directory if saving fails
120 | save_path = backup_dir
121 | if not os.path.exists(save_path):
122 | os.makedirs(save_path)
123 | save_file = 'save_step_%s.pth' % (step)
124 | save_dir = os.path.join(save_path, save_file)
125 | torch.save(ckpt, save_dir)
126 |
127 | all_ckpt = os.listdir(save_path)
128 | if len(all_ckpt) > max_keep:
129 | all_step = []
130 | for ckpt_name in all_ckpt:
131 | step = int(ckpt_name.split('_')[-1].split('.')[0])
132 | all_step.append(step)
133 | all_step = list(np.sort(all_step))[:-max_keep]
134 | for step in all_step:
135 | ckpt_path = os.path.join(save_path, 'save_step_%s.pth' % (step))
136 | os.system('rm {}'.format(ckpt_path))
137 |
138 |
139 | def cp_ckpt(remote_dir="data_wd/youtube_vos_jobs/result", curr_dir="backup"):
140 | exps = os.listdir(curr_dir)
141 | for exp in exps:
142 | exp_dir = os.path.join(curr_dir, exp)
143 | stages = os.listdir(exp_dir)
144 | for stage in stages:
145 | stage_dir = os.path.join(exp_dir, stage)
146 | finals = ["ema_ckpt", "ckpt"]
147 | for final in finals:
148 | final_dir = os.path.join(stage_dir, final)
149 | ckpts = os.listdir(final_dir)
150 | for ckpt in ckpts:
151 | if '.pth' not in ckpt:
152 | continue
153 | curr_ckpt_path = os.path.join(final_dir, ckpt)
154 | remote_ckpt_path = os.path.join(remote_dir, exp, stage,
155 | final, ckpt)
156 | if os.path.exists(remote_ckpt_path):
157 | os.system('rm {}'.format(remote_ckpt_path))
158 | try:
159 | shutil.copy(curr_ckpt_path, remote_ckpt_path)
160 | print("Copy {} to {}.".format(curr_ckpt_path,
161 | remote_ckpt_path))
162 |                     except OSError:
163 |                         return  # stop copying once the remote directory is unavailable
164 |
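165 | 
166 | if __name__ == '__main__':
167 |     # Minimal usage sketch with a toy module, kept behind the __main__ guard so
168 |     # importing this file stays side-effect free. It assumes a visible CUDA
169 |     # device (gpu 0); './toy_ckpt' is an illustrative path, not a repo convention.
170 |     import torch.nn as nn
171 |     toy_net = nn.Conv2d(3, 8, 3)
172 |     toy_opt = torch.optim.SGD(toy_net.parameters(), lr=0.01)
173 |     save_network(toy_net, toy_opt, step=100, save_path='./toy_ckpt', max_keep=2)
174 |     toy_net, removed_keys = load_network(toy_net,
175 |                                          './toy_ckpt/save_step_100.pth',
176 |                                          gpu=0)
177 |     print('Checkpoint keys without a match in the model:', removed_keys)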
--------------------------------------------------------------------------------
/utils/cp_ckpt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 |
5 | def cp_ckpt(remote_dir="data_wd/youtube_vos_jobs/result", curr_dir="backup"):
6 | exps = os.listdir(curr_dir)
7 | for exp in exps:
8 | print("Exp: ", exp)
9 | exp_dir = os.path.join(curr_dir, exp)
10 | stages = os.listdir(exp_dir)
11 | for stage in stages:
12 | print("Stage: ", stage)
13 | stage_dir = os.path.join(exp_dir, stage)
14 | finals = ["ema_ckpt", "ckpt"]
15 | for final in finals:
16 | print("Final: ", final)
17 | final_dir = os.path.join(stage_dir, final)
18 | ckpts = os.listdir(final_dir)
19 | for ckpt in ckpts:
20 | if '.pth' not in ckpt:
21 | continue
22 | curr_ckpt_path = os.path.join(final_dir, ckpt)
23 | remote_ckpt_path = os.path.join(remote_dir, exp, stage,
24 | final, ckpt)
25 | if os.path.exists(remote_ckpt_path):
26 | os.system('rm {}'.format(remote_ckpt_path))
27 | try:
28 | shutil.copy(curr_ckpt_path, remote_ckpt_path)
29 | print(ckpt, ': OK')
30 | except OSError as Inst:
31 | print(Inst)
32 | print(ckpt, ': Fail')
33 |
34 |
35 | if __name__ == "__main__":
36 | cp_ckpt()
37 |
--------------------------------------------------------------------------------
/utils/ema.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import unicode_literals
3 |
4 | import torch
5 |
6 |
7 | def get_param_buffer_for_ema(model,
8 | update_buffer=False,
9 | required_buffers=['running_mean', 'running_var']):
10 | params = model.parameters()
11 | all_param_buffer = [p for p in params if p.requires_grad]
12 | if update_buffer:
13 | named_buffers = model.named_buffers()
14 | for key, value in named_buffers:
15 | for buffer_name in required_buffers:
16 | if buffer_name in key:
17 | all_param_buffer.append(value)
18 | break
19 | return all_param_buffer
20 |
21 |
22 | class ExponentialMovingAverage:
23 | """
24 | Maintains (exponential) moving average of a set of parameters.
25 | """
26 | def __init__(self, parameters, decay, use_num_updates=True):
27 | """
28 | Args:
29 | parameters: Iterable of `torch.nn.Parameter`; usually the result of
30 | `model.parameters()`.
31 | decay: The exponential decay.
32 | use_num_updates: Whether to use number of updates when computing
33 | averages.
34 | """
35 | if decay < 0.0 or decay > 1.0:
36 | raise ValueError('Decay must be between 0 and 1')
37 | self.decay = decay
38 | self.num_updates = 0 if use_num_updates else None
39 | self.shadow_params = [p.clone().detach() for p in parameters]
40 | self.collected_params = []
41 |
42 | def update(self, parameters):
43 | """
44 | Update currently maintained parameters.
45 | Call this every time the parameters are updated, such as the result of
46 | the `optimizer.step()` call.
47 | Args:
48 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of
49 | parameters used to initialize this object.
50 | """
51 | decay = self.decay
52 | if self.num_updates is not None:
53 | self.num_updates += 1
54 | decay = min(decay,
55 | (1 + self.num_updates) / (10 + self.num_updates))
56 | one_minus_decay = 1.0 - decay
57 | with torch.no_grad():
58 | for s_param, param in zip(self.shadow_params, parameters):
59 | s_param.sub_(one_minus_decay * (s_param - param))
60 |
61 | def copy_to(self, parameters):
62 | """
63 | Copy current parameters into given collection of parameters.
64 | Args:
65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
66 | updated with the stored moving averages.
67 | """
68 | for s_param, param in zip(self.shadow_params, parameters):
69 | param.data.copy_(s_param.data)
70 |
71 | def store(self, parameters):
72 | """
73 | Save the current parameters for restoring later.
74 | Args:
75 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
76 | temporarily stored.
77 | """
78 | self.collected_params = [param.clone() for param in parameters]
79 |
80 | def restore(self, parameters):
81 | """
82 | Restore the parameters stored with the `store` method.
83 | Useful to validate the model with EMA parameters without affecting the
84 | original optimization process. Store the parameters before the
85 | `copy_to` method. After validation (or model saving), use this to
86 | restore the former parameters.
87 | Args:
88 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
89 | updated with the stored parameters.
90 | """
91 | for c_param, param in zip(self.collected_params, parameters):
92 | param.data.copy_(c_param.data)
93 | del (self.collected_params)
94 |
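95 | 
96 | if __name__ == '__main__':
97 |     # Minimal usage sketch with a toy model: keep an EMA copy of the trainable
98 |     # weights and temporarily swap it in for evaluation. The decay of 0.999 is
99 |     # an illustrative value; the training configs define the one actually used.
100 |     import torch.nn as nn
101 |     model = nn.Linear(4, 2)
102 |     ema_params = get_param_buffer_for_ema(model)
103 |     ema = ExponentialMovingAverage(ema_params, decay=0.999)
104 | 
105 |     # Call after every optimizer.step() during training.
106 |     ema.update(ema_params)
107 | 
108 |     # Evaluate with the averaged weights, then restore the live ones.
109 |     ema.store(ema_params)
110 |     ema.copy_to(ema_params)
111 |     # ... run validation here ...
112 |     ema.restore(ema_params)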
--------------------------------------------------------------------------------
/utils/eval.py:
--------------------------------------------------------------------------------
1 | import zipfile
2 | import os
3 |
4 |
5 | def zip_folder(source_folder, zip_dir):
6 | f = zipfile.ZipFile(zip_dir, 'w', zipfile.ZIP_DEFLATED)
7 | pre_len = len(os.path.dirname(source_folder))
8 | for dirpath, dirnames, filenames in os.walk(source_folder):
9 | for filename in filenames:
10 | pathfile = os.path.join(dirpath, filename)
11 | arcname = pathfile[pre_len:].strip(os.path.sep)
12 | f.write(pathfile, arcname)
13 | f.close()
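14 | 
15 | 
16 | if __name__ == '__main__':
17 |     # Usage sketch: pack predicted masks into a zip for a benchmark submission.
18 |     # The './demo_output/pred_masks' path only mirrors the demo defaults and is
19 |     # an assumption here; any existing folder works.
20 |     zip_folder('./demo_output/pred_masks', './demo_output/pred_masks.zip')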
--------------------------------------------------------------------------------
/utils/image.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch
4 | import threading
5 |
6 | _palette = [
7 | 0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128,
8 | 128, 128, 128, 128, 64, 0, 0, 191, 0, 0, 64, 128, 0, 191, 128, 0, 64, 0,
9 | 128, 191, 0, 128, 64, 128, 128, 191, 128, 128, 0, 64, 0, 128, 64, 0, 0,
10 | 191, 0, 128, 191, 0, 0, 64, 128, 128, 64, 128, 22, 22, 22, 23, 23, 23, 24,
11 | 24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28, 28, 28, 29, 29, 29, 30, 30,
12 | 30, 31, 31, 31, 32, 32, 32, 33, 33, 33, 34, 34, 34, 35, 35, 35, 36, 36, 36,
13 | 37, 37, 37, 38, 38, 38, 39, 39, 39, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43,
14 | 43, 43, 44, 44, 44, 45, 45, 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49,
15 | 49, 50, 50, 50, 51, 51, 51, 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55,
16 | 56, 56, 56, 57, 57, 57, 58, 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62,
17 | 62, 62, 63, 63, 63, 64, 64, 64, 65, 65, 65, 66, 66, 66, 67, 67, 67, 68, 68,
18 | 68, 69, 69, 69, 70, 70, 70, 71, 71, 71, 72, 72, 72, 73, 73, 73, 74, 74, 74,
19 | 75, 75, 75, 76, 76, 76, 77, 77, 77, 78, 78, 78, 79, 79, 79, 80, 80, 80, 81,
20 | 81, 81, 82, 82, 82, 83, 83, 83, 84, 84, 84, 85, 85, 85, 86, 86, 86, 87, 87,
21 | 87, 88, 88, 88, 89, 89, 89, 90, 90, 90, 91, 91, 91, 92, 92, 92, 93, 93, 93,
22 | 94, 94, 94, 95, 95, 95, 96, 96, 96, 97, 97, 97, 98, 98, 98, 99, 99, 99,
23 | 100, 100, 100, 101, 101, 101, 102, 102, 102, 103, 103, 103, 104, 104, 104,
24 | 105, 105, 105, 106, 106, 106, 107, 107, 107, 108, 108, 108, 109, 109, 109,
25 | 110, 110, 110, 111, 111, 111, 112, 112, 112, 113, 113, 113, 114, 114, 114,
26 | 115, 115, 115, 116, 116, 116, 117, 117, 117, 118, 118, 118, 119, 119, 119,
27 | 120, 120, 120, 121, 121, 121, 122, 122, 122, 123, 123, 123, 124, 124, 124,
28 | 125, 125, 125, 126, 126, 126, 127, 127, 127, 128, 128, 128, 129, 129, 129,
29 | 130, 130, 130, 131, 131, 131, 132, 132, 132, 133, 133, 133, 134, 134, 134,
30 | 135, 135, 135, 136, 136, 136, 137, 137, 137, 138, 138, 138, 139, 139, 139,
31 | 140, 140, 140, 141, 141, 141, 142, 142, 142, 143, 143, 143, 144, 144, 144,
32 | 145, 145, 145, 146, 146, 146, 147, 147, 147, 148, 148, 148, 149, 149, 149,
33 | 150, 150, 150, 151, 151, 151, 152, 152, 152, 153, 153, 153, 154, 154, 154,
34 | 155, 155, 155, 156, 156, 156, 157, 157, 157, 158, 158, 158, 159, 159, 159,
35 | 160, 160, 160, 161, 161, 161, 162, 162, 162, 163, 163, 163, 164, 164, 164,
36 | 165, 165, 165, 166, 166, 166, 167, 167, 167, 168, 168, 168, 169, 169, 169,
37 | 170, 170, 170, 171, 171, 171, 172, 172, 172, 173, 173, 173, 174, 174, 174,
38 | 175, 175, 175, 176, 176, 176, 177, 177, 177, 178, 178, 178, 179, 179, 179,
39 | 180, 180, 180, 181, 181, 181, 182, 182, 182, 183, 183, 183, 184, 184, 184,
40 | 185, 185, 185, 186, 186, 186, 187, 187, 187, 188, 188, 188, 189, 189, 189,
41 | 190, 190, 190, 191, 191, 191, 192, 192, 192, 193, 193, 193, 194, 194, 194,
42 | 195, 195, 195, 196, 196, 196, 197, 197, 197, 198, 198, 198, 199, 199, 199,
43 | 200, 200, 200, 201, 201, 201, 202, 202, 202, 203, 203, 203, 204, 204, 204,
44 | 205, 205, 205, 206, 206, 206, 207, 207, 207, 208, 208, 208, 209, 209, 209,
45 | 210, 210, 210, 211, 211, 211, 212, 212, 212, 213, 213, 213, 214, 214, 214,
46 | 215, 215, 215, 216, 216, 216, 217, 217, 217, 218, 218, 218, 219, 219, 219,
47 | 220, 220, 220, 221, 221, 221, 222, 222, 222, 223, 223, 223, 224, 224, 224,
48 | 225, 225, 225, 226, 226, 226, 227, 227, 227, 228, 228, 228, 229, 229, 229,
49 | 230, 230, 230, 231, 231, 231, 232, 232, 232, 233, 233, 233, 234, 234, 234,
50 | 235, 235, 235, 236, 236, 236, 237, 237, 237, 238, 238, 238, 239, 239, 239,
51 | 240, 240, 240, 241, 241, 241, 242, 242, 242, 243, 243, 243, 244, 244, 244,
52 | 245, 245, 245, 246, 246, 246, 247, 247, 247, 248, 248, 248, 249, 249, 249,
53 | 250, 250, 250, 251, 251, 251, 252, 252, 252, 253, 253, 253, 254, 254, 254,
54 | 255, 255, 255
55 | ]
56 |
57 |
58 | def label2colormap(label):
59 |
60 | m = label.astype(np.uint8)
61 | r, c = m.shape
62 | cmap = np.zeros((r, c, 3), dtype=np.uint8)
63 | cmap[:, :, 0] = (m & 1) << 7 | (m & 8) << 3 | (m & 64) >> 1
64 | cmap[:, :, 1] = (m & 2) << 6 | (m & 16) << 2 | (m & 128) >> 2
65 | cmap[:, :, 2] = (m & 4) << 5 | (m & 32) << 1
66 | return cmap
67 |
68 |
69 | def one_hot_mask(mask, cls_num):
70 | if len(mask.size()) == 3:
71 | mask = mask.unsqueeze(1)
72 | indices = torch.arange(0, cls_num + 1,
73 | device=mask.device).view(1, -1, 1, 1)
74 | return (mask == indices).float()
75 |
76 |
77 | def masked_image(image, colored_mask, mask, alpha=0.7):
78 | mask = np.expand_dims(mask > 0, axis=0)
79 | mask = np.repeat(mask, 3, axis=0)
80 | show_img = (image * alpha + colored_mask *
81 | (1 - alpha)) * mask + image * (1 - mask)
82 | return show_img
83 |
84 |
85 | def save_image(image, path):
86 | im = Image.fromarray(np.uint8(image * 255.).transpose((1, 2, 0)))
87 | im.save(path)
88 |
89 |
90 | def _save_mask(mask, path, squeeze_idx=None):
91 | if squeeze_idx is not None:
92 | unsqueezed_mask = mask * 0
93 | for idx in range(1, len(squeeze_idx)):
94 | obj_id = squeeze_idx[idx]
95 | mask_i = mask == idx
96 | unsqueezed_mask += (mask_i * obj_id).astype(np.uint8)
97 | mask = unsqueezed_mask
98 | mask = Image.fromarray(mask).convert('P')
99 | mask.putpalette(_palette)
100 | mask.save(path)
101 |
102 |
103 | def save_mask(mask_tensor, path, squeeze_idx=None):
104 | mask = mask_tensor.cpu().numpy().astype('uint8')
105 | threading.Thread(target=_save_mask, args=[mask, path, squeeze_idx]).start()
106 |
107 |
108 | def flip_tensor(tensor, dim=0):
109 | inv_idx = torch.arange(tensor.size(dim) - 1, -1, -1,
110 | device=tensor.device).long()
111 | tensor = tensor.index_select(dim, inv_idx)
112 | return tensor
113 |
114 |
115 | def shuffle_obj_mask(mask):
116 |
117 | bs, obj_num, _, _ = mask.size()
118 | new_masks = []
119 | for idx in range(bs):
120 | now_mask = mask[idx]
121 | random_matrix = torch.eye(obj_num, device=mask.device)
122 | fg = random_matrix[1:][torch.randperm(obj_num - 1)]
123 | random_matrix = torch.cat([random_matrix[0:1], fg], dim=0)
124 | now_mask = torch.einsum('nm,nhw->mhw', random_matrix, now_mask)
125 | new_masks.append(now_mask)
126 |
127 | return torch.stack(new_masks, dim=0)
128 |
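129 | 
130 | if __name__ == '__main__':
131 |     # Shape-check sketch for the mask helpers; the 4x4 toy mask with two
132 |     # objects below is an arbitrary example.
133 |     toy_mask = torch.zeros(1, 1, 4, 4)
134 |     toy_mask[0, 0, :2, :2] = 1
135 |     toy_mask[0, 0, 2:, 2:] = 2
136 |     one_hot = one_hot_mask(toy_mask, cls_num=2)
137 |     print(one_hot.shape)  # torch.Size([1, 3, 4, 4]): background + 2 objects
138 |     cmap = label2colormap(toy_mask[0, 0].numpy())
139 |     print(cmap.shape)  # (4, 4, 3) uint8 color map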
--------------------------------------------------------------------------------
/utils/learning.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 |
4 | def adjust_learning_rate(optimizer,
5 | base_lr,
6 | p,
7 | itr,
8 | max_itr,
9 | restart=1,
10 | warm_up_steps=1000,
11 | is_cosine_decay=False,
12 | min_lr=1e-5,
13 | encoder_lr_ratio=1.0,
14 | freeze_params=[]):
15 |
16 | if restart > 1:
17 | each_max_itr = int(math.ceil(float(max_itr) / restart))
18 | itr = itr % each_max_itr
19 | warm_up_steps /= restart
20 | max_itr = each_max_itr
21 |
22 | if itr < warm_up_steps:
23 | now_lr = min_lr + (base_lr - min_lr) * itr / warm_up_steps
24 | else:
25 | itr = itr - warm_up_steps
26 | max_itr = max_itr - warm_up_steps
27 | if is_cosine_decay:
28 | now_lr = min_lr + (base_lr - min_lr) * (math.cos(math.pi * itr /
29 | (max_itr + 1)) +
30 | 1.) * 0.5
31 | else:
32 | now_lr = min_lr + (base_lr - min_lr) * (1 - itr / (max_itr + 1))**p
33 |
34 | for param_group in optimizer.param_groups:
35 | if encoder_lr_ratio != 1.0 and "encoder." in param_group["name"]:
36 | param_group['lr'] = (now_lr - min_lr) * encoder_lr_ratio + min_lr
37 | else:
38 | param_group['lr'] = now_lr
39 |
40 | for freeze_param in freeze_params:
41 | if freeze_param in param_group["name"]:
42 | param_group['lr'] = 0
43 | param_group['weight_decay'] = 0
44 | break
45 |
46 | return now_lr
47 |
48 |
49 | def get_trainable_params(model,
50 | base_lr,
51 | weight_decay,
52 | use_frozen_bn=False,
53 | exclusive_wd_dict={},
54 | no_wd_keys=[]):
55 | params = []
56 | memo = set()
57 | total_param = 0
58 | for key, value in model.named_parameters():
59 | if value in memo:
60 | continue
61 | total_param += value.numel()
62 | if not value.requires_grad:
63 | continue
64 | memo.add(value)
65 | wd = weight_decay
66 | for exclusive_key in exclusive_wd_dict.keys():
67 | if exclusive_key in key:
68 | wd = exclusive_wd_dict[exclusive_key]
69 | break
70 | if len(value.shape) == 1: # normalization layers
71 | if 'bias' in key: # bias requires no weight decay
72 | wd = 0.
73 |             elif not use_frozen_bn:  # without frozen BN, normalization weights get no weight decay
74 |                 wd = 0.
75 |             elif 'encoder.' not in key:  # with frozen BN, only the encoder's frozen BN weights keep weight decay
76 | wd = 0.
77 | else:
78 | for no_wd_key in no_wd_keys:
79 | if no_wd_key in key:
80 | wd = 0.
81 | break
82 | params += [{
83 | "params": [value],
84 | "lr": base_lr,
85 | "weight_decay": wd,
86 | "name": key
87 | }]
88 |
89 | print('Total Param: {:.2f}M'.format(total_param / 1e6))
90 | return params
91 |
92 |
93 | def freeze_params(module):
94 | for p in module.parameters():
95 | p.requires_grad = False
96 |
97 |
98 | def calculate_params(state_dict):
99 | memo = set()
100 | total_param = 0
101 | for key, value in state_dict.items():
102 | if value in memo:
103 | continue
104 | memo.add(value)
105 | total_param += value.numel()
106 | print('Total Param: {:.2f}M'.format(total_param / 1e6))
107 |
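108 | 
109 | if __name__ == '__main__':
110 |     # Usage sketch wiring the helpers to an optimizer on a toy model. The
111 |     # learning rate, weight decay, and schedule lengths are illustrative
112 |     # values, not the ones from the training configs.
113 |     import torch
114 |     model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3),
115 |                                 torch.nn.BatchNorm2d(8))
116 |     params = get_trainable_params(model, base_lr=2e-4, weight_decay=0.07)
117 |     optimizer = torch.optim.AdamW(params, lr=2e-4, weight_decay=0.07)
118 |     for step in range(100):  # in training this sits next to optimizer.step()
119 |         now_lr = adjust_learning_rate(optimizer,
120 |                                       base_lr=2e-4,
121 |                                       p=0.9,
122 |                                       itr=step,
123 |                                       max_itr=100,
124 |                                       warm_up_steps=10,
125 |                                       is_cosine_decay=True)
126 |     print('final lr: {:.2e}'.format(now_lr))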
--------------------------------------------------------------------------------
/utils/math.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def generate_permute_matrix(dim, num, keep_first=True, gpu_id=0):
5 | all_matrix = []
6 | for idx in range(num):
7 | random_matrix = torch.eye(dim, device=torch.device('cuda', gpu_id))
8 | if keep_first:
9 | fg = random_matrix[1:][torch.randperm(dim - 1)]
10 | random_matrix = torch.cat([random_matrix[0:1], fg], dim=0)
11 | else:
12 | random_matrix = random_matrix[torch.randperm(dim)]
13 | all_matrix.append(random_matrix)
14 | return torch.stack(all_matrix, dim=0)
15 |
16 |
17 | def truncated_normal_(tensor, mean=0, std=.02):
18 | size = tensor.shape
19 | tmp = tensor.new_empty(size + (4, )).normal_()
20 | valid = (tmp < 2) & (tmp > -2)
21 | ind = valid.max(-1, keepdim=True)[1]
22 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
23 | tensor.data.mul_(std).add_(mean)
24 | return tensor
25 |
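26 | 
27 | if __name__ == '__main__':
28 |     # Usage sketch: truncated-normal initialization of a toy weight tensor.
29 |     # generate_permute_matrix needs a CUDA device, so it is only exercised
30 |     # when one is available; the sizes below are arbitrary.
31 |     w = torch.empty(8, 4)
32 |     truncated_normal_(w, mean=0., std=0.02)
33 |     print(w.abs().max().item())  # roughly bounded by 2 * std = 0.04
34 |     if torch.cuda.is_available():
35 |         perm = generate_permute_matrix(dim=5, num=2, keep_first=True, gpu_id=0)
36 |         print(perm.shape)  # torch.Size([2, 5, 5])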
--------------------------------------------------------------------------------
/utils/meters.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 |
4 | class AverageMeter(object):
5 | """Computes and stores the average and current value"""
6 | def __init__(self, momentum=0.999):
7 | self.val = 0
8 | self.avg = 0
9 | self.sum = 0
10 | self.count = 0
11 | self.long_count = 0
12 | self.momentum = momentum
13 | self.moving_avg = 0
14 |
15 | def reset(self):
16 | self.val = 0
17 | self.avg = 0
18 | self.sum = 0
19 | self.count = 0
20 |
21 | def update(self, val, n=1):
22 | if self.long_count == 0:
23 | self.moving_avg = val
24 | else:
25 | momentum = min(self.momentum, 1. - 1. / self.long_count)
26 | self.moving_avg = self.moving_avg * momentum + val * (1 - momentum)
27 | self.val = val
28 | self.sum += val * n
29 | self.count += n
30 | self.long_count += n
31 | self.avg = self.sum / self.count
32 |
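33 | 
34 | if __name__ == '__main__':
35 |     # Usage sketch: track a loss value. `avg` is the plain mean since the last
36 |     # reset(), while `moving_avg` is the momentum-smoothed value that survives
37 |     # resets; the numbers below are arbitrary.
38 |     meter = AverageMeter(momentum=0.9)
39 |     for loss in [1.0, 0.8, 0.6, 0.5]:
40 |         meter.update(loss)
41 |     print(meter.avg, meter.moving_avg)
42 |     meter.reset()
43 |     print(meter.avg, meter.moving_avg)  # avg cleared, moving_avg preserved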
--------------------------------------------------------------------------------
/utils/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def pytorch_iou(pred, target, obj_num, epsilon=1e-6):
5 | '''
6 | pred: [bs, h, w]
7 | target: [bs, h, w]
8 | obj_num: [bs]
9 | '''
10 | bs = pred.size(0)
11 | all_iou = []
12 | for idx in range(bs):
13 | now_pred = pred[idx].unsqueeze(0)
14 | now_target = target[idx].unsqueeze(0)
15 | now_obj_num = obj_num[idx]
16 |
17 | obj_ids = torch.arange(0, now_obj_num + 1,
18 | device=now_pred.device).int().view(-1, 1, 1)
19 |         if obj_ids.size(0) == 1:  # only contains the background
20 | continue
21 | else:
22 | obj_ids = obj_ids[1:]
23 | now_pred = (now_pred == obj_ids).float()
24 | now_target = (now_target == obj_ids).float()
25 |
26 | intersection = (now_pred * now_target).sum((1, 2))
27 | union = ((now_pred + now_target) > 0).float().sum((1, 2))
28 |
29 | now_iou = (intersection + epsilon) / (union + epsilon)
30 |
31 | all_iou.append(now_iou.mean())
32 | if len(all_iou) > 0:
33 | all_iou = torch.stack(all_iou).mean()
34 | else:
35 | all_iou = torch.ones((1), device=pred.device)
36 | return all_iou
37 |
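38 | 
39 | if __name__ == '__main__':
40 |     # Sanity-check sketch on a tiny 4x4 label map with a single object; a
41 |     # perfect prediction should score an IoU of 1.
42 |     target = torch.zeros(1, 4, 4)
43 |     target[0, 1:3, 1:3] = 1
44 |     pred = target.clone()
45 |     print(pytorch_iou(pred, target, obj_num=[1]))  # tensor(1.)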
--------------------------------------------------------------------------------