├── CNAME ├── README.md ├── aot_plus ├── LICENSE ├── Makefile ├── README.md ├── configs │ ├── default.py │ ├── models │ │ ├── aotb.py │ │ ├── aotl.py │ │ ├── aots.py │ │ ├── aott.py │ │ ├── default.py │ │ ├── r101_aotl.py │ │ ├── r50_aotl.py │ │ ├── rs101_aotl.py │ │ └── swinb_aotl.py │ ├── pre.py │ ├── pre_dav.py │ ├── pre_vost.py │ ├── pre_ytb.py │ ├── pre_ytb_dav.py │ └── ytb.py ├── dataloaders │ ├── __init__.py │ ├── eval_datasets.py │ ├── image_transforms.py │ ├── train_datasets.py │ └── video_transforms.py ├── docker │ └── Dockerfile ├── networks │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── __init__.cpython-38.pyc │ ├── decoders │ │ ├── __init__.py │ │ └── fpn.py │ ├── encoders │ │ ├── __init__.py │ │ ├── mobilenetv2.py │ │ ├── mobilenetv3.py │ │ ├── resnest │ │ │ ├── __init__.py │ │ │ ├── resnest.py │ │ │ ├── resnet.py │ │ │ └── splat.py │ │ ├── resnet.py │ │ └── swin │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ └── swin_transformer.py │ ├── engines │ │ ├── __init__.py │ │ └── aot_engine.py │ ├── layers │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── basic.py │ │ ├── loss.py │ │ ├── normalization.py │ │ ├── position.py │ │ └── transformer.py │ ├── managers │ │ ├── evaluator.py │ │ └── trainer.py │ └── models │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── aot.cpython-37.pyc │ │ └── aot.cpython-38.pyc │ │ └── aot.py ├── tools │ ├── demo.py │ ├── eval.py │ └── train.py ├── train_vost.sh └── utils │ ├── __init__.py │ ├── checkpoint.py │ ├── ema.py │ ├── eval.py │ ├── image.py │ ├── learning.py │ ├── math.py │ ├── meters.py │ └── metric.py ├── data.html ├── evaluation ├── .gitignore ├── LICENSE ├── README.md ├── evaluation_method.py └── source │ ├── __init__.py │ ├── dataset.py │ ├── evaluation.py │ ├── metrics.py │ ├── results.py │ └── utils.py ├── figs ├── .DS_Store ├── 1403_cut_corn.gif ├── 1GijsAKflo683C-s55149QvaQ47_PXz0q.png ├── 3525_break_eggs.gif ├── 4751_mold_clay.gif ├── 9672_divide_wheel.gif ├── 9699_divide_wheel.gif ├── bear.gif ├── bibtex.txt ├── firstpage.png └── skater.gif ├── index.html └── workshop.html /CNAME: -------------------------------------------------------------------------------- 1 | www.vostdataset.org 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VOST: Video Object Segmentation under Transformations 2 | 3 | The evaluation metric implementation is in the [evaluation](evaluation) folder. 4 | AOT+ baseline implementation and model checkpoints are available in the [aot_plus](aot_plus) folder. 5 | -------------------------------------------------------------------------------- /aot_plus/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, z-x-yang 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. 
Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /aot_plus/Makefile: -------------------------------------------------------------------------------- 1 | # Handy commands: 2 | # - `make docker-build`: builds DOCKERIMAGE (default: `packnet-sfm:latest`) 3 | PROJECT ?= aot_plus 4 | WORKSPACE ?= /workspace/$(PROJECT) 5 | DOCKER_IMAGE ?= ${PROJECT}:latest 6 | 7 | SHMSIZE ?= 444G 8 | DOCKER_OPTS := \ 9 | --name ${PROJECT} \ 10 | --rm -it \ 11 | --shm-size=${SHMSIZE} \ 12 | -e AWS_DEFAULT_REGION \ 13 | -e AWS_ACCESS_KEY_ID \ 14 | -e AWS_SECRET_ACCESS_KEY \ 15 | -e HOST_HOSTNAME= \ 16 | -e NCCL_DEBUG=VERSION \ 17 | -e DISPLAY=${DISPLAY} \ 18 | -e XAUTHORITY \ 19 | -e NVIDIA_DRIVER_CAPABILITIES=all \ 20 | -v ~/.aws:/root/.aws \ 21 | -v /root/.ssh:/root/.ssh \ 22 | -v ~/.cache:/root/.cache \ 23 | -v /data:/data \ 24 | -v /mnt/fsx/:/mnt/fsx \ 25 | -v /dev/null:/dev/raw1394 \ 26 | -v /tmp:/tmp \ 27 | -v /tmp/.X11-unix/X0:/tmp/.X11-unix/X0 \ 28 | -v /var/run/docker.sock:/var/run/docker.sock \ 29 | -v ${PWD}:${WORKSPACE} \ 30 | -w ${WORKSPACE} \ 31 | --privileged \ 32 | --ipc=host \ 33 | --network=host 34 | 35 | NGPUS=$(shell nvidia-smi -L | wc -l) 36 | 37 | 38 | .PHONY: all clean docker-build docker-overfit-pose 39 | 40 | all: clean 41 | 42 | clean: 43 | find . -name "*.pyc" | xargs rm -f && \ 44 | find . -name "__pycache__" | xargs rm -rf 45 | 46 | docker-build: 47 | docker build \ 48 | -f docker/Dockerfile \ 49 | -t ${DOCKER_IMAGE} . 50 | 51 | docker-start-interactive: docker-build 52 | nvidia-docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} bash 53 | 54 | docker-start-jupyter: docker-build 55 | nvidia-docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} \ 56 | bash -c "jupyter notebook --port=8888 -ip=0.0.0.0 --allow-root --no-browser" 57 | 58 | docker-run: docker-build 59 | nvidia-docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} \ 60 | bash -c "${COMMAND}" 61 | 62 | docker-run-mpi: docker-build 63 | nvidia-docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} \ 64 | bash -c "${MPI_CMD} ${COMMAND}" -------------------------------------------------------------------------------- /aot_plus/README.md: -------------------------------------------------------------------------------- 1 | # AOT+ 2 | This is the implementation of AOT+ baseline used in [VOST](https://www.vostdataset.org) dataset. The implementation is derived from [AOT](https://github.com/z-x-yang/AOT). 
3 | 4 | ## Installation 5 | We provide a Dockerfile under `$AOT_ROOT/docker/Dockerfile` to re-create the environment used in our experiments. You can either configure the environment yourself, using the Dockerfile as a guide, or build it via: 6 | ~~~ 7 | cd $AOT_ROOT 8 | make docker-build 9 | make docker-start-interactive 10 | ~~~ 11 | 12 | 13 | ## Training and evaluation 14 | 1. Link the VOST folder to [datasets/VOST](datasets/VOST). 15 | 16 | 2. To evaluate the pre-trained AOT+ model on the VOST validation set, download it from [here](https://tri-ml-public.s3.amazonaws.com/datasets/aotplus.pth) into the [pretrain_models](pretrain_models) folder and run the following command: 17 | 18 | ~~~ 19 | python tools/eval.py --exp_name aotplus --stage pre_vost --model r50_aotl --dataset vost --split val --gpu_num 8 --ckpt_path pretrain_models/aotplus.pth --ms 1.0 1.1 1.2 0.9 0.8 20 | ~~~ 21 | To compute the metrics, please refer to the [evaluation](../evaluation/) folder. 22 | 23 | 3. To train AOT+ on VOST yourself, download the checkpoint pre-trained on static images and YouTube-VOS from [here](https://tri-ml-public.s3.amazonaws.com/datasets/pre_ytb.pth) into the [pretrain_models](pretrain_models) folder and run the following script: 24 | 25 | ~~~ 26 | sh train_vost.sh 27 | ~~~ 28 | 29 | 30 | ## Citations 31 | Please consider citing the related paper(s) in your publications if it helps your research. 32 | ``` 33 | @inproceedings{tokmakov2023breaking, 34 | title={Breaking the “Object” in Video Object Segmentation}, 35 | author={Tokmakov, Pavel and Li, Jie and Gaidon, Adrien}, 36 | booktitle={CVPR}, 37 | year={2023} 38 | } 39 | 40 | @inproceedings{yang2021aot, 41 | title={Associating Objects with Transformers for Video Object Segmentation}, 42 | author={Yang, Zongxin and Wei, Yunchao and Yang, Yi}, 43 | booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, 44 | year={2021} 45 | } 46 | ``` 47 | 48 | ## License 49 | This project is released under the BSD-3-Clause license. See [LICENSE](LICENSE) for additional details. 50 | -------------------------------------------------------------------------------- /aot_plus/configs/default.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | 5 | class DefaultEngineConfig(): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | model_cfg = importlib.import_module('configs.models.'
+ 8 | model).ModelConfig() 9 | self.__dict__.update(model_cfg.__dict__) # add model config 10 | 11 | self.EXP_NAME = exp_name + '_' + self.MODEL_NAME 12 | 13 | self.STAGE_NAME = 'default' 14 | 15 | self.DATASETS = ['youtubevos'] 16 | self.DATA_WORKERS = 8 17 | self.DATA_RANDOMCROP = (465, 18 | 465) if self.MODEL_ALIGN_CORNERS else (464, 19 | 464) 20 | self.DATA_RANDOMFLIP = 0.5 21 | self.DATA_MAX_CROP_STEPS = 10 22 | self.DATA_SHORT_EDGE_LEN = 480 23 | self.DATA_MIN_SCALE_FACTOR = 0.7 24 | self.DATA_MAX_SCALE_FACTOR = 1.3 25 | self.DATA_RANDOM_REVERSE_SEQ = True 26 | self.DATA_SEQ_LEN = 5 27 | self.DATA_DAVIS_REPEAT = 5 28 | self.DATA_VOST_REPEAT = 1 29 | self.DATA_VOST_IGNORE_THRESH = 0.2 30 | self.DATA_VOST_ALL_FRAMES = False 31 | self.DATA_VOST_VALID_FRAMES = False 32 | self.DATA_RANDOM_GAP_DAVIS = 12 # max frame interval between two sampled frames for DAVIS (24fps) 33 | self.DATA_RANDOM_GAP_YTB = 3 # max frame interval between two sampled frames for YouTube-VOS (6fps) 34 | self.DATA_RANDOM_GAP_VOST = 3 35 | self.DATA_RANDOM_GAP_VISOR = 1 36 | self.DATA_DYNAMIC_MERGE_PROB = 0.2 37 | self.IGNORE_IN_MERGE = True 38 | self.DATA_VISOR_REPEAT = 1 39 | self.DATA_VISOR_IGNORE_THRESH = 0.2 40 | 41 | self.PRETRAIN = True 42 | self.PRETRAIN_FULL = False # if False, load encoder only 43 | self.PRETRAIN_MODEL = '' 44 | 45 | self.TRAIN_TOTAL_STEPS = 100000 46 | self.TRAIN_START_STEP = 0 47 | self.TRAIN_WEIGHT_DECAY = 0.07 48 | self.TRAIN_WEIGHT_DECAY_EXCLUSIVE = { 49 | # 'encoder.': 0.01 50 | } 51 | self.TRAIN_WEIGHT_DECAY_EXEMPTION = [ 52 | 'absolute_pos_embed', 'relative_position_bias_table', 53 | 'relative_emb_v', 'conv_out' 54 | ] 55 | self.TRAIN_LR = 2e-4 56 | self.TRAIN_LR_MIN = 2e-5 if 'mobilenetv2' in self.MODEL_ENCODER else 1e-5 57 | self.TRAIN_LR_POWER = 0.9 58 | self.TRAIN_LR_ENCODER_RATIO = 0.1 59 | self.TRAIN_LR_WARM_UP_RATIO = 0.05 60 | self.TRAIN_LR_COSINE_DECAY = False 61 | self.TRAIN_LR_RESTART = 1 62 | self.TRAIN_LR_UPDATE_STEP = 1 63 | self.TRAIN_AUX_LOSS_WEIGHT = 1.0 64 | self.TRAIN_AUX_LOSS_RATIO = 1.0 65 | self.TRAIN_OPT = 'adamw' 66 | self.TRAIN_SGD_MOMENTUM = 0.9 67 | self.TRAIN_GPUS = 4 68 | self.TRAIN_BATCH_SIZE = 16 69 | self.TRAIN_TBLOG = False 70 | self.TRAIN_TBLOG_STEP = 50 71 | self.TRAIN_LOG_STEP = 20 72 | self.TRAIN_IMG_LOG = True 73 | self.TRAIN_TOP_K_PERCENT_PIXELS = 0.15 74 | self.TRAIN_SEQ_TRAINING_FREEZE_PARAMS = ['patch_wise_id_bank'] 75 | self.TRAIN_SEQ_TRAINING_START_RATIO = 0.5 76 | self.TRAIN_HARD_MINING_RATIO = 0.5 77 | self.TRAIN_EMA_RATIO = 0.1 78 | self.TRAIN_CLIP_GRAD_NORM = 5. 79 | self.TRAIN_SAVE_STEP = 1000 80 | self.TRAIN_EVAL = False 81 | self.TRAIN_MAX_KEEP_CKPT = 8 82 | self.TRAIN_RESUME = False 83 | self.TRAIN_RESUME_CKPT = None 84 | self.TRAIN_RESUME_STEP = 0 85 | self.TRAIN_AUTO_RESUME = True 86 | self.TRAIN_DATASET_FULL_RESOLUTION = False 87 | self.TRAIN_ENABLE_PREV_FRAME = False 88 | self.TRAIN_ENCODER_FREEZE_AT = 2 89 | self.TRAIN_LSTT_EMB_DROPOUT = 0. 90 | self.TRAIN_LSTT_ID_DROPOUT = 0. 91 | self.TRAIN_LSTT_DROPPATH = 0.1 92 | self.TRAIN_LSTT_DROPPATH_SCALING = False 93 | self.TRAIN_LSTT_DROPPATH_LST = False 94 | self.TRAIN_LSTT_LT_DROPOUT = 0. 95 | self.TRAIN_LSTT_ST_DROPOUT = 0. 96 | 97 | self.TEST_GPU_ID = 0 98 | self.TEST_GPU_NUM = 1 99 | self.TEST_FRAME_LOG = False 100 | self.TEST_DATASET = 'youtubevos' 101 | self.TEST_DATASET_FULL_RESOLUTION = False 102 | self.TEST_DATASET_SPLIT = 'val' 103 | self.TEST_CKPT_PATH = None 104 | # if "None", evaluate the latest checkpoint. 
105 | self.TEST_CKPT_STEP = None 106 | self.TEST_FLIP = False 107 | self.TEST_MULTISCALE = [1] 108 | self.TEST_MIN_SIZE = None 109 | self.TEST_MAX_SIZE = 800 * 1.3 110 | self.TEST_WORKERS = 4 111 | 112 | # GPU distribution 113 | self.DIST_ENABLE = True 114 | self.DIST_BACKEND = "nccl" # "gloo" 115 | self.DIST_URL = "tcp://127.0.0.1:13241" 116 | self.DIST_START_GPU = 0 117 | 118 | def init_dir(self): 119 | self.DIR_DATA = './datasets' 120 | self.DIR_DAVIS = os.path.join(self.DIR_DATA, 'DAVIS') 121 | self.DIR_VOST = os.path.join(self.DIR_DATA, 'VOST') 122 | self.DIR_VISOR = os.path.join(self.DIR_DATA, 'VISOR') 123 | self.DIR_YTB = os.path.join(self.DIR_DATA, 'YTB') 124 | self.DIR_STATIC = os.path.join(self.DIR_DATA, 'Static') 125 | 126 | self.DIR_ROOT = './results' 127 | 128 | self.DIR_RESULT = os.path.join(self.DIR_ROOT, self.EXP_NAME, 129 | self.STAGE_NAME) 130 | self.DIR_CKPT = os.path.join(self.DIR_RESULT, 'ckpt') 131 | self.DIR_EMA_CKPT = os.path.join(self.DIR_RESULT, 'ema_ckpt') 132 | self.DIR_LOG = os.path.join(self.DIR_RESULT, 'log') 133 | self.DIR_TB_LOG = os.path.join(self.DIR_RESULT, 'log', 'tensorboard') 134 | self.DIR_IMG_LOG = os.path.join(self.DIR_RESULT, 'log', 'img') 135 | self.DIR_EVALUATION = os.path.join(self.DIR_RESULT, 'eval') 136 | 137 | for path in [ 138 | self.DIR_RESULT, self.DIR_CKPT, self.DIR_EMA_CKPT, 139 | self.DIR_LOG, self.DIR_EVALUATION, self.DIR_IMG_LOG, 140 | self.DIR_TB_LOG 141 | ]: 142 | if not os.path.isdir(path): 143 | try: 144 | os.makedirs(path) 145 | except Exception as inst: 146 | print(inst) 147 | print('Failed to make dir: {}.'.format(path)) 148 | -------------------------------------------------------------------------------- /aot_plus/configs/models/aotb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTB' 8 | 9 | self.MODEL_LSTT_NUM = 3 10 | -------------------------------------------------------------------------------- /aot_plus/configs/models/aotl.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTL' 8 | 9 | self.MODEL_LSTT_NUM = 3 10 | 11 | self.TRAIN_LONG_TERM_MEM_GAP = 2 12 | 13 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot_plus/configs/models/aots.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTS' 8 | 9 | self.MODEL_LSTT_NUM = 2 10 | -------------------------------------------------------------------------------- /aot_plus/configs/models/aott.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTT' 8 | -------------------------------------------------------------------------------- /aot_plus/configs/models/default.py: -------------------------------------------------------------------------------- 1 | class DefaultModelConfig(): 2 | 
def __init__(self): 3 | self.MODEL_NAME = 'AOTDefault' 4 | 5 | self.MODEL_VOS = 'aot' 6 | self.MODEL_ENGINE = 'aotengine' 7 | self.MODEL_ALIGN_CORNERS = True 8 | self.MODEL_ENCODER = 'mobilenetv2' 9 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/mobilenet_v2-b0353104.pth' # https://download.pytorch.org/models/mobilenet_v2-b0353104.pth 10 | self.MODEL_ENCODER_DIM = [24, 32, 96, 1280] # 4x, 8x, 16x, 16x 11 | self.MODEL_ENCODER_EMBEDDING_DIM = 256 12 | self.MODEL_DECODER_INTERMEDIATE_LSTT = True 13 | self.MODEL_SIMPLIFIED_STM = True 14 | self.MODEL_STM_STOPGRAD = False 15 | self.MODEL_LINEAR_Q = True 16 | self.MODEL_NORM_INP = True 17 | self.MODEL_RECURRENT_LTM = False 18 | self.MODEL_RECURRENT_STM = True 19 | self.MODEL_JOINT_LONGATT = True 20 | self.MODEL_FREEZE_BN = True 21 | self.MODEL_FREEZE_BACKBONE = False 22 | self.MODEL_MAX_OBJ_NUM = 10 23 | self.MODEL_IGNORE_TOKEN = True 24 | self.MODEL_SELF_HEADS = 8 25 | self.MODEL_ATT_HEADS = 8 26 | self.MODEL_LSTT_NUM = 1 27 | self.MODEL_EPSILON = 1e-5 28 | self.MODEL_USE_PREV_PROB = False 29 | 30 | self.TRAIN_LONG_TERM_MEM_GAP = 9999 31 | 32 | self.TEST_LONG_TERM_MEM_GAP = 9999 33 | -------------------------------------------------------------------------------- /aot_plus/configs/models/r101_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R101_AOTL' 8 | 9 | self.MODEL_ENCODER = 'resnet101' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet101-63fe2227.pth' # https://download.pytorch.org/models/resnet101-63fe2227.pth 11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 12 | self.MODEL_LSTT_NUM = 3 13 | 14 | self.TRAIN_LONG_TERM_MEM_GAP = 2 15 | 16 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot_plus/configs/models/r50_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R50_AOTL' 8 | 9 | self.MODEL_ENCODER = 'resnet50' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet50-0676ba61.pth' # https://download.pytorch.org/models/resnet50-0676ba61.pth 11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 12 | self.MODEL_LSTT_NUM = 3 13 | 14 | self.TRAIN_LONG_TERM_MEM_GAP = 2 15 | 16 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot_plus/configs/models/rs101_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R101_AOTL' 8 | 9 | self.MODEL_ENCODER = 'resnest101' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnest101-22405ba7.pth' # https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest101-22405ba7.pth 11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 12 | self.MODEL_LSTT_NUM = 3 13 | 14 | self.TRAIN_LONG_TERM_MEM_GAP = 2 15 | 16 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot_plus/configs/models/swinb_aotl.py: 
-------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'SwinB_AOTL' 8 | 9 | self.MODEL_ENCODER = 'swin_base' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/swin_base_patch4_window7_224_22k.pth' # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth 11 | self.MODEL_ALIGN_CORNERS = False 12 | self.MODEL_ENCODER_DIM = [128, 256, 512, 512] # 4x, 8x, 16x, 16x 13 | self.MODEL_LSTT_NUM = 3 14 | 15 | self.TRAIN_LONG_TERM_MEM_GAP = 2 16 | 17 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot_plus/configs/pre.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultEngineConfig 2 | 3 | 4 | class EngineConfig(DefaultEngineConfig): 5 | def __init__(self, exp_name='default', model='AOTT'): 6 | super().__init__(exp_name, model) 7 | self.STAGE_NAME = 'pre' 8 | 9 | self.init_dir() 10 | 11 | self.DATASETS = ['static'] 12 | 13 | self.DATA_DYNAMIC_MERGE_PROB = 1.0 14 | 15 | self.TRAIN_LR = 4e-4 16 | self.TRAIN_LR_MIN = 2e-5 17 | self.TRAIN_WEIGHT_DECAY = 0.03 18 | self.TRAIN_SEQ_TRAINING_START_RATIO = 1.0 19 | self.TRAIN_AUX_LOSS_RATIO = 0.1 20 | 21 | self.MODEL_SIMPLIFIED_STM = True 22 | self.MODEL_LINEAR_Q = True 23 | self.MODEL_RECURRENT_STM = True 24 | self.MODEL_JOINT_LONGATT = True 25 | -------------------------------------------------------------------------------- /aot_plus/configs/pre_dav.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'PRE_DAV' 9 | 10 | self.init_dir() 11 | 12 | self.DATASETS = ['davis2017'] 13 | 14 | self.TRAIN_TOTAL_STEPS = 50000 15 | 16 | pretrain_stage = 'PRE' 17 | pretrain_ckpt = 'save_step_100000.pth' 18 | self.PRETRAIN_FULL = True # if False, load encoder only 19 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', 20 | self.EXP_NAME, pretrain_stage, 21 | 'ema_ckpt', pretrain_ckpt) 22 | -------------------------------------------------------------------------------- /aot_plus/configs/pre_vost.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'pre_vost' 9 | 10 | self.init_dir() 11 | 12 | self.DATASETS = ['vost'] 13 | self.TRAIN_TOTAL_STEPS = 20000 14 | self.DATA_SEQ_LEN = 15 15 | self.TRAIN_LONG_TERM_MEM_GAP = 4 16 | self.MODEL_LINEAR_Q = False 17 | self.MODEL_JOINT_LONGATT = False 18 | self.MODEL_SIMPLIFIED_STM = True 19 | self.MODEL_RECURRENT_STM = True 20 | self.MODEL_IGNORE_TOKEN = True 21 | 22 | self.PRETRAIN_FULL = True # if False, load encoder only 23 | self.PRETRAIN_MODEL = os.path.join('pretrain_models', 'pre_ytb.pth') -------------------------------------------------------------------------------- /aot_plus/configs/pre_ytb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 
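# Note on the stage configs in this folder, inferred from the DATASETS and
# PRETRAIN_MODEL settings in each file: 'pre' trains on static images,
# 'pre_ytb' (this stage) continues on YouTube-VOS starting from the 'pre' EMA
# checkpoint, and 'pre_vost' fine-tunes on VOST starting from
# pretrain_models/pre_ytb.pth (see configs/pre_vost.py).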
3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'pre_ytb' 9 | 10 | self.init_dir() 11 | 12 | pretrain_stage = 'PRE' 13 | pretrain_ckpt = 'save_step_100000.pth' 14 | self.DATA_SEQ_LEN = 10 15 | self.TRAIN_LONG_TERM_MEM_GAP = 4 16 | self.MODEL_JOINT_LONGATT = False 17 | self.MODEL_STM_STOPGRAD = False 18 | self.TRAIN_TOTAL_STEPS = 80000 19 | self.MODEL_LINEAR_Q = True 20 | self.PRETRAIN_FULL = True # if False, load encoder only 21 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', 22 | self.EXP_NAME, pretrain_stage, 23 | 'ema_ckpt', pretrain_ckpt) 24 | -------------------------------------------------------------------------------- /aot_plus/configs/pre_ytb_dav.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'PRE_YTB_DAV' 9 | 10 | self.init_dir() 11 | 12 | self.DATASETS = ['youtubevos', 'davis2017'] 13 | 14 | pretrain_stage = 'PRE' 15 | pretrain_ckpt = 'save_step_100000.pth' 16 | self.PRETRAIN_FULL = True # if False, load encoder only 17 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', 18 | self.EXP_NAME, pretrain_stage, 19 | 'ema_ckpt', pretrain_ckpt) 20 | -------------------------------------------------------------------------------- /aot_plus/configs/ytb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'YTB' 9 | 10 | self.init_dir() 11 | -------------------------------------------------------------------------------- /aot_plus/dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/dataloaders/__init__.py -------------------------------------------------------------------------------- /aot_plus/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG PYTORCH="1.9.0" 2 | ARG CUDA="11.1" 3 | ARG CUDNN="8" 4 | 5 | FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel 6 | RUN rm -rf /etc/apt/sources.list.d/cuda.list 7 | RUN rm -rf /etc/apt/sources.list.d/nvidia-ml.list 8 | RUN apt-key del 7fa2af80 9 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub 10 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub 11 | 12 | ENV PROJECT=aot 13 | ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX" 14 | ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all" 15 | ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" 16 | ARG python=3.7 17 | ENV PYTHON_VERSION=${python} 18 | 19 | # To fix GPG key error when running apt-get update 20 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub 21 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub 22 | 23 | # Core tools 24 | 
RUN apt-get update && apt-get install -y \ 25 | cmake \ 26 | curl \ 27 | docker.io \ 28 | ffmpeg \ 29 | git \ 30 | htop \ 31 | libsm6 \ 32 | libxext6 \ 33 | libglib2.0-0 \ 34 | libsm6 \ 35 | libxrender-dev \ 36 | libxext6 \ 37 | ninja-build \ 38 | unzip \ 39 | vim \ 40 | wget \ 41 | sudo \ 42 | && apt-get clean \ 43 | && rm -rf /var/lib/apt/lists/* 44 | 45 | # Install OpenSSH for MPI to communicate between containers 46 | RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ 47 | echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ 48 | mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config 49 | 50 | RUN ln -sf /usr/bin/python${PYTHON_VERSION} /usr/bin/python 51 | 52 | RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ 53 | python get-pip.py && \ 54 | rm get-pip.py 55 | 56 | # Install Pydata and other deps 57 | RUN pip install easydict scipy numpy pyquaternion matplotlib jupyter h5py \ 58 | awscli tqdm progress path.py pyyaml opencv-python pandas \ 59 | numba cython scikit-learn moviepy imageio yacs Pillow scikit-image 60 | 61 | # Settings for S3 62 | RUN aws configure set default.s3.max_concurrent_requests 100 && \ 63 | aws configure set default.s3.max_queue_size 10000 64 | 65 | #pip install spatial-correlation-sampler 66 | 67 | # Expose Port for jupyter (8888) 68 | EXPOSE 8888 69 | 70 | # create project workspace dir 71 | WORKDIR /workspace/${PROJECT} 72 | 73 | ENV PYTHONPATH="/workspace/${PROJECT}:$PYTHONPATH" 74 | RUN git config --global --add safe.directory /workspace/${PROJECT} 75 | -------------------------------------------------------------------------------- /aot_plus/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/networks/__init__.py -------------------------------------------------------------------------------- /aot_plus/networks/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/networks/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /aot_plus/networks/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/networks/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /aot_plus/networks/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.decoders.fpn import FPNSegmentationHead 2 | 3 | 4 | def build_decoder(name, **kwargs): 5 | 6 | if name == 'fpn': 7 | return FPNSegmentationHead(**kwargs) 8 | else: 9 | raise NotImplementedError 10 | -------------------------------------------------------------------------------- /aot_plus/networks/decoders/fpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from networks.layers.basic import ConvGN 5 | 6 | 7 | class FPNSegmentationHead(nn.Module): 8 | def __init__(self, 9 | in_dim, 10 | out_dim, 11 | decode_intermediate_input=True, 12 | hidden_dim=256, 13 | shortcut_dims=[24, 32, 96, 1280], 14 | align_corners=True): 15 
| super().__init__() 16 | self.align_corners = align_corners 17 | 18 | self.decode_intermediate_input = decode_intermediate_input 19 | 20 | self.conv_in = ConvGN(in_dim, hidden_dim, 1) 21 | 22 | self.conv_16x = ConvGN(hidden_dim, hidden_dim, 3) 23 | self.conv_8x = ConvGN(hidden_dim, hidden_dim // 2, 3) 24 | self.conv_4x = ConvGN(hidden_dim // 2, hidden_dim // 2, 3) 25 | 26 | self.adapter_16x = nn.Conv2d(shortcut_dims[-2], hidden_dim, 1) 27 | self.adapter_8x = nn.Conv2d(shortcut_dims[-3], hidden_dim, 1) 28 | self.adapter_4x = nn.Conv2d(shortcut_dims[-4], hidden_dim // 2, 1) 29 | 30 | self.conv_out = nn.Conv2d(hidden_dim // 2, out_dim, 1) 31 | 32 | self._init_weight() 33 | 34 | def forward(self, inputs, shortcuts): 35 | 36 | if self.decode_intermediate_input: 37 | x = torch.cat(inputs, dim=1) 38 | else: 39 | x = inputs[-1] 40 | 41 | x = F.relu_(self.conv_in(x)) 42 | x = F.relu_(self.conv_16x(self.adapter_16x(shortcuts[-2]) + x)) 43 | 44 | x = F.interpolate(x, 45 | size=shortcuts[-3].size()[-2:], 46 | mode="bilinear", 47 | align_corners=self.align_corners) 48 | x = F.relu_(self.conv_8x(self.adapter_8x(shortcuts[-3]) + x)) 49 | 50 | x = F.interpolate(x, 51 | size=shortcuts[-4].size()[-2:], 52 | mode="bilinear", 53 | align_corners=self.align_corners) 54 | x = F.relu_(self.conv_4x(self.adapter_4x(shortcuts[-4]) + x)) 55 | 56 | x = self.conv_out(x) 57 | 58 | return x 59 | 60 | def _init_weight(self): 61 | for p in self.parameters(): 62 | if p.dim() > 1: 63 | nn.init.xavier_uniform_(p) 64 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.encoders.mobilenetv2 import MobileNetV2 2 | from networks.encoders.mobilenetv3 import MobileNetV3Large 3 | from networks.encoders.resnet import ResNet101, ResNet50 4 | from networks.encoders.resnest import resnest 5 | from networks.encoders.swin import build_swin_model 6 | from networks.layers.normalization import FrozenBatchNorm2d 7 | from torch import nn 8 | 9 | 10 | def build_encoder(name, frozen_bn=True, freeze_at=-1): 11 | if frozen_bn: 12 | BatchNorm = FrozenBatchNorm2d 13 | else: 14 | BatchNorm = nn.BatchNorm2d 15 | 16 | if name == 'mobilenetv2': 17 | return MobileNetV2(16, BatchNorm, freeze_at=freeze_at) 18 | elif name == 'mobilenetv3': 19 | return MobileNetV3Large(16, BatchNorm, freeze_at=freeze_at) 20 | elif name == 'resnet50': 21 | return ResNet50(16, BatchNorm, freeze_at=freeze_at) 22 | elif name == 'resnet101': 23 | return ResNet101(16, BatchNorm, freeze_at=freeze_at) 24 | elif name == 'resnest50': 25 | return resnest.resnest50(norm_layer=BatchNorm, 26 | dilation=2, 27 | freeze_at=freeze_at) 28 | elif name == 'resnest101': 29 | return resnest.resnest101(norm_layer=BatchNorm, 30 | dilation=2, 31 | freeze_at=freeze_at) 32 | elif 'swin' in name: 33 | return build_swin_model(name, freeze_at=freeze_at) 34 | else: 35 | raise NotImplementedError 36 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch import Tensor 3 | from typing import Callable, Optional, List 4 | from utils.learning import freeze_params 5 | 6 | __all__ = ['MobileNetV2'] 7 | 8 | 9 | def _make_divisible(v: float, 10 | divisor: int, 11 | min_value: Optional[int] = None) -> int: 12 | """ 13 | This function is taken from the original 
tf repo. 14 | It ensures that all layers have a channel number that is divisible by 8 15 | It can be seen here: 16 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 17 | """ 18 | if min_value is None: 19 | min_value = divisor 20 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 21 | # Make sure that round down does not go down by more than 10%. 22 | if new_v < 0.9 * v: 23 | new_v += divisor 24 | return new_v 25 | 26 | 27 | class ConvBNActivation(nn.Sequential): 28 | def __init__( 29 | self, 30 | in_planes: int, 31 | out_planes: int, 32 | kernel_size: int = 3, 33 | stride: int = 1, 34 | groups: int = 1, 35 | padding: int = -1, 36 | norm_layer: Optional[Callable[..., nn.Module]] = None, 37 | activation_layer: Optional[Callable[..., nn.Module]] = None, 38 | dilation: int = 1, 39 | ) -> None: 40 | if padding == -1: 41 | padding = (kernel_size - 1) // 2 * dilation 42 | if norm_layer is None: 43 | norm_layer = nn.BatchNorm2d 44 | if activation_layer is None: 45 | activation_layer = nn.ReLU6 46 | super().__init__( 47 | nn.Conv2d(in_planes, 48 | out_planes, 49 | kernel_size, 50 | stride, 51 | padding, 52 | dilation=dilation, 53 | groups=groups, 54 | bias=False), norm_layer(out_planes), 55 | activation_layer(inplace=True)) 56 | self.out_channels = out_planes 57 | 58 | 59 | # necessary for backwards compatibility 60 | ConvBNReLU = ConvBNActivation 61 | 62 | 63 | class InvertedResidual(nn.Module): 64 | def __init__( 65 | self, 66 | inp: int, 67 | oup: int, 68 | stride: int, 69 | dilation: int, 70 | expand_ratio: int, 71 | norm_layer: Optional[Callable[..., nn.Module]] = None) -> None: 72 | super(InvertedResidual, self).__init__() 73 | self.stride = stride 74 | assert stride in [1, 2] 75 | 76 | if norm_layer is None: 77 | norm_layer = nn.BatchNorm2d 78 | 79 | self.kernel_size = 3 80 | self.dilation = dilation 81 | 82 | hidden_dim = int(round(inp * expand_ratio)) 83 | self.use_res_connect = self.stride == 1 and inp == oup 84 | 85 | layers: List[nn.Module] = [] 86 | if expand_ratio != 1: 87 | # pw 88 | layers.append( 89 | ConvBNReLU(inp, 90 | hidden_dim, 91 | kernel_size=1, 92 | norm_layer=norm_layer)) 93 | layers.extend([ 94 | # dw 95 | ConvBNReLU(hidden_dim, 96 | hidden_dim, 97 | stride=stride, 98 | dilation=dilation, 99 | groups=hidden_dim, 100 | norm_layer=norm_layer), 101 | # pw-linear 102 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 103 | norm_layer(oup), 104 | ]) 105 | self.conv = nn.Sequential(*layers) 106 | self.out_channels = oup 107 | self._is_cn = stride > 1 108 | 109 | def forward(self, x: Tensor) -> Tensor: 110 | if self.use_res_connect: 111 | return x + self.conv(x) 112 | else: 113 | return self.conv(x) 114 | 115 | 116 | class MobileNetV2(nn.Module): 117 | def __init__(self, 118 | output_stride=8, 119 | norm_layer: Optional[Callable[..., nn.Module]] = None, 120 | width_mult: float = 1.0, 121 | inverted_residual_setting: Optional[List[List[int]]] = None, 122 | round_nearest: int = 8, 123 | block: Optional[Callable[..., nn.Module]] = None, 124 | freeze_at=0) -> None: 125 | """ 126 | MobileNet V2 main class 127 | Args: 128 | num_classes (int): Number of classes 129 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 130 | inverted_residual_setting: Network structure 131 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 132 | Set to 1 to turn off rounding 133 | block: Module specifying inverted residual building block for mobilenet 134 
| norm_layer: Module specifying the normalization layer to use 135 | """ 136 | super(MobileNetV2, self).__init__() 137 | 138 | if block is None: 139 | block = InvertedResidual 140 | 141 | if norm_layer is None: 142 | norm_layer = nn.BatchNorm2d 143 | 144 | last_channel = 1280 145 | input_channel = 32 146 | current_stride = 1 147 | rate = 1 148 | 149 | if inverted_residual_setting is None: 150 | inverted_residual_setting = [ 151 | # t, c, n, s 152 | [1, 16, 1, 1], 153 | [6, 24, 2, 2], 154 | [6, 32, 3, 2], 155 | [6, 64, 4, 2], 156 | [6, 96, 3, 1], 157 | [6, 160, 3, 2], 158 | [6, 320, 1, 1], 159 | ] 160 | 161 | # only check the first element, assuming user knows t,c,n,s are required 162 | if len(inverted_residual_setting) == 0 or len( 163 | inverted_residual_setting[0]) != 4: 164 | raise ValueError("inverted_residual_setting should be non-empty " 165 | "or a 4-element list, got {}".format( 166 | inverted_residual_setting)) 167 | 168 | # building first layer 169 | input_channel = _make_divisible(input_channel * width_mult, 170 | round_nearest) 171 | self.last_channel = _make_divisible( 172 | last_channel * max(1.0, width_mult), round_nearest) 173 | features: List[nn.Module] = [ 174 | ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer) 175 | ] 176 | current_stride *= 2 177 | # building inverted residual blocks 178 | for t, c, n, s in inverted_residual_setting: 179 | if current_stride == output_stride: 180 | stride = 1 181 | dilation = rate 182 | rate *= s 183 | else: 184 | stride = s 185 | dilation = 1 186 | current_stride *= s 187 | output_channel = _make_divisible(c * width_mult, round_nearest) 188 | for i in range(n): 189 | if i == 0: 190 | features.append( 191 | block(input_channel, output_channel, stride, dilation, 192 | t, norm_layer)) 193 | else: 194 | features.append( 195 | block(input_channel, output_channel, 1, rate, t, 196 | norm_layer)) 197 | input_channel = output_channel 198 | 199 | # building last several layers 200 | features.append( 201 | ConvBNReLU(input_channel, 202 | self.last_channel, 203 | kernel_size=1, 204 | norm_layer=norm_layer)) 205 | # make it nn.Sequential 206 | self.features = nn.Sequential(*features) 207 | 208 | self._initialize_weights() 209 | 210 | feature_4x = self.features[0:4] 211 | feautre_8x = self.features[4:7] 212 | feature_16x = self.features[7:14] 213 | feature_32x = self.features[14:] 214 | 215 | self.stages = [feature_4x, feautre_8x, feature_16x, feature_32x] 216 | 217 | self.freeze(freeze_at) 218 | 219 | def forward(self, x): 220 | xs = [] 221 | for stage in self.stages: 222 | x = stage(x) 223 | xs.append(x) 224 | return xs 225 | 226 | def _initialize_weights(self): 227 | # weight initialization 228 | for m in self.modules(): 229 | if isinstance(m, nn.Conv2d): 230 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 231 | if m.bias is not None: 232 | nn.init.zeros_(m.bias) 233 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 234 | nn.init.ones_(m.weight) 235 | nn.init.zeros_(m.bias) 236 | elif isinstance(m, nn.Linear): 237 | nn.init.normal_(m.weight, 0, 0.01) 238 | nn.init.zeros_(m.bias) 239 | 240 | def freeze(self, freeze_at): 241 | if freeze_at >= 1: 242 | for m in self.stages[0][0]: 243 | freeze_params(m) 244 | 245 | for idx, stage in enumerate(self.stages, start=2): 246 | if freeze_at >= idx: 247 | freeze_params(stage) 248 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/mobilenetv3.py: -------------------------------------------------------------------------------- 1 | 
""" 2 | Creates a MobileNetV3 Model as defined in: 3 | Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam. (2019). 4 | Searching for MobileNetV3 5 | arXiv preprint arXiv:1905.02244. 6 | """ 7 | 8 | import torch.nn as nn 9 | import math 10 | from utils.learning import freeze_params 11 | 12 | 13 | def _make_divisible(v, divisor, min_value=None): 14 | """ 15 | This function is taken from the original tf repo. 16 | It ensures that all layers have a channel number that is divisible by 8 17 | It can be seen here: 18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 19 | :param v: 20 | :param divisor: 21 | :param min_value: 22 | :return: 23 | """ 24 | if min_value is None: 25 | min_value = divisor 26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 27 | # Make sure that round down does not go down by more than 10%. 28 | if new_v < 0.9 * v: 29 | new_v += divisor 30 | return new_v 31 | 32 | 33 | class h_sigmoid(nn.Module): 34 | def __init__(self, inplace=True): 35 | super(h_sigmoid, self).__init__() 36 | self.relu = nn.ReLU6(inplace=inplace) 37 | 38 | def forward(self, x): 39 | return self.relu(x + 3) / 6 40 | 41 | 42 | class h_swish(nn.Module): 43 | def __init__(self, inplace=True): 44 | super(h_swish, self).__init__() 45 | self.sigmoid = h_sigmoid(inplace=inplace) 46 | 47 | def forward(self, x): 48 | return x * self.sigmoid(x) 49 | 50 | 51 | class SELayer(nn.Module): 52 | def __init__(self, channel, reduction=4): 53 | super(SELayer, self).__init__() 54 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 55 | self.fc = nn.Sequential( 56 | nn.Linear(channel, _make_divisible(channel // reduction, 8)), 57 | nn.ReLU(inplace=True), 58 | nn.Linear(_make_divisible(channel // reduction, 8), channel), 59 | h_sigmoid()) 60 | 61 | def forward(self, x): 62 | b, c, _, _ = x.size() 63 | y = self.avg_pool(x).view(b, c) 64 | y = self.fc(y).view(b, c, 1, 1) 65 | return x * y 66 | 67 | 68 | def conv_3x3_bn(inp, oup, stride, norm_layer=nn.BatchNorm2d): 69 | return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 70 | norm_layer(oup), h_swish()) 71 | 72 | 73 | def conv_1x1_bn(inp, oup, norm_layer=nn.BatchNorm2d): 74 | return nn.Sequential(nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 75 | norm_layer(oup), h_swish()) 76 | 77 | 78 | class InvertedResidual(nn.Module): 79 | def __init__(self, 80 | inp, 81 | hidden_dim, 82 | oup, 83 | kernel_size, 84 | stride, 85 | use_se, 86 | use_hs, 87 | dilation=1, 88 | norm_layer=nn.BatchNorm2d): 89 | super(InvertedResidual, self).__init__() 90 | assert stride in [1, 2] 91 | 92 | self.identity = stride == 1 and inp == oup 93 | 94 | if inp == hidden_dim: 95 | self.conv = nn.Sequential( 96 | # dw 97 | nn.Conv2d(hidden_dim, 98 | hidden_dim, 99 | kernel_size, 100 | stride, (kernel_size - 1) // 2 * dilation, 101 | dilation=dilation, 102 | groups=hidden_dim, 103 | bias=False), 104 | norm_layer(hidden_dim), 105 | h_swish() if use_hs else nn.ReLU(inplace=True), 106 | # Squeeze-and-Excite 107 | SELayer(hidden_dim) if use_se else nn.Identity(), 108 | # pw-linear 109 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 110 | norm_layer(oup), 111 | ) 112 | else: 113 | self.conv = nn.Sequential( 114 | # pw 115 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 116 | norm_layer(hidden_dim), 117 | h_swish() if use_hs else nn.ReLU(inplace=True), 118 | # dw 119 | nn.Conv2d(hidden_dim, 120 | hidden_dim, 121 | kernel_size, 122 | stride, 
(kernel_size - 1) // 2 * dilation, 123 | dilation=dilation, 124 | groups=hidden_dim, 125 | bias=False), 126 | norm_layer(hidden_dim), 127 | # Squeeze-and-Excite 128 | SELayer(hidden_dim) if use_se else nn.Identity(), 129 | h_swish() if use_hs else nn.ReLU(inplace=True), 130 | # pw-linear 131 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 132 | norm_layer(oup), 133 | ) 134 | 135 | def forward(self, x): 136 | if self.identity: 137 | return x + self.conv(x) 138 | else: 139 | return self.conv(x) 140 | 141 | 142 | class MobileNetV3Large(nn.Module): 143 | def __init__(self, 144 | output_stride=16, 145 | norm_layer=nn.BatchNorm2d, 146 | width_mult=1., 147 | freeze_at=0): 148 | super(MobileNetV3Large, self).__init__() 149 | """ 150 | Constructs a MobileNetV3-Large model 151 | """ 152 | cfgs = [ 153 | # k, t, c, SE, HS, s 154 | [3, 1, 16, 0, 0, 1], 155 | [3, 4, 24, 0, 0, 2], 156 | [3, 3, 24, 0, 0, 1], 157 | [5, 3, 40, 1, 0, 2], 158 | [5, 3, 40, 1, 0, 1], 159 | [5, 3, 40, 1, 0, 1], 160 | [3, 6, 80, 0, 1, 2], 161 | [3, 2.5, 80, 0, 1, 1], 162 | [3, 2.3, 80, 0, 1, 1], 163 | [3, 2.3, 80, 0, 1, 1], 164 | [3, 6, 112, 1, 1, 1], 165 | [3, 6, 112, 1, 1, 1], 166 | [5, 6, 160, 1, 1, 2], 167 | [5, 6, 160, 1, 1, 1], 168 | [5, 6, 160, 1, 1, 1] 169 | ] 170 | self.cfgs = cfgs 171 | 172 | # building first layer 173 | input_channel = _make_divisible(16 * width_mult, 8) 174 | layers = [conv_3x3_bn(3, input_channel, 2, norm_layer)] 175 | # building inverted residual blocks 176 | block = InvertedResidual 177 | now_stride = 2 178 | rate = 1 179 | for k, t, c, use_se, use_hs, s in self.cfgs: 180 | if now_stride == output_stride: 181 | dilation = rate 182 | rate *= s 183 | s = 1 184 | else: 185 | dilation = 1 186 | now_stride *= s 187 | output_channel = _make_divisible(c * width_mult, 8) 188 | exp_size = _make_divisible(input_channel * t, 8) 189 | layers.append( 190 | block(input_channel, exp_size, output_channel, k, s, use_se, 191 | use_hs, dilation, norm_layer)) 192 | input_channel = output_channel 193 | 194 | self.features = nn.Sequential(*layers) 195 | self.conv = conv_1x1_bn(input_channel, exp_size, norm_layer) 196 | # building last several layers 197 | 198 | self._initialize_weights() 199 | 200 | feature_4x = self.features[0:4] 201 | feautre_8x = self.features[4:7] 202 | feature_16x = self.features[7:13] 203 | feature_32x = self.features[13:] 204 | 205 | self.stages = [feature_4x, feautre_8x, feature_16x, feature_32x] 206 | 207 | self.freeze(freeze_at) 208 | 209 | def forward(self, x): 210 | xs = [] 211 | for stage in self.stages: 212 | x = stage(x) 213 | xs.append(x) 214 | xs[-1] = self.conv(xs[-1]) 215 | return xs 216 | 217 | def _initialize_weights(self): 218 | for m in self.modules(): 219 | if isinstance(m, nn.Conv2d): 220 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 221 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 222 | if m.bias is not None: 223 | m.bias.data.zero_() 224 | elif isinstance(m, nn.BatchNorm2d): 225 | m.weight.data.fill_(1) 226 | m.bias.data.zero_() 227 | elif isinstance(m, nn.Linear): 228 | n = m.weight.size(1) 229 | m.weight.data.normal_(0, 0.01) 230 | m.bias.data.zero_() 231 | 232 | def freeze(self, freeze_at): 233 | if freeze_at >= 1: 234 | for m in self.stages[0][0]: 235 | freeze_params(m) 236 | 237 | for idx, stage in enumerate(self.stages, start=2): 238 | if freeze_at >= idx: 239 | freeze_params(stage) 240 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/resnest/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnest import * 2 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/resnest/resnest.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .resnet import ResNet, Bottleneck 3 | 4 | __all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269'] 5 | 6 | _url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth' 7 | 8 | _model_sha256 = { 9 | name: checksum 10 | for checksum, name in [ 11 | ('528c19ca', 'resnest50'), 12 | ('22405ba7', 'resnest101'), 13 | ('75117900', 'resnest200'), 14 | ('0cc87c48', 'resnest269'), 15 | ] 16 | } 17 | 18 | 19 | def short_hash(name): 20 | if name not in _model_sha256: 21 | raise ValueError( 22 | 'Pretrained model for {name} is not available.'.format(name=name)) 23 | return _model_sha256[name][:8] 24 | 25 | 26 | resnest_model_urls = { 27 | name: _url_format.format(name, short_hash(name)) 28 | for name in _model_sha256.keys() 29 | } 30 | 31 | 32 | def resnest50(pretrained=False, root='~/.encoding/models', **kwargs): 33 | model = ResNet(Bottleneck, [3, 4, 6, 3], 34 | radix=2, 35 | groups=1, 36 | bottleneck_width=64, 37 | deep_stem=True, 38 | stem_width=32, 39 | avg_down=True, 40 | avd=True, 41 | avd_first=False, 42 | **kwargs) 43 | if pretrained: 44 | model.load_state_dict( 45 | torch.hub.load_state_dict_from_url(resnest_model_urls['resnest50'], 46 | progress=True, 47 | check_hash=True)) 48 | return model 49 | 50 | 51 | def resnest101(pretrained=False, root='~/.encoding/models', **kwargs): 52 | model = ResNet(Bottleneck, [3, 4, 23, 3], 53 | radix=2, 54 | groups=1, 55 | bottleneck_width=64, 56 | deep_stem=True, 57 | stem_width=64, 58 | avg_down=True, 59 | avd=True, 60 | avd_first=False, 61 | **kwargs) 62 | if pretrained: 63 | model.load_state_dict( 64 | torch.hub.load_state_dict_from_url( 65 | resnest_model_urls['resnest101'], 66 | progress=True, 67 | check_hash=True)) 68 | return model 69 | 70 | 71 | def resnest200(pretrained=False, root='~/.encoding/models', **kwargs): 72 | model = ResNet(Bottleneck, [3, 24, 36, 3], 73 | radix=2, 74 | groups=1, 75 | bottleneck_width=64, 76 | deep_stem=True, 77 | stem_width=64, 78 | avg_down=True, 79 | avd=True, 80 | avd_first=False, 81 | **kwargs) 82 | if pretrained: 83 | model.load_state_dict( 84 | torch.hub.load_state_dict_from_url( 85 | resnest_model_urls['resnest200'], 86 | progress=True, 87 | check_hash=True)) 88 | return model 89 | 90 | 91 | def resnest269(pretrained=False, root='~/.encoding/models', **kwargs): 92 | model = ResNet(Bottleneck, [3, 30, 48, 8], 93 | radix=2, 94 | groups=1, 95 | bottleneck_width=64, 96 | deep_stem=True, 97 | stem_width=64, 98 | avg_down=True, 99 | avd=True, 100 | avd_first=False, 101 | **kwargs) 102 | if pretrained: 103 | 
model.load_state_dict( 104 | torch.hub.load_state_dict_from_url( 105 | resnest_model_urls['resnest269'], 106 | progress=True, 107 | check_hash=True)) 108 | return model 109 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/resnest/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | 4 | from .splat import SplAtConv2d, DropBlock2D 5 | from utils.learning import freeze_params 6 | 7 | __all__ = ['ResNet', 'Bottleneck'] 8 | 9 | _url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth' 10 | 11 | _model_sha256 = {name: checksum for checksum, name in []} 12 | 13 | 14 | def short_hash(name): 15 | if name not in _model_sha256: 16 | raise ValueError( 17 | 'Pretrained model for {name} is not available.'.format(name=name)) 18 | return _model_sha256[name][:8] 19 | 20 | 21 | resnest_model_urls = { 22 | name: _url_format.format(name, short_hash(name)) 23 | for name in _model_sha256.keys() 24 | } 25 | 26 | 27 | class GlobalAvgPool2d(nn.Module): 28 | def __init__(self): 29 | """Global average pooling over the input's spatial dimensions""" 30 | super(GlobalAvgPool2d, self).__init__() 31 | 32 | def forward(self, inputs): 33 | return nn.functional.adaptive_avg_pool2d(inputs, 34 | 1).view(inputs.size(0), -1) 35 | 36 | 37 | class Bottleneck(nn.Module): 38 | """ResNet Bottleneck 39 | """ 40 | # pylint: disable=unused-argument 41 | expansion = 4 42 | 43 | def __init__(self, 44 | inplanes, 45 | planes, 46 | stride=1, 47 | downsample=None, 48 | radix=1, 49 | cardinality=1, 50 | bottleneck_width=64, 51 | avd=False, 52 | avd_first=False, 53 | dilation=1, 54 | is_first=False, 55 | rectified_conv=False, 56 | rectify_avg=False, 57 | norm_layer=None, 58 | dropblock_prob=0.0, 59 | last_gamma=False): 60 | super(Bottleneck, self).__init__() 61 | group_width = int(planes * (bottleneck_width / 64.)) * cardinality 62 | self.conv1 = nn.Conv2d(inplanes, 63 | group_width, 64 | kernel_size=1, 65 | bias=False) 66 | self.bn1 = norm_layer(group_width) 67 | self.dropblock_prob = dropblock_prob 68 | self.radix = radix 69 | self.avd = avd and (stride > 1 or is_first) 70 | self.avd_first = avd_first 71 | 72 | if self.avd: 73 | self.avd_layer = nn.AvgPool2d(3, stride, padding=1) 74 | stride = 1 75 | 76 | if dropblock_prob > 0.0: 77 | self.dropblock1 = DropBlock2D(dropblock_prob, 3) 78 | if radix == 1: 79 | self.dropblock2 = DropBlock2D(dropblock_prob, 3) 80 | self.dropblock3 = DropBlock2D(dropblock_prob, 3) 81 | 82 | if radix >= 1: 83 | self.conv2 = SplAtConv2d(group_width, 84 | group_width, 85 | kernel_size=3, 86 | stride=stride, 87 | padding=dilation, 88 | dilation=dilation, 89 | groups=cardinality, 90 | bias=False, 91 | radix=radix, 92 | rectify=rectified_conv, 93 | rectify_avg=rectify_avg, 94 | norm_layer=norm_layer, 95 | dropblock_prob=dropblock_prob) 96 | elif rectified_conv: 97 | from rfconv import RFConv2d 98 | self.conv2 = RFConv2d(group_width, 99 | group_width, 100 | kernel_size=3, 101 | stride=stride, 102 | padding=dilation, 103 | dilation=dilation, 104 | groups=cardinality, 105 | bias=False, 106 | average_mode=rectify_avg) 107 | self.bn2 = norm_layer(group_width) 108 | else: 109 | self.conv2 = nn.Conv2d(group_width, 110 | group_width, 111 | kernel_size=3, 112 | stride=stride, 113 | padding=dilation, 114 | dilation=dilation, 115 | groups=cardinality, 116 | bias=False) 117 | self.bn2 = norm_layer(group_width) 118 | 119 | self.conv3 = nn.Conv2d(group_width, 120 | planes 
* 4, 121 | kernel_size=1, 122 | bias=False) 123 | self.bn3 = norm_layer(planes * 4) 124 | 125 | if last_gamma: 126 | from torch.nn.init import zeros_ 127 | zeros_(self.bn3.weight) 128 | self.relu = nn.ReLU(inplace=True) 129 | self.downsample = downsample 130 | self.dilation = dilation 131 | self.stride = stride 132 | 133 | def forward(self, x): 134 | residual = x 135 | 136 | out = self.conv1(x) 137 | out = self.bn1(out) 138 | if self.dropblock_prob > 0.0: 139 | out = self.dropblock1(out) 140 | out = self.relu(out) 141 | 142 | if self.avd and self.avd_first: 143 | out = self.avd_layer(out) 144 | 145 | out = self.conv2(out) 146 | if self.radix == 0: 147 | out = self.bn2(out) 148 | if self.dropblock_prob > 0.0: 149 | out = self.dropblock2(out) 150 | out = self.relu(out) 151 | 152 | if self.avd and not self.avd_first: 153 | out = self.avd_layer(out) 154 | 155 | out = self.conv3(out) 156 | out = self.bn3(out) 157 | if self.dropblock_prob > 0.0: 158 | out = self.dropblock3(out) 159 | 160 | if self.downsample is not None: 161 | residual = self.downsample(x) 162 | 163 | out += residual 164 | out = self.relu(out) 165 | 166 | return out 167 | 168 | 169 | class ResNet(nn.Module): 170 | """ResNet Variants 171 | Parameters 172 | ---------- 173 | block : Block 174 | Class for the residual block. Options are BasicBlockV1, BottleneckV1. 175 | layers : list of int 176 | Numbers of layers in each block 177 | classes : int, default 1000 178 | Number of classification classes. 179 | dilated : bool, default False 180 | Applying dilation strategy to pretrained ResNet yielding a stride-8 model, 181 | typically used in Semantic Segmentation. 182 | norm_layer : object 183 | Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`; 184 | for Synchronized Cross-GPU BachNormalization). 185 | Reference: 186 | - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. 187 | - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." 
188 | """ 189 | 190 | # pylint: disable=unused-variable 191 | def __init__(self, 192 | block, 193 | layers, 194 | radix=1, 195 | groups=1, 196 | bottleneck_width=64, 197 | num_classes=1000, 198 | dilated=False, 199 | dilation=1, 200 | deep_stem=False, 201 | stem_width=64, 202 | avg_down=False, 203 | rectified_conv=False, 204 | rectify_avg=False, 205 | avd=False, 206 | avd_first=False, 207 | final_drop=0.0, 208 | dropblock_prob=0, 209 | last_gamma=False, 210 | norm_layer=nn.BatchNorm2d, 211 | freeze_at=0): 212 | self.cardinality = groups 213 | self.bottleneck_width = bottleneck_width 214 | # ResNet-D params 215 | self.inplanes = stem_width * 2 if deep_stem else 64 216 | self.avg_down = avg_down 217 | self.last_gamma = last_gamma 218 | # ResNeSt params 219 | self.radix = radix 220 | self.avd = avd 221 | self.avd_first = avd_first 222 | 223 | super(ResNet, self).__init__() 224 | self.rectified_conv = rectified_conv 225 | self.rectify_avg = rectify_avg 226 | if rectified_conv: 227 | from rfconv import RFConv2d 228 | conv_layer = RFConv2d 229 | else: 230 | conv_layer = nn.Conv2d 231 | conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {} 232 | if deep_stem: 233 | self.conv1 = nn.Sequential( 234 | conv_layer(3, 235 | stem_width, 236 | kernel_size=3, 237 | stride=2, 238 | padding=1, 239 | bias=False, 240 | **conv_kwargs), 241 | norm_layer(stem_width), 242 | nn.ReLU(inplace=True), 243 | conv_layer(stem_width, 244 | stem_width, 245 | kernel_size=3, 246 | stride=1, 247 | padding=1, 248 | bias=False, 249 | **conv_kwargs), 250 | norm_layer(stem_width), 251 | nn.ReLU(inplace=True), 252 | conv_layer(stem_width, 253 | stem_width * 2, 254 | kernel_size=3, 255 | stride=1, 256 | padding=1, 257 | bias=False, 258 | **conv_kwargs), 259 | ) 260 | else: 261 | self.conv1 = conv_layer(3, 262 | 64, 263 | kernel_size=7, 264 | stride=2, 265 | padding=3, 266 | bias=False, 267 | **conv_kwargs) 268 | self.bn1 = norm_layer(self.inplanes) 269 | self.relu = nn.ReLU(inplace=True) 270 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 271 | self.layer1 = self._make_layer(block, 272 | 64, 273 | layers[0], 274 | norm_layer=norm_layer, 275 | is_first=False) 276 | self.layer2 = self._make_layer(block, 277 | 128, 278 | layers[1], 279 | stride=2, 280 | norm_layer=norm_layer) 281 | if dilated or dilation == 4: 282 | self.layer3 = self._make_layer(block, 283 | 256, 284 | layers[2], 285 | stride=1, 286 | dilation=2, 287 | norm_layer=norm_layer, 288 | dropblock_prob=dropblock_prob) 289 | elif dilation == 2: 290 | self.layer3 = self._make_layer(block, 291 | 256, 292 | layers[2], 293 | stride=2, 294 | dilation=1, 295 | norm_layer=norm_layer, 296 | dropblock_prob=dropblock_prob) 297 | else: 298 | self.layer3 = self._make_layer(block, 299 | 256, 300 | layers[2], 301 | stride=2, 302 | norm_layer=norm_layer, 303 | dropblock_prob=dropblock_prob) 304 | 305 | self.stem = [self.conv1, self.bn1] 306 | self.stages = [self.layer1, self.layer2, self.layer3] 307 | 308 | for m in self.modules(): 309 | if isinstance(m, nn.Conv2d): 310 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 311 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 312 | elif isinstance(m, norm_layer): 313 | m.weight.data.fill_(1) 314 | m.bias.data.zero_() 315 | 316 | self.freeze(freeze_at) 317 | 318 | def _make_layer(self, 319 | block, 320 | planes, 321 | blocks, 322 | stride=1, 323 | dilation=1, 324 | norm_layer=None, 325 | dropblock_prob=0.0, 326 | is_first=True): 327 | downsample = None 328 | if stride != 1 or self.inplanes != planes * block.expansion: 329 | down_layers = [] 330 | if self.avg_down: 331 | if dilation == 1: 332 | down_layers.append( 333 | nn.AvgPool2d(kernel_size=stride, 334 | stride=stride, 335 | ceil_mode=True, 336 | count_include_pad=False)) 337 | else: 338 | down_layers.append( 339 | nn.AvgPool2d(kernel_size=1, 340 | stride=1, 341 | ceil_mode=True, 342 | count_include_pad=False)) 343 | down_layers.append( 344 | nn.Conv2d(self.inplanes, 345 | planes * block.expansion, 346 | kernel_size=1, 347 | stride=1, 348 | bias=False)) 349 | else: 350 | down_layers.append( 351 | nn.Conv2d(self.inplanes, 352 | planes * block.expansion, 353 | kernel_size=1, 354 | stride=stride, 355 | bias=False)) 356 | down_layers.append(norm_layer(planes * block.expansion)) 357 | downsample = nn.Sequential(*down_layers) 358 | 359 | layers = [] 360 | if dilation == 1 or dilation == 2: 361 | layers.append( 362 | block(self.inplanes, 363 | planes, 364 | stride, 365 | downsample=downsample, 366 | radix=self.radix, 367 | cardinality=self.cardinality, 368 | bottleneck_width=self.bottleneck_width, 369 | avd=self.avd, 370 | avd_first=self.avd_first, 371 | dilation=1, 372 | is_first=is_first, 373 | rectified_conv=self.rectified_conv, 374 | rectify_avg=self.rectify_avg, 375 | norm_layer=norm_layer, 376 | dropblock_prob=dropblock_prob, 377 | last_gamma=self.last_gamma)) 378 | elif dilation == 4: 379 | layers.append( 380 | block(self.inplanes, 381 | planes, 382 | stride, 383 | downsample=downsample, 384 | radix=self.radix, 385 | cardinality=self.cardinality, 386 | bottleneck_width=self.bottleneck_width, 387 | avd=self.avd, 388 | avd_first=self.avd_first, 389 | dilation=2, 390 | is_first=is_first, 391 | rectified_conv=self.rectified_conv, 392 | rectify_avg=self.rectify_avg, 393 | norm_layer=norm_layer, 394 | dropblock_prob=dropblock_prob, 395 | last_gamma=self.last_gamma)) 396 | else: 397 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 398 | 399 | self.inplanes = planes * block.expansion 400 | for i in range(1, blocks): 401 | layers.append( 402 | block(self.inplanes, 403 | planes, 404 | radix=self.radix, 405 | cardinality=self.cardinality, 406 | bottleneck_width=self.bottleneck_width, 407 | avd=self.avd, 408 | avd_first=self.avd_first, 409 | dilation=dilation, 410 | rectified_conv=self.rectified_conv, 411 | rectify_avg=self.rectify_avg, 412 | norm_layer=norm_layer, 413 | dropblock_prob=dropblock_prob, 414 | last_gamma=self.last_gamma)) 415 | 416 | return nn.Sequential(*layers) 417 | 418 | def forward(self, x): 419 | x = self.conv1(x) 420 | x = self.bn1(x) 421 | x = self.relu(x) 422 | x = self.maxpool(x) 423 | 424 | xs = [] 425 | 426 | x = self.layer1(x) 427 | xs.append(x) # 4X 428 | x = self.layer2(x) 429 | xs.append(x) # 8X 430 | x = self.layer3(x) 431 | xs.append(x) # 16X 432 | # Following STMVOS, we drop stage 5. 
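# Note: the stride-16 output of layer3 is appended a second time below as a
# stand-in for the dropped stage-5 feature, so the encoder still returns the
# four feature levels (4x, 8x, 16x, 16x) that downstream code expects.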
433 | xs.append(x) # 16X 434 | 435 | return xs 436 | 437 | def freeze(self, freeze_at): 438 | if freeze_at >= 1: 439 | for m in self.stem: 440 | freeze_params(m) 441 | 442 | for idx, stage in enumerate(self.stages, start=2): 443 | if freeze_at >= idx: 444 | freeze_params(stage) 445 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/resnest/splat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn import Conv2d, Module, ReLU 5 | from torch.nn.modules.utils import _pair 6 | 7 | __all__ = ['SplAtConv2d', 'DropBlock2D'] 8 | 9 | 10 | class DropBlock2D(object): 11 | def __init__(self, *args, **kwargs): 12 | raise NotImplementedError 13 | 14 | 15 | class SplAtConv2d(Module): 16 | """Split-Attention Conv2d 17 | """ 18 | def __init__(self, 19 | in_channels, 20 | channels, 21 | kernel_size, 22 | stride=(1, 1), 23 | padding=(0, 0), 24 | dilation=(1, 1), 25 | groups=1, 26 | bias=True, 27 | radix=2, 28 | reduction_factor=4, 29 | rectify=False, 30 | rectify_avg=False, 31 | norm_layer=None, 32 | dropblock_prob=0.0, 33 | **kwargs): 34 | super(SplAtConv2d, self).__init__() 35 | padding = _pair(padding) 36 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0) 37 | self.rectify_avg = rectify_avg 38 | inter_channels = max(in_channels * radix // reduction_factor, 32) 39 | self.radix = radix 40 | self.cardinality = groups 41 | self.channels = channels 42 | self.dropblock_prob = dropblock_prob 43 | if self.rectify: 44 | from rfconv import RFConv2d 45 | self.conv = RFConv2d(in_channels, 46 | channels * radix, 47 | kernel_size, 48 | stride, 49 | padding, 50 | dilation, 51 | groups=groups * radix, 52 | bias=bias, 53 | average_mode=rectify_avg, 54 | **kwargs) 55 | else: 56 | self.conv = Conv2d(in_channels, 57 | channels * radix, 58 | kernel_size, 59 | stride, 60 | padding, 61 | dilation, 62 | groups=groups * radix, 63 | bias=bias, 64 | **kwargs) 65 | self.use_bn = norm_layer is not None 66 | if self.use_bn: 67 | self.bn0 = norm_layer(channels * radix) 68 | self.relu = ReLU(inplace=True) 69 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) 70 | if self.use_bn: 71 | self.bn1 = norm_layer(inter_channels) 72 | self.fc2 = Conv2d(inter_channels, 73 | channels * radix, 74 | 1, 75 | groups=self.cardinality) 76 | if dropblock_prob > 0.0: 77 | self.dropblock = DropBlock2D(dropblock_prob, 3) 78 | self.rsoftmax = rSoftMax(radix, groups) 79 | 80 | def forward(self, x): 81 | x = self.conv(x) 82 | if self.use_bn: 83 | x = self.bn0(x) 84 | if self.dropblock_prob > 0.0: 85 | x = self.dropblock(x) 86 | x = self.relu(x) 87 | 88 | batch, rchannel = x.shape[:2] 89 | if self.radix > 1: 90 | if torch.__version__ < '1.5': 91 | splited = torch.split(x, int(rchannel // self.radix), dim=1) 92 | else: 93 | splited = torch.split(x, rchannel // self.radix, dim=1) 94 | gap = sum(splited) 95 | else: 96 | gap = x 97 | gap = F.adaptive_avg_pool2d(gap, 1) 98 | gap = self.fc1(gap) 99 | 100 | if self.use_bn: 101 | gap = self.bn1(gap) 102 | gap = self.relu(gap) 103 | 104 | atten = self.fc2(gap) 105 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1) 106 | 107 | if self.radix > 1: 108 | if torch.__version__ < '1.5': 109 | attens = torch.split(atten, int(rchannel // self.radix), dim=1) 110 | else: 111 | attens = torch.split(atten, rchannel // self.radix, dim=1) 112 | out = sum([att * split for (att, split) in zip(attens, splited)]) 113 
| else: 114 | out = atten * x 115 | return out.contiguous() 116 | 117 | 118 | class rSoftMax(nn.Module): 119 | def __init__(self, radix, cardinality): 120 | super().__init__() 121 | self.radix = radix 122 | self.cardinality = cardinality 123 | 124 | def forward(self, x): 125 | batch = x.size(0) 126 | if self.radix > 1: 127 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 128 | x = F.softmax(x, dim=1) 129 | x = x.reshape(batch, -1) 130 | else: 131 | x = torch.sigmoid(x) 132 | return x 133 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | from utils.learning import freeze_params 4 | 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | 9 | def __init__(self, 10 | inplanes, 11 | planes, 12 | stride=1, 13 | dilation=1, 14 | downsample=None, 15 | BatchNorm=None): 16 | super(Bottleneck, self).__init__() 17 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 18 | self.bn1 = BatchNorm(planes) 19 | self.conv2 = nn.Conv2d(planes, 20 | planes, 21 | kernel_size=3, 22 | stride=stride, 23 | dilation=dilation, 24 | padding=dilation, 25 | bias=False) 26 | self.bn2 = BatchNorm(planes) 27 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 28 | self.bn3 = BatchNorm(planes * 4) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.downsample = downsample 31 | self.stride = stride 32 | self.dilation = dilation 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv3(out) 46 | out = self.bn3(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class ResNet(nn.Module): 58 | def __init__(self, block, layers, output_stride, BatchNorm, freeze_at=0): 59 | self.inplanes = 64 60 | super(ResNet, self).__init__() 61 | 62 | if output_stride == 16: 63 | strides = [1, 2, 2, 1] 64 | dilations = [1, 1, 1, 2] 65 | elif output_stride == 8: 66 | strides = [1, 2, 1, 1] 67 | dilations = [1, 1, 2, 4] 68 | else: 69 | raise NotImplementedError 70 | 71 | # Modules 72 | self.conv1 = nn.Conv2d(3, 73 | 64, 74 | kernel_size=7, 75 | stride=2, 76 | padding=3, 77 | bias=False) 78 | self.bn1 = BatchNorm(64) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 81 | 82 | self.layer1 = self._make_layer(block, 83 | 64, 84 | layers[0], 85 | stride=strides[0], 86 | dilation=dilations[0], 87 | BatchNorm=BatchNorm) 88 | self.layer2 = self._make_layer(block, 89 | 128, 90 | layers[1], 91 | stride=strides[1], 92 | dilation=dilations[1], 93 | BatchNorm=BatchNorm) 94 | self.layer3 = self._make_layer(block, 95 | 256, 96 | layers[2], 97 | stride=strides[2], 98 | dilation=dilations[2], 99 | BatchNorm=BatchNorm) 100 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 101 | 102 | self.stem = [self.conv1, self.bn1] 103 | self.stages = [self.layer1, self.layer2, self.layer3] 104 | 105 | self._init_weight() 106 | self.freeze(freeze_at) 107 | 108 | def _make_layer(self, 109 | block, 110 | planes, 111 | blocks, 112 | stride=1, 113 | dilation=1, 114 | BatchNorm=None): 115 | downsample = None 
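# A 1x1 projection (with matching stride) is built below whenever the residual
# branch changes spatial resolution or channel width, so the shortcut can be
# added to the block output.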
116 | if stride != 1 or self.inplanes != planes * block.expansion: 117 | downsample = nn.Sequential( 118 | nn.Conv2d(self.inplanes, 119 | planes * block.expansion, 120 | kernel_size=1, 121 | stride=stride, 122 | bias=False), 123 | BatchNorm(planes * block.expansion), 124 | ) 125 | 126 | layers = [] 127 | layers.append( 128 | block(self.inplanes, planes, stride, max(dilation // 2, 1), 129 | downsample, BatchNorm)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append( 133 | block(self.inplanes, 134 | planes, 135 | dilation=dilation, 136 | BatchNorm=BatchNorm)) 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, input): 141 | x = self.conv1(input) 142 | x = self.bn1(x) 143 | x = self.relu(x) 144 | x = self.maxpool(x) 145 | 146 | xs = [] 147 | 148 | x = self.layer1(x) 149 | xs.append(x) # 4X 150 | x = self.layer2(x) 151 | xs.append(x) # 8X 152 | x = self.layer3(x) 153 | xs.append(x) # 16X 154 | # Following STMVOS, we drop stage 5. 155 | xs.append(x) # 16X 156 | 157 | return xs 158 | 159 | def _init_weight(self): 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 163 | m.weight.data.normal_(0, math.sqrt(2. / n)) 164 | elif isinstance(m, nn.BatchNorm2d): 165 | m.weight.data.fill_(1) 166 | m.bias.data.zero_() 167 | 168 | def freeze(self, freeze_at): 169 | if freeze_at >= 1: 170 | for m in self.stem: 171 | freeze_params(m) 172 | 173 | for idx, stage in enumerate(self.stages, start=2): 174 | if freeze_at >= idx: 175 | freeze_params(stage) 176 | 177 | 178 | def ResNet50(output_stride, BatchNorm, freeze_at=0): 179 | """Constructs a ResNet-50 model. 180 | Args: 181 | pretrained (bool): If True, returns a model pre-trained on ImageNet 182 | """ 183 | model = ResNet(Bottleneck, [3, 4, 6, 3], 184 | output_stride, 185 | BatchNorm, 186 | freeze_at=freeze_at) 187 | return model 188 | 189 | 190 | def ResNet101(output_stride, BatchNorm, freeze_at=0): 191 | """Constructs a ResNet-101 model. 
192 | Args: 193 | pretrained (bool): If True, returns a model pre-trained on ImageNet 194 | """ 195 | model = ResNet(Bottleneck, [3, 4, 23, 3], 196 | output_stride, 197 | BatchNorm, 198 | freeze_at=freeze_at) 199 | return model 200 | 201 | 202 | if __name__ == "__main__": 203 | import torch 204 | model = ResNet101(BatchNorm=nn.BatchNorm2d, output_stride=8) 205 | input = torch.rand(1, 3, 512, 512) 206 | output, low_level_feat = model(input) 207 | print(output.size()) 208 | print(low_level_feat.size()) 209 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/swin/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_swin_model -------------------------------------------------------------------------------- /aot_plus/networks/encoders/swin/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from .swin_transformer import SwinTransformer 9 | 10 | 11 | def build_swin_model(model_type, freeze_at=0): 12 | if model_type == 'swin_base': 13 | model = SwinTransformer(embed_dim=128, 14 | depths=[2, 2, 18, 2], 15 | num_heads=[4, 8, 16, 32], 16 | window_size=7, 17 | drop_path_rate=0.3, 18 | out_indices=(0, 1, 2), 19 | ape=False, 20 | patch_norm=True, 21 | frozen_stages=freeze_at, 22 | use_checkpoint=False) 23 | 24 | else: 25 | raise NotImplementedError(f"Unkown model: {model_type}") 26 | 27 | return model 28 | -------------------------------------------------------------------------------- /aot_plus/networks/engines/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.engines.aot_engine import AOTEngine, AOTInferEngine 2 | 3 | 4 | def build_engine(name, phase='train', **kwargs): 5 | if name == 'aotengine': 6 | if phase == 'train': 7 | return AOTEngine(**kwargs) 8 | elif phase == 'eval': 9 | return AOTInferEngine(**kwargs) 10 | else: 11 | raise NotImplementedError 12 | else: 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /aot_plus/networks/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/networks/layers/__init__.py -------------------------------------------------------------------------------- /aot_plus/networks/layers/basic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class GroupNorm1D(nn.Module): 7 | def __init__(self, indim, groups=8): 8 | super().__init__() 9 | self.gn = nn.GroupNorm(groups, indim) 10 | 11 | def forward(self, x): 12 | return self.gn(x.permute(1, 2, 0)).permute(2, 0, 1) 13 | 14 | 15 | class GNActDWConv2d(nn.Module): 16 | def __init__(self, indim, gn_groups=32): 17 | super().__init__() 18 | self.gn = nn.GroupNorm(gn_groups, indim) 19 | self.conv = nn.Conv2d(indim, 20 | indim, 21 | 5, 22 | dilation=1, 23 | padding=2, 24 | groups=indim, 25 | bias=False) 26 | 27 | def forward(self, x, size_2d): 28 | h, w = size_2d 29 | _, bs, c = x.size() 30 | x = 
x.view(h, w, bs, c).permute(2, 3, 0, 1) 31 | x = self.gn(x) 32 | x = F.gelu(x) 33 | x = self.conv(x) 34 | x = x.view(bs, c, h * w).permute(2, 0, 1) 35 | return x 36 | 37 | 38 | class ConvGN(nn.Module): 39 | def __init__(self, indim, outdim, kernel_size, gn_groups=8): 40 | super().__init__() 41 | self.conv = nn.Conv2d(indim, 42 | outdim, 43 | kernel_size, 44 | padding=kernel_size // 2) 45 | self.gn = nn.GroupNorm(gn_groups, outdim) 46 | 47 | def forward(self, x): 48 | return self.gn(self.conv(x)) 49 | 50 | 51 | def seq_to_2d(tensor, size_2d): 52 | h, w = size_2d 53 | _, n, c = tensor.size() 54 | tensor = tensor.view(h, w, n, c).permute(2, 3, 0, 1).contiguous() 55 | return tensor 56 | 57 | def twod_to_seq(tensor): 58 | n, c, h, w = tensor.size() 59 | tensor = tensor.view(n, c, h * w).permute(3, 0, 1).contiguous() 60 | return tensor 61 | 62 | 63 | def drop_path(x, drop_prob: float = 0., training: bool = False): 64 | if drop_prob == 0. or not training: 65 | return x 66 | keep_prob = 1 - drop_prob 67 | shape = (x.shape[0], ) + (1, ) * ( 68 | x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 69 | random_tensor = keep_prob + torch.rand( 70 | shape, dtype=x.dtype, device=x.device) 71 | random_tensor.floor_() # binarize 72 | output = x.div(keep_prob) * random_tensor 73 | return output 74 | 75 | 76 | class DropPath(nn.Module): 77 | def __init__(self, drop_prob=None, batch_dim=0): 78 | super(DropPath, self).__init__() 79 | self.drop_prob = drop_prob 80 | self.batch_dim = batch_dim 81 | 82 | def forward(self, x): 83 | return self.drop_path(x, self.drop_prob) 84 | 85 | def drop_path(self, x, drop_prob): 86 | if drop_prob == 0. or not self.training: 87 | return x 88 | keep_prob = 1 - drop_prob 89 | shape = [1 for _ in range(x.ndim)] 90 | shape[self.batch_dim] = x.shape[self.batch_dim] 91 | random_tensor = keep_prob + torch.rand( 92 | shape, dtype=x.dtype, device=x.device) 93 | random_tensor.floor_() # binarize 94 | output = x.div(keep_prob) * random_tensor 95 | return output 96 | 97 | 98 | class DropOutLogit(nn.Module): 99 | def __init__(self, drop_prob=None): 100 | super(DropOutLogit, self).__init__() 101 | self.drop_prob = drop_prob 102 | 103 | def forward(self, x): 104 | return self.drop_logit(x, self.drop_prob) 105 | 106 | def drop_logit(self, x, drop_prob): 107 | if drop_prob == 0. 
or not self.training: 108 | return x 109 | random_tensor = drop_prob + torch.rand( 110 | x.shape, dtype=x.dtype, device=x.device) 111 | random_tensor.floor_() # binarize 112 | mask = random_tensor * 1e+8 if ( 113 | x.dtype == torch.float32) else random_tensor * 1e+4 114 | output = x - mask 115 | return output 116 | -------------------------------------------------------------------------------- /aot_plus/networks/layers/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | try: 6 | from itertools import ifilterfalse 7 | except ImportError: # py3k 8 | from itertools import filterfalse as ifilterfalse 9 | 10 | 11 | def dice_loss(probas, labels, smooth=1): 12 | 13 | C = probas.size(1) 14 | losses = [] 15 | for c in list(range(C)): 16 | fg = (labels == c).float() 17 | if fg.sum() == 0: 18 | continue 19 | class_pred = probas[:, c] 20 | p0 = class_pred 21 | g0 = fg 22 | numerator = 2 * torch.sum(p0 * g0) + smooth 23 | denominator = torch.sum(p0) + torch.sum(g0) + smooth 24 | losses.append(1 - ((numerator) / (denominator))) 25 | return mean(losses) 26 | 27 | 28 | def tversky_loss(probas, labels, alpha=0.5, beta=0.5, epsilon=1e-6): 29 | ''' 30 | Tversky loss function. 31 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 32 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 33 | 34 | Same as soft dice loss when alpha=beta=0.5. 35 | Same as Jaccord loss when alpha=beta=1.0. 36 | See `Tversky loss function for image segmentation using 3D fully convolutional deep networks` 37 | https://arxiv.org/pdf/1706.05721.pdf 38 | ''' 39 | C = probas.size(1) 40 | losses = [] 41 | for c in list(range(C)): 42 | fg = (labels == c).float() 43 | if fg.sum() == 0: 44 | continue 45 | class_pred = probas[:, c] 46 | p0 = class_pred 47 | p1 = 1 - class_pred 48 | g0 = fg 49 | g1 = 1 - fg 50 | numerator = torch.sum(p0 * g0) 51 | denominator = numerator + alpha * \ 52 | torch.sum(p0*g1) + beta*torch.sum(p1*g0) 53 | losses.append(1 - ((numerator) / (denominator + epsilon))) 54 | return mean(losses) 55 | 56 | 57 | def flatten_probas(probas, labels, ignore=255): 58 | """ 59 | Flattens predictions in the batch 60 | """ 61 | B, C, H, W = probas.size() 62 | probas = probas.permute(0, 2, 3, 63 | 1).contiguous().view(-1, C) # B * H * W, C = P, C 64 | labels = labels.view(-1) 65 | if ignore is None: 66 | return probas, labels 67 | valid = (labels != ignore) 68 | vprobas = probas[valid.view(-1, 1).expand(-1, C)].reshape(-1, C) 69 | # vprobas = probas[torch.nonzero(valid).squeeze()] 70 | vlabels = labels[valid] 71 | return vprobas, vlabels 72 | 73 | 74 | def isnan(x): 75 | return x != x 76 | 77 | 78 | def mean(l, ignore_nan=False, empty=0): 79 | """ 80 | nanmean compatible with generators. 
81 | """ 82 | l = iter(l) 83 | if ignore_nan: 84 | l = ifilterfalse(isnan, l) 85 | try: 86 | n = 1 87 | acc = next(l) 88 | except StopIteration: 89 | if empty == 'raise': 90 | raise ValueError('Empty mean') 91 | return empty 92 | for n, v in enumerate(l, 2): 93 | acc += v 94 | if n == 1: 95 | return acc 96 | return acc / n 97 | 98 | 99 | class DiceLoss(nn.Module): 100 | def __init__(self, ignore_index=255): 101 | super(DiceLoss, self).__init__() 102 | self.ignore_index = ignore_index 103 | 104 | def forward(self, tmp_dic, label_dic, step=None): 105 | total_loss = [] 106 | for idx in range(len(tmp_dic)): 107 | pred = tmp_dic[idx] 108 | label = label_dic[idx] 109 | pred = F.softmax(pred, dim=1) 110 | label = label.view(1, 1, pred.size()[2], pred.size()[3]) 111 | loss = dice_loss( 112 | *flatten_probas(pred, label, ignore=self.ignore_index)) 113 | total_loss.append(loss.unsqueeze(0)) 114 | total_loss = torch.cat(total_loss, dim=0) 115 | return total_loss 116 | 117 | 118 | class SoftJaccordLoss(nn.Module): 119 | def __init__(self, ignore_index=255): 120 | super(SoftJaccordLoss, self).__init__() 121 | self.ignore_index = ignore_index 122 | 123 | def forward(self, tmp_dic, label_dic, step=None): 124 | total_loss = [] 125 | for idx in range(len(tmp_dic)): 126 | pred = tmp_dic[idx] 127 | label = label_dic[idx] 128 | pred = F.softmax(pred, dim=1) 129 | label = label.view(1, 1, pred.size()[2], pred.size()[3]) 130 | loss = tversky_loss(*flatten_probas(pred, 131 | label, 132 | ignore=self.ignore_index), 133 | alpha=1.0, 134 | beta=1.0) 135 | if loss != 0: 136 | total_loss.append(loss.unsqueeze(0)) 137 | else: 138 | total_loss.append(torch.zeros(1).cuda()) 139 | total_loss = torch.cat(total_loss, dim=0) 140 | return total_loss 141 | 142 | 143 | class CrossEntropyLoss(nn.Module): 144 | def __init__(self, 145 | top_k_percent_pixels=None, 146 | hard_example_mining_step=100000): 147 | super(CrossEntropyLoss, self).__init__() 148 | self.top_k_percent_pixels = top_k_percent_pixels 149 | if top_k_percent_pixels is not None: 150 | assert (top_k_percent_pixels > 0 and top_k_percent_pixels < 1) 151 | self.hard_example_mining_step = hard_example_mining_step + 1e-5 152 | if self.top_k_percent_pixels is None: 153 | self.celoss = nn.CrossEntropyLoss(ignore_index=255, 154 | reduction='mean') 155 | else: 156 | self.celoss = nn.CrossEntropyLoss(ignore_index=255, 157 | reduction='none') 158 | 159 | def forward(self, dic_tmp, y, step): 160 | total_loss = [] 161 | for i in range(len(dic_tmp)): 162 | pred_logits = dic_tmp[i] 163 | gts = y[i] 164 | if self.top_k_percent_pixels is None: 165 | final_loss = self.celoss(pred_logits, gts) 166 | else: 167 | # Only compute the loss for top k percent pixels. 168 | # First, compute the loss for all pixels. Note we do not put the loss 169 | # to loss_collection and set reduction = None to keep the shape. 
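# The kept fraction is annealed from all pixels down to top_k_percent_pixels
# over hard_example_mining_step training steps (the ratio computed below
# interpolates between the two), so early training sees the full loss.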
170 | num_pixels = float(pred_logits.size(2) * pred_logits.size(3)) 171 | pred_logits = pred_logits.view( 172 | -1, pred_logits.size(1), 173 | pred_logits.size(2) * pred_logits.size(3)) 174 | gts = gts.view(-1, gts.size(1) * gts.size(2)) 175 | pixel_losses = self.celoss(pred_logits, gts) 176 | if self.hard_example_mining_step == 0: 177 | top_k_pixels = int(self.top_k_percent_pixels * num_pixels) 178 | else: 179 | ratio = min(1.0, 180 | step / float(self.hard_example_mining_step)) 181 | top_k_pixels = int((ratio * self.top_k_percent_pixels + 182 | (1.0 - ratio)) * num_pixels) 183 | top_k_loss, top_k_indices = torch.topk(pixel_losses, 184 | k=top_k_pixels, 185 | dim=1) 186 | 187 | final_loss = torch.mean(top_k_loss) 188 | if final_loss != 0: 189 | final_loss = final_loss.unsqueeze(0) 190 | else: 191 | final_loss = torch.zeros(1).cuda() 192 | total_loss.append(final_loss) 193 | total_loss = torch.cat(total_loss, dim=0) 194 | return total_loss 195 | -------------------------------------------------------------------------------- /aot_plus/networks/layers/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FrozenBatchNorm2d(nn.Module): 7 | """ 8 | BatchNorm2d where the batch statistics and the affine parameters 9 | are fixed 10 | """ 11 | def __init__(self, n, epsilon=1e-5): 12 | super(FrozenBatchNorm2d, self).__init__() 13 | self.register_buffer("weight", torch.ones(n)) 14 | self.register_buffer("bias", torch.zeros(n)) 15 | self.register_buffer("running_mean", torch.zeros(n)) 16 | self.register_buffer("running_var", torch.ones(n) - epsilon) 17 | self.epsilon = epsilon 18 | 19 | def forward(self, x): 20 | """ 21 | Refer to Detectron2 (https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/batch_norm.py) 22 | """ 23 | if x.requires_grad: 24 | # When gradients are needed, F.batch_norm will use extra memory 25 | # because its backward op computes gradients for weight/bias as well. 26 | scale = self.weight * (self.running_var + self.epsilon).rsqrt() 27 | bias = self.bias - self.running_mean * scale 28 | scale = scale.reshape(1, -1, 1, 1) 29 | bias = bias.reshape(1, -1, 1, 1) 30 | out_dtype = x.dtype # may be half 31 | return x * scale.to(out_dtype) + bias.to(out_dtype) 32 | else: 33 | # When gradients are not needed, F.batch_norm is a single fused op 34 | # and provide more optimization opportunities. 
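# Since the running statistics and affine parameters are frozen buffers,
# F.batch_norm is invoked with training=False and simply applies the fixed
# normalization.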
35 | return F.batch_norm( 36 | x, 37 | self.running_mean, 38 | self.running_var, 39 | self.weight, 40 | self.bias, 41 | training=False, 42 | eps=self.epsilon, 43 | ) 44 | -------------------------------------------------------------------------------- /aot_plus/networks/layers/position.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from utils.math import truncated_normal_ 8 | 9 | 10 | class Downsample2D(nn.Module): 11 | def __init__(self, mode='nearest', scale=4): 12 | super().__init__() 13 | self.mode = mode 14 | self.scale = scale 15 | 16 | def forward(self, x): 17 | n, c, h, w = x.size() 18 | x = F.interpolate(x, 19 | size=(h // self.scale + 1, w // self.scale + 1), 20 | mode=self.mode) 21 | return x 22 | 23 | 24 | def generate_coord(x): 25 | _, _, h, w = x.size() 26 | device = x.device 27 | col = torch.arange(0, h, device=device) 28 | row = torch.arange(0, w, device=device) 29 | grid_h, grid_w = torch.meshgrid(col, row) 30 | return grid_h, grid_w 31 | 32 | 33 | class PositionEmbeddingSine(nn.Module): 34 | def __init__(self, 35 | num_pos_feats=64, 36 | temperature=10000, 37 | normalize=False, 38 | scale=None): 39 | super().__init__() 40 | self.num_pos_feats = num_pos_feats 41 | self.temperature = temperature 42 | self.normalize = normalize 43 | if scale is not None and normalize is False: 44 | raise ValueError("normalize should be True if scale is passed") 45 | if scale is None: 46 | scale = 2 * math.pi 47 | self.scale = scale 48 | 49 | def forward(self, x): 50 | grid_y, grid_x = generate_coord(x) 51 | 52 | y_embed = grid_y.unsqueeze(0).float() 53 | x_embed = grid_x.unsqueeze(0).float() 54 | 55 | if self.normalize: 56 | eps = 1e-6 57 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 58 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 59 | 60 | dim_t = torch.arange(self.num_pos_feats, 61 | dtype=torch.float32, 62 | device=x.device) 63 | dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats) 64 | 65 | pos_x = x_embed[:, :, :, None] / dim_t 66 | pos_y = y_embed[:, :, :, None] / dim_t 67 | pos_x = torch.stack( 68 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), 69 | dim=4).flatten(3) 70 | pos_y = torch.stack( 71 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), 72 | dim=4).flatten(3) 73 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 74 | return pos 75 | 76 | 77 | class PositionEmbeddingLearned(nn.Module): 78 | def __init__(self, num_pos_feats=64, H=30, W=30): 79 | super().__init__() 80 | self.H = H 81 | self.W = W 82 | self.pos_emb = nn.Parameter( 83 | truncated_normal_(torch.zeros(1, num_pos_feats, H, W))) 84 | 85 | def forward(self, x): 86 | bs, _, h, w = x.size() 87 | pos_emb = self.pos_emb 88 | if h != self.H or w != self.W: 89 | pos_emb = F.interpolate(pos_emb, size=(h, w), mode="bilinear") 90 | return pos_emb 91 | -------------------------------------------------------------------------------- /aot_plus/networks/models/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.models.aot import AOT 2 | 3 | 4 | def build_vos_model(name, cfg, **kwargs): 5 | 6 | if name == 'aot': 7 | return AOT(cfg, encoder=cfg.MODEL_ENCODER, **kwargs) 8 | else: 9 | raise NotImplementedError 10 | -------------------------------------------------------------------------------- /aot_plus/networks/models/__pycache__/__init__.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/networks/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /aot_plus/networks/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/networks/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /aot_plus/networks/models/__pycache__/aot.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/networks/models/__pycache__/aot.cpython-37.pyc -------------------------------------------------------------------------------- /aot_plus/networks/models/__pycache__/aot.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/networks/models/__pycache__/aot.cpython-38.pyc -------------------------------------------------------------------------------- /aot_plus/networks/models/aot.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from networks.encoders import build_encoder 4 | from networks.layers.transformer import LongShortTermTransformer 5 | from networks.decoders import build_decoder 6 | from networks.layers.position import PositionEmbeddingSine 7 | 8 | 9 | class AOT(nn.Module): 10 | def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'): 11 | super().__init__() 12 | self.cfg = cfg 13 | self.max_obj_num = cfg.MODEL_MAX_OBJ_NUM 14 | self.epsilon = cfg.MODEL_EPSILON 15 | 16 | self.encoder = build_encoder(encoder, 17 | frozen_bn=cfg.MODEL_FREEZE_BN, 18 | freeze_at=cfg.TRAIN_ENCODER_FREEZE_AT) 19 | self.encoder_projector = nn.Conv2d(cfg.MODEL_ENCODER_DIM[-1], 20 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 21 | kernel_size=1) 22 | 23 | self.LSTT = LongShortTermTransformer( 24 | cfg.MODEL_LSTT_NUM, 25 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 26 | cfg.MODEL_SELF_HEADS, 27 | cfg.MODEL_ATT_HEADS, 28 | emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT, 29 | droppath=cfg.TRAIN_LSTT_DROPPATH, 30 | lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT, 31 | st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT, 32 | droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST, 33 | droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING, 34 | intermediate_norm=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, 35 | return_intermediate=True, 36 | simplified=cfg.MODEL_SIMPLIFIED_STM, 37 | stopgrad=cfg.MODEL_STM_STOPGRAD, 38 | joint_longatt=cfg.MODEL_JOINT_LONGATT, 39 | linear_q=cfg.MODEL_LINEAR_Q, 40 | norm_inp=cfg.MODEL_NORM_INP, 41 | recurrent_stm=cfg.MODEL_RECURRENT_STM, 42 | recurrent_ltm=cfg.MODEL_RECURRENT_LTM) 43 | 44 | decoder_indim = cfg.MODEL_ENCODER_EMBEDDING_DIM * \ 45 | (cfg.MODEL_LSTT_NUM + 46 | 1) if cfg.MODEL_DECODER_INTERMEDIATE_LSTT else cfg.MODEL_ENCODER_EMBEDDING_DIM 47 | 48 | self.decoder = build_decoder( 49 | decoder, 50 | in_dim=decoder_indim, 51 | out_dim=cfg.MODEL_MAX_OBJ_NUM + 1, 52 | decode_intermediate_input=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, 53 | hidden_dim=cfg.MODEL_ENCODER_EMBEDDING_DIM, 54 | shortcut_dims=cfg.MODEL_ENCODER_DIM, 55 | 
align_corners=cfg.MODEL_ALIGN_CORNERS) 56 | 57 | id_dim = cfg.MODEL_MAX_OBJ_NUM + 1 58 | if cfg.MODEL_IGNORE_TOKEN: 59 | id_dim = cfg.MODEL_MAX_OBJ_NUM + 2 60 | if cfg.MODEL_ALIGN_CORNERS: 61 | self.patch_wise_id_bank = nn.Conv2d( 62 | id_dim, 63 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 64 | kernel_size=17, 65 | stride=16, 66 | padding=8) 67 | else: 68 | self.patch_wise_id_bank = nn.Conv2d( 69 | id_dim, 70 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 71 | kernel_size=16, 72 | stride=16, 73 | padding=0) 74 | 75 | self.id_dropout = nn.Dropout(cfg.TRAIN_LSTT_ID_DROPOUT, True) 76 | 77 | self.pos_generator = PositionEmbeddingSine( 78 | cfg.MODEL_ENCODER_EMBEDDING_DIM // 2, normalize=True) 79 | 80 | self._init_weight() 81 | 82 | def get_pos_emb(self, x): 83 | pos_emb = self.pos_generator(x) 84 | return pos_emb 85 | 86 | def get_id_emb(self, x): 87 | id_emb = self.patch_wise_id_bank(x) 88 | id_emb = self.id_dropout(id_emb) 89 | return id_emb 90 | 91 | def encode_image(self, img): 92 | xs = self.encoder(img) 93 | xs[-1] = self.encoder_projector(xs[-1]) 94 | return xs 95 | 96 | def decode_id_logits(self, lstt_emb, shortcuts): 97 | n, c, h, w = shortcuts[-1].size() 98 | decoder_inputs = [shortcuts[-1]] 99 | for emb in lstt_emb: 100 | decoder_inputs.append(emb.view(h, w, n, c).permute(2, 3, 0, 1)) 101 | pred_logit = self.decoder(decoder_inputs, shortcuts) 102 | return pred_logit 103 | 104 | def LSTT_forward(self, 105 | curr_embs, 106 | long_term_memories, 107 | short_term_memories, 108 | curr_id_emb=None, 109 | pos_emb=None, 110 | size_2d=(30, 30)): 111 | n, c, h, w = curr_embs[-1].size() 112 | curr_emb = curr_embs[-1].view(n, c, h * w).permute(2, 0, 1) 113 | lstt_embs, lstt_memories = self.LSTT(curr_emb, long_term_memories, 114 | short_term_memories, curr_id_emb, 115 | pos_emb, size_2d) 116 | lstt_curr_memories, lstt_long_memories, lstt_short_memories = zip( 117 | *lstt_memories) 118 | return lstt_embs, lstt_curr_memories, lstt_long_memories, lstt_short_memories 119 | 120 | def _init_weight(self): 121 | nn.init.xavier_uniform_(self.encoder_projector.weight) 122 | nn.init.orthogonal_( 123 | self.patch_wise_id_bank.weight.view( 124 | self.cfg.MODEL_ENCODER_EMBEDDING_DIM, -1).permute(0, 1), 125 | gain=17**-2 if self.cfg.MODEL_ALIGN_CORNERS else 16**-2) 126 | -------------------------------------------------------------------------------- /aot_plus/tools/demo.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | import os 4 | from time import time 5 | 6 | sys.path.append('.') 7 | sys.path.append('..') 8 | 9 | import cv2 10 | from PIL import Image 11 | from skimage.morphology.binary import binary_dilation 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as F 16 | from torch.utils.data import DataLoader 17 | from torchvision import transforms 18 | 19 | from networks.models import build_vos_model 20 | from networks.engines import build_engine 21 | from utils.checkpoint import load_network 22 | 23 | from dataloaders.eval_datasets import VOSTest 24 | import dataloaders.video_transforms as tr 25 | from utils.image import save_mask 26 | 27 | _palette = [ 28 | 255, 0, 0, 0, 0, 139, 255, 255, 84, 0, 255, 0, 139, 0, 139, 0, 128, 128, 29 | 128, 128, 128, 139, 0, 0, 218, 165, 32, 144, 238, 144, 160, 82, 45, 148, 0, 30 | 211, 255, 0, 255, 30, 144, 255, 255, 218, 185, 85, 107, 47, 255, 140, 0, 31 | 50, 205, 50, 123, 104, 238, 240, 230, 140, 72, 61, 139, 128, 128, 0, 0, 0, 32 | 205, 221, 160, 221, 143, 188, 143, 127, 255, 212, 176, 
224, 230, 244, 164, 33 | 96, 250, 128, 114, 70, 130, 180, 0, 128, 0, 173, 255, 47, 255, 105, 180, 34 | 238, 130, 238, 154, 205, 50, 220, 20, 60, 176, 48, 96, 0, 206, 209, 0, 191, 35 | 255, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43, 43, 43, 44, 44, 44, 45, 45, 36 | 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49, 49, 50, 50, 50, 51, 51, 51, 37 | 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55, 56, 56, 56, 57, 57, 57, 58, 38 | 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62, 62, 62, 63, 63, 63, 64, 64, 39 | 64, 65, 65, 65, 66, 66, 66, 67, 67, 67, 68, 68, 68, 69, 69, 69, 70, 70, 70, 40 | 71, 71, 71, 72, 72, 72, 73, 73, 73, 74, 74, 74, 75, 75, 75, 76, 76, 76, 77, 41 | 77, 77, 78, 78, 78, 79, 79, 79, 80, 80, 80, 81, 81, 81, 82, 82, 82, 83, 83, 42 | 83, 84, 84, 84, 85, 85, 85, 86, 86, 86, 87, 87, 87, 88, 88, 88, 89, 89, 89, 43 | 90, 90, 90, 91, 91, 91, 92, 92, 92, 93, 93, 93, 94, 94, 94, 95, 95, 95, 96, 44 | 96, 96, 97, 97, 97, 98, 98, 98, 99, 99, 99, 100, 100, 100, 101, 101, 101, 45 | 102, 102, 102, 103, 103, 103, 104, 104, 104, 105, 105, 105, 106, 106, 106, 46 | 107, 107, 107, 108, 108, 108, 109, 109, 109, 110, 110, 110, 111, 111, 111, 47 | 112, 112, 112, 113, 113, 113, 114, 114, 114, 115, 115, 115, 116, 116, 116, 48 | 117, 117, 117, 118, 118, 118, 119, 119, 119, 120, 120, 120, 121, 121, 121, 49 | 122, 122, 122, 123, 123, 123, 124, 124, 124, 125, 125, 125, 126, 126, 126, 50 | 127, 127, 127, 128, 128, 128, 129, 129, 129, 130, 130, 130, 131, 131, 131, 51 | 132, 132, 132, 133, 133, 133, 134, 134, 134, 135, 135, 135, 136, 136, 136, 52 | 137, 137, 137, 138, 138, 138, 139, 139, 139, 140, 140, 140, 141, 141, 141, 53 | 142, 142, 142, 143, 143, 143, 144, 144, 144, 145, 145, 145, 146, 146, 146, 54 | 147, 147, 147, 148, 148, 148, 149, 149, 149, 150, 150, 150, 151, 151, 151, 55 | 152, 152, 152, 153, 153, 153, 154, 154, 154, 155, 155, 155, 156, 156, 156, 56 | 157, 157, 157, 158, 158, 158, 159, 159, 159, 160, 160, 160, 161, 161, 161, 57 | 162, 162, 162, 163, 163, 163, 164, 164, 164, 165, 165, 165, 166, 166, 166, 58 | 167, 167, 167, 168, 168, 168, 169, 169, 169, 170, 170, 170, 171, 171, 171, 59 | 172, 172, 172, 173, 173, 173, 174, 174, 174, 175, 175, 175, 176, 176, 176, 60 | 177, 177, 177, 178, 178, 178, 179, 179, 179, 180, 180, 180, 181, 181, 181, 61 | 182, 182, 182, 183, 183, 183, 184, 184, 184, 185, 185, 185, 186, 186, 186, 62 | 187, 187, 187, 188, 188, 188, 189, 189, 189, 190, 190, 190, 191, 191, 191, 63 | 192, 192, 192, 193, 193, 193, 194, 194, 194, 195, 195, 195, 196, 196, 196, 64 | 197, 197, 197, 198, 198, 198, 199, 199, 199, 200, 200, 200, 201, 201, 201, 65 | 202, 202, 202, 203, 203, 203, 204, 204, 204, 205, 205, 205, 206, 206, 206, 66 | 207, 207, 207, 208, 208, 208, 209, 209, 209, 210, 210, 210, 211, 211, 211, 67 | 212, 212, 212, 213, 213, 213, 214, 214, 214, 215, 215, 215, 216, 216, 216, 68 | 217, 217, 217, 218, 218, 218, 219, 219, 219, 220, 220, 220, 221, 221, 221, 69 | 222, 222, 222, 223, 223, 223, 224, 224, 224, 225, 225, 225, 226, 226, 226, 70 | 227, 227, 227, 228, 228, 228, 229, 229, 229, 230, 230, 230, 231, 231, 231, 71 | 232, 232, 232, 233, 233, 233, 234, 234, 234, 235, 235, 235, 236, 236, 236, 72 | 237, 237, 237, 238, 238, 238, 239, 239, 239, 240, 240, 240, 241, 241, 241, 73 | 242, 242, 242, 243, 243, 243, 244, 244, 244, 245, 245, 245, 246, 246, 246, 74 | 247, 247, 247, 248, 248, 248, 249, 249, 249, 250, 250, 250, 251, 251, 251, 75 | 252, 252, 252, 253, 253, 253, 254, 254, 254, 255, 255, 255, 0, 0, 0 76 | ] 77 | color_palette = np.array(_palette).reshape(-1, 3) 78 | 79 | 80 | def overlay(image, 
mask, colors=[255, 0, 0], cscale=1, alpha=0.4): 81 | colors = np.atleast_2d(colors) * cscale 82 | 83 | im_overlay = image.copy() 84 | object_ids = np.unique(mask) 85 | 86 | for object_id in object_ids[1:]: 87 | # Overlay color on binary mask 88 | 89 | foreground = image * alpha + np.ones( 90 | image.shape) * (1 - alpha) * np.array(colors[object_id]) 91 | binary_mask = mask == object_id 92 | 93 | # Compose image 94 | im_overlay[binary_mask] = foreground[binary_mask] 95 | 96 | countours = binary_dilation(binary_mask) ^ binary_mask 97 | im_overlay[countours, :] = 0 98 | 99 | return im_overlay.astype(image.dtype) 100 | 101 | 102 | def demo(cfg): 103 | video_fps = 10 104 | gpu_id = cfg.TEST_GPU_ID 105 | 106 | # Load pre-trained model 107 | print('Build AOT model.') 108 | model = build_vos_model(cfg.MODEL_VOS, cfg).cuda(gpu_id) 109 | 110 | print('Load checkpoint from {}'.format(cfg.TEST_CKPT_PATH)) 111 | model, _ = load_network(model, cfg.TEST_CKPT_PATH, gpu_id) 112 | 113 | print('Build AOT engine.') 114 | engine = build_engine(cfg.MODEL_ENGINE, 115 | phase='eval', 116 | aot_model=model, 117 | gpu_id=gpu_id, 118 | long_term_mem_gap=cfg.TEST_LONG_TERM_MEM_GAP) 119 | 120 | # Prepare datasets for each sequence 121 | transform = transforms.Compose([ 122 | tr.MultiRestrictSize(cfg.TEST_MIN_SIZE, cfg.TEST_MAX_SIZE, 123 | cfg.TEST_FLIP, cfg.TEST_MULTISCALE, 124 | cfg.MODEL_ALIGN_CORNERS), 125 | tr.MultiToTensor() 126 | ]) 127 | if cfg.TEST_DATA_PATH is not None and cfg.SEQ_LIST is None: 128 | image_root = os.path.join(cfg.TEST_DATA_PATH, 'images') 129 | label_root = os.path.join(cfg.TEST_DATA_PATH, 'masks') 130 | sequences = os.listdir(image_root) 131 | elif cfg.SEQ_LIST is not None: 132 | sequences = [seq.strip() for seq in cfg.SEQ_LIST] 133 | print(sequences) 134 | image_root = os.path.join(cfg.TEST_DATA_PATH, 'JPEGImages_10fps') 135 | label_root = os.path.join(cfg.TEST_DATA_PATH, 'Annotations') 136 | else: 137 | image_root = cfg.TEST_FRAME_PATH 138 | label_root = cfg.TEST_LABEL_PATH 139 | sequences = [image_root.split('/')[-2]] 140 | 141 | seq_datasets = [] 142 | for seq_name in sequences: 143 | print('Build a dataset for sequence {}.'.format(seq_name)) 144 | if cfg.TEST_DATA_PATH is not None: 145 | seq_images = np.sort(os.listdir(os.path.join(image_root, seq_name))) 146 | else: 147 | seq_images = np.sort(os.listdir(image_root)) 148 | image_root = "/".join(image_root.split('/')[:-2]) 149 | label_root = "/".join(label_root.split('/')[:-2]) 150 | 151 | seq_labels = [seq_images[0].replace('jpg', 'png')] 152 | seq_dataset = VOSTest(image_root, 153 | label_root, 154 | seq_name, 155 | seq_images, 156 | seq_labels, 157 | transform=transform) 158 | seq_datasets.append(seq_dataset) 159 | 160 | # Infer 161 | output_root = cfg.TEST_OUTPUT_PATH 162 | output_mask_root = os.path.join(output_root, 'pred_masks') 163 | if not os.path.exists(output_mask_root): 164 | os.makedirs(output_mask_root) 165 | 166 | for seq_dataset in seq_datasets: 167 | seq_name = seq_dataset.seq_name 168 | image_seq_root = os.path.join(image_root, seq_name) 169 | output_mask_seq_root = os.path.join(output_mask_root, seq_name) 170 | if not os.path.exists(output_mask_seq_root): 171 | os.makedirs(output_mask_seq_root) 172 | print('Build a dataloader for sequence {}.'.format(seq_name)) 173 | seq_dataloader = DataLoader(seq_dataset, 174 | batch_size=1, 175 | shuffle=False, 176 | num_workers=cfg.TEST_WORKERS, 177 | pin_memory=True) 178 | num_frames = len(seq_dataset) 179 | max_gap = int(round(num_frames / 30)) 180 | gap = max(max_gap, 5) 181 | 
print(gap) 182 | engine.long_term_mem_gap = gap 183 | 184 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 185 | output_video_path = os.path.join( 186 | output_root, '{}_{}fps.avi'.format(seq_name, video_fps)) 187 | 188 | print('Start the inference of sequence {}:'.format(seq_name)) 189 | model.eval() 190 | engine.restart_engine() 191 | with torch.no_grad(): 192 | time_start = 0 193 | for frame_idx, samples in enumerate(seq_dataloader): 194 | # if time_start != 0: 195 | # print(time() - time_start) 196 | # time_start = time() 197 | sample = samples[0] 198 | img_name = sample['meta']['current_name'][0] 199 | 200 | obj_nums = sample['meta']['obj_num'] 201 | output_height = sample['meta']['height'] 202 | output_width = sample['meta']['width'] 203 | obj_idx = sample['meta']['obj_idx'] 204 | 205 | obj_nums = [int(obj_num) for obj_num in obj_nums] 206 | obj_idx = [int(_obj_idx) for _obj_idx in obj_idx] 207 | 208 | current_img = sample['current_img'] 209 | current_img = current_img.cuda(gpu_id, non_blocking=True) 210 | 211 | if frame_idx == 0: 212 | videoWriter = cv2.VideoWriter( 213 | output_video_path, fourcc, video_fps, 214 | (int(output_width), int(output_height))) 215 | print( 216 | 'Object number: {}. Inference size: {}x{}. Output size: {}x{}.' 217 | .format(obj_nums[0], 218 | current_img.size()[2], 219 | current_img.size()[3], int(output_height), 220 | int(output_width))) 221 | current_label = sample['current_label'].cuda( 222 | gpu_id, non_blocking=True).float() 223 | current_label = F.interpolate(current_label, 224 | size=current_img.size()[2:], 225 | mode="nearest") 226 | # add reference frame 227 | engine.add_reference_frame(current_img, 228 | current_label, 229 | frame_step=0, 230 | obj_nums=obj_nums) 231 | else: 232 | print('Processing image {}...'.format(img_name)) 233 | # predict segmentation 234 | engine.match_propogate_one_frame(current_img) 235 | pred_logit = engine.decode_current_logits( 236 | (output_height, output_width)) 237 | pred_prob = torch.softmax(pred_logit, dim=1) 238 | pred_label = torch.argmax(pred_prob, dim=1, 239 | keepdim=True).float() 240 | _pred_label = F.interpolate(pred_label, 241 | size=engine.input_size_2d, 242 | mode="nearest") 243 | # update memory 244 | engine.update_memory(_pred_label) 245 | 246 | # save results 247 | input_image_path = os.path.join(image_seq_root, img_name) 248 | output_mask_path = os.path.join( 249 | output_mask_seq_root, 250 | img_name.split('.')[0] + '.png') 251 | 252 | pred_label = Image.fromarray( 253 | pred_label.squeeze(0).squeeze(0).cpu().numpy().astype( 254 | 'uint8')).convert('P') 255 | pred_label.putpalette(_palette) 256 | pred_label.save(output_mask_path) 257 | 258 | input_image = Image.open(input_image_path) 259 | 260 | overlayed_image = overlay( 261 | np.array(input_image, dtype=np.uint8), 262 | np.array(pred_label, dtype=np.uint8), color_palette) 263 | videoWriter.write(overlayed_image[..., [2, 1, 0]]) 264 | 265 | print('Save a visualization video to {}.'.format(output_video_path)) 266 | videoWriter.release() 267 | 268 | 269 | #CUDA_VISIBLE_DEVICES=0 python tools/demo.py --seq_list datasets/VOST/ImageSets/val.txt --ckpt_path pretrain_models/vost_aot_best.pth --output_path demo_output/aot_best_10fps 270 | def main(): 271 | import argparse 272 | parser = argparse.ArgumentParser(description="AOT Demo") 273 | parser.add_argument('--exp_name', type=str, default='default') 274 | 275 | parser.add_argument('--stage', type=str, default='pre_vost') 276 | parser.add_argument('--model', type=str, default='r50_aotl') 277 | 278 | 
parser.add_argument('--gpu_id', type=int, default=0) 279 | 280 | parser.add_argument('--data_path', type=str, default='./datasets/Demo') 281 | parser.add_argument('--seq_name', type=str, default='') 282 | parser.add_argument('--seq_list', type=str, default='') 283 | parser.add_argument('--output_path', type=str, default='./demo_output') 284 | parser.add_argument('--ckpt_path', 285 | type=str, 286 | default='./pretrain_models/R50_AOTL_PRE_YTB_DAV.pth') 287 | 288 | parser.add_argument('--max_resolution', type=float, default=480 * 1.3) 289 | 290 | parser.add_argument('--amp', action='store_true') 291 | parser.set_defaults(amp=False) 292 | 293 | args = parser.parse_args() 294 | 295 | engine_config = importlib.import_module('configs.' + args.stage) 296 | cfg = engine_config.EngineConfig(args.exp_name, args.model) 297 | 298 | cfg.TEST_GPU_ID = args.gpu_id 299 | 300 | cfg.TEST_CKPT_PATH = args.ckpt_path 301 | if len(args.seq_list) != 0: 302 | lst_file = open(args.seq_list, 'r') 303 | lst = lst_file.readlines() 304 | cfg.SEQ_LIST = lst 305 | cfg.TEST_DATA_PATH = '/mnt/fsx/VOST/' 306 | cfg.TEST_FRAME_PATH = None 307 | elif len(args.seq_name) != 0: 308 | cfg.TEST_DATA_PATH = None 309 | cfg.TEST_SEQ_LIST = None 310 | cfg.SEQ_LIST = None 311 | cfg.TEST_FRAME_PATH = 'datasets/VOST/JPEGImages_10fps/%s/' % args.seq_name 312 | cfg.TEST_LABEL_PATH = 'datasets/VOST/Annotations/%s/' % args.seq_name 313 | else: 314 | cfg.TEST_DATA_PATH = args.data_path 315 | cfg.TEST_FRAME_PATH = None 316 | cfg.TEST_SEQ_LIST = None 317 | cfg.TEST_OUTPUT_PATH = args.output_path 318 | 319 | cfg.TEST_MIN_SIZE = None 320 | cfg.TEST_MAX_SIZE = args.max_resolution * 800. / 480. 321 | 322 | if args.amp: 323 | with torch.cuda.amp.autocast(enabled=True): 324 | demo(cfg) 325 | else: 326 | demo(cfg) 327 | 328 | 329 | if __name__ == '__main__': 330 | main() 331 | -------------------------------------------------------------------------------- /aot_plus/tools/eval.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | 4 | sys.path.append('.') 5 | sys.path.append('..') 6 | 7 | import torch 8 | import torch.multiprocessing as mp 9 | 10 | from networks.managers.evaluator import Evaluator 11 | 12 | 13 | def main_worker(gpu, cfg, seq_queue=None, info_queue=None, enable_amp=False): 14 | # Initiate a evaluating manager 15 | evaluator = Evaluator(rank=gpu, 16 | cfg=cfg, 17 | seq_queue=seq_queue, 18 | info_queue=info_queue) 19 | # Start evaluation 20 | if enable_amp: 21 | with torch.cuda.amp.autocast(enabled=True): 22 | evaluator.evaluating() 23 | else: 24 | evaluator.evaluating() 25 | 26 | 27 | #python tools/eval.py --exp_name aotplus --stage pre_vost --model r50_aotl --dataset vost --split val --gpu_num 8 --ckpt_path pretrain_models/aotplus.pth --ms 1.0 1.1 1.2 0.9 0.8 28 | def main(): 29 | import argparse 30 | parser = argparse.ArgumentParser(description="Eval VOS") 31 | parser.add_argument('--exp_name', type=str, default='default') 32 | 33 | parser.add_argument('--stage', type=str, default='pre') 34 | parser.add_argument('--model', type=str, default='aott') 35 | 36 | parser.add_argument('--gpu_id', type=int, default=0) 37 | parser.add_argument('--gpu_num', type=int, default=1) 38 | 39 | parser.add_argument('--ckpt_path', type=str, default='') 40 | parser.add_argument('--ckpt_step', type=int, default=-1) 41 | 42 | parser.add_argument('--dataset', type=str, default='') 43 | parser.add_argument('--split', type=str, default='') 44 | 45 | parser.add_argument('--no_ema', 
action='store_true') 46 | parser.set_defaults(no_ema=False) 47 | 48 | parser.add_argument('--flip', action='store_true') 49 | parser.set_defaults(flip=False) 50 | parser.add_argument('--ms', nargs='+', type=float, default=[1.]) 51 | 52 | parser.add_argument('--max_resolution', type=float, default=480 * 1.3) 53 | 54 | parser.add_argument('--amp', action='store_true') 55 | parser.set_defaults(amp=False) 56 | 57 | args = parser.parse_args() 58 | 59 | engine_config = importlib.import_module('configs.' + args.stage) 60 | cfg = engine_config.EngineConfig(args.exp_name, args.model) 61 | 62 | cfg.TEST_EMA = not args.no_ema 63 | 64 | cfg.TEST_GPU_ID = args.gpu_id 65 | cfg.TEST_GPU_NUM = args.gpu_num 66 | 67 | if args.ckpt_path != '': 68 | cfg.TEST_CKPT_PATH = args.ckpt_path 69 | if args.ckpt_step > 0: 70 | cfg.TEST_CKPT_STEP = args.ckpt_step 71 | 72 | if args.dataset != '': 73 | cfg.TEST_DATASET = args.dataset 74 | 75 | if args.split != '': 76 | cfg.TEST_DATASET_SPLIT = args.split 77 | 78 | cfg.TEST_FLIP = args.flip 79 | cfg.TEST_MULTISCALE = args.ms 80 | 81 | cfg.TEST_MIN_SIZE = None 82 | cfg.TEST_MAX_SIZE = args.max_resolution * 800. / 480. 83 | 84 | if args.gpu_num > 1: 85 | mp.set_start_method('spawn') 86 | seq_queue = mp.Queue() 87 | info_queue = mp.Queue() 88 | mp.spawn(main_worker, 89 | nprocs=cfg.TEST_GPU_NUM, 90 | args=(cfg, seq_queue, info_queue, args.amp)) 91 | else: 92 | main_worker(0, cfg, enable_amp=args.amp) 93 | 94 | 95 | if __name__ == '__main__': 96 | main() 97 | -------------------------------------------------------------------------------- /aot_plus/tools/train.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import random 3 | import sys 4 | 5 | sys.setrecursionlimit(10000) 6 | sys.path.append('.') 7 | sys.path.append('..') 8 | 9 | import torch.multiprocessing as mp 10 | 11 | from networks.managers.trainer import Trainer 12 | 13 | 14 | def main_worker(gpu, cfg, enable_amp=True, exp_name='default'): 15 | # Initiate a training manager 16 | trainer = Trainer(rank=gpu, cfg=cfg, enable_amp=enable_amp) 17 | # Start Training 18 | trainer.sequential_training() 19 | 20 | def main(): 21 | import argparse 22 | parser = argparse.ArgumentParser(description="Train VOS") 23 | parser.add_argument('--exp_name', type=str, default='') 24 | parser.add_argument('--stage', type=str, default='pre') 25 | parser.add_argument('--model', type=str, default='aott') 26 | 27 | parser.add_argument('--start_gpu', type=int, default=0) 28 | parser.add_argument('--gpu_num', type=int, default=-1) 29 | parser.add_argument('--batch_size', type=int, default=-1) 30 | parser.add_argument('--dist_url', type=str, default='') 31 | parser.add_argument('--amp', action='store_true') 32 | parser.set_defaults(amp=False) 33 | 34 | parser.add_argument('--pretrained_path', type=str, default='') 35 | 36 | parser.add_argument('--datasets', nargs='+', type=str, default=[]) 37 | parser.add_argument('--lr', type=float, default=-1.) 38 | parser.add_argument('--total_step', type=int, default=-1.) 39 | parser.add_argument('--start_step', type=int, default=-1.) 40 | 41 | args = parser.parse_args() 42 | 43 | engine_config = importlib.import_module('configs.' 
+ args.stage) 44 | 45 | cfg = engine_config.EngineConfig(args.exp_name, args.model) 46 | 47 | if len(args.datasets) > 0: 48 | cfg.DATASETS = args.datasets 49 | 50 | cfg.DIST_START_GPU = args.start_gpu 51 | if args.gpu_num > 0: 52 | cfg.TRAIN_GPUS = args.gpu_num 53 | if args.batch_size > 0: 54 | cfg.TRAIN_BATCH_SIZE = args.batch_size 55 | 56 | if args.pretrained_path != '': 57 | cfg.PRETRAIN_MODEL = args.pretrained_path 58 | 59 | if args.lr > 0: 60 | cfg.TRAIN_LR = args.lr 61 | 62 | if args.total_step > 0: 63 | cfg.TRAIN_TOTAL_STEPS = args.total_step 64 | 65 | if args.start_step > 0: 66 | cfg.TRAIN_START_STEP = args.start_step 67 | 68 | if args.dist_url == '': 69 | cfg.DIST_URL = 'tcp://127.0.0.1:123' + str(random.randint(0, 9)) + str( 70 | random.randint(0, 9)) 71 | else: 72 | cfg.DIST_URL = args.dist_url 73 | if cfg.TRAIN_GPUS == 1: 74 | main_worker(0, cfg, args.amp, args.exp_name) 75 | else: 76 | # Use torch.multiprocessing.spawn to launch distributed processes 77 | mp.spawn(main_worker, nprocs=cfg.TRAIN_GPUS, args=(cfg, args.amp, args.exp_name)) 78 | 79 | 80 | if __name__ == '__main__': 81 | main() 82 | -------------------------------------------------------------------------------- /aot_plus/train_vost.sh: -------------------------------------------------------------------------------- 1 | exp="aotplus" 2 | # exp="debug" 3 | gpu_num="4" 4 | devices="4,5,6,7" 5 | 6 | # model="aott" 7 | # model="aots" 8 | # model="aotb" 9 | # model="aotl" 10 | model="r50_aotl" 11 | # model="swinb_aotl" 12 | 13 | stage="pre_vost" 14 | CUDA_VISIBLE_DEVICES=${devices} python tools/train.py --amp \ 15 | --exp_name ${exp} \ 16 | --stage ${stage} \ 17 | --model ${model} \ 18 | --gpu_num ${gpu_num} 19 | 20 | dataset="vost" 21 | split="val" 22 | CUDA_VISIBLE_DEVICES=${devices} python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ 23 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} --ms 1.0 1.1 1.2 0.9 0.8 -------------------------------------------------------------------------------- /aot_plus/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/utils/__init__.py -------------------------------------------------------------------------------- /aot_plus/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | 5 | 6 | def load_network_and_optimizer(net, opt, pretrained_dir, gpu, scaler=None): 7 | pretrained = torch.load(pretrained_dir, 8 | map_location=torch.device("cuda:" + str(gpu))) 9 | pretrained_dict = pretrained['state_dict'] 10 | model_dict = net.state_dict() 11 | pretrained_dict_update = {} 12 | pretrained_dict_remove = [] 13 | for k, v in pretrained_dict.items(): 14 | if k in model_dict: 15 | pretrained_dict_update[k] = v 16 | elif k[:7] == 'module.': 17 | if k[7:] in model_dict: 18 | pretrained_dict_update[k[7:]] = v 19 | else: 20 | pretrained_dict_remove.append(k) 21 | model_dict.update(pretrained_dict_update) 22 | net.load_state_dict(model_dict) 23 | opt.load_state_dict(pretrained['optimizer']) 24 | if scaler is not None and 'scaler' in pretrained.keys(): 25 | scaler.load_state_dict(pretrained['scaler']) 26 | del (pretrained) 27 | return net.cuda(gpu), opt, pretrained_dict_remove 28 | 29 | 30 | def load_network_and_optimizer_v2(net, opt, pretrained_dir, gpu, scaler=None): 31 | pretrained = torch.load(pretrained_dir, 32 | 
map_location=torch.device("cuda:" + str(gpu))) 33 | # load model 34 | pretrained_dict = pretrained['state_dict'] 35 | model_dict = net.state_dict() 36 | pretrained_dict_update = {} 37 | pretrained_dict_remove = [] 38 | for k, v in pretrained_dict.items(): 39 | if k in model_dict: 40 | pretrained_dict_update[k] = v 41 | elif k[:7] == 'module.': 42 | if k[7:] in model_dict: 43 | pretrained_dict_update[k[7:]] = v 44 | else: 45 | pretrained_dict_remove.append(k) 46 | model_dict.update(pretrained_dict_update) 47 | net.load_state_dict(model_dict) 48 | 49 | # load optimizer 50 | opt_dict = opt.state_dict() 51 | all_params = { 52 | param_group['name']: param_group['params'][0] 53 | for param_group in opt_dict['param_groups'] 54 | } 55 | pretrained_opt_dict = {'state': {}, 'param_groups': []} 56 | for idx in range(len(pretrained['optimizer']['param_groups'])): 57 | param_group = pretrained['optimizer']['param_groups'][idx] 58 | if param_group['name'] in all_params.keys(): 59 | pretrained_opt_dict['state'][all_params[ 60 | param_group['name']]] = pretrained['optimizer']['state'][ 61 | param_group['params'][0]] 62 | param_group['params'][0] = all_params[param_group['name']] 63 | pretrained_opt_dict['param_groups'].append(param_group) 64 | 65 | opt_dict.update(pretrained_opt_dict) 66 | opt.load_state_dict(opt_dict) 67 | 68 | # load scaler 69 | if scaler is not None and 'scaler' in pretrained.keys(): 70 | scaler.load_state_dict(pretrained['scaler']) 71 | del (pretrained) 72 | return net.cuda(gpu), opt, pretrained_dict_remove 73 | 74 | 75 | def load_network(net, pretrained_dir, gpu): 76 | pretrained = torch.load(pretrained_dir, 77 | map_location=torch.device("cuda:" + str(gpu))) 78 | if 'state_dict' in pretrained.keys(): 79 | pretrained_dict = pretrained['state_dict'] 80 | elif 'model' in pretrained.keys(): 81 | pretrained_dict = pretrained['model'] 82 | else: 83 | pretrained_dict = pretrained 84 | model_dict = net.state_dict() 85 | pretrained_dict_update = {} 86 | pretrained_dict_remove = [] 87 | for k, v in pretrained_dict.items(): 88 | if k in model_dict and (len(v.shape) > 2 and v.shape[1] != model_dict[k].shape[1]): 89 | model_dict[k][:, :-1, :, :] = v 90 | continue 91 | if k in model_dict: 92 | pretrained_dict_update[k] = v 93 | elif k[:7] == 'module.': 94 | if k[7:] in model_dict: 95 | pretrained_dict_update[k[7:]] = v 96 | else: 97 | pretrained_dict_remove.append(k) 98 | model_dict.update(pretrained_dict_update) 99 | net.load_state_dict(model_dict) 100 | del (pretrained) 101 | return net.cuda(gpu), pretrained_dict_remove 102 | 103 | 104 | def save_network(net, 105 | opt, 106 | step, 107 | save_path, 108 | max_keep=8, 109 | backup_dir='./saved_models', 110 | scaler=None): 111 | ckpt = {'state_dict': net.state_dict(), 'optimizer': opt.state_dict()} 112 | if scaler is not None: 113 | ckpt['scaler'] = scaler.state_dict() 114 | 115 | try: 116 | if not os.path.exists(save_path): 117 | os.makedirs(save_path) 118 | save_file = 'save_step_%s.pth' % (step) 119 | save_dir = os.path.join(save_path, save_file) 120 | torch.save(ckpt, save_dir) 121 | except: 122 | save_path = backup_dir 123 | if not os.path.exists(save_path): 124 | os.makedirs(save_path) 125 | save_file = 'save_step_%s.pth' % (step) 126 | save_dir = os.path.join(save_path, save_file) 127 | torch.save(ckpt, save_dir) 128 | 129 | all_ckpt = os.listdir(save_path) 130 | if len(all_ckpt) > max_keep: 131 | all_step = [] 132 | for ckpt_name in all_ckpt: 133 | step = int(ckpt_name.split('_')[-1].split('.')[0]) 134 | all_step.append(step) 135 | 
all_step = list(np.sort(all_step))[:-max_keep] 136 | for step in all_step: 137 | ckpt_path = os.path.join(save_path, 'save_step_%s.pth' % (step)) 138 | os.system('rm {}'.format(ckpt_path)) 139 | -------------------------------------------------------------------------------- /aot_plus/utils/ema.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import unicode_literals 3 | 4 | import torch 5 | 6 | 7 | def get_param_buffer_for_ema(model, 8 | update_buffer=False, 9 | required_buffers=['running_mean', 'running_var']): 10 | params = model.parameters() 11 | all_param_buffer = [p for p in params if p.requires_grad] 12 | if update_buffer: 13 | named_buffers = model.named_buffers() 14 | for key, value in named_buffers: 15 | for buffer_name in required_buffers: 16 | if buffer_name in key: 17 | all_param_buffer.append(value) 18 | break 19 | return all_param_buffer 20 | 21 | 22 | class ExponentialMovingAverage: 23 | """ 24 | Maintains (exponential) moving average of a set of parameters. 25 | """ 26 | def __init__(self, parameters, decay, use_num_updates=True): 27 | """ 28 | Args: 29 | parameters: Iterable of `torch.nn.Parameter`; usually the result of 30 | `model.parameters()`. 31 | decay: The exponential decay. 32 | use_num_updates: Whether to use number of updates when computing 33 | averages. 34 | """ 35 | if decay < 0.0 or decay > 1.0: 36 | raise ValueError('Decay must be between 0 and 1') 37 | self.decay = decay 38 | self.num_updates = 0 if use_num_updates else None 39 | self.shadow_params = [p.clone().detach() for p in parameters] 40 | self.collected_params = [] 41 | 42 | def update(self, parameters): 43 | """ 44 | Update currently maintained parameters. 45 | Call this every time the parameters are updated, such as the result of 46 | the `optimizer.step()` call. 47 | Args: 48 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 49 | parameters used to initialize this object. 50 | """ 51 | decay = self.decay 52 | if self.num_updates is not None: 53 | self.num_updates += 1 54 | decay = min(decay, 55 | (1 + self.num_updates) / (10 + self.num_updates)) 56 | one_minus_decay = 1.0 - decay 57 | with torch.no_grad(): 58 | for s_param, param in zip(self.shadow_params, parameters): 59 | s_param.sub_(one_minus_decay * (s_param - param)) 60 | 61 | def copy_to(self, parameters): 62 | """ 63 | Copy current parameters into given collection of parameters. 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | updated with the stored moving averages. 67 | """ 68 | for s_param, param in zip(self.shadow_params, parameters): 69 | param.data.copy_(s_param.data) 70 | 71 | def store(self, parameters): 72 | """ 73 | Save the current parameters for restoring later. 74 | Args: 75 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 76 | temporarily stored. 77 | """ 78 | self.collected_params = [param.clone() for param in parameters] 79 | 80 | def restore(self, parameters): 81 | """ 82 | Restore the parameters stored with the `store` method. 83 | Useful to validate the model with EMA parameters without affecting the 84 | original optimization process. Store the parameters before the 85 | `copy_to` method. After validation (or model saving), use this to 86 | restore the former parameters. 87 | Args: 88 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 89 | updated with the stored parameters. 
90 | """ 91 | for c_param, param in zip(self.collected_params, parameters): 92 | param.data.copy_(c_param.data) 93 | del (self.collected_params) 94 | -------------------------------------------------------------------------------- /aot_plus/utils/eval.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | import os 3 | 4 | 5 | def zip_folder(source_folder, zip_dir): 6 | f = zipfile.ZipFile(zip_dir, 'w', zipfile.ZIP_DEFLATED) 7 | pre_len = len(os.path.dirname(source_folder)) 8 | for dirpath, dirnames, filenames in os.walk(source_folder): 9 | for filename in filenames: 10 | pathfile = os.path.join(dirpath, filename) 11 | arcname = pathfile[pre_len:].strip(os.path.sep) 12 | f.write(pathfile, arcname) 13 | f.close() -------------------------------------------------------------------------------- /aot_plus/utils/image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | import threading 5 | 6 | _palette = [ 7 | 0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128, 8 | 128, 128, 128, 128, 64, 0, 0, 191, 0, 0, 64, 128, 0, 191, 128, 0, 64, 0, 9 | 128, 191, 0, 128, 64, 128, 128, 191, 128, 128, 0, 64, 0, 128, 64, 0, 0, 10 | 191, 0, 128, 191, 0, 0, 64, 128, 128, 64, 128, 22, 22, 22, 23, 23, 23, 24, 11 | 24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28, 28, 28, 29, 29, 29, 30, 30, 12 | 30, 31, 31, 31, 32, 32, 32, 33, 33, 33, 34, 34, 34, 35, 35, 35, 36, 36, 36, 13 | 37, 37, 37, 38, 38, 38, 39, 39, 39, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43, 14 | 43, 43, 44, 44, 44, 45, 45, 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49, 15 | 49, 50, 50, 50, 51, 51, 51, 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55, 16 | 56, 56, 56, 57, 57, 57, 58, 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62, 17 | 62, 62, 63, 63, 63, 64, 64, 64, 65, 65, 65, 66, 66, 66, 67, 67, 67, 68, 68, 18 | 68, 69, 69, 69, 70, 70, 70, 71, 71, 71, 72, 72, 72, 73, 73, 73, 74, 74, 74, 19 | 75, 75, 75, 76, 76, 76, 77, 77, 77, 78, 78, 78, 79, 79, 79, 80, 80, 80, 81, 20 | 81, 81, 82, 82, 82, 83, 83, 83, 84, 84, 84, 85, 85, 85, 86, 86, 86, 87, 87, 21 | 87, 88, 88, 88, 89, 89, 89, 90, 90, 90, 91, 91, 91, 92, 92, 92, 93, 93, 93, 22 | 94, 94, 94, 95, 95, 95, 96, 96, 96, 97, 97, 97, 98, 98, 98, 99, 99, 99, 23 | 100, 100, 100, 101, 101, 101, 102, 102, 102, 103, 103, 103, 104, 104, 104, 24 | 105, 105, 105, 106, 106, 106, 107, 107, 107, 108, 108, 108, 109, 109, 109, 25 | 110, 110, 110, 111, 111, 111, 112, 112, 112, 113, 113, 113, 114, 114, 114, 26 | 115, 115, 115, 116, 116, 116, 117, 117, 117, 118, 118, 118, 119, 119, 119, 27 | 120, 120, 120, 121, 121, 121, 122, 122, 122, 123, 123, 123, 124, 124, 124, 28 | 125, 125, 125, 126, 126, 126, 127, 127, 127, 128, 128, 128, 129, 129, 129, 29 | 130, 130, 130, 131, 131, 131, 132, 132, 132, 133, 133, 133, 134, 134, 134, 30 | 135, 135, 135, 136, 136, 136, 137, 137, 137, 138, 138, 138, 139, 139, 139, 31 | 140, 140, 140, 141, 141, 141, 142, 142, 142, 143, 143, 143, 144, 144, 144, 32 | 145, 145, 145, 146, 146, 146, 147, 147, 147, 148, 148, 148, 149, 149, 149, 33 | 150, 150, 150, 151, 151, 151, 152, 152, 152, 153, 153, 153, 154, 154, 154, 34 | 155, 155, 155, 156, 156, 156, 157, 157, 157, 158, 158, 158, 159, 159, 159, 35 | 160, 160, 160, 161, 161, 161, 162, 162, 162, 163, 163, 163, 164, 164, 164, 36 | 165, 165, 165, 166, 166, 166, 167, 167, 167, 168, 168, 168, 169, 169, 169, 37 | 170, 170, 170, 171, 171, 171, 172, 172, 172, 173, 173, 173, 174, 174, 174, 38 
| 175, 175, 175, 176, 176, 176, 177, 177, 177, 178, 178, 178, 179, 179, 179, 39 | 180, 180, 180, 181, 181, 181, 182, 182, 182, 183, 183, 183, 184, 184, 184, 40 | 185, 185, 185, 186, 186, 186, 187, 187, 187, 188, 188, 188, 189, 189, 189, 41 | 190, 190, 190, 191, 191, 191, 192, 192, 192, 193, 193, 193, 194, 194, 194, 42 | 195, 195, 195, 196, 196, 196, 197, 197, 197, 198, 198, 198, 199, 199, 199, 43 | 200, 200, 200, 201, 201, 201, 202, 202, 202, 203, 203, 203, 204, 204, 204, 44 | 205, 205, 205, 206, 206, 206, 207, 207, 207, 208, 208, 208, 209, 209, 209, 45 | 210, 210, 210, 211, 211, 211, 212, 212, 212, 213, 213, 213, 214, 214, 214, 46 | 215, 215, 215, 216, 216, 216, 217, 217, 217, 218, 218, 218, 219, 219, 219, 47 | 220, 220, 220, 221, 221, 221, 222, 222, 222, 223, 223, 223, 224, 224, 224, 48 | 225, 225, 225, 226, 226, 226, 227, 227, 227, 228, 228, 228, 229, 229, 229, 49 | 230, 230, 230, 231, 231, 231, 232, 232, 232, 233, 233, 233, 234, 234, 234, 50 | 235, 235, 235, 236, 236, 236, 237, 237, 237, 238, 238, 238, 239, 239, 239, 51 | 240, 240, 240, 241, 241, 241, 242, 242, 242, 243, 243, 243, 244, 244, 244, 52 | 245, 245, 245, 246, 246, 246, 247, 247, 247, 248, 248, 248, 249, 249, 249, 53 | 250, 250, 250, 251, 251, 251, 252, 252, 252, 253, 253, 253, 254, 254, 254, 54 | 255, 255, 255 55 | ] 56 | 57 | 58 | def label2colormap(label): 59 | 60 | m = label.astype(np.uint8) 61 | r, c = m.shape 62 | cmap = np.zeros((r, c, 3), dtype=np.uint8) 63 | cmap[:, :, 0] = (m & 1) << 7 | (m & 8) << 3 | (m & 64) >> 1 64 | cmap[:, :, 1] = (m & 2) << 6 | (m & 16) << 2 | (m & 128) >> 2 65 | cmap[:, :, 2] = (m & 4) << 5 | (m & 32) << 1 66 | return cmap 67 | 68 | 69 | def one_hot_mask(mask, cls_num): 70 | if len(mask.size()) == 3: 71 | mask = mask.unsqueeze(1) 72 | indices = torch.arange(0, cls_num + 1, 73 | device=mask.device).view(1, -1, 1, 1) 74 | return (mask == indices).float(), (mask == 255).float() 75 | 76 | 77 | def masked_image(image, colored_mask, mask, alpha=0.7): 78 | mask = np.expand_dims(mask > 0, axis=0) 79 | mask = np.repeat(mask, 3, axis=0) 80 | show_img = (image * alpha + colored_mask * 81 | (1 - alpha)) * mask + image * (1 - mask) 82 | return show_img 83 | 84 | 85 | def save_image(image, path): 86 | im = Image.fromarray(np.uint8(image * 255.).transpose((1, 2, 0))) 87 | im.save(path) 88 | 89 | 90 | def _save_mask(mask, path, squeeze_idx=None): 91 | if squeeze_idx is not None: 92 | unsqueezed_mask = mask * 0 93 | for idx in range(1, len(squeeze_idx)): 94 | obj_id = squeeze_idx[idx] 95 | mask_i = mask == idx 96 | unsqueezed_mask += (mask_i * obj_id).astype(np.uint8) 97 | mask = unsqueezed_mask 98 | mask = Image.fromarray(mask).convert('P') 99 | mask.putpalette(_palette) 100 | mask.save(path) 101 | 102 | 103 | def save_mask(mask_tensor, path, squeeze_idx=None): 104 | mask = mask_tensor.cpu().numpy().astype('uint8') 105 | threading.Thread(target=_save_mask, args=[mask, path, squeeze_idx]).start() 106 | 107 | 108 | def flip_tensor(tensor, dim=0): 109 | inv_idx = torch.arange(tensor.size(dim) - 1, -1, -1, 110 | device=tensor.device).long() 111 | tensor = tensor.index_select(dim, inv_idx) 112 | return tensor 113 | 114 | 115 | def shuffle_obj_mask(mask): 116 | 117 | bs, obj_num, _, _ = mask.size() 118 | new_masks = [] 119 | for idx in range(bs): 120 | now_mask = mask[idx] 121 | random_matrix = torch.eye(obj_num, device=mask.device) 122 | fg = random_matrix[1:][torch.randperm(obj_num - 1)] 123 | random_matrix = torch.cat([random_matrix[0:1], fg], dim=0) 124 | now_mask = torch.einsum('nm,nhw->mhw', random_matrix, 
now_mask) 125 | new_masks.append(now_mask) 126 | 127 | return torch.stack(new_masks, dim=0) 128 | -------------------------------------------------------------------------------- /aot_plus/utils/learning.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def adjust_learning_rate(optimizer, 5 | base_lr, 6 | p, 7 | itr, 8 | max_itr, 9 | restart=1, 10 | warm_up_steps=1000, 11 | is_cosine_decay=False, 12 | min_lr=1e-5, 13 | encoder_lr_ratio=1.0, 14 | freeze_params=[]): 15 | 16 | if restart > 1: 17 | each_max_itr = int(math.ceil(float(max_itr) / restart)) 18 | itr = itr % each_max_itr 19 | warm_up_steps /= restart 20 | max_itr = each_max_itr 21 | 22 | if itr < warm_up_steps: 23 | now_lr = min_lr + (base_lr - min_lr) * itr / warm_up_steps 24 | else: 25 | itr = itr - warm_up_steps 26 | max_itr = max_itr - warm_up_steps 27 | if is_cosine_decay: 28 | now_lr = min_lr + (base_lr - min_lr) * (math.cos(math.pi * itr / 29 | (max_itr + 1)) + 30 | 1.) * 0.5 31 | else: 32 | now_lr = min_lr + (base_lr - min_lr) * (1 - itr / (max_itr + 1))**p 33 | 34 | for param_group in optimizer.param_groups: 35 | if encoder_lr_ratio != 1.0 and "encoder." in param_group["name"]: 36 | param_group['lr'] = (now_lr - min_lr) * encoder_lr_ratio + min_lr 37 | else: 38 | param_group['lr'] = now_lr 39 | 40 | for freeze_param in freeze_params: 41 | if freeze_param in param_group["name"]: 42 | param_group['lr'] = 0 43 | param_group['weight_decay'] = 0 44 | break 45 | 46 | return now_lr 47 | 48 | 49 | def get_trainable_params(model, 50 | base_lr, 51 | weight_decay, 52 | use_frozen_bn=False, 53 | exclusive_wd_dict={}, 54 | no_wd_keys=[]): 55 | params = [] 56 | memo = set() 57 | total_param = 0 58 | for key, value in model.named_parameters(): 59 | if value in memo: 60 | continue 61 | total_param += value.numel() 62 | if not value.requires_grad: 63 | continue 64 | memo.add(value) 65 | wd = weight_decay 66 | for exclusive_key in exclusive_wd_dict.keys(): 67 | if exclusive_key in key: 68 | wd = exclusive_wd_dict[exclusive_key] 69 | break 70 | if len(value.shape) == 1: # normalization layers 71 | if 'bias' in key: # bias requires no weight decay 72 | wd = 0. 73 | elif not use_frozen_bn: # if not use frozen BN, apply zero weight decay 74 | wd = 0. 75 | elif 'encoder.' not in key: # if use frozen BN, apply weight decay to all frozen BNs in the encoder 76 | wd = 0. 77 | else: 78 | for no_wd_key in no_wd_keys: 79 | if no_wd_key in key: 80 | wd = 0. 
81 | break 82 | params += [{ 83 | "params": [value], 84 | "lr": base_lr, 85 | "weight_decay": wd, 86 | "name": key 87 | }] 88 | 89 | print('Total Param: {:.2f}M'.format(total_param / 1e6)) 90 | return params 91 | 92 | 93 | def freeze_params(module): 94 | for p in module.parameters(): 95 | p.requires_grad = False 96 | -------------------------------------------------------------------------------- /aot_plus/utils/math.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def generate_permute_matrix(dim, num, keep_first=True, gpu_id=0): 5 | all_matrix = [] 6 | for idx in range(num): 7 | random_matrix = torch.eye(dim, device=torch.device('cuda', gpu_id)) 8 | if keep_first: 9 | fg = random_matrix[1:][torch.randperm(dim - 1)] 10 | random_matrix = torch.cat([random_matrix[0:1], fg], dim=0) 11 | else: 12 | random_matrix = random_matrix[torch.randperm(dim)] 13 | all_matrix.append(random_matrix) 14 | return torch.stack(all_matrix, dim=0) 15 | 16 | 17 | def truncated_normal_(tensor, mean=0, std=.02): 18 | size = tensor.shape 19 | tmp = tensor.new_empty(size + (4, )).normal_() 20 | valid = (tmp < 2) & (tmp > -2) 21 | ind = valid.max(-1, keepdim=True)[1] 22 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 23 | tensor.data.mul_(std).add_(mean) 24 | return tensor 25 | -------------------------------------------------------------------------------- /aot_plus/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | def __init__(self, momentum=0.999): 7 | self.val = 0 8 | self.avg = 0 9 | self.sum = 0 10 | self.count = 0 11 | self.long_count = 0 12 | self.momentum = momentum 13 | self.moving_avg = 0 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | if self.long_count == 0: 23 | self.moving_avg = val 24 | else: 25 | momentum = min(self.momentum, 1. - 1. 
/ self.long_count) 26 | self.moving_avg = self.moving_avg * momentum + val * (1 - momentum) 27 | self.val = val 28 | self.sum += val * n 29 | self.count += n 30 | self.long_count += n 31 | self.avg = self.sum / self.count 32 | -------------------------------------------------------------------------------- /aot_plus/utils/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pytorch_iou(pred, target, obj_num, epsilon=1e-6): 5 | ''' 6 | pred: [bs, h, w] 7 | target: [bs, h, w] 8 | obj_num: [bs] 9 | ''' 10 | bs = pred.size(0) 11 | all_iou = [] 12 | for idx in range(bs): 13 | now_pred = pred[idx].unsqueeze(0) 14 | now_target = target[idx].unsqueeze(0) 15 | now_obj_num = obj_num[idx] 16 | 17 | obj_ids = torch.arange(0, now_obj_num + 1, 18 | device=now_pred.device).int().view(-1, 1, 1) 19 | if obj_ids.size(0) == 1: # only contain background 20 | continue 21 | else: 22 | obj_ids = obj_ids[1:] 23 | now_pred = (now_pred == obj_ids).float() 24 | now_target = (now_target == obj_ids).float() 25 | 26 | intersection = (now_pred * now_target).sum((1, 2)) 27 | union = ((now_pred + now_target) > 0).float().sum((1, 2)) 28 | 29 | now_iou = (intersection + epsilon) / (union + epsilon) 30 | 31 | all_iou.append(now_iou.mean()) 32 | if len(all_iou) > 0: 33 | all_iou = torch.stack(all_iou).mean() 34 | else: 35 | all_iou = torch.ones((1), device=pred.device) 36 | return all_iou 37 | -------------------------------------------------------------------------------- /data.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 13 | 14 | 15 | 16 | 17 | 140 | 141 | 142 | 143 | 144 | 145 | VOST: Data 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 |



154 |
155 | VOST: Data 156 |

157 | 158 | 159 | 164 | 169 | 174 | 179 | 184 | 185 | 186 |
160 |
161 | [Home] 162 |
163 |
165 |
166 | [Paper] 167 |
168 |
170 |
171 | [Data] 172 |
173 |
175 |
176 | [Code] 177 |
178 |
180 |
181 | [Workshop] 182 |
183 |
187 |
188 | 189 | 190 |
191 |
192 |

You can download the training and validation set videos and annotations under this link. Test set videos will be released separately shortly before the challenge deadline. Please note that the original videos are sourced from Ego4D and EPIC-KITCHENS. Our dataset is shared under the Creative Commons Public license (CC BY-NC-SA 4.0). 193 |

194 |
195 |
196 |
197 |
198 | 199 | 200 |
201 | 202 | 203 | 204 | 210 | 211 |
205 |
206 | Email: 207 | support@vostdataset.org 208 |
209 |
212 | 213 | 214 | 215 | -------------------------------------------------------------------------------- /evaluation/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | docs/site/ 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | # pytest 105 | .pytest_cache 106 | 107 | # Pylint 108 | .pylintrc 109 | 110 | # PyCharm 111 | .idea/ 112 | .DS_Store 113 | 114 | # Generated C code 115 | _mask.c 116 | -------------------------------------------------------------------------------- /evaluation/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, DAVIS: Densely Annotated VIdeo Segmentation 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /evaluation/README.md: -------------------------------------------------------------------------------- 1 | # Semi-supervised VOS evaluation 2 | 3 | This package is derived from the DAVIS 2017 evaluation implementation and is used to evaluate semi-supervised video multi-object segmentation models for the VOST dataset. 4 | 5 | ## Installation 6 | Download the code: 7 | ```bash 8 | git clone https://github.com/TRI-ML/VOST.git 9 | ``` 10 | Install the required dependencies: 11 | ```bash 12 | pip install numpy Pillow opencv-python pandas scikit-image scikit-learn tqdm scipy 13 | ``` 14 | 15 | ## Evaluation 16 | In order to evaluate your method on the validation set of VOST, execute the following command: 17 | ```bash 18 | python evaluation_method.py --results_path PATH_TO_YOUR_RESULTS --dataset_path PATH_TO_VOST --set val 19 | ``` 20 | 21 | If you don't want to specify the dataset path every time, you can modify the default value in the variable `default_dataset_path` in `evaluation_method.py`. 22 | 23 | Once the evaluation has finished, two different CSV files will be generated inside the folder with the results: 24 | - `global_results-SUBSET.csv` contains the overall results for a certain `SUBSET`. 25 | - `per-sequence_results-SUBSET.csv` contains the per-sequence results for a certain `SUBSET`. 26 | 27 | If a folder that contains the previous files is evaluated again, the results will be read from the CSV files instead of recomputing them. 28 | 29 | ## Citation 30 | 31 | Please cite the following papers in your publications if this code helps your research.
32 | 33 | ```latex 34 | @inproceedings{tokmakov2023breaking, 35 | title={Breaking the “Object” in Video Object Segmentation}, 36 | author={Tokmakov, Pavel and Li, Jie and Gaidon, Adrien}, 37 | booktitle={CVPR}, 38 | year={2023} 39 | } 40 | ``` 41 | 42 | ```latex 43 | @article{Caelles_arXiv_2019, 44 | author = {Sergi Caelles and Jordi Pont-Tuset and Federico Perazzi and Alberto Montes and Kevis-Kokitsi Maninis and Luc {Van Gool}}, 45 | title = {The 2019 DAVIS Challenge on VOS: Unsupervised Multi-Object Segmentation}, 46 | journal = {arXiv}, 47 | year = {2019} 48 | } 49 | ``` 50 | 51 | ```latex 52 | @article{Pont-Tuset_arXiv_2017, 53 | author = {Jordi Pont-Tuset and Federico Perazzi and Sergi Caelles and Pablo Arbel\'aez and Alexander Sorkine-Hornung and Luc {Van Gool}}, 54 | title = {The 2017 DAVIS Challenge on Video Object Segmentation}, 55 | journal = {arXiv:1704.00675}, 56 | year = {2017} 57 | } 58 | ``` 59 | 60 | -------------------------------------------------------------------------------- /evaluation/evaluation_method.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | from time import time 5 | import argparse 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from source.evaluation import Evaluation 10 | 11 | default_dataset_path = '' 12 | 13 | time_start = time() 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--dataset_path', type=str, help='Path to the dataset folder containing the JPEGImages, Annotations, ' 16 | 'ImageSets, Annotations_unsupervised folders', 17 | required=False, default=default_dataset_path) 18 | parser.add_argument('--set', type=str, help='Subset to evaluate the results', default='val') 19 | parser.add_argument('--results_path', type=str, help='Path to the folder containing the sequences folders', 20 | required=True) 21 | parser.add_argument('--re', action='store_true') 22 | args, _ = parser.parse_known_args() 23 | csv_name_global = f'global_results-{args.set}.csv' 24 | csv_name_per_sequence = f'per-sequence_results-{args.set}.csv' 25 | 26 | # Check if the method has been evaluated before, if so read the results, otherwise compute the results 27 | csv_name_global_path = os.path.join(args.results_path, csv_name_global) 28 | csv_name_per_sequence_path = os.path.join(args.results_path, csv_name_per_sequence) 29 | if os.path.exists(csv_name_global_path) and os.path.exists(csv_name_per_sequence_path) and not args.re: 30 | print('Using precomputed results...') 31 | table_g = pd.read_csv(csv_name_global_path) 32 | table_seq = pd.read_csv(csv_name_per_sequence_path) 33 | else: 34 | print(f'Evaluating sequences ...') 35 | # Create dataset and evaluate 36 | dataset_eval = Evaluation(dataset_root=args.dataset_path, gt_set=args.set) 37 | metrics_res = dataset_eval.evaluate(args.results_path) 38 | J = metrics_res['J'] 39 | J_last = None 40 | if 'J_last' in metrics_res: 41 | J_last = metrics_res['J_last'] 42 | 43 | # Generate dataframe for the general results 44 | g_measures = ['J-Mean', 'J-Recall', 'J-Decay', 'J_last-Mean', 'J_last-Recall', 'J_last-Decay'] 45 | g_res = np.array([np.mean(J["M"]), np.mean(J["R"]), np.mean(J["D"]), np.mean(J_last["M"]), np.mean(J_last["R"]), np.mean(J_last["D"])]) 46 | g_res = np.reshape(g_res, [1, len(g_res)]) 47 | table_g = pd.DataFrame(data=g_res, columns=g_measures) 48 | with open(csv_name_global_path, 'w') as f: 49 | table_g.to_csv(f, index=False, float_format="%.3f") 50 | print(f'Global results saved in 
{csv_name_global_path}') 51 | 52 | # Generate a dataframe for the per sequence results 53 | seq_names = list(J['M_per_object'].keys()) 54 | seq_measures = ['Sequence', 'J-Mean', 'J_last-Mean'] 55 | J_per_object = [J['M_per_object'][x] for x in seq_names] 56 | J_last_per_object = [J_last['M_per_object'][x] for x in seq_names] 57 | table_seq = pd.DataFrame(data=list(zip(seq_names, J_per_object, J_last_per_object)), columns=seq_measures) 58 | with open(csv_name_per_sequence_path, 'w') as f: 59 | table_seq.to_csv(f, index=False, float_format="%.3f") 60 | print(f'Per-sequence results saved in {csv_name_per_sequence_path}') 61 | 62 | # Print the results 63 | sys.stdout.write(f"--------------------------- Global results for {args.set} ---------------------------\n") 64 | print(table_g.to_string(index=False)) 65 | sys.stdout.write(f"\n---------- Per sequence results for {args.set} ----------\n") 66 | print(table_seq.to_string(index=False)) 67 | total_time = time() - time_start 68 | sys.stdout.write('\nTotal time:' + str(total_time)) 69 | -------------------------------------------------------------------------------- /evaluation/source/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | __version__ = '0.1.0' 4 | -------------------------------------------------------------------------------- /evaluation/source/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from collections import defaultdict 4 | import numpy as np 5 | from PIL import Image 6 | 7 | 8 | class Dataset(object): 9 | SUBSET_OPTIONS = ['train', 'val', 'test'] 10 | VOID_LABEL = 255 11 | 12 | def __init__(self, root, subset='val', sequences='all'): 13 | """ 14 | Class to read the dataset 15 | :param root: Path to the dataset folder that contains JPEGImages, Annotations, etc. folders. 16 | :param subset: Set to load the annotations 17 | :param sequences: Sequences to consider, 'all' to use all the sequences in a set. 
18 | """ 19 | if subset not in self.SUBSET_OPTIONS: 20 | raise ValueError(f'Subset should be in {self.SUBSET_OPTIONS}') 21 | 22 | self.task = 'semi-supervised' 23 | self.subset = subset 24 | self.root = root 25 | self.img_path = os.path.join(self.root, 'JPEGImages') 26 | annotations_folder = 'Annotations' 27 | self.mask_path = os.path.join(self.root, annotations_folder) 28 | self.imagesets_path = os.path.join(self.root, 'ImageSets') 29 | 30 | self._check_directories() 31 | 32 | if sequences == 'all': 33 | with open(os.path.join(self.imagesets_path, f'{self.subset}.txt'), 'r') as f: 34 | tmp = f.readlines() 35 | sequences_names = [x.strip() for x in tmp] 36 | else: 37 | sequences_names = sequences if isinstance(sequences, list) else [sequences] 38 | self.sequences = defaultdict(dict) 39 | 40 | for seq in sequences_names: 41 | masks = np.sort(glob(os.path.join(self.mask_path, seq, '*.png'))).tolist() 42 | if len(masks) == 0: 43 | raise FileNotFoundError(f'Annotations for sequence {seq} not found.') 44 | self.sequences[seq]['masks'] = masks 45 | images = np.sort(glob(os.path.join(self.img_path, seq, '*.jpg'))).tolist() 46 | filtered_images = [] 47 | for img in images: 48 | ann = img.replace('jpg', 'png').replace('JPEGImages', 'Annotations') 49 | if ann not in masks: 50 | print(ann) 51 | else: 52 | filtered_images.append(img) 53 | self.sequences[seq]['images'] = filtered_images 54 | 55 | # images = np.sort(glob(os.path.join(self.img_path, seq, '*.jpg'))).tolist() 56 | # if len(images) == 0: 57 | # raise FileNotFoundError(f'Images for sequence {seq} not found.') 58 | # self.sequences[seq]['images'] = images 59 | # masks = np.sort(glob(os.path.join(self.mask_path, seq, '*.png'))).tolist() 60 | # masks.extend([-1] * (len(images) - len(masks))) 61 | # self.sequences[seq]['masks'] = masks 62 | 63 | def _check_directories(self): 64 | if not os.path.exists(self.root): 65 | raise FileNotFoundError(f'Dataset not found in the specified directory') 66 | if not os.path.exists(os.path.join(self.imagesets_path, f'{self.subset}.txt')): 67 | raise FileNotFoundError(f'Subset sequences list for {self.subset} not found') 68 | if self.subset in ['train', 'val'] and not os.path.exists(self.mask_path): 69 | raise FileNotFoundError(f'Annotations folder not found') 70 | 71 | def get_frames(self, sequence): 72 | for img, msk in zip(self.sequences[sequence]['images'], self.sequences[sequence]['masks']): 73 | image = np.array(Image.open(img)) 74 | mask = None if msk is None else np.array(Image.open(msk)) 75 | yield image, mask 76 | 77 | def _get_all_elements(self, sequence, obj_type): 78 | obj = np.array(Image.open(self.sequences[sequence][obj_type][0])) 79 | all_objs = np.zeros((len(self.sequences[sequence][obj_type]), *obj.shape)) 80 | obj_id = [] 81 | for i, obj in enumerate(self.sequences[sequence][obj_type]): 82 | all_objs[i, ...] = np.array(Image.open(obj)) 83 | obj_id.append(''.join(obj.split('/')[-1].split('.')[:-1])) 84 | return all_objs, obj_id 85 | 86 | def get_all_images(self, sequence): 87 | return self._get_all_elements(sequence, 'images') 88 | 89 | def get_all_masks(self, sequence, separate_objects_masks=False): 90 | masks, masks_id = self._get_all_elements(sequence, 'masks') 91 | masks_void = np.zeros_like(masks) 92 | 93 | # Separate void and object masks 94 | for i in range(masks.shape[0]): 95 | masks_void[i, ...] = masks[i, ...] == 255 96 | masks[i, masks[i, ...] 
== 255] = 0 97 | 98 | if separate_objects_masks: 99 | num_objects = int(np.max(masks[0, ...])) 100 | tmp = np.ones((num_objects, *masks.shape)) 101 | tmp = tmp * np.arange(1, num_objects + 1)[:, None, None, None] 102 | masks = (tmp == masks[None, ...]) 103 | masks = masks > 0 104 | return masks, masks_void, masks_id 105 | 106 | def get_sequences(self): 107 | for seq in self.sequences: 108 | yield seq 109 | 110 | -------------------------------------------------------------------------------- /evaluation/source/evaluation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from tqdm import tqdm 3 | import warnings 4 | warnings.filterwarnings("ignore", category=RuntimeWarning) 5 | 6 | import numpy as np 7 | from source.dataset import Dataset 8 | from source.metrics import db_eval_boundary, db_eval_iou 9 | from source import utils 10 | from source.results import Results 11 | from scipy.optimize import linear_sum_assignment 12 | from math import floor 13 | 14 | 15 | class Evaluation(object): 16 | def __init__(self, dataset_root, gt_set, sequences='all'): 17 | """ 18 | Class to evaluate sequences from a certain set 19 | :param dataset_root: Path to the dataset folder that contains JPEGImages, Annotations, etc. folders. 20 | :param gt_set: Set to compute the evaluation 21 | :param sequences: Sequences to consider for the evaluation, 'all' to use all the sequences in a set. 22 | """ 23 | self.dataset_root = dataset_root 24 | self.dataset = Dataset(root=dataset_root, subset=gt_set, sequences=sequences) 25 | 26 | @staticmethod 27 | def _evaluate_semisupervised(all_gt_masks, all_res_masks, all_void_masks, metric): 28 | if all_res_masks.shape[0] > all_gt_masks.shape[0]: 29 | print("\nIn your PNG files there is an index higher than the number of objects in the sequence!") 30 | all_res_masks = all_res_masks[:all_gt_masks.shape[0]] 31 | # sys.exit() 32 | elif all_res_masks.shape[0] < all_gt_masks.shape[0]: 33 | zero_padding = np.zeros((all_gt_masks.shape[0] - all_res_masks.shape[0], *all_res_masks.shape[1:])) 34 | all_res_masks = np.concatenate([all_res_masks, zero_padding], axis=0) 35 | j_metrics_res = np.zeros(all_gt_masks.shape[:2]) 36 | for ii in range(all_gt_masks.shape[0]): 37 | if 'J' in metric: 38 | j_metrics_res[ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks) 39 | return j_metrics_res 40 | 41 | def evaluate(self, res_path, metric=('J', 'J_last'), debug=False): 42 | metric = metric if isinstance(metric, tuple) or isinstance(metric, list) else [metric] 43 | 44 | # Containers 45 | metrics_res = {} 46 | if 'J' in metric: 47 | metrics_res['J'] = {"M": [], "R": [], "D": [], "M_per_object": {}} 48 | if 'J_last' in metric: 49 | metrics_res['J_last'] = {"M": [], "R": [], "D": [], "M_per_object": {}} 50 | 51 | # Sweep all sequences 52 | results = Results(root_dir=res_path) 53 | for seq in tqdm(list(self.dataset.get_sequences())): 54 | print(seq) 55 | all_gt_masks, all_void_masks, all_masks_id = self.dataset.get_all_masks(seq, True) 56 | all_gt_masks, all_masks_id = all_gt_masks[:, 1:-1, :, :], all_masks_id[1:-1] 57 | num_eval_frames = len(all_masks_id) 58 | last_quarter_ind = int(floor(num_eval_frames * 0.75)) 59 | all_res_masks = results.read_masks(seq, all_masks_id) 60 | j_metrics_res = self._evaluate_semisupervised(all_gt_masks, all_res_masks, None, metric) 61 | for ii in range(all_gt_masks.shape[0]): 62 | seq_name = f'{seq}_{ii+1}' 63 | if 'J' in metric: 64 | [JM, JR, JD] = utils.db_statistics(j_metrics_res[ii]) 
65 | metrics_res['J']["M"].append(JM) 66 | metrics_res['J']["R"].append(JR) 67 | metrics_res['J']["D"].append(JD) 68 | metrics_res['J']["M_per_object"][seq_name] = JM 69 | if 'J_last' in metric: 70 | [JM, JR, JD] = utils.db_statistics(j_metrics_res[ii][last_quarter_ind:]) 71 | metrics_res['J_last']["M"].append(JM) 72 | metrics_res['J_last']["R"].append(JR) 73 | metrics_res['J_last']["D"].append(JD) 74 | metrics_res['J_last']["M_per_object"][seq_name] = JM 75 | 76 | # Show progress 77 | if debug: 78 | sys.stdout.write(seq + '\n') 79 | sys.stdout.flush() 80 | return metrics_res 81 | -------------------------------------------------------------------------------- /evaluation/source/metrics.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import cv2 4 | 5 | 6 | def db_eval_iou(annotation, segmentation, void_pixels=None): 7 | """ Compute region similarity as the Jaccard Index. 8 | Arguments: 9 | annotation (ndarray): binary annotation map. 10 | segmentation (ndarray): binary segmentation map. 11 | void_pixels (ndarray): optional mask with void pixels 12 | 13 | Return: 14 | jaccard (float): region similarity 15 | """ 16 | assert annotation.shape == segmentation.shape, \ 17 | f'Annotation({annotation.shape}) and segmentation:{segmentation.shape} dimensions do not match.' 18 | annotation = annotation.astype(np.bool) 19 | segmentation = segmentation.astype(np.bool) 20 | 21 | if void_pixels is not None: 22 | assert annotation.shape == void_pixels.shape, \ 23 | f'Annotation({annotation.shape}) and void pixels:{void_pixels.shape} dimensions do not match.' 24 | void_pixels = void_pixels.astype(np.bool) 25 | else: 26 | void_pixels = np.zeros_like(segmentation) 27 | 28 | # Intersection between all sets 29 | inters = np.sum((segmentation & annotation) & np.logical_not(void_pixels), axis=(-2, -1)) 30 | union = np.sum((segmentation | annotation) & np.logical_not(void_pixels), axis=(-2, -1)) 31 | 32 | j = inters / union 33 | if j.ndim == 0: 34 | j = 1 if np.isclose(union, 0) else j 35 | else: 36 | j[np.isclose(union, 0)] = 1 37 | return j 38 | 39 | 40 | def db_eval_boundary(annotation, segmentation, void_pixels=None, bound_th=0.008): 41 | assert annotation.shape == segmentation.shape 42 | if void_pixels is not None: 43 | assert annotation.shape == void_pixels.shape 44 | if annotation.ndim == 3: 45 | n_frames = annotation.shape[0] 46 | f_res = np.zeros(n_frames) 47 | for frame_id in range(n_frames): 48 | void_pixels_frame = None if void_pixels is None else void_pixels[frame_id, :, :, ] 49 | f_res[frame_id] = f_measure(segmentation[frame_id, :, :, ], annotation[frame_id, :, :], void_pixels_frame, bound_th=bound_th) 50 | elif annotation.ndim == 2: 51 | f_res = f_measure(segmentation, annotation, void_pixels, bound_th=bound_th) 52 | else: 53 | raise ValueError(f'db_eval_boundary does not support tensors with {annotation.ndim} dimensions') 54 | return f_res 55 | 56 | 57 | def f_measure(foreground_mask, gt_mask, void_pixels=None, bound_th=0.008): 58 | """ 59 | Compute mean,recall and decay from per-frame evaluation. 60 | Calculates precision/recall for boundaries between foreground_mask and 61 | gt_mask using morphological operators to speed it up. 62 | 63 | Arguments: 64 | foreground_mask (ndarray): binary segmentation image. 65 | gt_mask (ndarray): binary annotated image. 
66 | void_pixels (ndarray): optional mask with void pixels 67 | 68 | Returns: 69 | F (float): boundaries F-measure 70 | """ 71 | assert np.atleast_3d(foreground_mask).shape[2] == 1 72 | if void_pixels is not None: 73 | void_pixels = void_pixels.astype(np.bool) 74 | else: 75 | void_pixels = np.zeros_like(foreground_mask).astype(np.bool) 76 | 77 | bound_pix = bound_th if bound_th >= 1 else \ 78 | np.ceil(bound_th * np.linalg.norm(foreground_mask.shape)) 79 | 80 | # Get the pixel boundaries of both masks 81 | fg_boundary = _seg2bmap(foreground_mask * np.logical_not(void_pixels)) 82 | gt_boundary = _seg2bmap(gt_mask * np.logical_not(void_pixels)) 83 | 84 | from skimage.morphology import disk 85 | 86 | # fg_dil = binary_dilation(fg_boundary, disk(bound_pix)) 87 | fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8)) 88 | # gt_dil = binary_dilation(gt_boundary, disk(bound_pix)) 89 | gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8)) 90 | 91 | # Get the intersection 92 | gt_match = gt_boundary * fg_dil 93 | fg_match = fg_boundary * gt_dil 94 | 95 | # Area of the intersection 96 | n_fg = np.sum(fg_boundary) 97 | n_gt = np.sum(gt_boundary) 98 | 99 | # % Compute precision and recall 100 | if n_fg == 0 and n_gt > 0: 101 | precision = 1 102 | recall = 0 103 | elif n_fg > 0 and n_gt == 0: 104 | precision = 0 105 | recall = 1 106 | elif n_fg == 0 and n_gt == 0: 107 | precision = 1 108 | recall = 1 109 | else: 110 | precision = np.sum(fg_match) / float(n_fg) 111 | recall = np.sum(gt_match) / float(n_gt) 112 | 113 | # Compute F measure 114 | if precision + recall == 0: 115 | F = 0 116 | else: 117 | F = 2 * precision * recall / (precision + recall) 118 | 119 | return F 120 | 121 | 122 | def _seg2bmap(seg, width=None, height=None): 123 | """ 124 | From a segmentation, compute a binary boundary map with 1 pixel wide 125 | boundaries. The boundary pixels are offset by 1/2 pixel towards the 126 | origin from the actual segment boundary. 127 | Arguments: 128 | seg : Segments labeled from 1..k. 129 | width : Width of desired bmap <= seg.shape[1] 130 | height : Height of desired bmap <= seg.shape[0] 131 | Returns: 132 | bmap (ndarray): Binary boundary map. 133 | David Martin 134 | January 2003 135 | """ 136 | 137 | seg = seg.astype(np.bool) 138 | seg[seg > 0] = 1 139 | 140 | assert np.atleast_3d(seg).shape[2] == 1 141 | 142 | width = seg.shape[1] if width is None else width 143 | height = seg.shape[0] if height is None else height 144 | 145 | h, w = seg.shape[:2] 146 | 147 | ar1 = float(width) / float(height) 148 | ar2 = float(w) / float(h) 149 | 150 | assert not ( 151 | width > w | height > h | abs(ar1 - ar2) > 0.01 152 | ), "Can" "t convert %dx%d seg to %dx%d bmap." 
% (w, h, width, height) 153 | 154 | e = np.zeros_like(seg) 155 | s = np.zeros_like(seg) 156 | se = np.zeros_like(seg) 157 | 158 | e[:, :-1] = seg[:, 1:] 159 | s[:-1, :] = seg[1:, :] 160 | se[:-1, :-1] = seg[1:, 1:] 161 | 162 | b = seg ^ e | seg ^ s | seg ^ se 163 | b[-1, :] = seg[-1, :] ^ e[-1, :] 164 | b[:, -1] = seg[:, -1] ^ s[:, -1] 165 | b[-1, -1] = 0 166 | 167 | if w == width and h == height: 168 | bmap = b 169 | else: 170 | bmap = np.zeros((height, width)) 171 | for x in range(w): 172 | for y in range(h): 173 | if b[y, x]: 174 | j = 1 + math.floor((y - 1) + height / h) 175 | i = 1 + math.floor((x - 1) + width / h) 176 | bmap[j, i] = 1 177 | 178 | return bmap 179 | 180 | 181 | if __name__ == '__main__': 182 | from source.dataset import Dataset 183 | from source.results import Results 184 | 185 | dataset = Dataset(root='input_dir/ref', subset='val', sequences='aerobatics') 186 | results = Results(root_dir='examples/osvos') 187 | # Test timing F measure 188 | for seq in dataset.get_sequences(): 189 | all_gt_masks, _, all_masks_id = dataset.get_all_masks(seq, True) 190 | all_gt_masks, all_masks_id = all_gt_masks[:, 1:-1, :, :], all_masks_id[1:-1] 191 | all_res_masks = results.read_masks(seq, all_masks_id) 192 | f_metrics_res = np.zeros(all_gt_masks.shape[:2]) 193 | for ii in range(all_gt_masks.shape[0]): 194 | f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...]) 195 | 196 | # Run using to profile code: python -m cProfile -o f_measure.prof metrics.py 197 | # snakeviz f_measure.prof 198 | -------------------------------------------------------------------------------- /evaluation/source/results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import sys 5 | 6 | 7 | class Results(object): 8 | def __init__(self, root_dir): 9 | self.root_dir = root_dir 10 | 11 | def _read_mask(self, sequence, frame_id): 12 | try: 13 | mask_path = os.path.join(self.root_dir, sequence, f'{frame_id}.png') 14 | return np.array(Image.open(mask_path)) 15 | except IOError as err: 16 | print(os.path.join(self.root_dir, sequence, f'{frame_id}.png')) 17 | sys.stdout.write(sequence + " frame %s not found!\n" % frame_id) 18 | sys.stdout.write("The frames have to be indexed PNG files placed inside the corespondent sequence " 19 | "folder.\nThe indexes have to match with the initial frame.\n") 20 | sys.stderr.write("IOError: " + err.strerror + "\n") 21 | sys.exit() 22 | 23 | def read_masks(self, sequence, masks_id): 24 | mask_0 = self._read_mask(sequence, masks_id[0]) 25 | masks = np.zeros((len(masks_id), *mask_0.shape)) 26 | for ii, m in enumerate(masks_id): 27 | masks[ii, ...] = self._read_mask(sequence, m) 28 | num_objects = int(np.max(masks)) 29 | tmp = np.ones((num_objects, *masks.shape)) 30 | tmp = tmp * np.arange(1, num_objects + 1)[:, None, None, None] 31 | masks = (tmp == masks[None, ...]) > 0 32 | return masks 33 | -------------------------------------------------------------------------------- /evaluation/source/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import numpy as np 4 | from PIL import Image 5 | import warnings 6 | from source.dataset import Dataset 7 | 8 | 9 | def _pascal_color_map(N=256, normalized=False): 10 | """ 11 | Python implementation of the color map function for the PASCAL VOC data set. 
12 | Official Matlab version can be found in the PASCAL VOC devkit 13 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit 14 | """ 15 | 16 | def bitget(byteval, idx): 17 | return (byteval & (1 << idx)) != 0 18 | 19 | dtype = 'float32' if normalized else 'uint8' 20 | cmap = np.zeros((N, 3), dtype=dtype) 21 | for i in range(N): 22 | r = g = b = 0 23 | c = i 24 | for j in range(8): 25 | r = r | (bitget(c, 0) << 7 - j) 26 | g = g | (bitget(c, 1) << 7 - j) 27 | b = b | (bitget(c, 2) << 7 - j) 28 | c = c >> 3 29 | 30 | cmap[i] = np.array([r, g, b]) 31 | 32 | cmap = cmap / 255 if normalized else cmap 33 | return cmap 34 | 35 | 36 | def overlay_semantic_mask(im, ann, alpha=0.5, colors=None, contour_thickness=None): 37 | im, ann = np.asarray(im, dtype=np.uint8), np.asarray(ann, dtype=np.int) 38 | if im.shape[:-1] != ann.shape: 39 | raise ValueError('First two dimensions of `im` and `ann` must match') 40 | if im.shape[-1] != 3: 41 | raise ValueError('im must have three channels at the 3 dimension') 42 | 43 | colors = colors or _pascal_color_map() 44 | colors = np.asarray(colors, dtype=np.uint8) 45 | 46 | mask = colors[ann] 47 | fg = im * alpha + (1 - alpha) * mask 48 | 49 | img = im.copy() 50 | img[ann > 0] = fg[ann > 0] 51 | 52 | if contour_thickness: # pragma: no cover 53 | import cv2 54 | for obj_id in np.unique(ann[ann > 0]): 55 | contours = cv2.findContours((ann == obj_id).astype( 56 | np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:] 57 | cv2.drawContours(img, contours[0], -1, colors[obj_id].tolist(), 58 | contour_thickness) 59 | return img 60 | 61 | 62 | def generate_obj_proposals(dataset_root, subset, num_proposals, save_path): 63 | dataset = Dataset(dataset_root, subset=subset, codalab=True) 64 | for seq in dataset.get_sequences(): 65 | save_dir = os.path.join(save_path, seq) 66 | if os.path.exists(save_dir): 67 | continue 68 | all_gt_masks, all_masks_id = dataset.get_all_masks(seq, True) 69 | img_size = all_gt_masks.shape[2:] 70 | num_rows = int(np.ceil(np.sqrt(num_proposals))) 71 | proposals = np.zeros((num_proposals, len(all_masks_id), *img_size)) 72 | height_slices = np.floor(np.arange(0, img_size[0] + 1, img_size[0]/num_rows)).astype(np.uint).tolist() 73 | width_slices = np.floor(np.arange(0, img_size[1] + 1, img_size[1]/num_rows)).astype(np.uint).tolist() 74 | ii = 0 75 | prev_h, prev_w = 0, 0 76 | for h in height_slices[1:]: 77 | for w in width_slices[1:]: 78 | proposals[ii, :, prev_h:h, prev_w:w] = 1 79 | prev_w = w 80 | ii += 1 81 | if ii == num_proposals: 82 | break 83 | prev_h, prev_w = h, 0 84 | if ii == num_proposals: 85 | break 86 | 87 | os.makedirs(save_dir, exist_ok=True) 88 | for i, mask_id in enumerate(all_masks_id): 89 | mask = np.sum(proposals[:, i, ...] * np.arange(1, proposals.shape[0] + 1)[:, None, None], axis=0) 90 | save_mask(mask, os.path.join(save_dir, f'{mask_id}.png')) 91 | 92 | 93 | def generate_random_permutation_gt_obj_proposals(dataset_root, subset, save_path): 94 | dataset = Dataset(dataset_root, subset=subset, codalab=True) 95 | for seq in dataset.get_sequences(): 96 | gt_masks, all_masks_id = dataset.get_all_masks(seq, True) 97 | obj_swap = np.random.permutation(np.arange(gt_masks.shape[0])) 98 | gt_masks = gt_masks[obj_swap, ...] 99 | save_dir = os.path.join(save_path, seq) 100 | os.makedirs(save_dir, exist_ok=True) 101 | for i, mask_id in enumerate(all_masks_id): 102 | mask = np.sum(gt_masks[:, i, ...] 
* np.arange(1, gt_masks.shape[0] + 1)[:, None, None], axis=0) 103 | save_mask(mask, os.path.join(save_dir, f'{mask_id}.png')) 104 | 105 | 106 | def color_map(N=256, normalized=False): 107 | def bitget(byteval, idx): 108 | return ((byteval & (1 << idx)) != 0) 109 | 110 | dtype = 'float32' if normalized else 'uint8' 111 | cmap = np.zeros((N, 3), dtype=dtype) 112 | for i in range(N): 113 | r = g = b = 0 114 | c = i 115 | for j in range(8): 116 | r = r | (bitget(c, 0) << 7-j) 117 | g = g | (bitget(c, 1) << 7-j) 118 | b = b | (bitget(c, 2) << 7-j) 119 | c = c >> 3 120 | 121 | cmap[i] = np.array([r, g, b]) 122 | 123 | cmap = cmap/255 if normalized else cmap 124 | return cmap 125 | 126 | 127 | def save_mask(mask, img_path): 128 | if np.max(mask) > 255: 129 | raise ValueError('Maximum id pixel value is 255') 130 | mask_img = Image.fromarray(mask.astype(np.uint8)) 131 | mask_img.putpalette(color_map().flatten().tolist()) 132 | mask_img.save(img_path) 133 | 134 | 135 | def db_statistics(per_frame_values): 136 | """ Compute mean,recall and decay from per-frame evaluation. 137 | Arguments: 138 | per_frame_values (ndarray): per-frame evaluation 139 | 140 | Returns: 141 | M,O,D (float,float,float): 142 | return evaluation statistics: mean,recall,decay. 143 | """ 144 | 145 | # strip off nan values 146 | with warnings.catch_warnings(): 147 | warnings.simplefilter("ignore", category=RuntimeWarning) 148 | M = np.nanmean(per_frame_values) 149 | O = np.nanmean(per_frame_values > 0.5) 150 | 151 | N_bins = 4 152 | ids = np.round(np.linspace(1, len(per_frame_values), N_bins + 1) + 1e-10) - 1 153 | ids = ids.astype(np.uint8) 154 | 155 | D_bins = [per_frame_values[ids[i]:ids[i + 1] + 1] for i in range(0, 4)] 156 | 157 | with warnings.catch_warnings(): 158 | warnings.simplefilter("ignore", category=RuntimeWarning) 159 | D = np.nanmean(D_bins[0]) - np.nanmean(D_bins[3]) 160 | 161 | return M, O, D 162 | 163 | 164 | def list_files(dir, extension=".png"): 165 | return [os.path.splitext(file_)[0] for file_ in os.listdir(dir) if file_.endswith(extension)] 166 | 167 | 168 | def force_symlink(file1, file2): 169 | try: 170 | os.symlink(file1, file2) 171 | except OSError as e: 172 | if e.errno == errno.EEXIST: 173 | os.remove(file2) 174 | os.symlink(file1, file2) 175 | -------------------------------------------------------------------------------- /figs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/figs/.DS_Store -------------------------------------------------------------------------------- /figs/1403_cut_corn.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/figs/1403_cut_corn.gif -------------------------------------------------------------------------------- /figs/1GijsAKflo683C-s55149QvaQ47_PXz0q.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/figs/1GijsAKflo683C-s55149QvaQ47_PXz0q.png -------------------------------------------------------------------------------- /figs/3525_break_eggs.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/figs/3525_break_eggs.gif 
--------------------------------------------------------------------------------
/figs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/figs/.DS_Store
--------------------------------------------------------------------------------
/figs/1403_cut_corn.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/figs/1403_cut_corn.gif
--------------------------------------------------------------------------------
/figs/1GijsAKflo683C-s55149QvaQ47_PXz0q.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/figs/1GijsAKflo683C-s55149QvaQ47_PXz0q.png
--------------------------------------------------------------------------------
/figs/3525_break_eggs.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/figs/3525_break_eggs.gif
--------------------------------------------------------------------------------
/figs/4751_mold_clay.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/figs/4751_mold_clay.gif
--------------------------------------------------------------------------------
/figs/9672_divide_wheel.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/figs/9672_divide_wheel.gif
--------------------------------------------------------------------------------
/figs/9699_divide_wheel.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/figs/9699_divide_wheel.gif
--------------------------------------------------------------------------------
/figs/bear.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/figs/bear.gif
--------------------------------------------------------------------------------
/figs/bibtex.txt:
--------------------------------------------------------------------------------
@inproceedings{tokmakov2023breaking,
  title={Breaking the “Object” in Video Object Segmentation},
  author={Tokmakov, Pavel and Li, Jie and Gaidon, Adrien},
  booktitle={CVPR},
  year={2023}
}
--------------------------------------------------------------------------------
/figs/firstpage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/figs/firstpage.png
--------------------------------------------------------------------------------
/figs/skater.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/figs/skater.gif
--------------------------------------------------------------------------------
/index.html:
--------------------------------------------------------------------------------
(HTML markup was stripped during extraction; only the page text is preserved below.)

VOST: Video Object Segmentation under Transformations

[Home] [Paper] [Data] [Code] [Workshop]

Dataset

VOST is a semi-supervised video object segmentation benchmark that focuses on complex object transformations. Unlike existing datasets, objects in VOST are broken, torn and molded into new shapes, dramatically changing their overall appearance. As our experiments demonstrate, this presents a major challenge for mainstream, appearance-centric VOS methods. The dataset consists of more than 700 high-resolution videos, captured in diverse environments, which are 21 seconds long on average and densely labeled with instance masks. A careful, multi-step approach is adopted to ensure that these videos focus on complex transformations, capturing their full temporal extent. Below, we provide a few key statistics of the dataset.

  • Number of videos: 713
  • Number of frames: 75 547
  • Number of transformations: 51
  • Number of object categories: 154
  • Annotation rate: 5 fps
  • Average video length: 21.2 seconds

Baseline Results

The video below shows the outputs of the AOT+ baseline from our paper on a few videos from the validation and test sets of VOST. AOT+ is an extension of AOT that improves its spatio-temporal modeling capacity. However, this model still largely relies on static appearance cues and struggles with complex transformations.

Paper

Pavel Tokmakov, Jie Li, Adrien Gaidon.
Breaking the “Object” in Video Object Segmentation.
CVPR 2023.

[Bibtex]

Email: support@vostdataset.org
--------------------------------------------------------------------------------