├── LICENSE ├── MODEL_ZOO.md ├── README.md ├── configs ├── default.py ├── models │ ├── aotb.py │ ├── aotl.py │ ├── aots.py │ ├── aott.py │ ├── deaotb.py │ ├── deaotl.py │ ├── deaots.py │ ├── deaott.py │ ├── default.py │ ├── default_deaot.py │ ├── r101_aotl.py │ ├── r50_aotl.py │ ├── r50_deaotl.py │ ├── rs101_aotl.py │ ├── swinb_aotl.py │ └── swinb_deaotl.py ├── pre.py ├── pre_dav.py ├── pre_ytb.py ├── pre_ytb_dav.py └── ytb.py ├── dataloaders ├── __init__.py ├── eval_datasets.py ├── image_transforms.py ├── train_datasets.py └── video_transforms.py ├── datasets ├── DAVIS │ └── README.md ├── Demo │ ├── images │ │ ├── 1001_3iEIq5HBY1s │ │ │ ├── 00002058.jpg │ │ │ ├── 00002059.jpg │ │ │ ├── 00002060.jpg │ │ │ ├── 00002061.jpg │ │ │ ├── 00002062.jpg │ │ │ ├── 00002063.jpg │ │ │ ├── 00002064.jpg │ │ │ ├── 00002065.jpg │ │ │ ├── 00002066.jpg │ │ │ ├── 00002067.jpg │ │ │ ├── 00002068.jpg │ │ │ ├── 00002069.jpg │ │ │ ├── 00002070.jpg │ │ │ ├── 00002071.jpg │ │ │ ├── 00002072.jpg │ │ │ ├── 00002073.jpg │ │ │ ├── 00002074.jpg │ │ │ ├── 00002075.jpg │ │ │ ├── 00002076.jpg │ │ │ ├── 00002077.jpg │ │ │ ├── 00002078.jpg │ │ │ ├── 00002079.jpg │ │ │ ├── 00002080.jpg │ │ │ ├── 00002081.jpg │ │ │ ├── 00002082.jpg │ │ │ ├── 00002083.jpg │ │ │ ├── 00002084.jpg │ │ │ ├── 00002085.jpg │ │ │ ├── 00002086.jpg │ │ │ ├── 00002087.jpg │ │ │ ├── 00002088.jpg │ │ │ ├── 00002089.jpg │ │ │ ├── 00002090.jpg │ │ │ ├── 00002091.jpg │ │ │ ├── 00002092.jpg │ │ │ ├── 00002093.jpg │ │ │ ├── 00002094.jpg │ │ │ ├── 00002095.jpg │ │ │ ├── 00002096.jpg │ │ │ ├── 00002097.jpg │ │ │ ├── 00002098.jpg │ │ │ ├── 00002099.jpg │ │ │ ├── 00002100.jpg │ │ │ ├── 00002101.jpg │ │ │ └── 00002102.jpg │ │ └── 1007_YCTBBdbKSSg │ │ │ ├── 00000693.jpg │ │ │ ├── 00000694.jpg │ │ │ ├── 00000695.jpg │ │ │ ├── 00000696.jpg │ │ │ ├── 00000697.jpg │ │ │ ├── 00000698.jpg │ │ │ ├── 00000699.jpg │ │ │ ├── 00000700.jpg │ │ │ ├── 00000701.jpg │ │ │ ├── 00000702.jpg │ │ │ ├── 00000703.jpg │ │ │ ├── 00000704.jpg │ │ │ ├── 00000705.jpg │ │ │ ├── 00000706.jpg │ │ │ ├── 00000707.jpg │ │ │ ├── 00000708.jpg │ │ │ ├── 00000709.jpg │ │ │ ├── 00000710.jpg │ │ │ ├── 00000711.jpg │ │ │ ├── 00000712.jpg │ │ │ ├── 00000713.jpg │ │ │ ├── 00000714.jpg │ │ │ ├── 00000715.jpg │ │ │ ├── 00000716.jpg │ │ │ ├── 00000717.jpg │ │ │ ├── 00000718.jpg │ │ │ ├── 00000719.jpg │ │ │ ├── 00000720.jpg │ │ │ ├── 00000721.jpg │ │ │ ├── 00000722.jpg │ │ │ ├── 00000723.jpg │ │ │ ├── 00000724.jpg │ │ │ ├── 00000725.jpg │ │ │ ├── 00000726.jpg │ │ │ ├── 00000727.jpg │ │ │ ├── 00000728.jpg │ │ │ ├── 00000729.jpg │ │ │ ├── 00000730.jpg │ │ │ ├── 00000731.jpg │ │ │ ├── 00000732.jpg │ │ │ ├── 00000733.jpg │ │ │ ├── 00000734.jpg │ │ │ ├── 00000735.jpg │ │ │ └── 00000736.jpg │ └── masks │ │ ├── 1001_3iEIq5HBY1s │ │ └── 00002058.png │ │ └── 1007_YCTBBdbKSSg │ │ └── 00000693.png ├── Static │ └── README.md └── YTB │ ├── 2018 │ ├── train │ │ └── README.md │ ├── valid │ │ └── README.md │ └── valid_all_frames │ │ └── README.md │ └── 2019 │ ├── train │ └── README.md │ ├── valid │ └── README.md │ └── valid_all_frames │ └── README.md ├── networks ├── __init__.py ├── decoders │ ├── __init__.py │ └── fpn.py ├── encoders │ ├── __init__.py │ ├── mobilenetv2.py │ ├── mobilenetv3.py │ ├── resnest │ │ ├── __init__.py │ │ ├── resnest.py │ │ ├── resnet.py │ │ └── splat.py │ ├── resnet.py │ └── swin │ │ ├── __init__.py │ │ ├── build.py │ │ └── swin_transformer.py ├── engines │ ├── __init__.py │ ├── aot_engine.py │ └── deaot_engine.py ├── layers │ ├── __init__.py │ ├── attention.py │ ├── basic.py │ ├── loss.py │ ├── 
normalization.py │ ├── position.py │ └── transformer.py ├── managers │ ├── evaluator.py │ └── trainer.py └── models │ ├── __init__.py │ ├── aot.py │ └── deaot.py ├── pretrain_models └── README.md ├── source ├── 1001_3iEIq5HBY1s.gif ├── 1007_YCTBBdbKSSg.gif ├── kobe.gif ├── messi.gif ├── overview.png ├── overview_deaot.png └── some_results.png ├── tools ├── demo.py ├── eval.py └── train.py ├── train_eval.sh └── utils ├── __init__.py ├── checkpoint.py ├── cp_ckpt.py ├── ema.py ├── eval.py ├── image.py ├── learning.py ├── math.py ├── meters.py └── metric.py /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, z-x-yang 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /MODEL_ZOO.md: -------------------------------------------------------------------------------- 1 | ## Model Zoo and Results 2 | 3 | ### Environment and Settings 4 | - 4/1 NVIDIA V100 GPUs for training/evaluation. 5 | - Auto-mixed precision was enabled in training but disabled in evaluation. 6 | - Test-time augmentations were not used. 7 | - The inference resolution of DAVIS/YouTube-VOS was 480p/1.3x480p as [CFBI](https://github.com/z-x-yang/CFBI). 8 | - Fully online inference. We passed all the modules frame by frame. 9 | - Multi-object FPS was recorded instead of single-object one. 10 | 11 | ### Pre-trained Models 12 | Stages: 13 | 14 | - `PRE`: the pre-training stage with static images. 15 | 16 | - `PRE_YTB_DAV`: the main-training stage with YouTube-VOS and DAVIS. All the kinds of evaluation share an **identical** model and the **same** parameters. 
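As noted after the tables, inference with any of these checkpoints only needs `--model` (the model type) and `--ckpt_path` (the checkpoint file) when running `eval.py`. A minimal sketch, assuming the R50-AOTL `PRE_YTB_DAV` checkpoint has been downloaded into `pretrain_models` and renamed as below (the filename is a placeholder, and the evaluation dataset/split are chosen via the other arguments of `tools/eval.py`):

```bash
# Placeholder checkpoint name; use the file you actually downloaded from the table below.
python tools/eval.py --model r50_aotl --ckpt_path pretrain_models/R50_AOTL_PRE_YTB_DAV.pth
```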
17 | 18 | 19 | | Model | Param (M) | PRE | PRE_YTB_DAV | 20 | |:---------- |:---------:|:--------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------:| 21 | | AOTT | 5.7 | [gdrive](https://drive.google.com/file/d/1_513h8Hok9ySQPMs_dHgX5sPexUhyCmy/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1owPmwV4owd_ll6GuilzklqTyAd0ZvbCu/view?usp=sharing) | 22 | | AOTS | 7.0 | [gdrive](https://drive.google.com/file/d/1QUP0-VED-lOF1oX_ppYWnXyBjvUzJJB7/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1beU5E6Mdnr_pPrgjWvdWurKAIwJSz1xf/view?usp=sharing) | 23 | | AOTB | 8.3 | [gdrive](https://drive.google.com/file/d/11Bx8n_INAha1IdpHjueGpf7BrKmCJDvK/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1hH-GOn4GAxHkV8ARcQzsUy8Ax6ndot-A/view?usp=sharing) | 24 | | AOTL | 8.3 | [gdrive](https://drive.google.com/file/d/1WL6QCsYeT7Bt-Gain9ZIrNNXpR2Hgh29/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1L1N2hkSPqrwGgnW9GyFHuG59_EYYfTG4/view?usp=sharing) | 25 | | R50-AOTL | 14.9 | [gdrive](https://drive.google.com/file/d/1hS4JIvOXeqvbs-CokwV6PwZV-EvzE6x8/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1qJDYn3Ibpquu4ffYoQmVjg1YCbr2JQep/view?usp=sharing) | 26 | | SwinB-AOTL | 65.4 | [gdrive](https://drive.google.com/file/d/1LlhKQiXD8JyZGGs3hZiNzcaCLqyvL9tj/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/192jCGQZdnuTsvX-CVra-KVZl2q1ZR0vW/view?usp=sharing) | 27 | 28 | | Model | Param (M) | PRE | PRE_YTB_DAV | 29 | |:---------- |:---------:|:--------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------:| 30 | | DeAOTT | 7.2 | [gdrive](https://drive.google.com/file/d/11C1ZBoFpL3ztKtINS8qqwPSldfYXexFK/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1ThWIZQS03cYWx1EKNN8MIMnJS5eRowzr/view?usp=sharing) | 31 | | DeAOTS | 10.2 | [gdrive](https://drive.google.com/file/d/1uUidrWVoaP9A5B5-EzQLbielUnRLRF3j/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1YwIAV5tBtn5spSFxKLBQBEQGwPHyQlHi/view?usp=sharing) | 32 | | DeAOTB | 13.2 | [gdrive](https://drive.google.com/file/d/1bEQr6vIgQMVITrSOtxWTMgycKpS0cor9/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1BHxsonnvJXylqHlZ1zJHHc-ymKyq-CFf/view?usp=sharing) | 33 | | DeAOTL | 13.2 | [gdrive](https://drive.google.com/file/d/1_vBL4KJlmBy0oBE4YFDOvsYL1ZtpEL32/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/18elNz_wi9JyVBcIUYKhRdL08MA-FqHD5/view?usp=sharing) | 34 | | R50-DeAOTL | 19.8 | [gdrive](https://drive.google.com/file/d/1sTRQ1g0WCpqVCdavv7uJiZNkXunBt3-R/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1QoChMkTVxdYZ_eBlZhK2acq9KMQZccPJ/view?usp=sharing) | 35 | | SwinB-DeAOTL | 70.3 | [gdrive](https://drive.google.com/file/d/16BZEE53no8CxT-pPLDC2q1d6Xlg8mWPU/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1g4E-F0RPOx9Nd6J7tU9AE1TjsouL4oZq/view?usp=sharing) | 36 | 37 | To use our pre-trained model to infer, a simple way is to set `--model` and `--ckpt_path` to your downloaded checkpoint's model type and file path when running `eval.py`. 38 | 39 | ### YouTube-VOS 2018 val 40 | `ALL-F`: all frames. The default evaluation setting of YouTube-VOS is 6fps, but 30fps sequences (all the frames) are also supplied by the dataset organizers. 
We noticed that many VOS methods prefer to evaluate with 30fps videos. Thus, we also supply our results here. Denser video sequences can significantly improve VOS performance when using the memory reading strategy (like AOTL, R50-AOTL, and SwinB-AOTL), but the efficiency will be influenced since more memorized frames are stored for object matching. 41 | | Model | Stage | FPS | All-F | Mean | J Seen | F Seen | J Unseen | F Unseen | Predictions | 42 | |:------------ |:-----------:|:--------:|:-----:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------------------------------------------------------------------------------------------:| 43 | | AOTT | PRE_YTB_DAV | 41.0 | | 80.2 | 80.4 | 85.0 | 73.6 | 81.7 | [gdrive](https://drive.google.com/file/d/1u8mvPRT08ENZHsw9Xf_4C6Sv9BoCzENR/view?usp=sharing) | 44 | | AOTT | PRE_YTB_DAV | 41.0 | √ | 80.9 | 80.0 | 84.7 | 75.2 | 83.5 | [gdrive](https://drive.google.com/file/d/1RGMI5-29Z0odq73rt26eCxOUYUd-fvVv/view?usp=sharing) | 45 | | DeAOTT | PRE_YTB_DAV | **53.4** | | **82.0** | **81.6** | **86.3** | **75.8** | **84.2** | - | 46 | | AOTS | PRE_YTB_DAV | 27.1 | | 82.9 | 82.3 | 87.0 | 77.1 | 85.1 | [gdrive](https://drive.google.com/file/d/1a4-rNnxjMuPBq21IKo31WDYZXMPgS7r2/view?usp=sharing) | 47 | | AOTS | PRE_YTB_DAV | 27.1 | √ | 83.0 | 82.2 | 87.0 | 77.3 | 85.7 | [gdrive](https://drive.google.com/file/d/1Z0cndyoCw5Na6u-VFRE8CyiIG2RbMIUO/view?usp=sharing) | 48 | | DeAOTS | PRE_YTB_DAV | **38.7** | | **84.0** | **83.3** | **88.3** | **77.9** | **86.6** | - | 49 | | AOTB | PRE_YTB_DAV | 20.5 | | 84.0 | 83.2 | 88.1 | 78.0 | 86.5 | [gdrive](https://drive.google.com/file/d/1J5nhuQbbjVLYNXViBIgo21ddQy-MiOLG/view?usp=sharing) | 50 | | AOTB | PRE_YTB_DAV | 20.5 | √ | 84.1 | 83.6 | 88.5 | 78.0 | 86.5 | [gdrive](https://drive.google.com/file/d/1gFaweB_GTJjHzSD61v_ZsY9K7UEND30O/view?usp=sharing) | 51 | | DeAOTB | PRE_YTB_DAV | **30.4** | | **84.6** | **83.9** | **88.9** | **78.5** | **87.0** | - | 52 | | AOTL | PRE_YTB_DAV | 16.0 | | 84.1 | 83.2 | 88.2 | 78.2 | 86.8 | [gdrive](https://drive.google.com/file/d/1kS8KWQ2L3wzxt44ROLTxwZOT7ZpT8Igc/view?usp=sharing) | 53 | | AOTL | PRE_YTB_DAV | 6.5 | √ | 84.5 | 83.7 | 88.8 | 78.4 | **87.1** | [gdrive](https://drive.google.com/file/d/1Rpm3e215kJOUvb562lJ2kYg2I3hkrxiM/view?usp=sharing) | 54 | | DeAOTL | PRE_YTB_DAV | **24.7** | | **84.8** | **84.2** | **89.4** | **78.6** | 87.0 | - | 55 | | R50-AOTL | PRE_YTB_DAV | 14.9 | | 84.6 | 83.7 | 88.5 | 78.8 | 87.3 | [gdrive](https://drive.google.com/file/d/1nbJZ1bbmEgyK-bg6HQ8LwCz5gVJ6wzIZ/view?usp=sharing) | 56 | | R50-AOTL | PRE_YTB_DAV | 6.4 | √ | 85.5 | 84.5 | 89.5 | 79.6 | 88.2 | [gdrive](https://drive.google.com/file/d/1NbB54ZhYvfJh38KFOgovYYPjWopd-2TE/view?usp=sharing) | 57 | | R50-DeAOTL | PRE_YTB_DAV | **22.4** | | **86.0** | **84.9** | **89.9** | **80.4** | **88.7** | - | 58 | | SwinB-AOTL | PRE_YTB_DAV | 9.3 | | 84.7 | 84.5 | 89.5 | 78.1 | 86.7 | [gdrive](https://drive.google.com/file/d/1QFowulSY0LHfpsjUV8ZE9rYc55L9DOC7/view?usp=sharing) | 59 | | SwinB-AOTL | PRE_YTB_DAV | 5.2 | √ | 85.1 | 85.1 | 90.1 | 78.4 | 86.9 | [gdrive](https://drive.google.com/file/d/1TulhVOhh01rkssNYbOQASeWKu7CQ5Azx/view?usp=sharing) | 60 | | SwinB-DeAOTL | PRE_YTB_DAV | **11.9** | | **86.2** | **85.6** | **90.6** | **80.0** | **88.4** | - | 61 | 62 | ### YouTube-VOS 2019 val 63 | | Model | Stage | FPS | All-F | Mean | J Seen | F Seen | J Unseen | F Unseen | Predictions | 64 | |:------------ 
|:-----------:|:--------:|:-----:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------------------------------------------------------------------------------------------:| 65 | | AOTT | PRE_YTB_DAV | 41.0 | | 80.0 | 79.8 | 84.2 | 74.1 | 82.1 | [gdrive](https://drive.google.com/file/d/1zzyhN1XYtajte5nbZ7opOdfXeDJgCxC5/view?usp=sharing) | 66 | | AOTT | PRE_YTB_DAV | 41.0 | √ | 80.9 | 79.9 | 84.4 | 75.6 | 83.8 | [gdrive](https://drive.google.com/file/d/1V_5vi9dAXOis_WrDieacSESm7OX20Bv-/view?usp=sharing) | 67 | | DeAOTT | PRE_YTB_DAV | **53.4** | | **82.0** | **81.2** | **85.6** | **76.4** | **84.7** | - | 68 | | AOTS | PRE_YTB_DAV | 27.1 | | 82.7 | 81.9 | 86.5 | 77.3 | 85.2 | [gdrive](https://drive.google.com/file/d/11YdkUeyjkTv8Uw7xMgPCBzJs6v5SDt6n/view?usp=sharing) | 69 | | AOTS | PRE_YTB_DAV | 27.1 | √ | 82.8 | 81.9 | 86.5 | 77.3 | 85.6 | [gdrive](https://drive.google.com/file/d/1UhyurGTJeAw412czU3_ebzNwF8xQ4QG_/view?usp=sharing) | 70 | | DeAOTS | PRE_YTB_DAV | **38.7** | | **83.8** | **82.8** | **87.5** | **78.1** | **86.8** | - | 71 | | AOTB | PRE_YTB_DAV | 20.5 | | 84.0 | 83.1 | 87.7 | 78.5 | 86.8 | [gdrive](https://drive.google.com/file/d/1NeI8cT4kVqTqVWAwtwiga1rkrvksNWaO/view?usp=sharing) | 72 | | AOTB | PRE_YTB_DAV | 20.5 | √ | 84.1 | 83.3 | 88.0 | 78.2 | 86.7 | [gdrive](https://drive.google.com/file/d/1kpYV2XFR0sOfLWD-wMhd-nUO6CFiLjlL/view?usp=sharing) | 73 | | DeAOTB | PRE_YTB_DAV | **30.4** | | **84.6** | **83.5** | **88.3** | **79.1** | **87.5** | - | 74 | | AOTL | PRE_YTB_DAV | 16.0 | | 84.0 | 82.8 | 87.6 | 78.6 | 87.1 | [gdrive](https://drive.google.com/file/d/1qKLlNXxmT31bW0weEHI_zAf4QwU8Lhou/view?usp=sharing) | 75 | | AOTL | PRE_YTB_DAV | 6.5 | √ | 84.2 | 83.0 | 87.8 | 78.7 | 87.3 | [gdrive](https://drive.google.com/file/d/1o3fwZ0cH71bqHSA3bYNjhP4GGv9Vyuwa/view?usp=sharing) | 76 | | DeAOTL | PRE_YTB_DAV | **24.7** | | **84.7** | **83.8** | **88.8** | **79.0** | **87.2** | - | 77 | | R50-AOTL | PRE_YTB_DAV | 14.9 | | 84.4 | 83.4 | 88.1 | 78.7 | 87.2 | [gdrive](https://drive.google.com/file/d/1I7ooSp8EYfU6fvkP6QcCMaxeencA68AH/view?usp=sharing) | 78 | | R50-AOTL | PRE_YTB_DAV | 6.4 | √ | 85.3 | 83.9 | 88.8 | 79.9 | 88.5 | [gdrive](https://drive.google.com/file/d/1OGqlkEu0uXa8QVWIVz_M5pmXXiYR2sh3/view?usp=sharing) | 79 | | R50-DeAOTL | PRE_YTB_DAV | **22.4** | | **85.9** | **84.6** | **89.4** | **80.8** | **88.9** | - | 80 | | SwinB-AOTL | PRE_YTB_DAV | 9.3 | | 84.7 | 84.0 | 88.8 | 78.7 | 87.1 | [gdrive](https://drive.google.com/file/d/1fPzCxi5GM7N2sLKkhoTC2yoY_oTQCHp1/view?usp=sharing) | 81 | | SwinB-AOTL | PRE_YTB_DAV | 5.2 | √ | 85.3 | 84.6 | 89.5 | 79.3 | 87.7 | [gdrive](https://drive.google.com/file/d/1e3D22s_rJ7Y2X2MHo7x5lcNtwmHFlwYB/view?usp=sharing) | 82 | | SwinB-DeAOTL | PRE_YTB_DAV | **11.9** | | **86.1** | **85.3** | **90.2** | **80.4** | **88.6** | - | 83 | 84 | ### DAVIS-2017 test 85 | 86 | | Model | Stage | FPS | Mean | J Score | F Score | Predictions | 87 | | ---------- |:-----------:|:----:|:--------:|:--------:|:--------:|:----:| 88 | | AOTT | PRE_YTB_DAV | **51.4** | 73.7 | 70.0 | 77.3 | [gdrive](https://drive.google.com/file/d/14Pu-6Uz4rfmJ_WyL2yl57KTx_pSSUNAf/view?usp=sharing) | 89 | | AOTS | PRE_YTB_DAV | 40.0 | 75.2 | 71.4 | 78.9 | [gdrive](https://drive.google.com/file/d/1zzAPZCRLgnBWuAXqejPPEYLqBxu67Rj1/view?usp=sharing) | 90 | | AOTB | PRE_YTB_DAV | 29.6 | 77.4 | 73.7 | 81.1 | [gdrive](https://drive.google.com/file/d/1WpQ-_Jrs7Ssfw0oekrejM2OVWEx_tBN1/view?usp=sharing) | 91 | | AOTL | PRE_YTB_DAV | 18.7 | 79.3 | 75.5 | 83.2 | 
[gdrive](https://drive.google.com/file/d/1rP1Zdgc0N1d8RR2EaXMz3F-o5zqcNVe8/view?usp=sharing) | 92 | | R50-AOTL | PRE_YTB_DAV | 18.0 | 79.5 | 76.0 | 83.0 | [gdrive](https://drive.google.com/file/d/1iQ5iNlvlS-In586ZNc4LIZMSdNIWDvle/view?usp=sharing) | 93 | | SwinB-AOTL | PRE_YTB_DAV | 12.1 | **82.1** | **78.2** | **85.9** | [gdrive](https://drive.google.com/file/d/1oVt4FPcZdfVHiOxjYYKef0q7Ovy4f5Q_/view?usp=sharing) | 94 | 95 | ### DAVIS-2017 val 96 | 97 | | Model | Stage | FPS | Mean | J Score | F Score | Predictions | 98 | | ---------- |:-----------:|:----:|:--------:|:--------:|:---------:|:----:| 99 | | AOTT | PRE_YTB_DAV | **51.4** | 79.2 | 76.5 | 81.9 | [gdrive](https://drive.google.com/file/d/10OUFhK2Sz-hOJrTDoTI0mA45KO1qodZt/view?usp=sharing) | 100 | | AOTS | PRE_YTB_DAV | 40.0 | 82.1 | 79.3 | 84.8 | [gdrive](https://drive.google.com/file/d/1T-JTYyksWlq45jxcLjnRaBvvYUhWgHFH/view?usp=sharing) | 101 | | AOTB | PRE_YTB_DAV | 29.6 | 83.3 | 80.6 | 85.9 | [gdrive](https://drive.google.com/file/d/1EVUnxQm9TLBTuwK82QyiSKk9R9V8NwRL/view?usp=sharing) | 102 | | AOTL | PRE_YTB_DAV | 18.7 | 83.6 | 80.8 | 86.3 | [gdrive](https://drive.google.com/file/d/1CFauSni2BxAe_fcl8W_6bFByuwJRbDYm/view?usp=sharing) | 103 | | R50-AOTL | PRE_YTB_DAV | 18.0 | 85.2 | 82.5 | 87.9 | [gdrive](https://drive.google.com/file/d/1vjloxnP8R4PZdsH2DDizfU2CrkdRHHyo/view?usp=sharing) | 104 | | SwinB-AOTL | PRE_YTB_DAV | 12.1 | **85.9** | **82.9** | **88.9** | [gdrive](https://drive.google.com/file/d/1tYCbKOas0i7Et2iyUAyDwaXnaD9YWxLr/view?usp=sharing) | 105 | 106 | ### DAVIS-2016 val 107 | 108 | | Model | Stage | FPS | Mean | J Score | F Score | Predictions | 109 | | ---------- |:-----------:|:----:|:--------:|:--------:|:--------:|:----:| 110 | | AOTT | PRE_YTB_DAV | **51.4** | 87.5 | 86.5 | 88.4 | [gdrive](https://drive.google.com/file/d/1LeW8WQhnylZ3umT7E379KdII92uUsGA9/view?usp=sharing) | 111 | | AOTS | PRE_YTB_DAV | 40.0 | 89.6 | 88.6 | 90.5 | [gdrive](https://drive.google.com/file/d/1vqGei5tLu1FPVrTi5bwRAsaGy3Upf7B1/view?usp=sharing) | 112 | | AOTB | PRE_YTB_DAV | 29.6 | 90.9 | 89.6 | 92.1 | [gdrive](https://drive.google.com/file/d/1qAppo2uOVu0FbE9t1FBUpymC3yWgw1LM/view?usp=sharing) | 113 | | AOTL | PRE_YTB_DAV | 18.7 | 91.1 | 89.5 | 92.7 | [gdrive](https://drive.google.com/file/d/1g6cjYhgBWjMaY3RGAm31qm3SPEF3QcKV/view?usp=sharing) | 114 | | R50-AOTL | PRE_YTB_DAV | 18.0 | 91.7 | 90.4 | 93.0 | [gdrive](https://drive.google.com/file/d/1QzxojqWKsvRf53K2AgKsK523ZVuYU4O-/view?usp=sharing) | 115 | | SwinB-AOTL | PRE_YTB_DAV | 12.1 | **92.2** | **90.6** | **93.8** | [gdrive](https://drive.google.com/file/d/1RIqUtAyVnopeogfT520d7a0yiULg1obp/view?usp=sharing) | 116 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AOT Series Frameworks in PyTorch 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/decoupling-features-in-hierarchical/semi-supervised-video-object-segmentation-on-15)](https://paperswithcode.com/sota/semi-supervised-video-object-segmentation-on-15?p=decoupling-features-in-hierarchical) 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/video-object-segmentation-on-youtube-vos)](https://paperswithcode.com/sota/video-object-segmentation-on-youtube-vos?p=associating-objects-with-scalable) 5 | 
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/semi-supervised-video-object-segmentation-on-18)](https://paperswithcode.com/sota/semi-supervised-video-object-segmentation-on-18?p=associating-objects-with-scalable) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/semi-supervised-video-object-segmentation-on-1)](https://paperswithcode.com/sota/semi-supervised-video-object-segmentation-on-1?p=associating-objects-with-scalable) 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/visual-object-tracking-on-davis-2017)](https://paperswithcode.com/sota/visual-object-tracking-on-davis-2017?p=associating-objects-with-scalable) 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/visual-object-tracking-on-davis-2016)](https://paperswithcode.com/sota/visual-object-tracking-on-davis-2016?p=associating-objects-with-scalable) 9 | 10 | ## News 11 | - `2024/03`: **AOST** - [AOST](https://arxiv.org/abs/2203.11442), the journal extension of AOT, has been accepted by TPAMI. AOST is the first scalable VOS framework supporting run-time speed-accuracy trade-offs, from real-time efficiency to SOTA performance. 12 | - `2023/07`: **Pyramid/Panoptic AOT** - The code of PAOT has been released in [paot](https://github.com/yoxu515/aot-benchmark/tree/paot) branch of this repository. We propose a benchmark [**VIPOSeg**](https://github.com/yoxu515/VIPOSeg-Benchmark) for panoptic VOS, and PAOT is designed to tackle the challenges in panoptic VOS and achieves SOTA performance. PAOT consists of a multi-scale architecture of LSTT (same as MS-AOT in VOT2022) and panoptic ID banks for thing and stuff. Please refer to the [paper](https://arxiv.org/abs/2305.04470) for more details. 13 | - `2023/07`: **WINNER** - DeAOT-based Tracker ranked **1st** in the [**VOTS 2023**](https://www.votchallenge.net/vots2023/) challenge ([leaderboard](https://eu.aihub.ml/competitions/201#results)). In detail, our [DMAOT](https://eu.aihub.ml/my/competition/submission/1139/detailed_results/) improves DeAOT by storing object-wise long-term memories instead of frame-wise long-term memories. This avoids the memory growth problem when processing long video sequences and produces better results when handling multiple objects. 14 | - `2023/06`: **WINNER** - DeAOT-based Tracker ranked **1st** in **two tracks** of [**EPIC-Kitchens**](https://epic-kitchens.github.io/2023) challenges ([leaderboard](http://epic-kitchens.github.io/2023)). In detail, our MS-DeAOT is a multi-scale version of DeAOT and is the winner of Semi-Supervised Video Object Segmentation (segmentation-based tracking) and TREK-150 Object Tracking (BBox-based tracking). Technical reports are coming soon. 15 | - `2023/04`: **SAM-Track** - We are pleased to announce the release of our latest project, [Segment and Track Anything (SAM-Track)](https://github.com/z-x-yang/Segment-and-Track-Anything). This innovative project merges two kinds of models, [SAM](https://github.com/facebookresearch/segment-anything) and our [DeAOT](https://github.com/yoxu515/aot-benchmark), to achieve seamless segmentation and efficient tracking of any objects in videos. 
16 | - `2022/10`: **WINNER** - AOT-based Tracker ranked **1st** in **four tracks** of the **VOT 2022** challenge ([presentation of results](https://data.votchallenge.net/vot2022/vot2022_st_rt.pdf)). In detail, our MS-AOT is the winner of two segmentation tracks, VOT-STs2022 and VOT-RTs2022 (real-time). In addition, the bounding box results of MS-AOT (initialized by [AlphaRef](https://github.com/MasterBin-IIAU/AlphaRefine), with the output bounding box fitted to the mask prediction) surpass the winners of two bounding box tracks, VOT-STb2022 and VOT-RTb2022 (real-time). The bounding box results were requested by the organizers after the competition deadline but were highlighted in the [workshop presentation](https://data.votchallenge.net/vot2022/vot2022_st_rt.pdf) (ECCV 2022). 17 | 18 | ## Intro 19 | A modular reference PyTorch implementation of AOT series frameworks: 20 | - **DeAOT**: Decoupling Features in Hierarchical Propagation for Video Object Segmentation (NeurIPS 2022, Spotlight) [[OpenReview](https://openreview.net/forum?id=DgM7-7eMkq0)][[PDF](https://arxiv.org/pdf/2210.09782.pdf)] 21 | 22 | 23 | - **AOT**: Associating Objects with Transformers for Video Object Segmentation (NeurIPS 2021, Score 8/8/7/8) [[OpenReview](https://openreview.net/forum?id=hl3v8io3ZYt)][[PDF](https://arxiv.org/abs/2106.02638)] 24 | 25 | 26 | An extension of AOT, [AOST](https://arxiv.org/abs/2203.11442) (accepted by TPAMI), is also available. AOST is a more robust and flexible framework, supporting run-time speed-accuracy trade-offs. 27 | 28 | ## Examples 29 | Benchmark examples: 30 | 31 | 32 | 33 | General examples (Messi and Kobe): 34 | 35 | 36 | 37 | ## Highlights 38 | - **High performance:** up to **85.5%** ([R50-AOTL](MODEL_ZOO.md#youtube-vos-2018-val)) on YouTube-VOS 2018 and **82.1%** ([SwinB-AOTL](MODEL_ZOO.md#davis-2017-test)) on DAVIS-2017 Test-dev under standard settings (without any test-time augmentation or post-processing). 39 | - **High efficiency:** up to **51fps** ([AOTT](MODEL_ZOO.md#davis-2017-test)) on DAVIS-2017 (480p) even with **10** objects and **41fps** on YouTube-VOS (1.3x480p). AOT can process multiple objects (up to a pre-defined number, 10 by default) as efficiently as processing a single object. This project also supports inferring any number of objects together within a video by automatic separation and aggregation. 40 | - **Multi-GPU training and inference** 41 | - **Mixed precision training and inference** 42 | - **Test-time augmentation:** multi-scale and flipping augmentations are supported. 43 | 44 | ## Requirements 45 | * Python3 46 | * pytorch >= 1.7.0 and torchvision 47 | * opencv-python 48 | * Pillow 49 | * Pytorch Correlation. We recommend installing it from [source](https://github.com/ClementPinard/Pytorch-Correlation-extension) instead of using `pip`: 50 | ```bash 51 | git clone https://github.com/ClementPinard/Pytorch-Correlation-extension.git 52 | cd Pytorch-Correlation-extension 53 | python setup.py install 54 | cd - 55 | ``` 56 | 57 | Optional: 58 | * scikit-image (required if you want to run our **Demo**) 59 | 60 | ## Model Zoo and Results 61 | Pre-trained models, benchmark scores, and pre-computed results reproduced by this project can be found in [MODEL_ZOO.md](MODEL_ZOO.md). 62 | 63 | ## Demo - Panoptic Propagation 64 | We provide a simple demo to demonstrate AOT's effectiveness. 
The demo will propagate more than **40** objects, including semantic regions (like sky) and instances (like person), together within a single complex scenario and predict its video panoptic segmentation. 65 | 66 | To run the demo, download the [checkpoint](https://drive.google.com/file/d/1qJDYn3Ibpquu4ffYoQmVjg1YCbr2JQep/view?usp=sharing) of R50-AOTL into [pretrain_models](pretrain_models), and then run: 67 | ```bash 68 | python tools/demo.py 69 | ``` 70 | which will predict the given scenarios in the resolution of 1.3x480p. You can also run this demo with other AOTs ([MODEL_ZOO.md](MODEL_ZOO.md)) by setting `--model` (model type) and `--ckpt_path` (checkpoint path). 71 | 72 | Two scenarios from [VSPW](https://www.vspwdataset.com/home) are supplied in [datasets/Demo](datasets/Demo): 73 | 74 | - 1001_3iEIq5HBY1s: 44 objects. 1080P. 75 | - 1007_YCTBBdbKSSg: 43 objects. 1080P. 76 | 77 | Results: 78 | 79 | 80 | 81 | 82 | ## Getting Started 83 | 0. Prepare a valid environment follow the [requirements](#requirements). 84 | 85 | 1. Prepare datasets: 86 | 87 | Please follow the below instruction to prepare datasets in each corresponding folder. 88 | * **Static** 89 | 90 | [datasets/Static](datasets/Static): pre-training dataset with static images. Guidance can be found in [AFB-URR](https://github.com/xmlyqing00/AFB-URR), which we referred to in the implementation of the pre-training. 91 | * **YouTube-VOS** 92 | 93 | A commonly-used large-scale VOS dataset. 94 | 95 | [datasets/YTB/2019](datasets/YTB/2019): version 2019, download [link](https://drive.google.com/drive/folders/1BWzrCWyPEmBEKm0lOHe5KLuBuQxUSwqz?usp=sharing). `train` is required for training. `valid` (6fps) and `valid_all_frames` (30fps, optional) are used for evaluation. 96 | 97 | [datasets/YTB/2018](datasets/YTB/2018): version 2018, download [link](https://drive.google.com/drive/folders/1bI5J1H3mxsIGo7Kp-pPZU8i6rnykOw7f?usp=sharing). Only `valid` (6fps) and `valid_all_frames` (30fps, optional) are required for this project and used for evaluation. 98 | 99 | * **DAVIS** 100 | 101 | A commonly-used small-scale VOS dataset. 102 | 103 | [datasets/DAVIS](datasets/DAVIS): [TrainVal](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip) (480p) contains both the training and validation split. [Test-Dev](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-test-dev-480p.zip) (480p) contains the Test-dev split. The [full-resolution version](https://davischallenge.org/davis2017/code.html) is also supported for training and evaluation but not required. 104 | 105 | 106 | 2. Prepare ImageNet pre-trained encoders 107 | 108 | Select and download below checkpoints into [pretrain_models](pretrain_models): 109 | 110 | - [MobileNet-V2](https://download.pytorch.org/models/mobilenet_v2-b0353104.pth) (default encoder) 111 | - [MobileNet-V3](https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth) 112 | - [ResNet-50](https://download.pytorch.org/models/resnet50-0676ba61.pth) 113 | - [ResNet-101](https://download.pytorch.org/models/resnet101-63fe2227.pth) 114 | - [ResNeSt-50](https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest50-528c19ca.pth) 115 | - [ResNeSt-101](https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest101-22405ba7.pth) 116 | - [Swin-Base](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth) 117 | 118 | The current default training configs are not optimized for encoders larger than ResNet-50. 
If you want to use larger encoders, we recommend early stopping the main-training stage at 80,000 iterations (100,000 in default) to avoid over-fitting on the seen classes of YouTube-VOS. 119 | 120 | 121 | 122 | 3. Training and Evaluation 123 | 124 | The [example script](train_eval.sh) will train AOTT with 2 stages using 4 GPUs and auto-mixed precision (`--amp`). The first stage is a pre-training stage using `Static` dataset, and the second stage is a main-training stage, which uses both `YouTube-VOS 2019 train` and `DAVIS-2017 train` for training, resulting in a model that can generalize to different domains (YouTube-VOS and DAVIS) and different frame rates (6fps, 24fps, and 30fps). 125 | 126 | Notably, you can use only the `YouTube-VOS 2019 train` split in the second stage by changing `pre_ytb_dav` to `pre_ytb`, which leads to better YouTube-VOS performance on unseen classes. Besides, if you don't want to do the first stage, you can start the training from stage `ytb`, but the performance will drop about 1~2% absolutely. 127 | 128 | After the training is finished (about 0.6 days for each stage with 4 Tesla V100 GPUs), the [example script](train_eval.sh) will evaluate the model on YouTube-VOS and DAVIS, and the results will be packed into Zip files. For calculating scores, please use official YouTube-VOS servers ([2018 server](https://competitions.codalab.org/competitions/19544) and [2019 server](https://competitions.codalab.org/competitions/20127)), official [DAVIS toolkit](https://github.com/davisvideochallenge/davis-2017) (for Val), and official [DAVIS server](https://competitions.codalab.org/competitions/20516#learn_the_details) (for Test-dev). 129 | 130 | ## Adding your own dataset 131 | Coming 132 | 133 | ## Troubleshooting 134 | Waiting 135 | 136 | ## TODO 137 | - [ ] Code documentation 138 | - [ ] Adding your own dataset 139 | - [ ] Results with test-time augmentations in Model Zoo 140 | - [ ] Support gradient accumulation 141 | - [x] Demo tool 142 | 143 | ## Citations 144 | Please consider citing the related paper(s) in your publications if it helps your research. 145 | ``` 146 | @article{yang2021aost, 147 | title={Scalable Video Object Segmentation with Identification Mechanism}, 148 | author={Yang, Zongxin and Miao, Jiaxu and Wei, Yunchao and Wang, Wenguan and Wang, Xiaohan and Yang, Yi}, 149 | journal={TPAMI}, 150 | year={2024} 151 | } 152 | @inproceedings{xu2023video, 153 | title={Video object segmentation in panoptic wild scenes}, 154 | author={Xu, Yuanyou and Yang, Zongxin and Yang, Yi}, 155 | booktitle={IJCAI}, 156 | year={2023} 157 | } 158 | @inproceedings{yang2022deaot, 159 | title={Decoupling Features in Hierarchical Propagation for Video Object Segmentation}, 160 | author={Yang, Zongxin and Yang, Yi}, 161 | booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, 162 | year={2022} 163 | } 164 | @inproceedings{yang2021aot, 165 | title={Associating Objects with Transformers for Video Object Segmentation}, 166 | author={Yang, Zongxin and Wei, Yunchao and Yang, Yi}, 167 | booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, 168 | year={2021} 169 | } 170 | ``` 171 | 172 | ## License 173 | This project is released under the BSD-3-Clause license. See [LICENSE](LICENSE) for additional details. 
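For convenience, the two training stages and the evaluation described in step 3 of [Getting Started](#getting-started) correspond roughly to the commands below. This is only a sketch: the stage names (`pre`, `pre_ytb_dav`, `ytb`) and the `--amp`/`--model`/`--ckpt_path` options come from this repository, while the remaining flag names and the experiment name (`aott_run`) are assumptions modeled on [train_eval.sh](train_eval.sh); please check `tools/train.py` and `tools/eval.py` for the exact interface.

```bash
# Stage 1: pre-training on static images (flag names other than --amp/--model are assumed).
python tools/train.py --amp --exp_name aott_run --model aott --stage pre --gpu_num 4
# Stage 2: main training on YouTube-VOS 2019 train + DAVIS-2017 train.
python tools/train.py --amp --exp_name aott_run --model aott --stage pre_ytb_dav --gpu_num 4
# Evaluation on the YouTube-VOS validation split.
python tools/eval.py --exp_name aott_run --model aott --stage pre_ytb_dav --dataset youtubevos --split val
```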
174 | -------------------------------------------------------------------------------- /configs/default.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | 5 | class DefaultEngineConfig(): 6 | def __init__(self, exp_name='default', model='aott'): 7 | model_cfg = importlib.import_module('configs.models.' + 8 | model).ModelConfig() 9 | self.__dict__.update(model_cfg.__dict__) # add model config 10 | 11 | self.EXP_NAME = exp_name + '_' + self.MODEL_NAME 12 | 13 | self.STAGE_NAME = 'YTB' 14 | 15 | self.DATASETS = ['youtubevos'] 16 | self.DATA_WORKERS = 8 17 | self.DATA_RANDOMCROP = (465, 18 | 465) if self.MODEL_ALIGN_CORNERS else (464, 19 | 464) 20 | self.DATA_RANDOMFLIP = 0.5 21 | self.DATA_MAX_CROP_STEPS = 10 22 | self.DATA_SHORT_EDGE_LEN = 480 23 | self.DATA_MIN_SCALE_FACTOR = 0.7 24 | self.DATA_MAX_SCALE_FACTOR = 1.3 25 | self.DATA_RANDOM_REVERSE_SEQ = True 26 | self.DATA_SEQ_LEN = 5 27 | self.DATA_DAVIS_REPEAT = 5 28 | self.DATA_RANDOM_GAP_DAVIS = 12 # max frame interval between two sampled frames for DAVIS (24fps) 29 | self.DATA_RANDOM_GAP_YTB = 3 # max frame interval between two sampled frames for YouTube-VOS (6fps) 30 | self.DATA_DYNAMIC_MERGE_PROB = 0.3 31 | 32 | self.PRETRAIN = True 33 | self.PRETRAIN_FULL = False # if False, load encoder only 34 | self.PRETRAIN_MODEL = './data_wd/pretrain_model/mobilenet_v2.pth' 35 | # self.PRETRAIN_MODEL = './pretrain_models/mobilenet_v2-b0353104.pth' 36 | 37 | self.TRAIN_TOTAL_STEPS = 100000 38 | self.TRAIN_START_STEP = 0 39 | self.TRAIN_WEIGHT_DECAY = 0.07 40 | self.TRAIN_WEIGHT_DECAY_EXCLUSIVE = { 41 | # 'encoder.': 0.01 42 | } 43 | self.TRAIN_WEIGHT_DECAY_EXEMPTION = [ 44 | 'absolute_pos_embed', 'relative_position_bias_table', 45 | 'relative_emb_v', 'conv_out' 46 | ] 47 | self.TRAIN_LR = 2e-4 48 | self.TRAIN_LR_MIN = 2e-5 if 'mobilenetv2' in self.MODEL_ENCODER else 1e-5 49 | self.TRAIN_LR_POWER = 0.9 50 | self.TRAIN_LR_ENCODER_RATIO = 0.1 51 | self.TRAIN_LR_WARM_UP_RATIO = 0.05 52 | self.TRAIN_LR_COSINE_DECAY = False 53 | self.TRAIN_LR_RESTART = 1 54 | self.TRAIN_LR_UPDATE_STEP = 1 55 | self.TRAIN_AUX_LOSS_WEIGHT = 1.0 56 | self.TRAIN_AUX_LOSS_RATIO = 1.0 57 | self.TRAIN_OPT = 'adamw' 58 | self.TRAIN_SGD_MOMENTUM = 0.9 59 | self.TRAIN_GPUS = 4 60 | self.TRAIN_BATCH_SIZE = 16 61 | self.TRAIN_TBLOG = False 62 | self.TRAIN_TBLOG_STEP = 50 63 | self.TRAIN_LOG_STEP = 20 64 | self.TRAIN_IMG_LOG = True 65 | self.TRAIN_TOP_K_PERCENT_PIXELS = 0.15 66 | self.TRAIN_SEQ_TRAINING_FREEZE_PARAMS = ['patch_wise_id_bank'] 67 | self.TRAIN_SEQ_TRAINING_START_RATIO = 0.5 68 | self.TRAIN_HARD_MINING_RATIO = 0.5 69 | self.TRAIN_EMA_RATIO = 0.1 70 | self.TRAIN_CLIP_GRAD_NORM = 5. 71 | self.TRAIN_SAVE_STEP = 5000 72 | self.TRAIN_MAX_KEEP_CKPT = 8 73 | self.TRAIN_RESUME = False 74 | self.TRAIN_RESUME_CKPT = None 75 | self.TRAIN_RESUME_STEP = 0 76 | self.TRAIN_AUTO_RESUME = True 77 | self.TRAIN_DATASET_FULL_RESOLUTION = False 78 | self.TRAIN_ENABLE_PREV_FRAME = False 79 | self.TRAIN_ENCODER_FREEZE_AT = 2 80 | self.TRAIN_LSTT_EMB_DROPOUT = 0. 81 | self.TRAIN_LSTT_ID_DROPOUT = 0. 82 | self.TRAIN_LSTT_DROPPATH = 0.1 83 | self.TRAIN_LSTT_DROPPATH_SCALING = False 84 | self.TRAIN_LSTT_DROPPATH_LST = False 85 | self.TRAIN_LSTT_LT_DROPOUT = 0. 86 | self.TRAIN_LSTT_ST_DROPOUT = 0. 
87 | 88 | self.TEST_GPU_ID = 0 89 | self.TEST_GPU_NUM = 1 90 | self.TEST_FRAME_LOG = False 91 | self.TEST_DATASET = 'youtubevos' 92 | self.TEST_DATASET_FULL_RESOLUTION = False 93 | self.TEST_DATASET_SPLIT = 'val' 94 | self.TEST_CKPT_PATH = None 95 | # if "None", evaluate the latest checkpoint. 96 | self.TEST_CKPT_STEP = None 97 | self.TEST_FLIP = False 98 | self.TEST_MULTISCALE = [1] 99 | self.TEST_MAX_SHORT_EDGE = None 100 | self.TEST_MAX_LONG_EDGE = 800 * 1.3 101 | self.TEST_WORKERS = 4 102 | 103 | # GPU distribution 104 | self.DIST_ENABLE = True 105 | self.DIST_BACKEND = "nccl" # "gloo" 106 | self.DIST_URL = "tcp://127.0.0.1:13241" 107 | self.DIST_START_GPU = 0 108 | 109 | def init_dir(self): 110 | self.DIR_DATA = '../VOS02/datasets'#'./datasets' 111 | self.DIR_DAVIS = os.path.join(self.DIR_DATA, 'DAVIS') 112 | self.DIR_YTB = os.path.join(self.DIR_DATA, 'YTB') 113 | self.DIR_STATIC = os.path.join(self.DIR_DATA, 'Static') 114 | 115 | self.DIR_ROOT = './'#'./data_wd/youtube_vos_jobs' 116 | 117 | self.DIR_RESULT = os.path.join(self.DIR_ROOT, 'result', self.EXP_NAME, 118 | self.STAGE_NAME) 119 | self.DIR_CKPT = os.path.join(self.DIR_RESULT, 'ckpt') 120 | self.DIR_EMA_CKPT = os.path.join(self.DIR_RESULT, 'ema_ckpt') 121 | self.DIR_LOG = os.path.join(self.DIR_RESULT, 'log') 122 | self.DIR_TB_LOG = os.path.join(self.DIR_RESULT, 'log', 'tensorboard') 123 | # self.DIR_IMG_LOG = os.path.join(self.DIR_RESULT, 'log', 'img') 124 | # self.DIR_EVALUATION = os.path.join(self.DIR_RESULT, 'eval') 125 | self.DIR_IMG_LOG = './img_logs' 126 | self.DIR_EVALUATION = './results' 127 | 128 | for path in [ 129 | self.DIR_RESULT, self.DIR_CKPT, self.DIR_EMA_CKPT, 130 | self.DIR_LOG, self.DIR_EVALUATION, self.DIR_IMG_LOG, 131 | self.DIR_TB_LOG 132 | ]: 133 | if not os.path.isdir(path): 134 | try: 135 | os.makedirs(path) 136 | except Exception as inst: 137 | print(inst) 138 | print('Failed to make dir: {}.'.format(path)) 139 | -------------------------------------------------------------------------------- /configs/models/aotb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTB' 8 | 9 | self.MODEL_LSTT_NUM = 3 10 | -------------------------------------------------------------------------------- /configs/models/aotl.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTL' 8 | 9 | self.MODEL_LSTT_NUM = 3 10 | 11 | self.TRAIN_LONG_TERM_MEM_GAP = 2 12 | 13 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /configs/models/aots.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTS' 8 | 9 | self.MODEL_LSTT_NUM = 2 10 | -------------------------------------------------------------------------------- /configs/models/aott.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | 
super().__init__() 7 | self.MODEL_NAME = 'AOTT' 8 | -------------------------------------------------------------------------------- /configs/models/deaotb.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'DeAOTB' 8 | 9 | self.MODEL_LSTT_NUM = 3 10 | -------------------------------------------------------------------------------- /configs/models/deaotl.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'DeAOTL' 8 | 9 | self.MODEL_LSTT_NUM = 3 10 | 11 | self.TRAIN_LONG_TERM_MEM_GAP = 2 12 | 13 | self.TEST_LONG_TERM_MEM_GAP = 5 14 | -------------------------------------------------------------------------------- /configs/models/deaots.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'DeAOTS' 8 | 9 | self.MODEL_LSTT_NUM = 2 10 | -------------------------------------------------------------------------------- /configs/models/deaott.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'DeAOTT' 8 | -------------------------------------------------------------------------------- /configs/models/default.py: -------------------------------------------------------------------------------- 1 | class DefaultModelConfig(): 2 | def __init__(self): 3 | self.MODEL_NAME = 'AOTDefault' 4 | 5 | self.MODEL_VOS = 'aot' 6 | self.MODEL_ENGINE = 'aotengine' 7 | self.MODEL_ALIGN_CORNERS = True 8 | self.MODEL_ENCODER = 'mobilenetv2' 9 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/mobilenet_v2-b0353104.pth' 10 | self.MODEL_ENCODER_DIM = [24, 32, 96, 1280] # 4x, 8x, 16x, 16x 11 | self.MODEL_ENCODER_EMBEDDING_DIM = 256 12 | self.MODEL_DECODER_INTERMEDIATE_LSTT = True 13 | self.MODEL_FREEZE_BN = True 14 | self.MODEL_FREEZE_BACKBONE = False 15 | self.MODEL_MAX_OBJ_NUM = 10 16 | self.MODEL_SELF_HEADS = 8 17 | self.MODEL_ATT_HEADS = 8 18 | self.MODEL_LSTT_NUM = 1 19 | self.MODEL_EPSILON = 1e-5 20 | self.MODEL_USE_PREV_PROB = False 21 | 22 | self.TRAIN_LONG_TERM_MEM_GAP = 9999 23 | self.TRAIN_AUG_TYPE = 'v1' 24 | 25 | self.TEST_LONG_TERM_MEM_GAP = 9999 26 | 27 | self.TEST_SHORT_TERM_MEM_SKIP = 1 28 | -------------------------------------------------------------------------------- /configs/models/default_deaot.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig as BaseConfig 2 | 3 | 4 | class DefaultModelConfig(BaseConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'DeAOTDefault' 8 | 9 | self.MODEL_VOS = 'deaot' 10 | self.MODEL_ENGINE = 'deaotengine' 11 | 12 | self.MODEL_DECODER_INTERMEDIATE_LSTT = False 13 | 14 | self.MODEL_SELF_HEADS = 1 15 | self.MODEL_ATT_HEADS = 1 16 | 17 | self.TRAIN_AUG_TYPE = 'v2' 18 | -------------------------------------------------------------------------------- 
/configs/models/r101_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R101_AOTL' 8 | 9 | self.MODEL_ENCODER = 'resnet101' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet101-63fe2227.pth' # https://download.pytorch.org/models/resnet101-63fe2227.pth 11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 12 | self.MODEL_LSTT_NUM = 3 13 | 14 | self.TRAIN_LONG_TERM_MEM_GAP = 2 15 | 16 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /configs/models/r50_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R50_AOTL' 8 | 9 | self.MODEL_ENCODER = 'resnet50' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet50-0676ba61.pth' # https://download.pytorch.org/models/resnet50-0676ba61.pth 11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 12 | self.MODEL_LSTT_NUM = 3 13 | 14 | self.TRAIN_LONG_TERM_MEM_GAP = 2 15 | 16 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /configs/models/r50_deaotl.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R50_DeAOTL' 8 | 9 | self.MODEL_ENCODER = 'resnet50' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet50-0676ba61.pth' # https://download.pytorch.org/models/resnet50-0676ba61.pth 11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 12 | 13 | self.MODEL_LSTT_NUM = 3 14 | 15 | self.TRAIN_LONG_TERM_MEM_GAP = 2 16 | 17 | self.TEST_LONG_TERM_MEM_GAP = 5 18 | -------------------------------------------------------------------------------- /configs/models/rs101_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R101_AOTL' 8 | 9 | self.MODEL_ENCODER = 'resnest101' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnest101-22405ba7.pth' # https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest101-22405ba7.pth 11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 12 | self.MODEL_LSTT_NUM = 3 13 | 14 | self.TRAIN_LONG_TERM_MEM_GAP = 2 15 | 16 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /configs/models/swinb_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'SwinB_AOTL' 8 | 9 | self.MODEL_ENCODER = 'swin_base' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/swin_base_patch4_window7_224_22k.pth' # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth 11 | self.MODEL_ALIGN_CORNERS = False 
12 | self.MODEL_ENCODER_DIM = [128, 256, 512, 512] # 4x, 8x, 16x, 16x 13 | self.MODEL_LSTT_NUM = 3 14 | 15 | self.TRAIN_LONG_TERM_MEM_GAP = 2 16 | 17 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /configs/models/swinb_deaotl.py: -------------------------------------------------------------------------------- 1 | from .default_deaot import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'SwinB_DeAOTL' 8 | 9 | self.MODEL_ENCODER = 'swin_base' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/swin_base_patch4_window7_224_22k.pth' # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth 11 | self.MODEL_ALIGN_CORNERS = False 12 | self.MODEL_ENCODER_DIM = [128, 256, 512, 512] # 4x, 8x, 16x, 16x 13 | 14 | self.MODEL_LSTT_NUM = 3 15 | 16 | self.TRAIN_LONG_TERM_MEM_GAP = 2 17 | 18 | self.TEST_LONG_TERM_MEM_GAP = 5 19 | -------------------------------------------------------------------------------- /configs/pre.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultEngineConfig 2 | 3 | 4 | class EngineConfig(DefaultEngineConfig): 5 | def __init__(self, exp_name='default', model='AOTT'): 6 | super().__init__(exp_name, model) 7 | self.STAGE_NAME = 'PRE' 8 | 9 | self.init_dir() 10 | 11 | self.DATASETS = ['static'] 12 | 13 | self.DATA_DYNAMIC_MERGE_PROB = 1.0 14 | 15 | self.TRAIN_LR = 4e-4 16 | self.TRAIN_LR_MIN = 2e-5 17 | self.TRAIN_WEIGHT_DECAY = 0.03 18 | self.TRAIN_SEQ_TRAINING_START_RATIO = 1.0 19 | self.TRAIN_AUX_LOSS_RATIO = 0.1 20 | -------------------------------------------------------------------------------- /configs/pre_dav.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'PRE_DAV' 9 | 10 | self.init_dir() 11 | 12 | self.DATASETS = ['davis2017'] 13 | 14 | self.TRAIN_TOTAL_STEPS = 50000 15 | 16 | pretrain_stage = 'PRE' 17 | pretrain_ckpt = 'save_step_100000.pth' 18 | self.PRETRAIN_FULL = True # if False, load encoder only 19 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', 20 | self.EXP_NAME, pretrain_stage, 21 | 'ema_ckpt', pretrain_ckpt) 22 | -------------------------------------------------------------------------------- /configs/pre_ytb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'PRE_YTB' 9 | 10 | self.init_dir() 11 | 12 | pretrain_stage = 'PRE' 13 | pretrain_ckpt = 'save_step_100000.pth' 14 | self.PRETRAIN_FULL = True # if False, load encoder only 15 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', 16 | self.EXP_NAME, pretrain_stage, 17 | 'ema_ckpt', pretrain_ckpt) 18 | -------------------------------------------------------------------------------- /configs/pre_ytb_dav.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class 
EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'PRE_YTB_DAV' 9 | 10 | self.init_dir() 11 | 12 | self.DATASETS = ['youtubevos', 'davis2017'] 13 | 14 | pretrain_stage = 'PRE' 15 | pretrain_ckpt = 'save_step_100000.pth' 16 | self.PRETRAIN_FULL = True # if False, load encoder only 17 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', 18 | self.EXP_NAME, pretrain_stage, 19 | 'ema_ckpt', pretrain_ckpt) 20 | -------------------------------------------------------------------------------- /configs/ytb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'YTB' 9 | 10 | self.init_dir() 11 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/dataloaders/__init__.py -------------------------------------------------------------------------------- /dataloaders/eval_datasets.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import shutil 4 | import json 5 | import cv2 6 | from PIL import Image 7 | 8 | import numpy as np 9 | from torch.utils.data import Dataset 10 | 11 | from utils.image import _palette 12 | 13 | 14 | class VOSTest(Dataset): 15 | def __init__(self, 16 | image_root, 17 | label_root, 18 | seq_name, 19 | images, 20 | labels, 21 | rgb=True, 22 | transform=None, 23 | single_obj=False, 24 | resolution=None): 25 | self.image_root = image_root 26 | self.label_root = label_root 27 | self.seq_name = seq_name 28 | self.images = images 29 | self.labels = labels 30 | self.obj_num = 1 31 | self.num_frame = len(self.images) 32 | self.transform = transform 33 | self.rgb = rgb 34 | self.single_obj = single_obj 35 | self.resolution = resolution 36 | 37 | self.obj_nums = [] 38 | self.obj_indices = [] 39 | 40 | curr_objs = [0] 41 | for img_name in self.images: 42 | self.obj_nums.append(len(curr_objs) - 1) 43 | current_label_name = img_name.split('.')[0] + '.png' 44 | if current_label_name in self.labels: 45 | current_label = self.read_label(current_label_name) 46 | curr_obj = list(np.unique(current_label)) 47 | for obj_idx in curr_obj: 48 | if obj_idx not in curr_objs: 49 | curr_objs.append(obj_idx) 50 | self.obj_indices.append(curr_objs.copy()) 51 | 52 | self.obj_nums[0] = self.obj_nums[1] 53 | 54 | def __len__(self): 55 | return len(self.images) 56 | 57 | def read_image(self, idx): 58 | img_name = self.images[idx] 59 | img_path = os.path.join(self.image_root, self.seq_name, img_name) 60 | img = cv2.imread(img_path) 61 | img = np.array(img, dtype=np.float32) 62 | if self.rgb: 63 | img = img[:, :, [2, 1, 0]] 64 | return img 65 | 66 | def read_label(self, label_name, squeeze_idx=None): 67 | label_path = os.path.join(self.label_root, self.seq_name, label_name) 68 | label = Image.open(label_path) 69 | label = np.array(label, dtype=np.uint8) 70 | if self.single_obj: 71 | label = (label > 0).astype(np.uint8) 72 | elif squeeze_idx is not None: 73 | squeezed_label = label * 0 74 | for idx in range(len(squeeze_idx)): 75 | 
obj_id = squeeze_idx[idx] 76 | if obj_id == 0: 77 | continue 78 | mask = label == obj_id 79 | squeezed_label += (mask * idx).astype(np.uint8) 80 | label = squeezed_label 81 | return label 82 | 83 | def __getitem__(self, idx): 84 | img_name = self.images[idx] 85 | current_img = self.read_image(idx) 86 | height, width, channels = current_img.shape 87 | if self.resolution is not None: 88 | width = int(np.ceil( 89 | float(width) * self.resolution / float(height))) 90 | height = int(self.resolution) 91 | 92 | current_label_name = img_name.split('.')[0] + '.png' 93 | obj_num = self.obj_nums[idx] 94 | obj_idx = self.obj_indices[idx] 95 | 96 | if current_label_name in self.labels: 97 | current_label = self.read_label(current_label_name, obj_idx) 98 | sample = { 99 | 'current_img': current_img, 100 | 'current_label': current_label 101 | } 102 | else: 103 | sample = {'current_img': current_img} 104 | 105 | sample['meta'] = { 106 | 'seq_name': self.seq_name, 107 | 'frame_num': self.num_frame, 108 | 'obj_num': obj_num, 109 | 'current_name': img_name, 110 | 'height': height, 111 | 'width': width, 112 | 'flip': False, 113 | 'obj_idx': obj_idx 114 | } 115 | 116 | if self.transform is not None: 117 | sample = self.transform(sample) 118 | return sample 119 | 120 | 121 | class YOUTUBEVOS_Test(object): 122 | def __init__(self, 123 | root='./datasets/YTB', 124 | year=2018, 125 | split='val', 126 | transform=None, 127 | rgb=True, 128 | result_root=None): 129 | if split == 'val': 130 | split = 'valid' 131 | root = os.path.join(root, str(year), split) 132 | self.db_root_dir = root 133 | self.result_root = result_root 134 | self.rgb = rgb 135 | self.transform = transform 136 | self.seq_list_file = os.path.join(self.db_root_dir, 'meta.json') 137 | self._check_preprocess() 138 | self.seqs = list(self.ann_f.keys()) 139 | self.image_root = os.path.join(root, 'JPEGImages') 140 | self.label_root = os.path.join(root, 'Annotations') 141 | 142 | def __len__(self): 143 | return len(self.seqs) 144 | 145 | def __getitem__(self, idx): 146 | seq_name = self.seqs[idx] 147 | data = self.ann_f[seq_name]['objects'] 148 | obj_names = list(data.keys()) 149 | images = [] 150 | labels = [] 151 | for obj_n in obj_names: 152 | images += map(lambda x: x + '.jpg', list(data[obj_n]["frames"])) 153 | labels.append(data[obj_n]["frames"][0] + '.png') 154 | images = np.sort(np.unique(images)) 155 | labels = np.sort(np.unique(labels)) 156 | 157 | try: 158 | if not os.path.isfile( 159 | os.path.join(self.result_root, seq_name, labels[0])): 160 | if not os.path.exists(os.path.join(self.result_root, 161 | seq_name)): 162 | os.makedirs(os.path.join(self.result_root, seq_name)) 163 | shutil.copy( 164 | os.path.join(self.label_root, seq_name, labels[0]), 165 | os.path.join(self.result_root, seq_name, labels[0])) 166 | except Exception as inst: 167 | print(inst) 168 | print('Failed to create a result folder for sequence {}.'.format( 169 | seq_name)) 170 | 171 | seq_dataset = VOSTest(self.image_root, 172 | self.label_root, 173 | seq_name, 174 | images, 175 | labels, 176 | transform=self.transform, 177 | rgb=self.rgb) 178 | return seq_dataset 179 | 180 | def _check_preprocess(self): 181 | _seq_list_file = self.seq_list_file 182 | if not os.path.isfile(_seq_list_file): 183 | print(_seq_list_file) 184 | return False 185 | else: 186 | self.ann_f = json.load(open(self.seq_list_file, 'r'))['videos'] 187 | return True 188 | 189 | 190 | class YOUTUBEVOS_DenseTest(object): 191 | def __init__(self, 192 | root='./datasets/YTB', 193 | year=2018, 194 | split='val', 
195 | transform=None, 196 | rgb=True, 197 | result_root=None): 198 | if split == 'val': 199 | split = 'valid' 200 | root_sparse = os.path.join(root, str(year), split) 201 | root_dense = root_sparse + '_all_frames' 202 | self.db_root_dir = root_dense 203 | self.result_root = result_root 204 | self.rgb = rgb 205 | self.transform = transform 206 | self.seq_list_file = os.path.join(root_sparse, 'meta.json') 207 | self._check_preprocess() 208 | self.seqs = list(self.ann_f.keys()) 209 | self.image_root = os.path.join(root_dense, 'JPEGImages') 210 | self.label_root = os.path.join(root_sparse, 'Annotations') 211 | 212 | def __len__(self): 213 | return len(self.seqs) 214 | 215 | def __getitem__(self, idx): 216 | seq_name = self.seqs[idx] 217 | 218 | data = self.ann_f[seq_name]['objects'] 219 | obj_names = list(data.keys()) 220 | images_sparse = [] 221 | for obj_n in obj_names: 222 | images_sparse += map(lambda x: x + '.jpg', 223 | list(data[obj_n]["frames"])) 224 | images_sparse = np.sort(np.unique(images_sparse)) 225 | 226 | images = np.sort( 227 | list(os.listdir(os.path.join(self.image_root, seq_name)))) 228 | start_img = images_sparse[0] 229 | end_img = images_sparse[-1] 230 | for start_idx in range(len(images)): 231 | if start_img in images[start_idx]: 232 | break 233 | for end_idx in range(len(images))[::-1]: 234 | if end_img in images[end_idx]: 235 | break 236 | images = images[start_idx:(end_idx + 1)] 237 | labels = np.sort( 238 | list(os.listdir(os.path.join(self.label_root, seq_name)))) 239 | 240 | try: 241 | if not os.path.isfile( 242 | os.path.join(self.result_root, seq_name, labels[0])): 243 | if not os.path.exists(os.path.join(self.result_root, 244 | seq_name)): 245 | os.makedirs(os.path.join(self.result_root, seq_name)) 246 | shutil.copy( 247 | os.path.join(self.label_root, seq_name, labels[0]), 248 | os.path.join(self.result_root, seq_name, labels[0])) 249 | except Exception as inst: 250 | print(inst) 251 | print('Failed to create a result folder for sequence {}.'.format( 252 | seq_name)) 253 | 254 | seq_dataset = VOSTest(self.image_root, 255 | self.label_root, 256 | seq_name, 257 | images, 258 | labels, 259 | transform=self.transform, 260 | rgb=self.rgb) 261 | seq_dataset.images_sparse = images_sparse 262 | 263 | return seq_dataset 264 | 265 | def _check_preprocess(self): 266 | _seq_list_file = self.seq_list_file 267 | if not os.path.isfile(_seq_list_file): 268 | print(_seq_list_file) 269 | return False 270 | else: 271 | self.ann_f = json.load(open(self.seq_list_file, 'r'))['videos'] 272 | return True 273 | 274 | 275 | class DAVIS_Test(object): 276 | def __init__(self, 277 | split=['val'], 278 | root='./DAVIS', 279 | year=2017, 280 | transform=None, 281 | rgb=True, 282 | full_resolution=False, 283 | result_root=None): 284 | self.transform = transform 285 | self.rgb = rgb 286 | self.result_root = result_root 287 | if year == 2016: 288 | self.single_obj = True 289 | else: 290 | self.single_obj = False 291 | if full_resolution: 292 | resolution = 'Full-Resolution' 293 | else: 294 | resolution = '480p' 295 | self.image_root = os.path.join(root, 'JPEGImages', resolution) 296 | self.label_root = os.path.join(root, 'Annotations', resolution) 297 | seq_names = [] 298 | for spt in split: 299 | if spt == 'test': 300 | spt = 'test-dev' 301 | with open(os.path.join(root, 'ImageSets', str(year), 302 | spt + '.txt')) as f: 303 | seqs_tmp = f.readlines() 304 | seqs_tmp = list(map(lambda elem: elem.strip(), seqs_tmp)) 305 | seq_names.extend(seqs_tmp) 306 | self.seqs = list(np.unique(seq_names)) 
307 | 308 | def __len__(self): 309 | return len(self.seqs) 310 | 311 | def __getitem__(self, idx): 312 | seq_name = self.seqs[idx] 313 | images = list( 314 | np.sort(os.listdir(os.path.join(self.image_root, seq_name)))) 315 | labels = [images[0].replace('jpg', 'png')] 316 | 317 | if not os.path.isfile( 318 | os.path.join(self.result_root, seq_name, labels[0])): 319 | seq_result_folder = os.path.join(self.result_root, seq_name) 320 | try: 321 | if not os.path.exists(seq_result_folder): 322 | os.makedirs(seq_result_folder) 323 | except Exception as inst: 324 | print(inst) 325 | print( 326 | 'Failed to create a result folder for sequence {}.'.format( 327 | seq_name)) 328 | source_label_path = os.path.join(self.label_root, seq_name, 329 | labels[0]) 330 | result_label_path = os.path.join(self.result_root, seq_name, 331 | labels[0]) 332 | if self.single_obj: 333 | label = Image.open(source_label_path) 334 | label = np.array(label, dtype=np.uint8) 335 | label = (label > 0).astype(np.uint8) 336 | label = Image.fromarray(label).convert('P') 337 | label.putpalette(_palette) 338 | label.save(result_label_path) 339 | else: 340 | shutil.copy(source_label_path, result_label_path) 341 | 342 | seq_dataset = VOSTest(self.image_root, 343 | self.label_root, 344 | seq_name, 345 | images, 346 | labels, 347 | transform=self.transform, 348 | rgb=self.rgb, 349 | single_obj=self.single_obj, 350 | resolution=480) 351 | return seq_dataset 352 | 353 | 354 | class _EVAL_TEST(Dataset): 355 | def __init__(self, transform, seq_name): 356 | self.seq_name = seq_name 357 | self.num_frame = 10 358 | self.transform = transform 359 | 360 | def __len__(self): 361 | return self.num_frame 362 | 363 | def __getitem__(self, idx): 364 | current_frame_obj_num = 2 365 | height = 400 366 | width = 400 367 | img_name = 'test{}.jpg'.format(idx) 368 | current_img = np.zeros((height, width, 3)).astype(np.float32) 369 | if idx == 0: 370 | current_label = (current_frame_obj_num * np.ones( 371 | (height, width))).astype(np.uint8) 372 | sample = { 373 | 'current_img': current_img, 374 | 'current_label': current_label 375 | } 376 | else: 377 | sample = {'current_img': current_img} 378 | 379 | sample['meta'] = { 380 | 'seq_name': self.seq_name, 381 | 'frame_num': self.num_frame, 382 | 'obj_num': current_frame_obj_num, 383 | 'current_name': img_name, 384 | 'height': height, 385 | 'width': width, 386 | 'flip': False 387 | } 388 | 389 | if self.transform is not None: 390 | sample = self.transform(sample) 391 | return sample 392 | 393 | 394 | class EVAL_TEST(object): 395 | def __init__(self, transform=None, result_root=None): 396 | self.transform = transform 397 | self.result_root = result_root 398 | 399 | self.seqs = ['test1', 'test2', 'test3'] 400 | 401 | def __len__(self): 402 | return len(self.seqs) 403 | 404 | def __getitem__(self, idx): 405 | seq_name = self.seqs[idx] 406 | 407 | if not os.path.exists(os.path.join(self.result_root, seq_name)): 408 | os.makedirs(os.path.join(self.result_root, seq_name)) 409 | 410 | seq_dataset = _EVAL_TEST(self.transform, seq_name) 411 | return seq_dataset 412 | -------------------------------------------------------------------------------- /datasets/DAVIS/README.md: -------------------------------------------------------------------------------- 1 | Put DAVIS 2017 here. 
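(Illustrative note, not part of the repository files.) The evaluation wrappers defined above in `dataloaders/eval_datasets.py` are two-level containers: indexing `DAVIS_Test` (or `YOUTUBEVOS_Test`) returns a per-sequence `VOSTest` dataset, which is then read frame by frame and yields a dict holding `current_img`, an optional `current_label`, and a `meta` entry. A minimal sketch of that contract, assuming DAVIS 2017 has been placed in this folder; the `result_root` path is a placeholder:

```python
# Minimal usage sketch (assumed, not repository code); paths are placeholders.
from dataloaders.eval_datasets import DAVIS_Test

davis = DAVIS_Test(split=['val'],
                   root='./datasets/DAVIS',
                   year=2017,
                   result_root='./results/davis2017_val')

for seq_dataset in davis:            # each item is a per-sequence VOSTest
    for sample in seq_dataset:       # each item is one frame of that sequence
        img = sample['current_img']                # H x W x 3 frame
        first_label = sample.get('current_label')  # only for annotated frames
        meta = sample['meta']        # seq_name, obj_num, height, width, ...
        # a VOS engine would consume (img, first_label, meta) here
```

The repository's full evaluation loop (model inference and mask saving) is driven elsewhere (e.g. `networks/managers/evaluator.py`); the sketch above only illustrates the dataset-side interface.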
-------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002058.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002058.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002059.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002059.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002060.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002060.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002061.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002061.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002062.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002062.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002063.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002063.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002064.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002064.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002065.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002065.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002066.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002066.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002067.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002067.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002068.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002068.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002069.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002069.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002070.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002070.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002071.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002071.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002072.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002072.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002073.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002073.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002074.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002074.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002075.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002075.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002076.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002076.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002077.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002077.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002078.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002078.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002079.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002079.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002080.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002080.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002081.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002081.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002082.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002082.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002083.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002083.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002084.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002084.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002085.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002085.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002086.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002086.jpg 
-------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002087.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002087.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002088.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002088.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002089.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002089.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002090.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002090.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002091.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002091.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002092.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002092.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002093.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002093.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002094.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002094.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002095.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002095.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002096.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002096.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002097.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002097.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002098.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002098.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002099.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002099.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002100.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002100.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002101.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002101.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1001_3iEIq5HBY1s/00002102.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1001_3iEIq5HBY1s/00002102.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000693.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000693.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000694.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000694.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000695.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000695.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000696.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000696.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000697.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000697.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000698.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000698.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000699.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000699.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000700.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000700.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000701.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000701.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000702.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000702.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000703.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000703.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000704.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000704.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000705.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000705.jpg 
-------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000706.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000706.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000707.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000707.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000708.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000708.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000709.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000709.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000710.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000710.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000711.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000711.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000712.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000712.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000713.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000713.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000714.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000714.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000715.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000715.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000716.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000716.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000717.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000717.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000718.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000718.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000719.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000719.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000720.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000720.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000721.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000721.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000722.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000722.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000723.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000723.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000724.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000724.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000725.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000725.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000726.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000726.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000727.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000727.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000728.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000728.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000729.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000729.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000730.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000730.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000731.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000731.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000732.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000732.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000733.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000733.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000734.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000734.jpg 
-------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000735.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000735.jpg -------------------------------------------------------------------------------- /datasets/Demo/images/1007_YCTBBdbKSSg/00000736.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/images/1007_YCTBBdbKSSg/00000736.jpg -------------------------------------------------------------------------------- /datasets/Demo/masks/1001_3iEIq5HBY1s/00002058.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/masks/1001_3iEIq5HBY1s/00002058.png -------------------------------------------------------------------------------- /datasets/Demo/masks/1007_YCTBBdbKSSg/00000693.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/datasets/Demo/masks/1007_YCTBBdbKSSg/00000693.png -------------------------------------------------------------------------------- /datasets/Static/README.md: -------------------------------------------------------------------------------- 1 | Put the static dataset here. Guidance can be found in [AFB-URR](https://github.com/xmlyqing00/AFB-URR), which we referred to when implementing the pre-training stage. 2 | -------------------------------------------------------------------------------- /datasets/YTB/2018/train/README.md: -------------------------------------------------------------------------------- 1 | Put the training split of YouTube-VOS 2018 here. -------------------------------------------------------------------------------- /datasets/YTB/2018/valid/README.md: -------------------------------------------------------------------------------- 1 | Put the validation split of YouTube-VOS 2018 here. -------------------------------------------------------------------------------- /datasets/YTB/2018/valid_all_frames/README.md: -------------------------------------------------------------------------------- 1 | Put the all-frame validation split of YouTube-VOS 2018 here. -------------------------------------------------------------------------------- /datasets/YTB/2019/train/README.md: -------------------------------------------------------------------------------- 1 | Put the training split of YouTube-VOS 2019 here. -------------------------------------------------------------------------------- /datasets/YTB/2019/valid/README.md: -------------------------------------------------------------------------------- 1 | Put the validation split of YouTube-VOS 2019 here. -------------------------------------------------------------------------------- /datasets/YTB/2019/valid_all_frames/README.md: -------------------------------------------------------------------------------- 1 | Put the all-frame validation split of YouTube-VOS 2019 here.
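(Illustrative note, not part of the repository files.) The README stubs above only mark where the YouTube-VOS splits should live; `YOUTUBEVOS_Test` and `YOUTUBEVOS_DenseTest` in `dataloaders/eval_datasets.py` then resolve them as `<root>/<year>/valid` (frames, `meta.json`, first-frame `Annotations`) and `<root>/<year>/valid_all_frames` (dense frames only), respectively. A minimal construction sketch, assuming the splits are already in place; the `result_root` values are placeholders:

```python
# Assumed usage sketch; result_root paths are placeholders.
from dataloaders.eval_datasets import YOUTUBEVOS_Test, YOUTUBEVOS_DenseTest

# Sparse validation: frames, meta.json and annotations are read from
#   ./datasets/YTB/2018/valid/
ytb_val = YOUTUBEVOS_Test(root='./datasets/YTB',
                          year=2018,
                          split='val',
                          result_root='./results/ytb2018_valid')

# Dense variant: frames come from ./datasets/YTB/2018/valid_all_frames/,
# while meta.json and the first-frame annotations are still taken from
#   ./datasets/YTB/2018/valid/
ytb_dense = YOUTUBEVOS_DenseTest(root='./datasets/YTB',
                                 year=2018,
                                 split='val',
                                 result_root='./results/ytb2018_valid_all_frames')

print(len(ytb_val), len(ytb_dense))  # number of validation sequences
```

When a sequence is fetched, both wrappers copy its first-frame annotation into `result_root`, so the result folder always holds a complete set of masks for evaluation or submission.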
-------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/networks/__init__.py -------------------------------------------------------------------------------- /networks/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.decoders.fpn import FPNSegmentationHead 2 | 3 | 4 | def build_decoder(name, **kwargs): 5 | 6 | if name == 'fpn': 7 | return FPNSegmentationHead(**kwargs) 8 | else: 9 | raise NotImplementedError 10 | -------------------------------------------------------------------------------- /networks/decoders/fpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from networks.layers.basic import ConvGN 5 | 6 | 7 | class FPNSegmentationHead(nn.Module): 8 | def __init__(self, 9 | in_dim, 10 | out_dim, 11 | decode_intermediate_input=True, 12 | hidden_dim=256, 13 | shortcut_dims=[24, 32, 96, 1280], 14 | align_corners=True): 15 | super().__init__() 16 | self.align_corners = align_corners 17 | 18 | self.decode_intermediate_input = decode_intermediate_input 19 | 20 | self.conv_in = ConvGN(in_dim, hidden_dim, 1) 21 | 22 | self.conv_16x = ConvGN(hidden_dim, hidden_dim, 3) 23 | self.conv_8x = ConvGN(hidden_dim, hidden_dim // 2, 3) 24 | self.conv_4x = ConvGN(hidden_dim // 2, hidden_dim // 2, 3) 25 | 26 | self.adapter_16x = nn.Conv2d(shortcut_dims[-2], hidden_dim, 1) 27 | self.adapter_8x = nn.Conv2d(shortcut_dims[-3], hidden_dim, 1) 28 | self.adapter_4x = nn.Conv2d(shortcut_dims[-4], hidden_dim // 2, 1) 29 | 30 | self.conv_out = nn.Conv2d(hidden_dim // 2, out_dim, 1) 31 | 32 | self._init_weight() 33 | 34 | def forward(self, inputs, shortcuts): 35 | 36 | if self.decode_intermediate_input: 37 | x = torch.cat(inputs, dim=1) 38 | else: 39 | x = inputs[-1] 40 | 41 | x = F.relu_(self.conv_in(x)) 42 | x = F.relu_(self.conv_16x(self.adapter_16x(shortcuts[-2]) + x)) 43 | 44 | x = F.interpolate(x, 45 | size=shortcuts[-3].size()[-2:], 46 | mode="bilinear", 47 | align_corners=self.align_corners) 48 | x = F.relu_(self.conv_8x(self.adapter_8x(shortcuts[-3]) + x)) 49 | 50 | x = F.interpolate(x, 51 | size=shortcuts[-4].size()[-2:], 52 | mode="bilinear", 53 | align_corners=self.align_corners) 54 | x = F.relu_(self.conv_4x(self.adapter_4x(shortcuts[-4]) + x)) 55 | 56 | x = self.conv_out(x) 57 | 58 | return x 59 | 60 | def _init_weight(self): 61 | for p in self.parameters(): 62 | if p.dim() > 1: 63 | nn.init.xavier_uniform_(p) 64 | -------------------------------------------------------------------------------- /networks/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.encoders.mobilenetv2 import MobileNetV2 2 | from networks.encoders.mobilenetv3 import MobileNetV3Large 3 | from networks.encoders.resnet import ResNet101, ResNet50 4 | from networks.encoders.resnest import resnest 5 | from networks.encoders.swin import build_swin_model 6 | from networks.layers.normalization import FrozenBatchNorm2d 7 | from torch import nn 8 | 9 | 10 | def build_encoder(name, frozen_bn=True, freeze_at=-1): 11 | if frozen_bn: 12 | BatchNorm = FrozenBatchNorm2d 13 | else: 14 | BatchNorm = nn.BatchNorm2d 15 | 16 | if name == 'mobilenetv2': 17 | return 
MobileNetV2(16, BatchNorm, freeze_at=freeze_at) 18 | elif name == 'mobilenetv3': 19 | return MobileNetV3Large(16, BatchNorm, freeze_at=freeze_at) 20 | elif name == 'resnet50': 21 | return ResNet50(16, BatchNorm, freeze_at=freeze_at) 22 | elif name == 'resnet101': 23 | return ResNet101(16, BatchNorm, freeze_at=freeze_at) 24 | elif name == 'resnest50': 25 | return resnest.resnest50(norm_layer=BatchNorm, 26 | dilation=2, 27 | freeze_at=freeze_at) 28 | elif name == 'resnest101': 29 | return resnest.resnest101(norm_layer=BatchNorm, 30 | dilation=2, 31 | freeze_at=freeze_at) 32 | elif 'swin' in name: 33 | return build_swin_model(name, freeze_at=freeze_at) 34 | else: 35 | raise NotImplementedError 36 | -------------------------------------------------------------------------------- /networks/encoders/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch import Tensor 3 | from typing import Callable, Optional, List 4 | from utils.learning import freeze_params 5 | 6 | __all__ = ['MobileNetV2'] 7 | 8 | 9 | def _make_divisible(v: float, 10 | divisor: int, 11 | min_value: Optional[int] = None) -> int: 12 | """ 13 | This function is taken from the original tf repo. 14 | It ensures that all layers have a channel number that is divisible by 8 15 | It can be seen here: 16 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 17 | """ 18 | if min_value is None: 19 | min_value = divisor 20 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 21 | # Make sure that round down does not go down by more than 10%. 22 | if new_v < 0.9 * v: 23 | new_v += divisor 24 | return new_v 25 | 26 | 27 | class ConvBNActivation(nn.Sequential): 28 | def __init__( 29 | self, 30 | in_planes: int, 31 | out_planes: int, 32 | kernel_size: int = 3, 33 | stride: int = 1, 34 | groups: int = 1, 35 | padding: int = -1, 36 | norm_layer: Optional[Callable[..., nn.Module]] = None, 37 | activation_layer: Optional[Callable[..., nn.Module]] = None, 38 | dilation: int = 1, 39 | ) -> None: 40 | if padding == -1: 41 | padding = (kernel_size - 1) // 2 * dilation 42 | if norm_layer is None: 43 | norm_layer = nn.BatchNorm2d 44 | if activation_layer is None: 45 | activation_layer = nn.ReLU6 46 | super().__init__( 47 | nn.Conv2d(in_planes, 48 | out_planes, 49 | kernel_size, 50 | stride, 51 | padding, 52 | dilation=dilation, 53 | groups=groups, 54 | bias=False), norm_layer(out_planes), 55 | activation_layer(inplace=True)) 56 | self.out_channels = out_planes 57 | 58 | 59 | # necessary for backwards compatibility 60 | ConvBNReLU = ConvBNActivation 61 | 62 | 63 | class InvertedResidual(nn.Module): 64 | def __init__( 65 | self, 66 | inp: int, 67 | oup: int, 68 | stride: int, 69 | dilation: int, 70 | expand_ratio: int, 71 | norm_layer: Optional[Callable[..., nn.Module]] = None) -> None: 72 | super(InvertedResidual, self).__init__() 73 | self.stride = stride 74 | assert stride in [1, 2] 75 | 76 | if norm_layer is None: 77 | norm_layer = nn.BatchNorm2d 78 | 79 | self.kernel_size = 3 80 | self.dilation = dilation 81 | 82 | hidden_dim = int(round(inp * expand_ratio)) 83 | self.use_res_connect = self.stride == 1 and inp == oup 84 | 85 | layers: List[nn.Module] = [] 86 | if expand_ratio != 1: 87 | # pw 88 | layers.append( 89 | ConvBNReLU(inp, 90 | hidden_dim, 91 | kernel_size=1, 92 | norm_layer=norm_layer)) 93 | layers.extend([ 94 | # dw 95 | ConvBNReLU(hidden_dim, 96 | hidden_dim, 97 | stride=stride, 98 | dilation=dilation, 99 
| groups=hidden_dim, 100 | norm_layer=norm_layer), 101 | # pw-linear 102 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 103 | norm_layer(oup), 104 | ]) 105 | self.conv = nn.Sequential(*layers) 106 | self.out_channels = oup 107 | self._is_cn = stride > 1 108 | 109 | def forward(self, x: Tensor) -> Tensor: 110 | if self.use_res_connect: 111 | return x + self.conv(x) 112 | else: 113 | return self.conv(x) 114 | 115 | 116 | class MobileNetV2(nn.Module): 117 | def __init__(self, 118 | output_stride=8, 119 | norm_layer: Optional[Callable[..., nn.Module]] = None, 120 | width_mult: float = 1.0, 121 | inverted_residual_setting: Optional[List[List[int]]] = None, 122 | round_nearest: int = 8, 123 | block: Optional[Callable[..., nn.Module]] = None, 124 | freeze_at=0) -> None: 125 | """ 126 | MobileNet V2 main class 127 | Args: 128 | num_classes (int): Number of classes 129 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 130 | inverted_residual_setting: Network structure 131 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 132 | Set to 1 to turn off rounding 133 | block: Module specifying inverted residual building block for mobilenet 134 | norm_layer: Module specifying the normalization layer to use 135 | """ 136 | super(MobileNetV2, self).__init__() 137 | 138 | if block is None: 139 | block = InvertedResidual 140 | 141 | if norm_layer is None: 142 | norm_layer = nn.BatchNorm2d 143 | 144 | last_channel = 1280 145 | input_channel = 32 146 | current_stride = 1 147 | rate = 1 148 | 149 | if inverted_residual_setting is None: 150 | inverted_residual_setting = [ 151 | # t, c, n, s 152 | [1, 16, 1, 1], 153 | [6, 24, 2, 2], 154 | [6, 32, 3, 2], 155 | [6, 64, 4, 2], 156 | [6, 96, 3, 1], 157 | [6, 160, 3, 2], 158 | [6, 320, 1, 1], 159 | ] 160 | 161 | # only check the first element, assuming user knows t,c,n,s are required 162 | if len(inverted_residual_setting) == 0 or len( 163 | inverted_residual_setting[0]) != 4: 164 | raise ValueError("inverted_residual_setting should be non-empty " 165 | "or a 4-element list, got {}".format( 166 | inverted_residual_setting)) 167 | 168 | # building first layer 169 | input_channel = _make_divisible(input_channel * width_mult, 170 | round_nearest) 171 | self.last_channel = _make_divisible( 172 | last_channel * max(1.0, width_mult), round_nearest) 173 | features: List[nn.Module] = [ 174 | ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer) 175 | ] 176 | current_stride *= 2 177 | # building inverted residual blocks 178 | for t, c, n, s in inverted_residual_setting: 179 | if current_stride == output_stride: 180 | stride = 1 181 | dilation = rate 182 | rate *= s 183 | else: 184 | stride = s 185 | dilation = 1 186 | current_stride *= s 187 | output_channel = _make_divisible(c * width_mult, round_nearest) 188 | for i in range(n): 189 | if i == 0: 190 | features.append( 191 | block(input_channel, output_channel, stride, dilation, 192 | t, norm_layer)) 193 | else: 194 | features.append( 195 | block(input_channel, output_channel, 1, rate, t, 196 | norm_layer)) 197 | input_channel = output_channel 198 | 199 | # building last several layers 200 | features.append( 201 | ConvBNReLU(input_channel, 202 | self.last_channel, 203 | kernel_size=1, 204 | norm_layer=norm_layer)) 205 | # make it nn.Sequential 206 | self.features = nn.Sequential(*features) 207 | 208 | self._initialize_weights() 209 | 210 | feature_4x = self.features[0:4] 211 | feautre_8x = self.features[4:7] 212 | 
feature_16x = self.features[7:14] 213 | feature_32x = self.features[14:] 214 | 215 | self.stages = [feature_4x, feautre_8x, feature_16x, feature_32x] 216 | 217 | self.freeze(freeze_at) 218 | 219 | def forward(self, x): 220 | xs = [] 221 | for stage in self.stages: 222 | x = stage(x) 223 | xs.append(x) 224 | return xs 225 | 226 | def _initialize_weights(self): 227 | # weight initialization 228 | for m in self.modules(): 229 | if isinstance(m, nn.Conv2d): 230 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 231 | if m.bias is not None: 232 | nn.init.zeros_(m.bias) 233 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 234 | nn.init.ones_(m.weight) 235 | nn.init.zeros_(m.bias) 236 | elif isinstance(m, nn.Linear): 237 | nn.init.normal_(m.weight, 0, 0.01) 238 | nn.init.zeros_(m.bias) 239 | 240 | def freeze(self, freeze_at): 241 | if freeze_at >= 1: 242 | for m in self.stages[0][0]: 243 | freeze_params(m) 244 | 245 | for idx, stage in enumerate(self.stages, start=2): 246 | if freeze_at >= idx: 247 | freeze_params(stage) 248 | -------------------------------------------------------------------------------- /networks/encoders/mobilenetv3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Creates a MobileNetV3 Model as defined in: 3 | Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam. (2019). 4 | Searching for MobileNetV3 5 | arXiv preprint arXiv:1905.02244. 6 | """ 7 | 8 | import torch.nn as nn 9 | import math 10 | from utils.learning import freeze_params 11 | 12 | 13 | def _make_divisible(v, divisor, min_value=None): 14 | """ 15 | This function is taken from the original tf repo. 16 | It ensures that all layers have a channel number that is divisible by 8 17 | It can be seen here: 18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 19 | :param v: 20 | :param divisor: 21 | :param min_value: 22 | :return: 23 | """ 24 | if min_value is None: 25 | min_value = divisor 26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 27 | # Make sure that round down does not go down by more than 10%. 
28 | if new_v < 0.9 * v: 29 | new_v += divisor 30 | return new_v 31 | 32 | 33 | class h_sigmoid(nn.Module): 34 | def __init__(self, inplace=True): 35 | super(h_sigmoid, self).__init__() 36 | self.relu = nn.ReLU6(inplace=inplace) 37 | 38 | def forward(self, x): 39 | return self.relu(x + 3) / 6 40 | 41 | 42 | class h_swish(nn.Module): 43 | def __init__(self, inplace=True): 44 | super(h_swish, self).__init__() 45 | self.sigmoid = h_sigmoid(inplace=inplace) 46 | 47 | def forward(self, x): 48 | return x * self.sigmoid(x) 49 | 50 | 51 | class SELayer(nn.Module): 52 | def __init__(self, channel, reduction=4): 53 | super(SELayer, self).__init__() 54 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 55 | self.fc = nn.Sequential( 56 | nn.Linear(channel, _make_divisible(channel // reduction, 8)), 57 | nn.ReLU(inplace=True), 58 | nn.Linear(_make_divisible(channel // reduction, 8), channel), 59 | h_sigmoid()) 60 | 61 | def forward(self, x): 62 | b, c, _, _ = x.size() 63 | y = self.avg_pool(x).view(b, c) 64 | y = self.fc(y).view(b, c, 1, 1) 65 | return x * y 66 | 67 | 68 | def conv_3x3_bn(inp, oup, stride, norm_layer=nn.BatchNorm2d): 69 | return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 70 | norm_layer(oup), h_swish()) 71 | 72 | 73 | def conv_1x1_bn(inp, oup, norm_layer=nn.BatchNorm2d): 74 | return nn.Sequential(nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 75 | norm_layer(oup), h_swish()) 76 | 77 | 78 | class InvertedResidual(nn.Module): 79 | def __init__(self, 80 | inp, 81 | hidden_dim, 82 | oup, 83 | kernel_size, 84 | stride, 85 | use_se, 86 | use_hs, 87 | dilation=1, 88 | norm_layer=nn.BatchNorm2d): 89 | super(InvertedResidual, self).__init__() 90 | assert stride in [1, 2] 91 | 92 | self.identity = stride == 1 and inp == oup 93 | 94 | if inp == hidden_dim: 95 | self.conv = nn.Sequential( 96 | # dw 97 | nn.Conv2d(hidden_dim, 98 | hidden_dim, 99 | kernel_size, 100 | stride, (kernel_size - 1) // 2 * dilation, 101 | dilation=dilation, 102 | groups=hidden_dim, 103 | bias=False), 104 | norm_layer(hidden_dim), 105 | h_swish() if use_hs else nn.ReLU(inplace=True), 106 | # Squeeze-and-Excite 107 | SELayer(hidden_dim) if use_se else nn.Identity(), 108 | # pw-linear 109 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 110 | norm_layer(oup), 111 | ) 112 | else: 113 | self.conv = nn.Sequential( 114 | # pw 115 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 116 | norm_layer(hidden_dim), 117 | h_swish() if use_hs else nn.ReLU(inplace=True), 118 | # dw 119 | nn.Conv2d(hidden_dim, 120 | hidden_dim, 121 | kernel_size, 122 | stride, (kernel_size - 1) // 2 * dilation, 123 | dilation=dilation, 124 | groups=hidden_dim, 125 | bias=False), 126 | norm_layer(hidden_dim), 127 | # Squeeze-and-Excite 128 | SELayer(hidden_dim) if use_se else nn.Identity(), 129 | h_swish() if use_hs else nn.ReLU(inplace=True), 130 | # pw-linear 131 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 132 | norm_layer(oup), 133 | ) 134 | 135 | def forward(self, x): 136 | if self.identity: 137 | return x + self.conv(x) 138 | else: 139 | return self.conv(x) 140 | 141 | 142 | class MobileNetV3Large(nn.Module): 143 | def __init__(self, 144 | output_stride=16, 145 | norm_layer=nn.BatchNorm2d, 146 | width_mult=1., 147 | freeze_at=0): 148 | super(MobileNetV3Large, self).__init__() 149 | """ 150 | Constructs a MobileNetV3-Large model 151 | """ 152 | cfgs = [ 153 | # k, t, c, SE, HS, s 154 | [3, 1, 16, 0, 0, 1], 155 | [3, 4, 24, 0, 0, 2], 156 | [3, 3, 24, 0, 0, 1], 157 | [5, 3, 40, 1, 0, 2], 158 | [5, 3, 40, 1, 0, 1], 159 | [5, 3, 40, 1, 0, 1], 
160 | [3, 6, 80, 0, 1, 2], 161 | [3, 2.5, 80, 0, 1, 1], 162 | [3, 2.3, 80, 0, 1, 1], 163 | [3, 2.3, 80, 0, 1, 1], 164 | [3, 6, 112, 1, 1, 1], 165 | [3, 6, 112, 1, 1, 1], 166 | [5, 6, 160, 1, 1, 2], 167 | [5, 6, 160, 1, 1, 1], 168 | [5, 6, 160, 1, 1, 1] 169 | ] 170 | self.cfgs = cfgs 171 | 172 | # building first layer 173 | input_channel = _make_divisible(16 * width_mult, 8) 174 | layers = [conv_3x3_bn(3, input_channel, 2, norm_layer)] 175 | # building inverted residual blocks 176 | block = InvertedResidual 177 | now_stride = 2 178 | rate = 1 179 | for k, t, c, use_se, use_hs, s in self.cfgs: 180 | if now_stride == output_stride: 181 | dilation = rate 182 | rate *= s 183 | s = 1 184 | else: 185 | dilation = 1 186 | now_stride *= s 187 | output_channel = _make_divisible(c * width_mult, 8) 188 | exp_size = _make_divisible(input_channel * t, 8) 189 | layers.append( 190 | block(input_channel, exp_size, output_channel, k, s, use_se, 191 | use_hs, dilation, norm_layer)) 192 | input_channel = output_channel 193 | 194 | self.features = nn.Sequential(*layers) 195 | self.conv = conv_1x1_bn(input_channel, exp_size, norm_layer) 196 | # building last several layers 197 | 198 | self._initialize_weights() 199 | 200 | feature_4x = self.features[0:4] 201 | feautre_8x = self.features[4:7] 202 | feature_16x = self.features[7:13] 203 | feature_32x = self.features[13:] 204 | 205 | self.stages = [feature_4x, feautre_8x, feature_16x, feature_32x] 206 | 207 | self.freeze(freeze_at) 208 | 209 | def forward(self, x): 210 | xs = [] 211 | for stage in self.stages: 212 | x = stage(x) 213 | xs.append(x) 214 | xs[-1] = self.conv(xs[-1]) 215 | return xs 216 | 217 | def _initialize_weights(self): 218 | for m in self.modules(): 219 | if isinstance(m, nn.Conv2d): 220 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 221 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 222 | if m.bias is not None: 223 | m.bias.data.zero_() 224 | elif isinstance(m, nn.BatchNorm2d): 225 | m.weight.data.fill_(1) 226 | m.bias.data.zero_() 227 | elif isinstance(m, nn.Linear): 228 | n = m.weight.size(1) 229 | m.weight.data.normal_(0, 0.01) 230 | m.bias.data.zero_() 231 | 232 | def freeze(self, freeze_at): 233 | if freeze_at >= 1: 234 | for m in self.stages[0][0]: 235 | freeze_params(m) 236 | 237 | for idx, stage in enumerate(self.stages, start=2): 238 | if freeze_at >= idx: 239 | freeze_params(stage) 240 | -------------------------------------------------------------------------------- /networks/encoders/resnest/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnest import * 2 | -------------------------------------------------------------------------------- /networks/encoders/resnest/resnest.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .resnet import ResNet, Bottleneck 3 | 4 | __all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269'] 5 | 6 | _url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth' 7 | 8 | _model_sha256 = { 9 | name: checksum 10 | for checksum, name in [ 11 | ('528c19ca', 'resnest50'), 12 | ('22405ba7', 'resnest101'), 13 | ('75117900', 'resnest200'), 14 | ('0cc87c48', 'resnest269'), 15 | ] 16 | } 17 | 18 | 19 | def short_hash(name): 20 | if name not in _model_sha256: 21 | raise ValueError( 22 | 'Pretrained model for {name} is not available.'.format(name=name)) 23 | return _model_sha256[name][:8] 24 | 25 | 26 | resnest_model_urls = { 27 | name: _url_format.format(name, short_hash(name)) 28 | for name in _model_sha256.keys() 29 | } 30 | 31 | 32 | def resnest50(pretrained=False, root='~/.encoding/models', **kwargs): 33 | model = ResNet(Bottleneck, [3, 4, 6, 3], 34 | radix=2, 35 | groups=1, 36 | bottleneck_width=64, 37 | deep_stem=True, 38 | stem_width=32, 39 | avg_down=True, 40 | avd=True, 41 | avd_first=False, 42 | **kwargs) 43 | if pretrained: 44 | model.load_state_dict( 45 | torch.hub.load_state_dict_from_url(resnest_model_urls['resnest50'], 46 | progress=True, 47 | check_hash=True)) 48 | return model 49 | 50 | 51 | def resnest101(pretrained=False, root='~/.encoding/models', **kwargs): 52 | model = ResNet(Bottleneck, [3, 4, 23, 3], 53 | radix=2, 54 | groups=1, 55 | bottleneck_width=64, 56 | deep_stem=True, 57 | stem_width=64, 58 | avg_down=True, 59 | avd=True, 60 | avd_first=False, 61 | **kwargs) 62 | if pretrained: 63 | model.load_state_dict( 64 | torch.hub.load_state_dict_from_url( 65 | resnest_model_urls['resnest101'], 66 | progress=True, 67 | check_hash=True)) 68 | return model 69 | 70 | 71 | def resnest200(pretrained=False, root='~/.encoding/models', **kwargs): 72 | model = ResNet(Bottleneck, [3, 24, 36, 3], 73 | radix=2, 74 | groups=1, 75 | bottleneck_width=64, 76 | deep_stem=True, 77 | stem_width=64, 78 | avg_down=True, 79 | avd=True, 80 | avd_first=False, 81 | **kwargs) 82 | if pretrained: 83 | model.load_state_dict( 84 | torch.hub.load_state_dict_from_url( 85 | resnest_model_urls['resnest200'], 86 | progress=True, 87 | check_hash=True)) 88 | return model 89 | 90 | 91 | def resnest269(pretrained=False, root='~/.encoding/models', **kwargs): 92 | model = ResNet(Bottleneck, [3, 30, 48, 8], 93 | radix=2, 94 | groups=1, 95 | bottleneck_width=64, 96 | deep_stem=True, 97 | stem_width=64, 98 | avg_down=True, 99 | avd=True, 100 | avd_first=False, 101 | **kwargs) 102 | if pretrained: 103 | 
model.load_state_dict( 104 | torch.hub.load_state_dict_from_url( 105 | resnest_model_urls['resnest269'], 106 | progress=True, 107 | check_hash=True)) 108 | return model 109 | -------------------------------------------------------------------------------- /networks/encoders/resnest/splat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn import Conv2d, Module, ReLU 5 | from torch.nn.modules.utils import _pair 6 | 7 | __all__ = ['SplAtConv2d', 'DropBlock2D'] 8 | 9 | 10 | class DropBlock2D(object): 11 | def __init__(self, *args, **kwargs): 12 | raise NotImplementedError 13 | 14 | 15 | class SplAtConv2d(Module): 16 | """Split-Attention Conv2d 17 | """ 18 | def __init__(self, 19 | in_channels, 20 | channels, 21 | kernel_size, 22 | stride=(1, 1), 23 | padding=(0, 0), 24 | dilation=(1, 1), 25 | groups=1, 26 | bias=True, 27 | radix=2, 28 | reduction_factor=4, 29 | rectify=False, 30 | rectify_avg=False, 31 | norm_layer=None, 32 | dropblock_prob=0.0, 33 | **kwargs): 34 | super(SplAtConv2d, self).__init__() 35 | padding = _pair(padding) 36 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0) 37 | self.rectify_avg = rectify_avg 38 | inter_channels = max(in_channels * radix // reduction_factor, 32) 39 | self.radix = radix 40 | self.cardinality = groups 41 | self.channels = channels 42 | self.dropblock_prob = dropblock_prob 43 | if self.rectify: 44 | from rfconv import RFConv2d 45 | self.conv = RFConv2d(in_channels, 46 | channels * radix, 47 | kernel_size, 48 | stride, 49 | padding, 50 | dilation, 51 | groups=groups * radix, 52 | bias=bias, 53 | average_mode=rectify_avg, 54 | **kwargs) 55 | else: 56 | self.conv = Conv2d(in_channels, 57 | channels * radix, 58 | kernel_size, 59 | stride, 60 | padding, 61 | dilation, 62 | groups=groups * radix, 63 | bias=bias, 64 | **kwargs) 65 | self.use_bn = norm_layer is not None 66 | if self.use_bn: 67 | self.bn0 = norm_layer(channels * radix) 68 | self.relu = ReLU(inplace=True) 69 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) 70 | if self.use_bn: 71 | self.bn1 = norm_layer(inter_channels) 72 | self.fc2 = Conv2d(inter_channels, 73 | channels * radix, 74 | 1, 75 | groups=self.cardinality) 76 | if dropblock_prob > 0.0: 77 | self.dropblock = DropBlock2D(dropblock_prob, 3) 78 | self.rsoftmax = rSoftMax(radix, groups) 79 | 80 | def forward(self, x): 81 | x = self.conv(x) 82 | if self.use_bn: 83 | x = self.bn0(x) 84 | if self.dropblock_prob > 0.0: 85 | x = self.dropblock(x) 86 | x = self.relu(x) 87 | 88 | batch, rchannel = x.shape[:2] 89 | if self.radix > 1: 90 | if torch.__version__ < '1.5': 91 | splited = torch.split(x, int(rchannel // self.radix), dim=1) 92 | else: 93 | splited = torch.split(x, rchannel // self.radix, dim=1) 94 | gap = sum(splited) 95 | else: 96 | gap = x 97 | gap = F.adaptive_avg_pool2d(gap, 1) 98 | gap = self.fc1(gap) 99 | 100 | if self.use_bn: 101 | gap = self.bn1(gap) 102 | gap = self.relu(gap) 103 | 104 | atten = self.fc2(gap) 105 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1) 106 | 107 | if self.radix > 1: 108 | if torch.__version__ < '1.5': 109 | attens = torch.split(atten, int(rchannel // self.radix), dim=1) 110 | else: 111 | attens = torch.split(atten, rchannel // self.radix, dim=1) 112 | out = sum([att * split for (att, split) in zip(attens, splited)]) 113 | else: 114 | out = atten * x 115 | return out.contiguous() 116 | 117 | 118 | class rSoftMax(nn.Module): 119 | def 
__init__(self, radix, cardinality): 120 | super().__init__() 121 | self.radix = radix 122 | self.cardinality = cardinality 123 | 124 | def forward(self, x): 125 | batch = x.size(0) 126 | if self.radix > 1: 127 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 128 | x = F.softmax(x, dim=1) 129 | x = x.reshape(batch, -1) 130 | else: 131 | x = torch.sigmoid(x) 132 | return x 133 | -------------------------------------------------------------------------------- /networks/encoders/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | from utils.learning import freeze_params 4 | 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | 9 | def __init__(self, 10 | inplanes, 11 | planes, 12 | stride=1, 13 | dilation=1, 14 | downsample=None, 15 | BatchNorm=None): 16 | super(Bottleneck, self).__init__() 17 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 18 | self.bn1 = BatchNorm(planes) 19 | self.conv2 = nn.Conv2d(planes, 20 | planes, 21 | kernel_size=3, 22 | stride=stride, 23 | dilation=dilation, 24 | padding=dilation, 25 | bias=False) 26 | self.bn2 = BatchNorm(planes) 27 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 28 | self.bn3 = BatchNorm(planes * 4) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.downsample = downsample 31 | self.stride = stride 32 | self.dilation = dilation 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv3(out) 46 | out = self.bn3(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class ResNet(nn.Module): 58 | def __init__(self, block, layers, output_stride, BatchNorm, freeze_at=0): 59 | self.inplanes = 64 60 | super(ResNet, self).__init__() 61 | 62 | if output_stride == 16: 63 | strides = [1, 2, 2, 1] 64 | dilations = [1, 1, 1, 2] 65 | elif output_stride == 8: 66 | strides = [1, 2, 1, 1] 67 | dilations = [1, 1, 2, 4] 68 | else: 69 | raise NotImplementedError 70 | 71 | # Modules 72 | self.conv1 = nn.Conv2d(3, 73 | 64, 74 | kernel_size=7, 75 | stride=2, 76 | padding=3, 77 | bias=False) 78 | self.bn1 = BatchNorm(64) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 81 | 82 | self.layer1 = self._make_layer(block, 83 | 64, 84 | layers[0], 85 | stride=strides[0], 86 | dilation=dilations[0], 87 | BatchNorm=BatchNorm) 88 | self.layer2 = self._make_layer(block, 89 | 128, 90 | layers[1], 91 | stride=strides[1], 92 | dilation=dilations[1], 93 | BatchNorm=BatchNorm) 94 | self.layer3 = self._make_layer(block, 95 | 256, 96 | layers[2], 97 | stride=strides[2], 98 | dilation=dilations[2], 99 | BatchNorm=BatchNorm) 100 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 101 | 102 | self.stem = [self.conv1, self.bn1] 103 | self.stages = [self.layer1, self.layer2, self.layer3] 104 | 105 | self._init_weight() 106 | self.freeze(freeze_at) 107 | 108 | def _make_layer(self, 109 | block, 110 | planes, 111 | blocks, 112 | stride=1, 113 | dilation=1, 114 | BatchNorm=None): 115 | downsample = None 116 | if stride != 1 or self.inplanes != planes * block.expansion: 117 | downsample = nn.Sequential( 118 | 
nn.Conv2d(self.inplanes, 119 | planes * block.expansion, 120 | kernel_size=1, 121 | stride=stride, 122 | bias=False), 123 | BatchNorm(planes * block.expansion), 124 | ) 125 | 126 | layers = [] 127 | layers.append( 128 | block(self.inplanes, planes, stride, max(dilation // 2, 1), 129 | downsample, BatchNorm)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append( 133 | block(self.inplanes, 134 | planes, 135 | dilation=dilation, 136 | BatchNorm=BatchNorm)) 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, input): 141 | x = self.conv1(input) 142 | x = self.bn1(x) 143 | x = self.relu(x) 144 | x = self.maxpool(x) 145 | 146 | xs = [] 147 | 148 | x = self.layer1(x) 149 | xs.append(x) # 4X 150 | x = self.layer2(x) 151 | xs.append(x) # 8X 152 | x = self.layer3(x) 153 | xs.append(x) # 16X 154 | # Following STMVOS, we drop stage 5. 155 | xs.append(x) # 16X 156 | 157 | return xs 158 | 159 | def _init_weight(self): 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 163 | m.weight.data.normal_(0, math.sqrt(2. / n)) 164 | elif isinstance(m, nn.BatchNorm2d): 165 | m.weight.data.fill_(1) 166 | m.bias.data.zero_() 167 | 168 | def freeze(self, freeze_at): 169 | if freeze_at >= 1: 170 | for m in self.stem: 171 | freeze_params(m) 172 | 173 | for idx, stage in enumerate(self.stages, start=2): 174 | if freeze_at >= idx: 175 | freeze_params(stage) 176 | 177 | 178 | def ResNet50(output_stride, BatchNorm, freeze_at=0): 179 | """Constructs a ResNet-50 model. 180 | Args: 181 | pretrained (bool): If True, returns a model pre-trained on ImageNet 182 | """ 183 | model = ResNet(Bottleneck, [3, 4, 6, 3], 184 | output_stride, 185 | BatchNorm, 186 | freeze_at=freeze_at) 187 | return model 188 | 189 | 190 | def ResNet101(output_stride, BatchNorm, freeze_at=0): 191 | """Constructs a ResNet-101 model. 
192 | Args: 193 | pretrained (bool): If True, returns a model pre-trained on ImageNet 194 | """ 195 | model = ResNet(Bottleneck, [3, 4, 23, 3], 196 | output_stride, 197 | BatchNorm, 198 | freeze_at=freeze_at) 199 | return model 200 | 201 | 202 | if __name__ == "__main__": 203 | import torch 204 | model = ResNet101(BatchNorm=nn.BatchNorm2d, output_stride=8) 205 | input = torch.rand(1, 3, 512, 512) 206 | xs = model(input) # list of multi-scale features (4x, 8x, 16x, 16x) 207 | print(xs[0].size()) 208 | print(xs[-1].size()) 209 | -------------------------------------------------------------------------------- /networks/encoders/swin/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_swin_model -------------------------------------------------------------------------------- /networks/encoders/swin/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from .swin_transformer import SwinTransformer 9 | 10 | 11 | def build_swin_model(model_type, freeze_at=0): 12 | if model_type == 'swin_base': 13 | model = SwinTransformer(embed_dim=128, 14 | depths=[2, 2, 18, 2], 15 | num_heads=[4, 8, 16, 32], 16 | window_size=7, 17 | drop_path_rate=0.3, 18 | out_indices=(0, 1, 2), 19 | ape=False, 20 | patch_norm=True, 21 | frozen_stages=freeze_at, 22 | use_checkpoint=False) 23 | 24 | else: 25 | raise NotImplementedError(f"Unknown model: {model_type}") 26 | 27 | return model 28 | -------------------------------------------------------------------------------- /networks/engines/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.engines.aot_engine import AOTEngine, AOTInferEngine 2 | from networks.engines.deaot_engine import DeAOTEngine, DeAOTInferEngine 3 | 4 | 5 | def build_engine(name, phase='train', **kwargs): 6 | if name == 'aotengine': 7 | if phase == 'train': 8 | return AOTEngine(**kwargs) 9 | elif phase == 'eval': 10 | return AOTInferEngine(**kwargs) 11 | else: 12 | raise NotImplementedError 13 | elif name == 'deaotengine': 14 | if phase == 'train': 15 | return DeAOTEngine(**kwargs) 16 | elif phase == 'eval': 17 | return DeAOTInferEngine(**kwargs) 18 | else: 19 | raise NotImplementedError 20 | else: 21 | raise NotImplementedError 22 | -------------------------------------------------------------------------------- /networks/engines/deaot_engine.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from utils.image import one_hot_mask 4 | 5 | from networks.layers.basic import seq_to_2d 6 | from networks.engines.aot_engine import AOTEngine, AOTInferEngine 7 | 8 | 9 | class DeAOTEngine(AOTEngine): 10 | def __init__(self, 11 | aot_model, 12 | gpu_id=0, 13 | long_term_mem_gap=9999, 14 | short_term_mem_skip=1, 15 | layer_loss_scaling_ratio=2.): 16 | super().__init__(aot_model, gpu_id, long_term_mem_gap, 17 | short_term_mem_skip) 18 | self.layer_loss_scaling_ratio = layer_loss_scaling_ratio 19 | 20 | def update_short_term_memory(self, curr_mask, curr_id_emb=None, skip_long_term_update=False): 21 | 22 | if curr_id_emb is None: 23 | if len(curr_mask.size()) == 3 or curr_mask.size()[0] == 1: 24 | curr_one_hot_mask =
one_hot_mask(curr_mask, self.max_obj_num) 25 | else: 26 | curr_one_hot_mask = curr_mask 27 | curr_id_emb = self.assign_identity(curr_one_hot_mask) 28 | 29 | lstt_curr_memories = self.curr_lstt_output[1] 30 | lstt_curr_memories_2d = [] 31 | for layer_idx in range(len(lstt_curr_memories)): 32 | curr_k, curr_v, curr_id_k, curr_id_v = lstt_curr_memories[ 33 | layer_idx] 34 | curr_id_k, curr_id_v = self.AOT.LSTT.layers[ 35 | layer_idx].fuse_key_value_id(curr_id_k, curr_id_v, curr_id_emb) 36 | lstt_curr_memories[layer_idx][2], lstt_curr_memories[layer_idx][ 37 | 3] = curr_id_k, curr_id_v 38 | local_curr_id_k = seq_to_2d( 39 | curr_id_k, self.enc_size_2d) if curr_id_k is not None else None 40 | local_curr_id_v = seq_to_2d(curr_id_v, self.enc_size_2d) 41 | lstt_curr_memories_2d.append([ 42 | seq_to_2d(curr_k, self.enc_size_2d), 43 | seq_to_2d(curr_v, self.enc_size_2d), local_curr_id_k, 44 | local_curr_id_v 45 | ]) 46 | 47 | self.short_term_memories_list.append(lstt_curr_memories_2d) 48 | self.short_term_memories_list = self.short_term_memories_list[ 49 | -self.short_term_mem_skip:] 50 | self.short_term_memories = self.short_term_memories_list[0] 51 | 52 | if self.frame_step - self.last_mem_step >= self.long_term_mem_gap: 53 | # skip the update of long-term memory or not 54 | if not skip_long_term_update: 55 | self.update_long_term_memory(lstt_curr_memories) 56 | self.last_mem_step = self.frame_step 57 | 58 | 59 | class DeAOTInferEngine(AOTInferEngine): 60 | def __init__(self, 61 | aot_model, 62 | gpu_id=0, 63 | long_term_mem_gap=9999, 64 | short_term_mem_skip=1, 65 | max_aot_obj_num=None): 66 | super().__init__(aot_model, gpu_id, long_term_mem_gap, 67 | short_term_mem_skip, max_aot_obj_num) 68 | 69 | def add_reference_frame(self, img, mask, obj_nums, frame_step=-1): 70 | if isinstance(obj_nums, list): 71 | obj_nums = obj_nums[0] 72 | self.obj_nums = obj_nums 73 | aot_num = max(np.ceil(obj_nums / self.max_aot_obj_num), 1) 74 | while (aot_num > len(self.aot_engines)): 75 | new_engine = DeAOTEngine(self.AOT, self.gpu_id, 76 | self.long_term_mem_gap, 77 | self.short_term_mem_skip) 78 | new_engine.eval() 79 | self.aot_engines.append(new_engine) 80 | 81 | separated_masks, separated_obj_nums = self.separate_mask( 82 | mask, obj_nums) 83 | img_embs = None 84 | for aot_engine, separated_mask, separated_obj_num in zip( 85 | self.aot_engines, separated_masks, separated_obj_nums): 86 | aot_engine.add_reference_frame(img, 87 | separated_mask, 88 | obj_nums=[separated_obj_num], 89 | frame_step=frame_step, 90 | img_embs=img_embs) 91 | if img_embs is None: # reuse image embeddings 92 | img_embs = aot_engine.curr_enc_embs 93 | 94 | self.update_size() 95 | -------------------------------------------------------------------------------- /networks/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/networks/layers/__init__.py -------------------------------------------------------------------------------- /networks/layers/basic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class GroupNorm1D(nn.Module): 7 | def __init__(self, indim, groups=8): 8 | super().__init__() 9 | self.gn = nn.GroupNorm(groups, indim) 10 | 11 | def forward(self, x): 12 | return self.gn(x.permute(1, 2, 0)).permute(2, 0, 1) 13 | 14 | 15 | class GNActDWConv2d(nn.Module): 16 | 
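    # GroupNorm -> GELU -> 5x5 depth-wise conv, applied to (HW, B, C) sequence
    # features that forward() temporarily reshapes to (B, C, H, W).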
def __init__(self, indim, gn_groups=32): 17 | super().__init__() 18 | self.gn = nn.GroupNorm(gn_groups, indim) 19 | self.conv = nn.Conv2d(indim, 20 | indim, 21 | 5, 22 | dilation=1, 23 | padding=2, 24 | groups=indim, 25 | bias=False) 26 | 27 | def forward(self, x, size_2d): 28 | h, w = size_2d 29 | _, bs, c = x.size() 30 | x = x.view(h, w, bs, c).permute(2, 3, 0, 1) 31 | x = self.gn(x) 32 | x = F.gelu(x) 33 | x = self.conv(x) 34 | x = x.view(bs, c, h * w).permute(2, 0, 1) 35 | return x 36 | 37 | 38 | class DWConv2d(nn.Module): 39 | def __init__(self, indim, dropout=0.1): 40 | super().__init__() 41 | self.conv = nn.Conv2d(indim, 42 | indim, 43 | 5, 44 | dilation=1, 45 | padding=2, 46 | groups=indim, 47 | bias=False) 48 | self.dropout = nn.Dropout2d(p=dropout, inplace=True) 49 | 50 | def forward(self, x, size_2d): 51 | h, w = size_2d 52 | _, bs, c = x.size() 53 | x = x.view(h, w, bs, c).permute(2, 3, 0, 1) 54 | x = self.conv(x) 55 | x = self.dropout(x) 56 | x = x.view(bs, c, h * w).permute(2, 0, 1) 57 | return x 58 | 59 | 60 | class ScaleOffset(nn.Module): 61 | def __init__(self, indim): 62 | super().__init__() 63 | self.gamma = nn.Parameter(torch.ones(indim)) 64 | # torch.nn.init.normal_(self.gamma, std=0.02) 65 | self.beta = nn.Parameter(torch.zeros(indim)) 66 | 67 | def forward(self, x): 68 | if len(x.size()) == 3: 69 | return x * self.gamma + self.beta 70 | else: 71 | return x * self.gamma.view(1, -1, 1, 1) + self.beta.view( 72 | 1, -1, 1, 1) 73 | 74 | 75 | class ConvGN(nn.Module): 76 | def __init__(self, indim, outdim, kernel_size, gn_groups=8): 77 | super().__init__() 78 | self.conv = nn.Conv2d(indim, 79 | outdim, 80 | kernel_size, 81 | padding=kernel_size // 2) 82 | self.gn = nn.GroupNorm(gn_groups, outdim) 83 | 84 | def forward(self, x): 85 | return self.gn(self.conv(x)) 86 | 87 | 88 | def seq_to_2d(tensor, size_2d): 89 | h, w = size_2d 90 | _, n, c = tensor.size() 91 | tensor = tensor.view(h, w, n, c).permute(2, 3, 0, 1).contiguous() 92 | return tensor 93 | 94 | 95 | def drop_path(x, drop_prob: float = 0., training: bool = False): 96 | if drop_prob == 0. or not training: 97 | return x 98 | keep_prob = 1 - drop_prob 99 | shape = ( 100 | x.shape[0], 101 | x.shape[1], 102 | ) + (1, ) * (x.ndim - 2 103 | ) # work with diff dim tensors, not just 2D ConvNets 104 | random_tensor = keep_prob + torch.rand( 105 | shape, dtype=x.dtype, device=x.device) 106 | random_tensor.floor_() # binarize 107 | output = x.div(keep_prob) * random_tensor 108 | return output 109 | 110 | 111 | def mask_out(x, y, mask_rate=0.15, training=False): 112 | if mask_rate == 0. or not training: 113 | return x 114 | 115 | keep_prob = 1 - mask_rate 116 | shape = ( 117 | x.shape[0], 118 | x.shape[1], 119 | ) + (1, ) * (x.ndim - 2 120 | ) # work with diff dim tensors, not just 2D ConvNets 121 | random_tensor = keep_prob + torch.rand( 122 | shape, dtype=x.dtype, device=x.device) 123 | random_tensor.floor_() # binarize 124 | output = x * random_tensor + y * (1 - random_tensor) 125 | 126 | return output 127 | 128 | 129 | class DropPath(nn.Module): 130 | def __init__(self, drop_prob=None, batch_dim=0): 131 | super(DropPath, self).__init__() 132 | self.drop_prob = drop_prob 133 | self.batch_dim = batch_dim 134 | 135 | def forward(self, x): 136 | return self.drop_path(x, self.drop_prob) 137 | 138 | def drop_path(self, x, drop_prob): 139 | if drop_prob == 0. 
or not self.training: 140 | return x 141 | keep_prob = 1 - drop_prob 142 | shape = [1 for _ in range(x.ndim)] 143 | shape[self.batch_dim] = x.shape[self.batch_dim] 144 | random_tensor = keep_prob + torch.rand( 145 | shape, dtype=x.dtype, device=x.device) 146 | random_tensor.floor_() # binarize 147 | output = x.div(keep_prob) * random_tensor 148 | return output 149 | 150 | 151 | class DropOutLogit(nn.Module): 152 | def __init__(self, drop_prob=None): 153 | super(DropOutLogit, self).__init__() 154 | self.drop_prob = drop_prob 155 | 156 | def forward(self, x): 157 | return self.drop_logit(x, self.drop_prob) 158 | 159 | def drop_logit(self, x, drop_prob): 160 | if drop_prob == 0. or not self.training: 161 | return x 162 | random_tensor = drop_prob + torch.rand( 163 | x.shape, dtype=x.dtype, device=x.device) 164 | random_tensor.floor_() # binarize 165 | mask = random_tensor * 1e+8 if ( 166 | x.dtype == torch.float32) else random_tensor * 1e+4 167 | output = x - mask 168 | return output 169 | -------------------------------------------------------------------------------- /networks/layers/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | try: 6 | from itertools import ifilterfalse 7 | except ImportError: # py3k 8 | from itertools import filterfalse as ifilterfalse 9 | 10 | 11 | def dice_loss(probas, labels, smooth=1): 12 | 13 | C = probas.size(1) 14 | losses = [] 15 | for c in list(range(C)): 16 | fg = (labels == c).float() 17 | if fg.sum() == 0: 18 | continue 19 | class_pred = probas[:, c] 20 | p0 = class_pred 21 | g0 = fg 22 | numerator = 2 * torch.sum(p0 * g0) + smooth 23 | denominator = torch.sum(p0) + torch.sum(g0) + smooth 24 | losses.append(1 - ((numerator) / (denominator))) 25 | return mean(losses) 26 | 27 | 28 | def tversky_loss(probas, labels, alpha=0.5, beta=0.5, epsilon=1e-6): 29 | ''' 30 | Tversky loss function. 31 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 32 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 33 | 34 | Same as soft dice loss when alpha=beta=0.5. 35 | Same as Jaccord loss when alpha=beta=1.0. 36 | See `Tversky loss function for image segmentation using 3D fully convolutional deep networks` 37 | https://arxiv.org/pdf/1706.05721.pdf 38 | ''' 39 | C = probas.size(1) 40 | losses = [] 41 | for c in list(range(C)): 42 | fg = (labels == c).float() 43 | if fg.sum() == 0: 44 | continue 45 | class_pred = probas[:, c] 46 | p0 = class_pred 47 | p1 = 1 - class_pred 48 | g0 = fg 49 | g1 = 1 - fg 50 | numerator = torch.sum(p0 * g0) 51 | denominator = numerator + alpha * \ 52 | torch.sum(p0*g1) + beta*torch.sum(p1*g0) 53 | losses.append(1 - ((numerator) / (denominator + epsilon))) 54 | return mean(losses) 55 | 56 | 57 | def flatten_probas(probas, labels, ignore=255): 58 | """ 59 | Flattens predictions in the batch 60 | """ 61 | B, C, H, W = probas.size() 62 | probas = probas.permute(0, 2, 3, 63 | 1).contiguous().view(-1, C) # B * H * W, C = P, C 64 | labels = labels.view(-1) 65 | if ignore is None: 66 | return probas, labels 67 | valid = (labels != ignore) 68 | vprobas = probas[valid.view(-1, 1).expand(-1, C)].reshape(-1, C) 69 | # vprobas = probas[torch.nonzero(valid).squeeze()] 70 | vlabels = labels[valid] 71 | return vprobas, vlabels 72 | 73 | 74 | def isnan(x): 75 | return x != x 76 | 77 | 78 | def mean(l, ignore_nan=False, empty=0): 79 | """ 80 | nanmean compatible with generators. 
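    Returns the arithmetic mean of the iterable `l`. NaN entries are skipped when
    `ignore_nan` is True; an empty input returns `empty`, or raises a ValueError
    when `empty` == 'raise'.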
81 | """ 82 | l = iter(l) 83 | if ignore_nan: 84 | l = ifilterfalse(isnan, l) 85 | try: 86 | n = 1 87 | acc = next(l) 88 | except StopIteration: 89 | if empty == 'raise': 90 | raise ValueError('Empty mean') 91 | return empty 92 | for n, v in enumerate(l, 2): 93 | acc += v 94 | if n == 1: 95 | return acc 96 | return acc / n 97 | 98 | 99 | class DiceLoss(nn.Module): 100 | def __init__(self, ignore_index=255): 101 | super(DiceLoss, self).__init__() 102 | self.ignore_index = ignore_index 103 | 104 | def forward(self, tmp_dic, label_dic, step=None): 105 | total_loss = [] 106 | for idx in range(len(tmp_dic)): 107 | pred = tmp_dic[idx] 108 | label = label_dic[idx] 109 | pred = F.softmax(pred, dim=1) 110 | label = label.view(1, 1, pred.size()[2], pred.size()[3]) 111 | loss = dice_loss( 112 | *flatten_probas(pred, label, ignore=self.ignore_index)) 113 | total_loss.append(loss.unsqueeze(0)) 114 | total_loss = torch.cat(total_loss, dim=0) 115 | return total_loss 116 | 117 | 118 | class SoftJaccordLoss(nn.Module): 119 | def __init__(self, ignore_index=255): 120 | super(SoftJaccordLoss, self).__init__() 121 | self.ignore_index = ignore_index 122 | 123 | def forward(self, tmp_dic, label_dic, step=None): 124 | total_loss = [] 125 | for idx in range(len(tmp_dic)): 126 | pred = tmp_dic[idx] 127 | label = label_dic[idx] 128 | pred = F.softmax(pred, dim=1) 129 | label = label.view(1, 1, pred.size()[2], pred.size()[3]) 130 | loss = tversky_loss(*flatten_probas(pred, 131 | label, 132 | ignore=self.ignore_index), 133 | alpha=1.0, 134 | beta=1.0) 135 | total_loss.append(loss.unsqueeze(0)) 136 | total_loss = torch.cat(total_loss, dim=0) 137 | return total_loss 138 | 139 | 140 | class CrossEntropyLoss(nn.Module): 141 | def __init__(self, 142 | top_k_percent_pixels=None, 143 | hard_example_mining_step=100000): 144 | super(CrossEntropyLoss, self).__init__() 145 | self.top_k_percent_pixels = top_k_percent_pixels 146 | if top_k_percent_pixels is not None: 147 | assert (top_k_percent_pixels > 0 and top_k_percent_pixels < 1) 148 | self.hard_example_mining_step = hard_example_mining_step + 1e-5 149 | if self.top_k_percent_pixels is None: 150 | self.celoss = nn.CrossEntropyLoss(ignore_index=255, 151 | reduction='mean') 152 | else: 153 | self.celoss = nn.CrossEntropyLoss(ignore_index=255, 154 | reduction='none') 155 | 156 | def forward(self, dic_tmp, y, step): 157 | total_loss = [] 158 | for i in range(len(dic_tmp)): 159 | pred_logits = dic_tmp[i] 160 | gts = y[i] 161 | if self.top_k_percent_pixels is None: 162 | final_loss = self.celoss(pred_logits, gts) 163 | else: 164 | # Only compute the loss for top k percent pixels. 165 | # First, compute the loss for all pixels. Note we do not put the loss 166 | # to loss_collection and set reduction = None to keep the shape. 
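                # The kept fraction anneals from all pixels down to top_k_percent_pixels
                # over hard_example_mining_step steps, so hard-example mining is phased
                # in gradually as training progresses.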
167 | num_pixels = float(pred_logits.size(2) * pred_logits.size(3)) 168 | pred_logits = pred_logits.view( 169 | -1, pred_logits.size(1), 170 | pred_logits.size(2) * pred_logits.size(3)) 171 | gts = gts.view(-1, gts.size(1) * gts.size(2)) 172 | pixel_losses = self.celoss(pred_logits, gts) 173 | if self.hard_example_mining_step == 0: 174 | top_k_pixels = int(self.top_k_percent_pixels * num_pixels) 175 | else: 176 | ratio = min(1.0, 177 | step / float(self.hard_example_mining_step)) 178 | top_k_pixels = int((ratio * self.top_k_percent_pixels + 179 | (1.0 - ratio)) * num_pixels) 180 | top_k_loss, top_k_indices = torch.topk(pixel_losses, 181 | k=top_k_pixels, 182 | dim=1) 183 | 184 | final_loss = torch.mean(top_k_loss) 185 | final_loss = final_loss.unsqueeze(0) 186 | total_loss.append(final_loss) 187 | total_loss = torch.cat(total_loss, dim=0) 188 | return total_loss 189 | -------------------------------------------------------------------------------- /networks/layers/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FrozenBatchNorm2d(nn.Module): 7 | """ 8 | BatchNorm2d where the batch statistics and the affine parameters 9 | are fixed 10 | """ 11 | def __init__(self, n, epsilon=1e-5): 12 | super(FrozenBatchNorm2d, self).__init__() 13 | self.register_buffer("weight", torch.ones(n)) 14 | self.register_buffer("bias", torch.zeros(n)) 15 | self.register_buffer("running_mean", torch.zeros(n)) 16 | self.register_buffer("running_var", torch.ones(n) - epsilon) 17 | self.epsilon = epsilon 18 | 19 | def forward(self, x): 20 | """ 21 | Refer to Detectron2 (https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/batch_norm.py) 22 | """ 23 | if x.requires_grad: 24 | # When gradients are needed, F.batch_norm will use extra memory 25 | # because its backward op computes gradients for weight/bias as well. 26 | scale = self.weight * (self.running_var + self.epsilon).rsqrt() 27 | bias = self.bias - self.running_mean * scale 28 | scale = scale.reshape(1, -1, 1, 1) 29 | bias = bias.reshape(1, -1, 1, 1) 30 | out_dtype = x.dtype # may be half 31 | return x * scale.to(out_dtype) + bias.to(out_dtype) 32 | else: 33 | # When gradients are not needed, F.batch_norm is a single fused op 34 | # and provide more optimization opportunities. 
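            # training=False makes batch_norm use the frozen buffers registered above
            # without updating them, matching the scale/bias computation in the
            # gradient-enabled branch.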
35 | return F.batch_norm( 36 | x, 37 | self.running_mean, 38 | self.running_var, 39 | self.weight, 40 | self.bias, 41 | training=False, 42 | eps=self.epsilon, 43 | ) 44 | -------------------------------------------------------------------------------- /networks/layers/position.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from utils.math import truncated_normal_ 8 | 9 | 10 | class Downsample2D(nn.Module): 11 | def __init__(self, mode='nearest', scale=4): 12 | super().__init__() 13 | self.mode = mode 14 | self.scale = scale 15 | 16 | def forward(self, x): 17 | n, c, h, w = x.size() 18 | x = F.interpolate(x, 19 | size=(h // self.scale + 1, w // self.scale + 1), 20 | mode=self.mode) 21 | return x 22 | 23 | 24 | def generate_coord(x): 25 | _, _, h, w = x.size() 26 | device = x.device 27 | col = torch.arange(0, h, device=device) 28 | row = torch.arange(0, w, device=device) 29 | grid_h, grid_w = torch.meshgrid(col, row) 30 | return grid_h, grid_w 31 | 32 | 33 | class PositionEmbeddingSine(nn.Module): 34 | def __init__(self, 35 | num_pos_feats=64, 36 | temperature=10000, 37 | normalize=False, 38 | scale=None): 39 | super().__init__() 40 | self.num_pos_feats = num_pos_feats 41 | self.temperature = temperature 42 | self.normalize = normalize 43 | if scale is not None and normalize is False: 44 | raise ValueError("normalize should be True if scale is passed") 45 | if scale is None: 46 | scale = 2 * math.pi 47 | self.scale = scale 48 | 49 | def forward(self, x): 50 | grid_y, grid_x = generate_coord(x) 51 | 52 | y_embed = grid_y.unsqueeze(0).float() 53 | x_embed = grid_x.unsqueeze(0).float() 54 | 55 | if self.normalize: 56 | eps = 1e-6 57 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 58 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 59 | 60 | dim_t = torch.arange(self.num_pos_feats, 61 | dtype=torch.float32, 62 | device=x.device) 63 | dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats) 64 | 65 | pos_x = x_embed[:, :, :, None] / dim_t 66 | pos_y = y_embed[:, :, :, None] / dim_t 67 | pos_x = torch.stack( 68 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), 69 | dim=4).flatten(3) 70 | pos_y = torch.stack( 71 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), 72 | dim=4).flatten(3) 73 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 74 | return pos 75 | 76 | 77 | class PositionEmbeddingLearned(nn.Module): 78 | def __init__(self, num_pos_feats=64, H=30, W=30): 79 | super().__init__() 80 | self.H = H 81 | self.W = W 82 | self.pos_emb = nn.Parameter( 83 | truncated_normal_(torch.zeros(1, num_pos_feats, H, W))) 84 | 85 | def forward(self, x): 86 | bs, _, h, w = x.size() 87 | pos_emb = self.pos_emb 88 | if h != self.H or w != self.W: 89 | pos_emb = F.interpolate(pos_emb, size=(h, w), mode="bilinear") 90 | return pos_emb 91 | -------------------------------------------------------------------------------- /networks/models/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.models.aot import AOT 2 | from networks.models.deaot import DeAOT 3 | 4 | 5 | def build_vos_model(name, cfg, **kwargs): 6 | if name == 'aot': 7 | return AOT(cfg, encoder=cfg.MODEL_ENCODER, **kwargs) 8 | elif name == 'deaot': 9 | return DeAOT(cfg, encoder=cfg.MODEL_ENCODER, **kwargs) 10 | else: 11 | raise NotImplementedError 12 | 
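# Usage sketch (illustrative, not part of the original file): tools/demo.py and
# tools/eval.py build the network through this factory, e.g.
#   model = build_vos_model(cfg.MODEL_VOS, cfg).cuda(gpu_id)
# and then wrap it with an inference engine from networks.engines.build_engine.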
-------------------------------------------------------------------------------- /networks/models/aot.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from networks.encoders import build_encoder 4 | from networks.layers.transformer import LongShortTermTransformer 5 | from networks.decoders import build_decoder 6 | from networks.layers.position import PositionEmbeddingSine 7 | 8 | 9 | class AOT(nn.Module): 10 | def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'): 11 | super().__init__() 12 | self.cfg = cfg 13 | self.max_obj_num = cfg.MODEL_MAX_OBJ_NUM 14 | self.epsilon = cfg.MODEL_EPSILON 15 | 16 | self.encoder = build_encoder(encoder, 17 | frozen_bn=cfg.MODEL_FREEZE_BN, 18 | freeze_at=cfg.TRAIN_ENCODER_FREEZE_AT) 19 | self.encoder_projector = nn.Conv2d(cfg.MODEL_ENCODER_DIM[-1], 20 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 21 | kernel_size=1) 22 | 23 | self.LSTT = LongShortTermTransformer( 24 | cfg.MODEL_LSTT_NUM, 25 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 26 | cfg.MODEL_SELF_HEADS, 27 | cfg.MODEL_ATT_HEADS, 28 | emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT, 29 | droppath=cfg.TRAIN_LSTT_DROPPATH, 30 | lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT, 31 | st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT, 32 | droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST, 33 | droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING, 34 | intermediate_norm=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, 35 | return_intermediate=True) 36 | 37 | decoder_indim = cfg.MODEL_ENCODER_EMBEDDING_DIM * \ 38 | (cfg.MODEL_LSTT_NUM + 39 | 1) if cfg.MODEL_DECODER_INTERMEDIATE_LSTT else cfg.MODEL_ENCODER_EMBEDDING_DIM 40 | 41 | self.decoder = build_decoder( 42 | decoder, 43 | in_dim=decoder_indim, 44 | out_dim=cfg.MODEL_MAX_OBJ_NUM + 1, 45 | decode_intermediate_input=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, 46 | hidden_dim=cfg.MODEL_ENCODER_EMBEDDING_DIM, 47 | shortcut_dims=cfg.MODEL_ENCODER_DIM, 48 | align_corners=cfg.MODEL_ALIGN_CORNERS) 49 | 50 | if cfg.MODEL_ALIGN_CORNERS: 51 | self.patch_wise_id_bank = nn.Conv2d( 52 | cfg.MODEL_MAX_OBJ_NUM + 1, 53 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 54 | kernel_size=17, 55 | stride=16, 56 | padding=8) 57 | else: 58 | self.patch_wise_id_bank = nn.Conv2d( 59 | cfg.MODEL_MAX_OBJ_NUM + 1, 60 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 61 | kernel_size=16, 62 | stride=16, 63 | padding=0) 64 | 65 | self.id_dropout = nn.Dropout(cfg.TRAIN_LSTT_ID_DROPOUT, True) 66 | 67 | self.pos_generator = PositionEmbeddingSine( 68 | cfg.MODEL_ENCODER_EMBEDDING_DIM // 2, normalize=True) 69 | 70 | self._init_weight() 71 | 72 | def get_pos_emb(self, x): 73 | pos_emb = self.pos_generator(x) 74 | return pos_emb 75 | 76 | def get_id_emb(self, x): 77 | id_emb = self.patch_wise_id_bank(x) 78 | id_emb = self.id_dropout(id_emb) 79 | return id_emb 80 | 81 | def encode_image(self, img): 82 | xs = self.encoder(img) 83 | xs[-1] = self.encoder_projector(xs[-1]) 84 | return xs 85 | 86 | def decode_id_logits(self, lstt_emb, shortcuts): 87 | n, c, h, w = shortcuts[-1].size() 88 | decoder_inputs = [shortcuts[-1]] 89 | for emb in lstt_emb: 90 | decoder_inputs.append(emb.view(h, w, n, c).permute(2, 3, 0, 1)) 91 | pred_logit = self.decoder(decoder_inputs, shortcuts) 92 | return pred_logit 93 | 94 | def LSTT_forward(self, 95 | curr_embs, 96 | long_term_memories, 97 | short_term_memories, 98 | curr_id_emb=None, 99 | pos_emb=None, 100 | size_2d=(30, 30)): 101 | n, c, h, w = curr_embs[-1].size() 102 | curr_emb = curr_embs[-1].view(n, c, h * w).permute(2, 0, 1) 103 | lstt_embs, lstt_memories = self.LSTT(curr_emb, 
long_term_memories, 104 | short_term_memories, curr_id_emb, 105 | pos_emb, size_2d) 106 | lstt_curr_memories, lstt_long_memories, lstt_short_memories = zip( 107 | *lstt_memories) 108 | return lstt_embs, lstt_curr_memories, lstt_long_memories, lstt_short_memories 109 | 110 | def _init_weight(self): 111 | nn.init.xavier_uniform_(self.encoder_projector.weight) 112 | nn.init.orthogonal_( 113 | self.patch_wise_id_bank.weight.view( 114 | self.cfg.MODEL_ENCODER_EMBEDDING_DIM, -1).permute(0, 1), 115 | gain=17**-2 if self.cfg.MODEL_ALIGN_CORNERS else 16**-2) 116 | -------------------------------------------------------------------------------- /networks/models/deaot.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from networks.layers.transformer import DualBranchGPM 4 | from networks.models.aot import AOT 5 | from networks.decoders import build_decoder 6 | 7 | 8 | class DeAOT(AOT): 9 | def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'): 10 | super().__init__(cfg, encoder, decoder) 11 | 12 | self.LSTT = DualBranchGPM( 13 | cfg.MODEL_LSTT_NUM, 14 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 15 | cfg.MODEL_SELF_HEADS, 16 | cfg.MODEL_ATT_HEADS, 17 | emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT, 18 | droppath=cfg.TRAIN_LSTT_DROPPATH, 19 | lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT, 20 | st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT, 21 | droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST, 22 | droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING, 23 | intermediate_norm=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, 24 | return_intermediate=True) 25 | 26 | decoder_indim = cfg.MODEL_ENCODER_EMBEDDING_DIM * \ 27 | (cfg.MODEL_LSTT_NUM * 2 + 28 | 1) if cfg.MODEL_DECODER_INTERMEDIATE_LSTT else cfg.MODEL_ENCODER_EMBEDDING_DIM * 2 29 | 30 | self.decoder = build_decoder( 31 | decoder, 32 | in_dim=decoder_indim, 33 | out_dim=cfg.MODEL_MAX_OBJ_NUM + 1, 34 | decode_intermediate_input=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, 35 | hidden_dim=cfg.MODEL_ENCODER_EMBEDDING_DIM, 36 | shortcut_dims=cfg.MODEL_ENCODER_DIM, 37 | align_corners=cfg.MODEL_ALIGN_CORNERS) 38 | 39 | self.id_norm = nn.LayerNorm(cfg.MODEL_ENCODER_EMBEDDING_DIM) 40 | 41 | self._init_weight() 42 | 43 | def decode_id_logits(self, lstt_emb, shortcuts): 44 | n, c, h, w = shortcuts[-1].size() 45 | decoder_inputs = [shortcuts[-1]] 46 | for emb in lstt_emb: 47 | decoder_inputs.append(emb.view(h, w, n, -1).permute(2, 3, 0, 1)) 48 | pred_logit = self.decoder(decoder_inputs, shortcuts) 49 | return pred_logit 50 | 51 | def get_id_emb(self, x): 52 | id_emb = self.patch_wise_id_bank(x) 53 | id_emb = self.id_norm(id_emb.permute(2, 3, 0, 1)).permute(2, 3, 0, 1) 54 | id_emb = self.id_dropout(id_emb) 55 | return id_emb 56 | -------------------------------------------------------------------------------- /pretrain_models/README.md: -------------------------------------------------------------------------------- 1 | Put pretrained models here. 
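For example, `tools/demo.py` expects its default checkpoint at
`./pretrain_models/R50_AOTL_PRE_YTB_DAV.pth`; see `MODEL_ZOO.md` for the released
models.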
-------------------------------------------------------------------------------- /source/1001_3iEIq5HBY1s.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/source/1001_3iEIq5HBY1s.gif -------------------------------------------------------------------------------- /source/1007_YCTBBdbKSSg.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/source/1007_YCTBBdbKSSg.gif -------------------------------------------------------------------------------- /source/kobe.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/source/kobe.gif -------------------------------------------------------------------------------- /source/messi.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/source/messi.gif -------------------------------------------------------------------------------- /source/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/source/overview.png -------------------------------------------------------------------------------- /source/overview_deaot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/source/overview_deaot.png -------------------------------------------------------------------------------- /source/some_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/source/some_results.png -------------------------------------------------------------------------------- /tools/demo.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | import os 4 | 5 | sys.path.append('.') 6 | sys.path.append('..') 7 | 8 | import cv2 9 | from PIL import Image 10 | from skimage.morphology.binary import binary_dilation 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.utils.data import DataLoader 16 | from torchvision import transforms 17 | 18 | from networks.models import build_vos_model 19 | from networks.engines import build_engine 20 | from utils.checkpoint import load_network 21 | 22 | from dataloaders.eval_datasets import VOSTest 23 | import dataloaders.video_transforms as tr 24 | from utils.image import save_mask 25 | 26 | _palette = [ 27 | 255, 0, 0, 0, 0, 139, 255, 255, 84, 0, 255, 0, 139, 0, 139, 0, 128, 128, 28 | 128, 128, 128, 139, 0, 0, 218, 165, 32, 144, 238, 144, 160, 82, 45, 148, 0, 29 | 211, 255, 0, 255, 30, 144, 255, 255, 218, 185, 85, 107, 47, 255, 140, 0, 30 | 50, 205, 50, 123, 104, 238, 240, 230, 140, 72, 61, 139, 128, 128, 0, 0, 0, 31 | 205, 221, 160, 221, 143, 188, 143, 127, 255, 212, 176, 224, 230, 244, 164, 32 | 96, 250, 128, 114, 70, 130, 180, 0, 128, 0, 173, 255, 47, 255, 105, 180, 33 | 238, 130, 238, 154, 205, 50, 220, 20, 60, 176, 48, 96, 0, 
206, 209, 0, 191, 34 | 255, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43, 43, 43, 44, 44, 44, 45, 45, 35 | 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49, 49, 50, 50, 50, 51, 51, 51, 36 | 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55, 56, 56, 56, 57, 57, 57, 58, 37 | 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62, 62, 62, 63, 63, 63, 64, 64, 38 | 64, 65, 65, 65, 66, 66, 66, 67, 67, 67, 68, 68, 68, 69, 69, 69, 70, 70, 70, 39 | 71, 71, 71, 72, 72, 72, 73, 73, 73, 74, 74, 74, 75, 75, 75, 76, 76, 76, 77, 40 | 77, 77, 78, 78, 78, 79, 79, 79, 80, 80, 80, 81, 81, 81, 82, 82, 82, 83, 83, 41 | 83, 84, 84, 84, 85, 85, 85, 86, 86, 86, 87, 87, 87, 88, 88, 88, 89, 89, 89, 42 | 90, 90, 90, 91, 91, 91, 92, 92, 92, 93, 93, 93, 94, 94, 94, 95, 95, 95, 96, 43 | 96, 96, 97, 97, 97, 98, 98, 98, 99, 99, 99, 100, 100, 100, 101, 101, 101, 44 | 102, 102, 102, 103, 103, 103, 104, 104, 104, 105, 105, 105, 106, 106, 106, 45 | 107, 107, 107, 108, 108, 108, 109, 109, 109, 110, 110, 110, 111, 111, 111, 46 | 112, 112, 112, 113, 113, 113, 114, 114, 114, 115, 115, 115, 116, 116, 116, 47 | 117, 117, 117, 118, 118, 118, 119, 119, 119, 120, 120, 120, 121, 121, 121, 48 | 122, 122, 122, 123, 123, 123, 124, 124, 124, 125, 125, 125, 126, 126, 126, 49 | 127, 127, 127, 128, 128, 128, 129, 129, 129, 130, 130, 130, 131, 131, 131, 50 | 132, 132, 132, 133, 133, 133, 134, 134, 134, 135, 135, 135, 136, 136, 136, 51 | 137, 137, 137, 138, 138, 138, 139, 139, 139, 140, 140, 140, 141, 141, 141, 52 | 142, 142, 142, 143, 143, 143, 144, 144, 144, 145, 145, 145, 146, 146, 146, 53 | 147, 147, 147, 148, 148, 148, 149, 149, 149, 150, 150, 150, 151, 151, 151, 54 | 152, 152, 152, 153, 153, 153, 154, 154, 154, 155, 155, 155, 156, 156, 156, 55 | 157, 157, 157, 158, 158, 158, 159, 159, 159, 160, 160, 160, 161, 161, 161, 56 | 162, 162, 162, 163, 163, 163, 164, 164, 164, 165, 165, 165, 166, 166, 166, 57 | 167, 167, 167, 168, 168, 168, 169, 169, 169, 170, 170, 170, 171, 171, 171, 58 | 172, 172, 172, 173, 173, 173, 174, 174, 174, 175, 175, 175, 176, 176, 176, 59 | 177, 177, 177, 178, 178, 178, 179, 179, 179, 180, 180, 180, 181, 181, 181, 60 | 182, 182, 182, 183, 183, 183, 184, 184, 184, 185, 185, 185, 186, 186, 186, 61 | 187, 187, 187, 188, 188, 188, 189, 189, 189, 190, 190, 190, 191, 191, 191, 62 | 192, 192, 192, 193, 193, 193, 194, 194, 194, 195, 195, 195, 196, 196, 196, 63 | 197, 197, 197, 198, 198, 198, 199, 199, 199, 200, 200, 200, 201, 201, 201, 64 | 202, 202, 202, 203, 203, 203, 204, 204, 204, 205, 205, 205, 206, 206, 206, 65 | 207, 207, 207, 208, 208, 208, 209, 209, 209, 210, 210, 210, 211, 211, 211, 66 | 212, 212, 212, 213, 213, 213, 214, 214, 214, 215, 215, 215, 216, 216, 216, 67 | 217, 217, 217, 218, 218, 218, 219, 219, 219, 220, 220, 220, 221, 221, 221, 68 | 222, 222, 222, 223, 223, 223, 224, 224, 224, 225, 225, 225, 226, 226, 226, 69 | 227, 227, 227, 228, 228, 228, 229, 229, 229, 230, 230, 230, 231, 231, 231, 70 | 232, 232, 232, 233, 233, 233, 234, 234, 234, 235, 235, 235, 236, 236, 236, 71 | 237, 237, 237, 238, 238, 238, 239, 239, 239, 240, 240, 240, 241, 241, 241, 72 | 242, 242, 242, 243, 243, 243, 244, 244, 244, 245, 245, 245, 246, 246, 246, 73 | 247, 247, 247, 248, 248, 248, 249, 249, 249, 250, 250, 250, 251, 251, 251, 74 | 252, 252, 252, 253, 253, 253, 254, 254, 254, 255, 255, 255, 0, 0, 0 75 | ] 76 | color_palette = np.array(_palette).reshape(-1, 3) 77 | 78 | 79 | def overlay(image, mask, colors=[255, 0, 0], cscale=1, alpha=0.4): 80 | colors = np.atleast_2d(colors) * cscale 81 | 82 | im_overlay = image.copy() 83 | object_ids = np.unique(mask) 
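    # Label 0 is treated as background and skipped; every other object id is
    # blended with its palette color and outlined with a black contour.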
84 | 85 | for object_id in object_ids[1:]: 86 | # Overlay color on binary mask 87 | 88 | foreground = image * alpha + np.ones( 89 | image.shape) * (1 - alpha) * np.array(colors[object_id]) 90 | binary_mask = mask == object_id 91 | 92 | # Compose image 93 | im_overlay[binary_mask] = foreground[binary_mask] 94 | 95 | countours = binary_dilation(binary_mask) ^ binary_mask 96 | im_overlay[countours, :] = 0 97 | 98 | return im_overlay.astype(image.dtype) 99 | 100 | 101 | def demo(cfg): 102 | video_fps = 15 103 | gpu_id = cfg.TEST_GPU_ID 104 | 105 | # Load pre-trained model 106 | print('Build AOT model.') 107 | model = build_vos_model(cfg.MODEL_VOS, cfg).cuda(gpu_id) 108 | 109 | print('Load checkpoint from {}'.format(cfg.TEST_CKPT_PATH)) 110 | model, _ = load_network(model, cfg.TEST_CKPT_PATH, gpu_id) 111 | 112 | print('Build AOT engine.') 113 | engine = build_engine(cfg.MODEL_ENGINE, 114 | phase='eval', 115 | aot_model=model, 116 | gpu_id=gpu_id, 117 | long_term_mem_gap=cfg.TEST_LONG_TERM_MEM_GAP) 118 | 119 | # Prepare datasets for each sequence 120 | transform = transforms.Compose([ 121 | tr.MultiRestrictSize(cfg.TEST_MIN_SIZE, cfg.TEST_MAX_SIZE, 122 | cfg.TEST_FLIP, cfg.TEST_MULTISCALE, 123 | cfg.MODEL_ALIGN_CORNERS), 124 | tr.MultiToTensor() 125 | ]) 126 | image_root = os.path.join(cfg.TEST_DATA_PATH, 'images') 127 | label_root = os.path.join(cfg.TEST_DATA_PATH, 'masks') 128 | 129 | sequences = os.listdir(image_root) 130 | seq_datasets = [] 131 | for seq_name in sequences: 132 | print('Build a dataset for sequence {}.'.format(seq_name)) 133 | seq_images = np.sort(os.listdir(os.path.join(image_root, seq_name))) 134 | seq_labels = [seq_images[0].replace('jpg', 'png')] 135 | seq_dataset = VOSTest(image_root, 136 | label_root, 137 | seq_name, 138 | seq_images, 139 | seq_labels, 140 | transform=transform) 141 | seq_datasets.append(seq_dataset) 142 | 143 | # Infer 144 | output_root = cfg.TEST_OUTPUT_PATH 145 | output_mask_root = os.path.join(output_root, 'pred_masks') 146 | if not os.path.exists(output_mask_root): 147 | os.makedirs(output_mask_root) 148 | 149 | for seq_dataset in seq_datasets: 150 | seq_name = seq_dataset.seq_name 151 | image_seq_root = os.path.join(image_root, seq_name) 152 | output_mask_seq_root = os.path.join(output_mask_root, seq_name) 153 | if not os.path.exists(output_mask_seq_root): 154 | os.makedirs(output_mask_seq_root) 155 | print('Build a dataloader for sequence {}.'.format(seq_name)) 156 | seq_dataloader = DataLoader(seq_dataset, 157 | batch_size=1, 158 | shuffle=False, 159 | num_workers=cfg.TEST_WORKERS, 160 | pin_memory=True) 161 | 162 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 163 | output_video_path = os.path.join( 164 | output_root, '{}_{}fps.avi'.format(seq_name, video_fps)) 165 | 166 | print('Start the inference of sequence {}:'.format(seq_name)) 167 | model.eval() 168 | engine.restart_engine() 169 | with torch.no_grad(): 170 | for frame_idx, samples in enumerate(seq_dataloader): 171 | sample = samples[0] 172 | img_name = sample['meta']['current_name'][0] 173 | 174 | obj_nums = sample['meta']['obj_num'] 175 | output_height = sample['meta']['height'] 176 | output_width = sample['meta']['width'] 177 | obj_idx = sample['meta']['obj_idx'] 178 | 179 | obj_nums = [int(obj_num) for obj_num in obj_nums] 180 | obj_idx = [int(_obj_idx) for _obj_idx in obj_idx] 181 | 182 | current_img = sample['current_img'] 183 | current_img = current_img.cuda(gpu_id, non_blocking=True) 184 | 185 | if frame_idx == 0: 186 | videoWriter = cv2.VideoWriter( 187 | output_video_path, fourcc, 
video_fps, 188 | (int(output_width), int(output_height))) 189 | print( 190 | 'Object number: {}. Inference size: {}x{}. Output size: {}x{}.' 191 | .format(obj_nums[0], 192 | current_img.size()[2], 193 | current_img.size()[3], int(output_height), 194 | int(output_width))) 195 | current_label = sample['current_label'].cuda( 196 | gpu_id, non_blocking=True).float() 197 | current_label = F.interpolate(current_label, 198 | size=current_img.size()[2:], 199 | mode="nearest") 200 | # add reference frame 201 | engine.add_reference_frame(current_img, 202 | current_label, 203 | frame_step=0, 204 | obj_nums=obj_nums) 205 | else: 206 | print('Processing image {}...'.format(img_name)) 207 | # predict segmentation 208 | engine.match_propogate_one_frame(current_img) 209 | pred_logit = engine.decode_current_logits( 210 | (output_height, output_width)) 211 | pred_prob = torch.softmax(pred_logit, dim=1) 212 | pred_label = torch.argmax(pred_prob, dim=1, 213 | keepdim=True).float() 214 | _pred_label = F.interpolate(pred_label, 215 | size=engine.input_size_2d, 216 | mode="nearest") 217 | # update memory 218 | engine.update_memory(_pred_label) 219 | 220 | # save results 221 | input_image_path = os.path.join(image_seq_root, img_name) 222 | output_mask_path = os.path.join( 223 | output_mask_seq_root, 224 | img_name.split('.')[0] + '.png') 225 | 226 | pred_label = Image.fromarray( 227 | pred_label.squeeze(0).squeeze(0).cpu().numpy().astype( 228 | 'uint8')).convert('P') 229 | pred_label.putpalette(_palette) 230 | pred_label.save(output_mask_path) 231 | 232 | input_image = Image.open(input_image_path) 233 | 234 | overlayed_image = overlay( 235 | np.array(input_image, dtype=np.uint8), 236 | np.array(pred_label, dtype=np.uint8), color_palette) 237 | videoWriter.write(overlayed_image[..., [2, 1, 0]]) 238 | 239 | print('Save a visualization video to {}.'.format(output_video_path)) 240 | videoWriter.release() 241 | 242 | 243 | def main(): 244 | import argparse 245 | parser = argparse.ArgumentParser(description="AOT Demo") 246 | parser.add_argument('--exp_name', type=str, default='default') 247 | 248 | parser.add_argument('--stage', type=str, default='pre_ytb_dav') 249 | parser.add_argument('--model', type=str, default='r50_aotl') 250 | 251 | parser.add_argument('--gpu_id', type=int, default=0) 252 | 253 | parser.add_argument('--data_path', type=str, default='./datasets/Demo') 254 | parser.add_argument('--output_path', type=str, default='./demo_output') 255 | parser.add_argument('--ckpt_path', 256 | type=str, 257 | default='./pretrain_models/R50_AOTL_PRE_YTB_DAV.pth') 258 | 259 | parser.add_argument('--max_resolution', type=float, default=480 * 1.3) 260 | 261 | parser.add_argument('--amp', action='store_true') 262 | parser.set_defaults(amp=False) 263 | 264 | args = parser.parse_args() 265 | 266 | engine_config = importlib.import_module('configs.' + args.stage) 267 | cfg = engine_config.EngineConfig(args.exp_name, args.model) 268 | 269 | cfg.TEST_GPU_ID = args.gpu_id 270 | 271 | cfg.TEST_CKPT_PATH = args.ckpt_path 272 | cfg.TEST_DATA_PATH = args.data_path 273 | cfg.TEST_OUTPUT_PATH = args.output_path 274 | 275 | cfg.TEST_MIN_SIZE = None 276 | cfg.TEST_MAX_SIZE = args.max_resolution * 800. / 480. 
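    # Same scaling as in tools/eval.py: the 480p-based --max_resolution budget is
    # multiplied by 800/480 before being used as the maximum inference size.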
277 | 278 | if args.amp: 279 | with torch.cuda.amp.autocast(enabled=True): 280 | demo(cfg) 281 | else: 282 | demo(cfg) 283 | 284 | 285 | if __name__ == '__main__': 286 | main() 287 | -------------------------------------------------------------------------------- /tools/eval.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | 4 | sys.path.append('.') 5 | sys.path.append('..') 6 | 7 | import torch 8 | import torch.multiprocessing as mp 9 | 10 | from networks.managers.evaluator import Evaluator 11 | 12 | 13 | def main_worker(gpu, cfg, seq_queue=None, info_queue=None, enable_amp=False): 14 | # Initiate an evaluating manager 15 | evaluator = Evaluator(rank=gpu, 16 | cfg=cfg, 17 | seq_queue=seq_queue, 18 | info_queue=info_queue) 19 | # Start evaluation 20 | if enable_amp: 21 | with torch.cuda.amp.autocast(enabled=True): 22 | evaluator.evaluating() 23 | else: 24 | evaluator.evaluating() 25 | 26 | 27 | def main(): 28 | import argparse 29 | parser = argparse.ArgumentParser(description="Eval VOS") 30 | parser.add_argument('--exp_name', type=str, default='default') 31 | 32 | parser.add_argument('--stage', type=str, default='pre') 33 | parser.add_argument('--model', type=str, default='aott') 34 | parser.add_argument('--lstt_num', type=int, default=-1) 35 | parser.add_argument('--lt_gap', type=int, default=-1) 36 | parser.add_argument('--st_skip', type=int, default=-1) 37 | parser.add_argument('--max_id_num', type=int, default=-1) 38 | 39 | parser.add_argument('--gpu_id', type=int, default=0) 40 | parser.add_argument('--gpu_num', type=int, default=1) 41 | 42 | parser.add_argument('--ckpt_path', type=str, default='') 43 | parser.add_argument('--ckpt_step', type=int, default=-1) 44 | 45 | parser.add_argument('--dataset', type=str, default='') 46 | parser.add_argument('--split', type=str, default='') 47 | 48 | parser.add_argument('--ema', action='store_true') 49 | parser.set_defaults(ema=False) 50 | 51 | parser.add_argument('--flip', action='store_true') 52 | parser.set_defaults(flip=False) 53 | parser.add_argument('--ms', nargs='+', type=float, default=[1.]) 54 | 55 | parser.add_argument('--max_resolution', type=float, default=480 * 1.3) 56 | 57 | parser.add_argument('--amp', action='store_true') 58 | parser.set_defaults(amp=False) 59 | 60 | args = parser.parse_args() 61 | 62 | engine_config = importlib.import_module('configs.' + args.stage) 63 | cfg = engine_config.EngineConfig(args.exp_name, args.model) 64 | 65 | cfg.TEST_EMA = args.ema 66 | 67 | cfg.TEST_GPU_ID = args.gpu_id 68 | cfg.TEST_GPU_NUM = args.gpu_num 69 | 70 | if args.lstt_num > 0: 71 | cfg.MODEL_LSTT_NUM = args.lstt_num 72 | if args.lt_gap > 0: 73 | cfg.TEST_LONG_TERM_MEM_GAP = args.lt_gap 74 | if args.st_skip > 0: 75 | cfg.TEST_SHORT_TERM_MEM_SKIP = args.st_skip 76 | 77 | if args.max_id_num > 0: 78 | cfg.MODEL_MAX_OBJ_NUM = args.max_id_num 79 | 80 | if args.ckpt_path != '': 81 | cfg.TEST_CKPT_PATH = args.ckpt_path 82 | if args.ckpt_step > 0: 83 | cfg.TEST_CKPT_STEP = args.ckpt_step 84 | 85 | if args.dataset != '': 86 | cfg.TEST_DATASET = args.dataset 87 | 88 | if args.split != '': 89 | cfg.TEST_DATASET_SPLIT = args.split 90 | 91 | cfg.TEST_FLIP = args.flip 92 | cfg.TEST_MULTISCALE = args.ms 93 | 94 | if cfg.TEST_MULTISCALE != [1.]: 95 | cfg.TEST_MAX_SHORT_EDGE = args.max_resolution # for preventing OOM 96 | else: 97 | cfg.TEST_MAX_SHORT_EDGE = None # the default resolution setting of CFBI and AOT 98 | cfg.TEST_MAX_LONG_EDGE = args.max_resolution * 800. / 480. 
99 | 100 | if args.gpu_num > 1: 101 | mp.set_start_method('spawn') 102 | seq_queue = mp.Queue() 103 | info_queue = mp.Queue() 104 | mp.spawn(main_worker, 105 | nprocs=cfg.TEST_GPU_NUM, 106 | args=(cfg, seq_queue, info_queue, args.amp)) 107 | else: 108 | main_worker(0, cfg, enable_amp=args.amp) 109 | 110 | 111 | if __name__ == '__main__': 112 | main() 113 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import random 3 | import sys 4 | 5 | sys.setrecursionlimit(10000) 6 | sys.path.append('.') 7 | sys.path.append('..') 8 | 9 | import torch.multiprocessing as mp 10 | 11 | from networks.managers.trainer import Trainer 12 | 13 | 14 | def main_worker(gpu, cfg, enable_amp=True): 15 | # Initiate a training manager 16 | trainer = Trainer(rank=gpu, cfg=cfg, enable_amp=enable_amp) 17 | # Start Training 18 | trainer.sequential_training() 19 | 20 | 21 | def main(): 22 | import argparse 23 | parser = argparse.ArgumentParser(description="Train VOS") 24 | parser.add_argument('--exp_name', type=str, default='') 25 | parser.add_argument('--stage', type=str, default='pre') 26 | parser.add_argument('--model', type=str, default='aott') 27 | parser.add_argument('--max_id_num', type=int, default=-1) 28 | 29 | parser.add_argument('--start_gpu', type=int, default=0) 30 | parser.add_argument('--gpu_num', type=int, default=-1) 31 | parser.add_argument('--batch_size', type=int, default=-1) 32 | parser.add_argument('--dist_url', type=str, default='') 33 | parser.add_argument('--amp', action='store_true') 34 | parser.set_defaults(amp=False) 35 | 36 | parser.add_argument('--pretrained_path', type=str, default='') 37 | 38 | parser.add_argument('--datasets', nargs='+', type=str, default=[]) 39 | parser.add_argument('--lr', type=float, default=-1.) 40 | parser.add_argument('--total_step', type=int, default=-1) 41 | parser.add_argument('--start_step', type=int, default=-1) 42 | 43 | args = parser.parse_args() 44 | 45 | engine_config = importlib.import_module('configs.' 
+ args.stage) 46 | 47 | cfg = engine_config.EngineConfig(args.exp_name, args.model) 48 | 49 | if len(args.datasets) > 0: 50 | cfg.DATASETS = args.datasets 51 | 52 | cfg.DIST_START_GPU = args.start_gpu 53 | if args.gpu_num > 0: 54 | cfg.TRAIN_GPUS = args.gpu_num 55 | if args.batch_size > 0: 56 | cfg.TRAIN_BATCH_SIZE = args.batch_size 57 | 58 | if args.pretrained_path != '': 59 | cfg.PRETRAIN_MODEL = args.pretrained_path 60 | 61 | if args.max_id_num > 0: 62 | cfg.MODEL_MAX_OBJ_NUM = args.max_id_num 63 | 64 | if args.lr > 0: 65 | cfg.TRAIN_LR = args.lr 66 | 67 | if args.total_step > 0: 68 | cfg.TRAIN_TOTAL_STEPS = args.total_step 69 | 70 | if args.start_step > 0: 71 | cfg.TRAIN_START_STEP = args.start_step 72 | 73 | if args.dist_url == '': 74 | cfg.DIST_URL = 'tcp://127.0.0.1:123' + str(random.randint(0, 9)) + str( 75 | random.randint(0, 9)) 76 | else: 77 | cfg.DIST_URL = args.dist_url 78 | 79 | if cfg.TRAIN_GPUS > 1: 80 | # Use torch.multiprocessing.spawn to launch distributed processes 81 | mp.spawn(main_worker, nprocs=cfg.TRAIN_GPUS, args=(cfg, args.amp)) 82 | else: 83 | cfg.TRAIN_GPUS = 1 84 | main_worker(0, cfg, args.amp) 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /train_eval.sh: -------------------------------------------------------------------------------- 1 | exp="default" 2 | gpu_num="4" 3 | 4 | model="aott" 5 | # model="aots" 6 | # model="aotb" 7 | # model="aotl" 8 | # model="r50_aotl" 9 | # model="swinb_aotl" 10 | 11 | ## Training ## 12 | stage="pre" 13 | python tools/train.py --amp \ 14 | --exp_name ${exp} \ 15 | --stage ${stage} \ 16 | --model ${model} \ 17 | --gpu_num ${gpu_num} 18 | 19 | stage="pre_ytb_dav" 20 | python tools/train.py --amp \ 21 | --exp_name ${exp} \ 22 | --stage ${stage} \ 23 | --model ${model} \ 24 | --gpu_num ${gpu_num} 25 | 26 | ## Evaluation ## 27 | dataset="davis2017" 28 | split="test" 29 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ 30 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} 31 | 32 | dataset="davis2017" 33 | split="val" 34 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ 35 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} 36 | 37 | dataset="davis2016" 38 | split="val" 39 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ 40 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} 41 | 42 | dataset="youtubevos2018" 43 | split="val" # or "val_all_frames" 44 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ 45 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} 46 | 47 | dataset="youtubevos2019" 48 | split="val" # or "val_all_frames" 49 | python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ 50 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoxu515/aot-benchmark/6852c2d2284b1ebeb7e4dd0c0f05fdf4102bd34d/utils/__init__.py -------------------------------------------------------------------------------- /utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import shutil 4 | import numpy as np 5 | 6 | 7 | def load_network_and_optimizer(net, opt, pretrained_dir, gpu, scaler=None): 8 | 
pretrained = torch.load(pretrained_dir, 9 | map_location=torch.device("cuda:" + str(gpu))) 10 | pretrained_dict = pretrained['state_dict'] 11 | model_dict = net.state_dict() 12 | pretrained_dict_update = {} 13 | pretrained_dict_remove = [] 14 | for k, v in pretrained_dict.items(): 15 | if k in model_dict: 16 | pretrained_dict_update[k] = v 17 | elif k[:7] == 'module.': 18 | if k[7:] in model_dict: 19 | pretrained_dict_update[k[7:]] = v 20 | else: 21 | pretrained_dict_remove.append(k) 22 | model_dict.update(pretrained_dict_update) 23 | net.load_state_dict(model_dict) 24 | opt.load_state_dict(pretrained['optimizer']) 25 | if scaler is not None and 'scaler' in pretrained.keys(): 26 | scaler.load_state_dict(pretrained['scaler']) 27 | del (pretrained) 28 | return net.cuda(gpu), opt, pretrained_dict_remove 29 | 30 | 31 | def load_network_and_optimizer_v2(net, opt, pretrained_dir, gpu, scaler=None): 32 | pretrained = torch.load(pretrained_dir, 33 | map_location=torch.device("cuda:" + str(gpu))) 34 | # load model 35 | pretrained_dict = pretrained['state_dict'] 36 | model_dict = net.state_dict() 37 | pretrained_dict_update = {} 38 | pretrained_dict_remove = [] 39 | for k, v in pretrained_dict.items(): 40 | if k in model_dict: 41 | pretrained_dict_update[k] = v 42 | elif k[:7] == 'module.': 43 | if k[7:] in model_dict: 44 | pretrained_dict_update[k[7:]] = v 45 | else: 46 | pretrained_dict_remove.append(k) 47 | model_dict.update(pretrained_dict_update) 48 | net.load_state_dict(model_dict) 49 | 50 | # load optimizer 51 | opt_dict = opt.state_dict() 52 | all_params = { 53 | param_group['name']: param_group['params'][0] 54 | for param_group in opt_dict['param_groups'] 55 | } 56 | pretrained_opt_dict = {'state': {}, 'param_groups': []} 57 | for idx in range(len(pretrained['optimizer']['param_groups'])): 58 | param_group = pretrained['optimizer']['param_groups'][idx] 59 | if param_group['name'] in all_params.keys(): 60 | pretrained_opt_dict['state'][all_params[ 61 | param_group['name']]] = pretrained['optimizer']['state'][ 62 | param_group['params'][0]] 63 | param_group['params'][0] = all_params[param_group['name']] 64 | pretrained_opt_dict['param_groups'].append(param_group) 65 | 66 | opt_dict.update(pretrained_opt_dict) 67 | opt.load_state_dict(opt_dict) 68 | 69 | # load scaler 70 | if scaler is not None and 'scaler' in pretrained.keys(): 71 | scaler.load_state_dict(pretrained['scaler']) 72 | del (pretrained) 73 | return net.cuda(gpu), opt, pretrained_dict_remove 74 | 75 | 76 | def load_network(net, pretrained_dir, gpu): 77 | pretrained = torch.load(pretrained_dir, 78 | map_location=torch.device("cuda:" + str(gpu))) 79 | if 'state_dict' in pretrained.keys(): 80 | pretrained_dict = pretrained['state_dict'] 81 | elif 'model' in pretrained.keys(): 82 | pretrained_dict = pretrained['model'] 83 | else: 84 | pretrained_dict = pretrained 85 | model_dict = net.state_dict() 86 | pretrained_dict_update = {} 87 | pretrained_dict_remove = [] 88 | for k, v in pretrained_dict.items(): 89 | if k in model_dict: 90 | pretrained_dict_update[k] = v 91 | elif k[:7] == 'module.': 92 | if k[7:] in model_dict: 93 | pretrained_dict_update[k[7:]] = v 94 | else: 95 | pretrained_dict_remove.append(k) 96 | model_dict.update(pretrained_dict_update) 97 | net.load_state_dict(model_dict) 98 | del (pretrained) 99 | return net.cuda(gpu), pretrained_dict_remove 100 | 101 | 102 | def save_network(net, 103 | opt, 104 | step, 105 | save_path, 106 | max_keep=8, 107 | backup_dir='./saved_models', 108 | scaler=None): 109 | ckpt = 
{'state_dict': net.state_dict(), 'optimizer': opt.state_dict()} 110 | if scaler is not None: 111 | ckpt['scaler'] = scaler.state_dict() 112 | 113 | try: 114 | if not os.path.exists(save_path): 115 | os.makedirs(save_path) 116 | save_file = 'save_step_%s.pth' % (step) 117 | save_dir = os.path.join(save_path, save_file) 118 | torch.save(ckpt, save_dir) 119 | except: 120 | save_path = backup_dir 121 | if not os.path.exists(save_path): 122 | os.makedirs(save_path) 123 | save_file = 'save_step_%s.pth' % (step) 124 | save_dir = os.path.join(save_path, save_file) 125 | torch.save(ckpt, save_dir) 126 | 127 | all_ckpt = os.listdir(save_path) 128 | if len(all_ckpt) > max_keep: 129 | all_step = [] 130 | for ckpt_name in all_ckpt: 131 | step = int(ckpt_name.split('_')[-1].split('.')[0]) 132 | all_step.append(step) 133 | all_step = list(np.sort(all_step))[:-max_keep] 134 | for step in all_step: 135 | ckpt_path = os.path.join(save_path, 'save_step_%s.pth' % (step)) 136 | os.system('rm {}'.format(ckpt_path)) 137 | 138 | 139 | def cp_ckpt(remote_dir="data_wd/youtube_vos_jobs/result", curr_dir="backup"): 140 | exps = os.listdir(curr_dir) 141 | for exp in exps: 142 | exp_dir = os.path.join(curr_dir, exp) 143 | stages = os.listdir(exp_dir) 144 | for stage in stages: 145 | stage_dir = os.path.join(exp_dir, stage) 146 | finals = ["ema_ckpt", "ckpt"] 147 | for final in finals: 148 | final_dir = os.path.join(stage_dir, final) 149 | ckpts = os.listdir(final_dir) 150 | for ckpt in ckpts: 151 | if '.pth' not in ckpt: 152 | continue 153 | curr_ckpt_path = os.path.join(final_dir, ckpt) 154 | remote_ckpt_path = os.path.join(remote_dir, exp, stage, 155 | final, ckpt) 156 | if os.path.exists(remote_ckpt_path): 157 | os.system('rm {}'.format(remote_ckpt_path)) 158 | try: 159 | shutil.copy(curr_ckpt_path, remote_ckpt_path) 160 | print("Copy {} to {}.".format(curr_ckpt_path, 161 | remote_ckpt_path)) 162 | except OSError as Inst: 163 | return 164 | -------------------------------------------------------------------------------- /utils/cp_ckpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | def cp_ckpt(remote_dir="data_wd/youtube_vos_jobs/result", curr_dir="backup"): 6 | exps = os.listdir(curr_dir) 7 | for exp in exps: 8 | print("Exp: ", exp) 9 | exp_dir = os.path.join(curr_dir, exp) 10 | stages = os.listdir(exp_dir) 11 | for stage in stages: 12 | print("Stage: ", stage) 13 | stage_dir = os.path.join(exp_dir, stage) 14 | finals = ["ema_ckpt", "ckpt"] 15 | for final in finals: 16 | print("Final: ", final) 17 | final_dir = os.path.join(stage_dir, final) 18 | ckpts = os.listdir(final_dir) 19 | for ckpt in ckpts: 20 | if '.pth' not in ckpt: 21 | continue 22 | curr_ckpt_path = os.path.join(final_dir, ckpt) 23 | remote_ckpt_path = os.path.join(remote_dir, exp, stage, 24 | final, ckpt) 25 | if os.path.exists(remote_ckpt_path): 26 | os.system('rm {}'.format(remote_ckpt_path)) 27 | try: 28 | shutil.copy(curr_ckpt_path, remote_ckpt_path) 29 | print(ckpt, ': OK') 30 | except OSError as Inst: 31 | print(Inst) 32 | print(ckpt, ': Fail') 33 | 34 | 35 | if __name__ == "__main__": 36 | cp_ckpt() 37 | -------------------------------------------------------------------------------- /utils/ema.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import unicode_literals 3 | 4 | import torch 5 | 6 | 7 | def get_param_buffer_for_ema(model, 8 | update_buffer=False, 9 | 
required_buffers=['running_mean', 'running_var']): 10 | params = model.parameters() 11 | all_param_buffer = [p for p in params if p.requires_grad] 12 | if update_buffer: 13 | named_buffers = model.named_buffers() 14 | for key, value in named_buffers: 15 | for buffer_name in required_buffers: 16 | if buffer_name in key: 17 | all_param_buffer.append(value) 18 | break 19 | return all_param_buffer 20 | 21 | 22 | class ExponentialMovingAverage: 23 | """ 24 | Maintains (exponential) moving average of a set of parameters. 25 | """ 26 | def __init__(self, parameters, decay, use_num_updates=True): 27 | """ 28 | Args: 29 | parameters: Iterable of `torch.nn.Parameter`; usually the result of 30 | `model.parameters()`. 31 | decay: The exponential decay. 32 | use_num_updates: Whether to use number of updates when computing 33 | averages. 34 | """ 35 | if decay < 0.0 or decay > 1.0: 36 | raise ValueError('Decay must be between 0 and 1') 37 | self.decay = decay 38 | self.num_updates = 0 if use_num_updates else None 39 | self.shadow_params = [p.clone().detach() for p in parameters] 40 | self.collected_params = [] 41 | 42 | def update(self, parameters): 43 | """ 44 | Update currently maintained parameters. 45 | Call this every time the parameters are updated, such as the result of 46 | the `optimizer.step()` call. 47 | Args: 48 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 49 | parameters used to initialize this object. 50 | """ 51 | decay = self.decay 52 | if self.num_updates is not None: 53 | self.num_updates += 1 54 | decay = min(decay, 55 | (1 + self.num_updates) / (10 + self.num_updates)) 56 | one_minus_decay = 1.0 - decay 57 | with torch.no_grad(): 58 | for s_param, param in zip(self.shadow_params, parameters): 59 | s_param.sub_(one_minus_decay * (s_param - param)) 60 | 61 | def copy_to(self, parameters): 62 | """ 63 | Copy current parameters into given collection of parameters. 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | updated with the stored moving averages. 67 | """ 68 | for s_param, param in zip(self.shadow_params, parameters): 69 | param.data.copy_(s_param.data) 70 | 71 | def store(self, parameters): 72 | """ 73 | Save the current parameters for restoring later. 74 | Args: 75 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 76 | temporarily stored. 77 | """ 78 | self.collected_params = [param.clone() for param in parameters] 79 | 80 | def restore(self, parameters): 81 | """ 82 | Restore the parameters stored with the `store` method. 83 | Useful to validate the model with EMA parameters without affecting the 84 | original optimization process. Store the parameters before the 85 | `copy_to` method. After validation (or model saving), use this to 86 | restore the former parameters. 87 | Args: 88 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 89 | updated with the stored parameters. 
90 | """ 91 | for c_param, param in zip(self.collected_params, parameters): 92 | param.data.copy_(c_param.data) 93 | del (self.collected_params) 94 | -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | import os 3 | 4 | 5 | def zip_folder(source_folder, zip_dir): 6 | f = zipfile.ZipFile(zip_dir, 'w', zipfile.ZIP_DEFLATED) 7 | pre_len = len(os.path.dirname(source_folder)) 8 | for dirpath, dirnames, filenames in os.walk(source_folder): 9 | for filename in filenames: 10 | pathfile = os.path.join(dirpath, filename) 11 | arcname = pathfile[pre_len:].strip(os.path.sep) 12 | f.write(pathfile, arcname) 13 | f.close() -------------------------------------------------------------------------------- /utils/image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | import threading 5 | 6 | _palette = [ 7 | 0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128, 8 | 128, 128, 128, 128, 64, 0, 0, 191, 0, 0, 64, 128, 0, 191, 128, 0, 64, 0, 9 | 128, 191, 0, 128, 64, 128, 128, 191, 128, 128, 0, 64, 0, 128, 64, 0, 0, 10 | 191, 0, 128, 191, 0, 0, 64, 128, 128, 64, 128, 22, 22, 22, 23, 23, 23, 24, 11 | 24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28, 28, 28, 29, 29, 29, 30, 30, 12 | 30, 31, 31, 31, 32, 32, 32, 33, 33, 33, 34, 34, 34, 35, 35, 35, 36, 36, 36, 13 | 37, 37, 37, 38, 38, 38, 39, 39, 39, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43, 14 | 43, 43, 44, 44, 44, 45, 45, 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49, 15 | 49, 50, 50, 50, 51, 51, 51, 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55, 16 | 56, 56, 56, 57, 57, 57, 58, 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62, 17 | 62, 62, 63, 63, 63, 64, 64, 64, 65, 65, 65, 66, 66, 66, 67, 67, 67, 68, 68, 18 | 68, 69, 69, 69, 70, 70, 70, 71, 71, 71, 72, 72, 72, 73, 73, 73, 74, 74, 74, 19 | 75, 75, 75, 76, 76, 76, 77, 77, 77, 78, 78, 78, 79, 79, 79, 80, 80, 80, 81, 20 | 81, 81, 82, 82, 82, 83, 83, 83, 84, 84, 84, 85, 85, 85, 86, 86, 86, 87, 87, 21 | 87, 88, 88, 88, 89, 89, 89, 90, 90, 90, 91, 91, 91, 92, 92, 92, 93, 93, 93, 22 | 94, 94, 94, 95, 95, 95, 96, 96, 96, 97, 97, 97, 98, 98, 98, 99, 99, 99, 23 | 100, 100, 100, 101, 101, 101, 102, 102, 102, 103, 103, 103, 104, 104, 104, 24 | 105, 105, 105, 106, 106, 106, 107, 107, 107, 108, 108, 108, 109, 109, 109, 25 | 110, 110, 110, 111, 111, 111, 112, 112, 112, 113, 113, 113, 114, 114, 114, 26 | 115, 115, 115, 116, 116, 116, 117, 117, 117, 118, 118, 118, 119, 119, 119, 27 | 120, 120, 120, 121, 121, 121, 122, 122, 122, 123, 123, 123, 124, 124, 124, 28 | 125, 125, 125, 126, 126, 126, 127, 127, 127, 128, 128, 128, 129, 129, 129, 29 | 130, 130, 130, 131, 131, 131, 132, 132, 132, 133, 133, 133, 134, 134, 134, 30 | 135, 135, 135, 136, 136, 136, 137, 137, 137, 138, 138, 138, 139, 139, 139, 31 | 140, 140, 140, 141, 141, 141, 142, 142, 142, 143, 143, 143, 144, 144, 144, 32 | 145, 145, 145, 146, 146, 146, 147, 147, 147, 148, 148, 148, 149, 149, 149, 33 | 150, 150, 150, 151, 151, 151, 152, 152, 152, 153, 153, 153, 154, 154, 154, 34 | 155, 155, 155, 156, 156, 156, 157, 157, 157, 158, 158, 158, 159, 159, 159, 35 | 160, 160, 160, 161, 161, 161, 162, 162, 162, 163, 163, 163, 164, 164, 164, 36 | 165, 165, 165, 166, 166, 166, 167, 167, 167, 168, 168, 168, 169, 169, 169, 37 | 170, 170, 170, 171, 171, 171, 172, 172, 172, 173, 173, 173, 174, 174, 174, 38 | 175, 175, 175, 
176, 176, 176, 177, 177, 177, 178, 178, 178, 179, 179, 179, 39 | 180, 180, 180, 181, 181, 181, 182, 182, 182, 183, 183, 183, 184, 184, 184, 40 | 185, 185, 185, 186, 186, 186, 187, 187, 187, 188, 188, 188, 189, 189, 189, 41 | 190, 190, 190, 191, 191, 191, 192, 192, 192, 193, 193, 193, 194, 194, 194, 42 | 195, 195, 195, 196, 196, 196, 197, 197, 197, 198, 198, 198, 199, 199, 199, 43 | 200, 200, 200, 201, 201, 201, 202, 202, 202, 203, 203, 203, 204, 204, 204, 44 | 205, 205, 205, 206, 206, 206, 207, 207, 207, 208, 208, 208, 209, 209, 209, 45 | 210, 210, 210, 211, 211, 211, 212, 212, 212, 213, 213, 213, 214, 214, 214, 46 | 215, 215, 215, 216, 216, 216, 217, 217, 217, 218, 218, 218, 219, 219, 219, 47 | 220, 220, 220, 221, 221, 221, 222, 222, 222, 223, 223, 223, 224, 224, 224, 48 | 225, 225, 225, 226, 226, 226, 227, 227, 227, 228, 228, 228, 229, 229, 229, 49 | 230, 230, 230, 231, 231, 231, 232, 232, 232, 233, 233, 233, 234, 234, 234, 50 | 235, 235, 235, 236, 236, 236, 237, 237, 237, 238, 238, 238, 239, 239, 239, 51 | 240, 240, 240, 241, 241, 241, 242, 242, 242, 243, 243, 243, 244, 244, 244, 52 | 245, 245, 245, 246, 246, 246, 247, 247, 247, 248, 248, 248, 249, 249, 249, 53 | 250, 250, 250, 251, 251, 251, 252, 252, 252, 253, 253, 253, 254, 254, 254, 54 | 255, 255, 255 55 | ] 56 | 57 | 58 | def label2colormap(label): 59 | 60 | m = label.astype(np.uint8) 61 | r, c = m.shape 62 | cmap = np.zeros((r, c, 3), dtype=np.uint8) 63 | cmap[:, :, 0] = (m & 1) << 7 | (m & 8) << 3 | (m & 64) >> 1 64 | cmap[:, :, 1] = (m & 2) << 6 | (m & 16) << 2 | (m & 128) >> 2 65 | cmap[:, :, 2] = (m & 4) << 5 | (m & 32) << 1 66 | return cmap 67 | 68 | 69 | def one_hot_mask(mask, cls_num): 70 | if len(mask.size()) == 3: 71 | mask = mask.unsqueeze(1) 72 | indices = torch.arange(0, cls_num + 1, 73 | device=mask.device).view(1, -1, 1, 1) 74 | return (mask == indices).float() 75 | 76 | 77 | def masked_image(image, colored_mask, mask, alpha=0.7): 78 | mask = np.expand_dims(mask > 0, axis=0) 79 | mask = np.repeat(mask, 3, axis=0) 80 | show_img = (image * alpha + colored_mask * 81 | (1 - alpha)) * mask + image * (1 - mask) 82 | return show_img 83 | 84 | 85 | def save_image(image, path): 86 | im = Image.fromarray(np.uint8(image * 255.).transpose((1, 2, 0))) 87 | im.save(path) 88 | 89 | 90 | def _save_mask(mask, path, squeeze_idx=None): 91 | if squeeze_idx is not None: 92 | unsqueezed_mask = mask * 0 93 | for idx in range(1, len(squeeze_idx)): 94 | obj_id = squeeze_idx[idx] 95 | mask_i = mask == idx 96 | unsqueezed_mask += (mask_i * obj_id).astype(np.uint8) 97 | mask = unsqueezed_mask 98 | mask = Image.fromarray(mask).convert('P') 99 | mask.putpalette(_palette) 100 | mask.save(path) 101 | 102 | 103 | def save_mask(mask_tensor, path, squeeze_idx=None): 104 | mask = mask_tensor.cpu().numpy().astype('uint8') 105 | threading.Thread(target=_save_mask, args=[mask, path, squeeze_idx]).start() 106 | 107 | 108 | def flip_tensor(tensor, dim=0): 109 | inv_idx = torch.arange(tensor.size(dim) - 1, -1, -1, 110 | device=tensor.device).long() 111 | tensor = tensor.index_select(dim, inv_idx) 112 | return tensor 113 | 114 | 115 | def shuffle_obj_mask(mask): 116 | 117 | bs, obj_num, _, _ = mask.size() 118 | new_masks = [] 119 | for idx in range(bs): 120 | now_mask = mask[idx] 121 | random_matrix = torch.eye(obj_num, device=mask.device) 122 | fg = random_matrix[1:][torch.randperm(obj_num - 1)] 123 | random_matrix = torch.cat([random_matrix[0:1], fg], dim=0) 124 | now_mask = torch.einsum('nm,nhw->mhw', random_matrix, now_mask) 125 | 
new_masks.append(now_mask) 126 | 127 | return torch.stack(new_masks, dim=0) 128 | -------------------------------------------------------------------------------- /utils/learning.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def adjust_learning_rate(optimizer, 5 | base_lr, 6 | p, 7 | itr, 8 | max_itr, 9 | restart=1, 10 | warm_up_steps=1000, 11 | is_cosine_decay=False, 12 | min_lr=1e-5, 13 | encoder_lr_ratio=1.0, 14 | freeze_params=[]): 15 | 16 | if restart > 1: 17 | each_max_itr = int(math.ceil(float(max_itr) / restart)) 18 | itr = itr % each_max_itr 19 | warm_up_steps /= restart 20 | max_itr = each_max_itr 21 | 22 | if itr < warm_up_steps: 23 | now_lr = min_lr + (base_lr - min_lr) * itr / warm_up_steps 24 | else: 25 | itr = itr - warm_up_steps 26 | max_itr = max_itr - warm_up_steps 27 | if is_cosine_decay: 28 | now_lr = min_lr + (base_lr - min_lr) * (math.cos(math.pi * itr / 29 | (max_itr + 1)) + 30 | 1.) * 0.5 31 | else: 32 | now_lr = min_lr + (base_lr - min_lr) * (1 - itr / (max_itr + 1))**p 33 | 34 | for param_group in optimizer.param_groups: 35 | if encoder_lr_ratio != 1.0 and "encoder." in param_group["name"]: 36 | param_group['lr'] = (now_lr - min_lr) * encoder_lr_ratio + min_lr 37 | else: 38 | param_group['lr'] = now_lr 39 | 40 | for freeze_param in freeze_params: 41 | if freeze_param in param_group["name"]: 42 | param_group['lr'] = 0 43 | param_group['weight_decay'] = 0 44 | break 45 | 46 | return now_lr 47 | 48 | 49 | def get_trainable_params(model, 50 | base_lr, 51 | weight_decay, 52 | use_frozen_bn=False, 53 | exclusive_wd_dict={}, 54 | no_wd_keys=[]): 55 | params = [] 56 | memo = set() 57 | total_param = 0 58 | for key, value in model.named_parameters(): 59 | if value in memo: 60 | continue 61 | total_param += value.numel() 62 | if not value.requires_grad: 63 | continue 64 | memo.add(value) 65 | wd = weight_decay 66 | for exclusive_key in exclusive_wd_dict.keys(): 67 | if exclusive_key in key: 68 | wd = exclusive_wd_dict[exclusive_key] 69 | break 70 | if len(value.shape) == 1: # normalization layers 71 | if 'bias' in key: # bias requires no weight decay 72 | wd = 0. 73 | elif not use_frozen_bn: # if not use frozen BN, apply zero weight decay 74 | wd = 0. 75 | elif 'encoder.' not in key: # if use frozen BN, apply weight decay to all frozen BNs in the encoder 76 | wd = 0. 77 | else: 78 | for no_wd_key in no_wd_keys: 79 | if no_wd_key in key: 80 | wd = 0. 
81 | break 82 | params += [{ 83 | "params": [value], 84 | "lr": base_lr, 85 | "weight_decay": wd, 86 | "name": key 87 | }] 88 | 89 | print('Total Param: {:.2f}M'.format(total_param / 1e6)) 90 | return params 91 | 92 | 93 | def freeze_params(module): 94 | for p in module.parameters(): 95 | p.requires_grad = False 96 | 97 | 98 | def calculate_params(state_dict): 99 | memo = set() 100 | total_param = 0 101 | for key, value in state_dict.items(): 102 | if value in memo: 103 | continue 104 | memo.add(value) 105 | total_param += value.numel() 106 | print('Total Param: {:.2f}M'.format(total_param / 1e6)) 107 | -------------------------------------------------------------------------------- /utils/math.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def generate_permute_matrix(dim, num, keep_first=True, gpu_id=0): 5 | all_matrix = [] 6 | for idx in range(num): 7 | random_matrix = torch.eye(dim, device=torch.device('cuda', gpu_id)) 8 | if keep_first: 9 | fg = random_matrix[1:][torch.randperm(dim - 1)] 10 | random_matrix = torch.cat([random_matrix[0:1], fg], dim=0) 11 | else: 12 | random_matrix = random_matrix[torch.randperm(dim)] 13 | all_matrix.append(random_matrix) 14 | return torch.stack(all_matrix, dim=0) 15 | 16 | 17 | def truncated_normal_(tensor, mean=0, std=.02): 18 | size = tensor.shape 19 | tmp = tensor.new_empty(size + (4, )).normal_() 20 | valid = (tmp < 2) & (tmp > -2) 21 | ind = valid.max(-1, keepdim=True)[1] 22 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 23 | tensor.data.mul_(std).add_(mean) 24 | return tensor 25 | -------------------------------------------------------------------------------- /utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | def __init__(self, momentum=0.999): 7 | self.val = 0 8 | self.avg = 0 9 | self.sum = 0 10 | self.count = 0 11 | self.long_count = 0 12 | self.momentum = momentum 13 | self.moving_avg = 0 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | if self.long_count == 0: 23 | self.moving_avg = val 24 | else: 25 | momentum = min(self.momentum, 1. - 1. 
/ self.long_count) 26 | self.moving_avg = self.moving_avg * momentum + val * (1 - momentum) 27 | self.val = val 28 | self.sum += val * n 29 | self.count += n 30 | self.long_count += n 31 | self.avg = self.sum / self.count 32 | -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pytorch_iou(pred, target, obj_num, epsilon=1e-6): 5 | ''' 6 | pred: [bs, h, w] 7 | target: [bs, h, w] 8 | obj_num: [bs] 9 | ''' 10 | bs = pred.size(0) 11 | all_iou = [] 12 | for idx in range(bs): 13 | now_pred = pred[idx].unsqueeze(0) 14 | now_target = target[idx].unsqueeze(0) 15 | now_obj_num = obj_num[idx] 16 | 17 | obj_ids = torch.arange(0, now_obj_num + 1, 18 | device=now_pred.device).int().view(-1, 1, 1) 19 | if obj_ids.size(0) == 1: # only contain background 20 | continue 21 | else: 22 | obj_ids = obj_ids[1:] 23 | now_pred = (now_pred == obj_ids).float() 24 | now_target = (now_target == obj_ids).float() 25 | 26 | intersection = (now_pred * now_target).sum((1, 2)) 27 | union = ((now_pred + now_target) > 0).float().sum((1, 2)) 28 | 29 | now_iou = (intersection + epsilon) / (union + epsilon) 30 | 31 | all_iou.append(now_iou.mean()) 32 | if len(all_iou) > 0: 33 | all_iou = torch.stack(all_iou).mean() 34 | else: 35 | all_iou = torch.ones((1), device=pred.device) 36 | return all_iou 37 | --------------------------------------------------------------------------------
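A minimal usage sketch for tools/demo.py listed above, assuming the demo frames under ./datasets/Demo and an R50_AOTL_PRE_YTB_DAV.pth checkpoint under ./pretrain_models/ (these are simply the script's argparse defaults; the checkpoint is not included in the repository and must be obtained separately):

python tools/demo.py --model r50_aotl --stage pre_ytb_dav \
    --ckpt_path ./pretrain_models/R50_AOTL_PRE_YTB_DAV.pth \
    --data_path ./datasets/Demo --output_path ./demo_output \
    --gpu_id 0 --amp

Following the paths constructed in demo(cfg), per-frame PNG masks are written to <output_path>/pred_masks/<sequence>/ and an overlay video to <output_path>/<sequence>_15fps.avi.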