├── CNAME ├── README.md ├── aot_plus ├── LICENSE ├── Makefile ├── README.md ├── configs │ ├── default.py │ ├── models │ │ ├── aotb.py │ │ ├── aotl.py │ │ ├── aots.py │ │ ├── aott.py │ │ ├── default.py │ │ ├── r101_aotl.py │ │ ├── r50_aotl.py │ │ ├── rs101_aotl.py │ │ └── swinb_aotl.py │ ├── pre.py │ ├── pre_dav.py │ ├── pre_vost.py │ ├── pre_ytb.py │ ├── pre_ytb_dav.py │ └── ytb.py ├── dataloaders │ ├── __init__.py │ ├── eval_datasets.py │ ├── image_transforms.py │ ├── train_datasets.py │ └── video_transforms.py ├── docker │ └── Dockerfile ├── networks │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── __init__.cpython-38.pyc │ ├── decoders │ │ ├── __init__.py │ │ └── fpn.py │ ├── encoders │ │ ├── __init__.py │ │ ├── mobilenetv2.py │ │ ├── mobilenetv3.py │ │ ├── resnest │ │ │ ├── __init__.py │ │ │ ├── resnest.py │ │ │ ├── resnet.py │ │ │ └── splat.py │ │ ├── resnet.py │ │ └── swin │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ └── swin_transformer.py │ ├── engines │ │ ├── __init__.py │ │ └── aot_engine.py │ ├── layers │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── basic.py │ │ ├── loss.py │ │ ├── normalization.py │ │ ├── position.py │ │ └── transformer.py │ ├── managers │ │ ├── evaluator.py │ │ └── trainer.py │ └── models │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── aot.cpython-37.pyc │ │ └── aot.cpython-38.pyc │ │ └── aot.py ├── tools │ ├── demo.py │ ├── eval.py │ └── train.py ├── train_vost.sh └── utils │ ├── __init__.py │ ├── checkpoint.py │ ├── ema.py │ ├── eval.py │ ├── image.py │ ├── learning.py │ ├── math.py │ ├── meters.py │ └── metric.py ├── data.html ├── evaluation ├── .gitignore ├── LICENSE ├── README.md ├── evaluation_method.py └── source │ ├── __init__.py │ ├── dataset.py │ ├── evaluation.py │ ├── metrics.py │ ├── results.py │ └── utils.py ├── figs ├── .DS_Store ├── 1403_cut_corn.gif ├── 1GijsAKflo683C-s55149QvaQ47_PXz0q.png ├── 3525_break_eggs.gif ├── 4751_mold_clay.gif ├── 9672_divide_wheel.gif ├── 9699_divide_wheel.gif ├── bear.gif ├── bibtex.txt ├── firstpage.png └── skater.gif ├── index.html └── workshop.html /CNAME: -------------------------------------------------------------------------------- 1 | www.vostdataset.org 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VOST: Video Object Segmentation under Transformations 2 | 3 | The evaluation metric implementation is in the [evaluation](evaluation) folder. 4 | AOT+ baseline implementation and model checkpoints are available in the [aot_plus](aot_plus) folder. 5 | -------------------------------------------------------------------------------- /aot_plus/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, z-x-yang 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. 
Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /aot_plus/Makefile: -------------------------------------------------------------------------------- 1 | # Handy commands: 2 | # - `make docker-build`: builds DOCKERIMAGE (default: `packnet-sfm:latest`) 3 | PROJECT ?= aot_plus 4 | WORKSPACE ?= /workspace/$(PROJECT) 5 | DOCKER_IMAGE ?= ${PROJECT}:latest 6 | 7 | SHMSIZE ?= 444G 8 | DOCKER_OPTS := \ 9 | --name ${PROJECT} \ 10 | --rm -it \ 11 | --shm-size=${SHMSIZE} \ 12 | -e AWS_DEFAULT_REGION \ 13 | -e AWS_ACCESS_KEY_ID \ 14 | -e AWS_SECRET_ACCESS_KEY \ 15 | -e HOST_HOSTNAME= \ 16 | -e NCCL_DEBUG=VERSION \ 17 | -e DISPLAY=${DISPLAY} \ 18 | -e XAUTHORITY \ 19 | -e NVIDIA_DRIVER_CAPABILITIES=all \ 20 | -v ~/.aws:/root/.aws \ 21 | -v /root/.ssh:/root/.ssh \ 22 | -v ~/.cache:/root/.cache \ 23 | -v /data:/data \ 24 | -v /mnt/fsx/:/mnt/fsx \ 25 | -v /dev/null:/dev/raw1394 \ 26 | -v /tmp:/tmp \ 27 | -v /tmp/.X11-unix/X0:/tmp/.X11-unix/X0 \ 28 | -v /var/run/docker.sock:/var/run/docker.sock \ 29 | -v ${PWD}:${WORKSPACE} \ 30 | -w ${WORKSPACE} \ 31 | --privileged \ 32 | --ipc=host \ 33 | --network=host 34 | 35 | NGPUS=$(shell nvidia-smi -L | wc -l) 36 | 37 | 38 | .PHONY: all clean docker-build docker-overfit-pose 39 | 40 | all: clean 41 | 42 | clean: 43 | find . -name "*.pyc" | xargs rm -f && \ 44 | find . -name "__pycache__" | xargs rm -rf 45 | 46 | docker-build: 47 | docker build \ 48 | -f docker/Dockerfile \ 49 | -t ${DOCKER_IMAGE} . 50 | 51 | docker-start-interactive: docker-build 52 | nvidia-docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} bash 53 | 54 | docker-start-jupyter: docker-build 55 | nvidia-docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} \ 56 | bash -c "jupyter notebook --port=8888 -ip=0.0.0.0 --allow-root --no-browser" 57 | 58 | docker-run: docker-build 59 | nvidia-docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} \ 60 | bash -c "${COMMAND}" 61 | 62 | docker-run-mpi: docker-build 63 | nvidia-docker run ${DOCKER_OPTS} ${DOCKER_IMAGE} \ 64 | bash -c "${MPI_CMD} ${COMMAND}" -------------------------------------------------------------------------------- /aot_plus/README.md: -------------------------------------------------------------------------------- 1 | # AOT+ 2 | This is the implementation of AOT+ baseline used in [VOST](https://www.vostdataset.org) dataset. The implementation is derived from [AOT](https://github.com/z-x-yang/AOT). 
3 | 4 | ## Installation 5 | We provide a Docker file to re-create the environment used in our experiments under `$AOT_ROOT/docker/Dockerfile`. You can either configure the environment yourself, using the Docker file as a guide, or build it via: 6 | ~~~ 7 | cd $AOT_ROOT 8 | make docker-build 9 | make docker-start-interactive 10 | ~~~ 11 | 12 | 13 | ## Training and evaluation 14 | 1. Link the VOST folder in [datasets/VOST](datasets/VOST) 15 | 16 | 2. To evaluate the pre-trained AOT+ model on the validation set of VOST, download it from [here](https://tri-ml-public.s3.amazonaws.com/datasets/aotplus.pth) into the [pretrain_models](pretrain_models) folder and run the following command: 17 | 18 | ~~~ 19 | python tools/eval.py --exp_name aotplus --stage pre_vost --model r50_aotl --dataset vost --split val --gpu_num 8 --ckpt_path pretrain_models/aotplus.pth --ms 1.0 1.1 1.2 0.9 0.8 20 | ~~~ 21 | To compute the metrics, please refer to the [evaluation](../evaluation/) folder. 22 | 23 | 3. To train AOT+ on VOST yourself, download the checkpoint pre-trained on static images and YouTubeVOS from [here](https://tri-ml-public.s3.amazonaws.com/datasets/pre_ytb.pth) into the [pretrain_models](pretrain_models) folder and run this script: 24 | 25 | ~~~ 26 | sh train_vost.sh 27 | ~~~ 28 | 29 | 30 | ## Citations 31 | Please consider citing the related paper(s) in your publications if it helps your research. 32 | ``` 33 | @inproceedings{tokmakov2023breaking, 34 | title={Breaking the “Object” in Video Object Segmentation}, 35 | author={Tokmakov, Pavel and Li, Jie and Gaidon, Adrien}, 36 | booktitle={CVPR}, 37 | year={2023} 38 | } 39 | 40 | @inproceedings{yang2021aot, 41 | title={Associating Objects with Transformers for Video Object Segmentation}, 42 | author={Yang, Zongxin and Wei, Yunchao and Yang, Yi}, 43 | booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, 44 | year={2021} 45 | } 46 | ``` 47 | 48 | ## License 49 | This project is released under the BSD-3-Clause license. See [LICENSE](LICENSE) for additional details. 50 | -------------------------------------------------------------------------------- /aot_plus/configs/default.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | 5 | class DefaultEngineConfig(): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | model_cfg = importlib.import_module('configs.models.' 
+ 8 | model).ModelConfig() 9 | self.__dict__.update(model_cfg.__dict__) # add model config 10 | 11 | self.EXP_NAME = exp_name + '_' + self.MODEL_NAME 12 | 13 | self.STAGE_NAME = 'default' 14 | 15 | self.DATASETS = ['youtubevos'] 16 | self.DATA_WORKERS = 8 17 | self.DATA_RANDOMCROP = (465, 18 | 465) if self.MODEL_ALIGN_CORNERS else (464, 19 | 464) 20 | self.DATA_RANDOMFLIP = 0.5 21 | self.DATA_MAX_CROP_STEPS = 10 22 | self.DATA_SHORT_EDGE_LEN = 480 23 | self.DATA_MIN_SCALE_FACTOR = 0.7 24 | self.DATA_MAX_SCALE_FACTOR = 1.3 25 | self.DATA_RANDOM_REVERSE_SEQ = True 26 | self.DATA_SEQ_LEN = 5 27 | self.DATA_DAVIS_REPEAT = 5 28 | self.DATA_VOST_REPEAT = 1 29 | self.DATA_VOST_IGNORE_THRESH = 0.2 30 | self.DATA_VOST_ALL_FRAMES = False 31 | self.DATA_VOST_VALID_FRAMES = False 32 | self.DATA_RANDOM_GAP_DAVIS = 12 # max frame interval between two sampled frames for DAVIS (24fps) 33 | self.DATA_RANDOM_GAP_YTB = 3 # max frame interval between two sampled frames for YouTube-VOS (6fps) 34 | self.DATA_RANDOM_GAP_VOST = 3 35 | self.DATA_RANDOM_GAP_VISOR = 1 36 | self.DATA_DYNAMIC_MERGE_PROB = 0.2 37 | self.IGNORE_IN_MERGE = True 38 | self.DATA_VISOR_REPEAT = 1 39 | self.DATA_VISOR_IGNORE_THRESH = 0.2 40 | 41 | self.PRETRAIN = True 42 | self.PRETRAIN_FULL = False # if False, load encoder only 43 | self.PRETRAIN_MODEL = '' 44 | 45 | self.TRAIN_TOTAL_STEPS = 100000 46 | self.TRAIN_START_STEP = 0 47 | self.TRAIN_WEIGHT_DECAY = 0.07 48 | self.TRAIN_WEIGHT_DECAY_EXCLUSIVE = { 49 | # 'encoder.': 0.01 50 | } 51 | self.TRAIN_WEIGHT_DECAY_EXEMPTION = [ 52 | 'absolute_pos_embed', 'relative_position_bias_table', 53 | 'relative_emb_v', 'conv_out' 54 | ] 55 | self.TRAIN_LR = 2e-4 56 | self.TRAIN_LR_MIN = 2e-5 if 'mobilenetv2' in self.MODEL_ENCODER else 1e-5 57 | self.TRAIN_LR_POWER = 0.9 58 | self.TRAIN_LR_ENCODER_RATIO = 0.1 59 | self.TRAIN_LR_WARM_UP_RATIO = 0.05 60 | self.TRAIN_LR_COSINE_DECAY = False 61 | self.TRAIN_LR_RESTART = 1 62 | self.TRAIN_LR_UPDATE_STEP = 1 63 | self.TRAIN_AUX_LOSS_WEIGHT = 1.0 64 | self.TRAIN_AUX_LOSS_RATIO = 1.0 65 | self.TRAIN_OPT = 'adamw' 66 | self.TRAIN_SGD_MOMENTUM = 0.9 67 | self.TRAIN_GPUS = 4 68 | self.TRAIN_BATCH_SIZE = 16 69 | self.TRAIN_TBLOG = False 70 | self.TRAIN_TBLOG_STEP = 50 71 | self.TRAIN_LOG_STEP = 20 72 | self.TRAIN_IMG_LOG = True 73 | self.TRAIN_TOP_K_PERCENT_PIXELS = 0.15 74 | self.TRAIN_SEQ_TRAINING_FREEZE_PARAMS = ['patch_wise_id_bank'] 75 | self.TRAIN_SEQ_TRAINING_START_RATIO = 0.5 76 | self.TRAIN_HARD_MINING_RATIO = 0.5 77 | self.TRAIN_EMA_RATIO = 0.1 78 | self.TRAIN_CLIP_GRAD_NORM = 5. 79 | self.TRAIN_SAVE_STEP = 1000 80 | self.TRAIN_EVAL = False 81 | self.TRAIN_MAX_KEEP_CKPT = 8 82 | self.TRAIN_RESUME = False 83 | self.TRAIN_RESUME_CKPT = None 84 | self.TRAIN_RESUME_STEP = 0 85 | self.TRAIN_AUTO_RESUME = True 86 | self.TRAIN_DATASET_FULL_RESOLUTION = False 87 | self.TRAIN_ENABLE_PREV_FRAME = False 88 | self.TRAIN_ENCODER_FREEZE_AT = 2 89 | self.TRAIN_LSTT_EMB_DROPOUT = 0. 90 | self.TRAIN_LSTT_ID_DROPOUT = 0. 91 | self.TRAIN_LSTT_DROPPATH = 0.1 92 | self.TRAIN_LSTT_DROPPATH_SCALING = False 93 | self.TRAIN_LSTT_DROPPATH_LST = False 94 | self.TRAIN_LSTT_LT_DROPOUT = 0. 95 | self.TRAIN_LSTT_ST_DROPOUT = 0. 96 | 97 | self.TEST_GPU_ID = 0 98 | self.TEST_GPU_NUM = 1 99 | self.TEST_FRAME_LOG = False 100 | self.TEST_DATASET = 'youtubevos' 101 | self.TEST_DATASET_FULL_RESOLUTION = False 102 | self.TEST_DATASET_SPLIT = 'val' 103 | self.TEST_CKPT_PATH = None 104 | # if "None", evaluate the latest checkpoint. 
105 | self.TEST_CKPT_STEP = None 106 | self.TEST_FLIP = False 107 | self.TEST_MULTISCALE = [1] 108 | self.TEST_MIN_SIZE = None 109 | self.TEST_MAX_SIZE = 800 * 1.3 110 | self.TEST_WORKERS = 4 111 | 112 | # GPU distribution 113 | self.DIST_ENABLE = True 114 | self.DIST_BACKEND = "nccl" # "gloo" 115 | self.DIST_URL = "tcp://127.0.0.1:13241" 116 | self.DIST_START_GPU = 0 117 | 118 | def init_dir(self): 119 | self.DIR_DATA = './datasets' 120 | self.DIR_DAVIS = os.path.join(self.DIR_DATA, 'DAVIS') 121 | self.DIR_VOST = os.path.join(self.DIR_DATA, 'VOST') 122 | self.DIR_VISOR = os.path.join(self.DIR_DATA, 'VISOR') 123 | self.DIR_YTB = os.path.join(self.DIR_DATA, 'YTB') 124 | self.DIR_STATIC = os.path.join(self.DIR_DATA, 'Static') 125 | 126 | self.DIR_ROOT = './results' 127 | 128 | self.DIR_RESULT = os.path.join(self.DIR_ROOT, self.EXP_NAME, 129 | self.STAGE_NAME) 130 | self.DIR_CKPT = os.path.join(self.DIR_RESULT, 'ckpt') 131 | self.DIR_EMA_CKPT = os.path.join(self.DIR_RESULT, 'ema_ckpt') 132 | self.DIR_LOG = os.path.join(self.DIR_RESULT, 'log') 133 | self.DIR_TB_LOG = os.path.join(self.DIR_RESULT, 'log', 'tensorboard') 134 | self.DIR_IMG_LOG = os.path.join(self.DIR_RESULT, 'log', 'img') 135 | self.DIR_EVALUATION = os.path.join(self.DIR_RESULT, 'eval') 136 | 137 | for path in [ 138 | self.DIR_RESULT, self.DIR_CKPT, self.DIR_EMA_CKPT, 139 | self.DIR_LOG, self.DIR_EVALUATION, self.DIR_IMG_LOG, 140 | self.DIR_TB_LOG 141 | ]: 142 | if not os.path.isdir(path): 143 | try: 144 | os.makedirs(path) 145 | except Exception as inst: 146 | print(inst) 147 | print('Failed to make dir: {}.'.format(path)) 148 | -------------------------------------------------------------------------------- /aot_plus/configs/models/aotb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTB' 8 | 9 | self.MODEL_LSTT_NUM = 3 10 | -------------------------------------------------------------------------------- /aot_plus/configs/models/aotl.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTL' 8 | 9 | self.MODEL_LSTT_NUM = 3 10 | 11 | self.TRAIN_LONG_TERM_MEM_GAP = 2 12 | 13 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot_plus/configs/models/aots.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTS' 8 | 9 | self.MODEL_LSTT_NUM = 2 10 | -------------------------------------------------------------------------------- /aot_plus/configs/models/aott.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultModelConfig 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'AOTT' 8 | -------------------------------------------------------------------------------- /aot_plus/configs/models/default.py: -------------------------------------------------------------------------------- 1 | class DefaultModelConfig(): 2 | 
def __init__(self): 3 | self.MODEL_NAME = 'AOTDefault' 4 | 5 | self.MODEL_VOS = 'aot' 6 | self.MODEL_ENGINE = 'aotengine' 7 | self.MODEL_ALIGN_CORNERS = True 8 | self.MODEL_ENCODER = 'mobilenetv2' 9 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/mobilenet_v2-b0353104.pth' # https://download.pytorch.org/models/mobilenet_v2-b0353104.pth 10 | self.MODEL_ENCODER_DIM = [24, 32, 96, 1280] # 4x, 8x, 16x, 16x 11 | self.MODEL_ENCODER_EMBEDDING_DIM = 256 12 | self.MODEL_DECODER_INTERMEDIATE_LSTT = True 13 | self.MODEL_SIMPLIFIED_STM = True 14 | self.MODEL_STM_STOPGRAD = False 15 | self.MODEL_LINEAR_Q = True 16 | self.MODEL_NORM_INP = True 17 | self.MODEL_RECURRENT_LTM = False 18 | self.MODEL_RECURRENT_STM = True 19 | self.MODEL_JOINT_LONGATT = True 20 | self.MODEL_FREEZE_BN = True 21 | self.MODEL_FREEZE_BACKBONE = False 22 | self.MODEL_MAX_OBJ_NUM = 10 23 | self.MODEL_IGNORE_TOKEN = True 24 | self.MODEL_SELF_HEADS = 8 25 | self.MODEL_ATT_HEADS = 8 26 | self.MODEL_LSTT_NUM = 1 27 | self.MODEL_EPSILON = 1e-5 28 | self.MODEL_USE_PREV_PROB = False 29 | 30 | self.TRAIN_LONG_TERM_MEM_GAP = 9999 31 | 32 | self.TEST_LONG_TERM_MEM_GAP = 9999 33 | -------------------------------------------------------------------------------- /aot_plus/configs/models/r101_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R101_AOTL' 8 | 9 | self.MODEL_ENCODER = 'resnet101' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet101-63fe2227.pth' # https://download.pytorch.org/models/resnet101-63fe2227.pth 11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 12 | self.MODEL_LSTT_NUM = 3 13 | 14 | self.TRAIN_LONG_TERM_MEM_GAP = 2 15 | 16 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot_plus/configs/models/r50_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R50_AOTL' 8 | 9 | self.MODEL_ENCODER = 'resnet50' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet50-0676ba61.pth' # https://download.pytorch.org/models/resnet50-0676ba61.pth 11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 12 | self.MODEL_LSTT_NUM = 3 13 | 14 | self.TRAIN_LONG_TERM_MEM_GAP = 2 15 | 16 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot_plus/configs/models/rs101_aotl.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'R101_AOTL' 8 | 9 | self.MODEL_ENCODER = 'resnest101' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnest101-22405ba7.pth' # https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest101-22405ba7.pth 11 | self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x 12 | self.MODEL_LSTT_NUM = 3 13 | 14 | self.TRAIN_LONG_TERM_MEM_GAP = 2 15 | 16 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot_plus/configs/models/swinb_aotl.py: 
-------------------------------------------------------------------------------- 1 | from .default import DefaultModelConfig 2 | 3 | 4 | class ModelConfig(DefaultModelConfig): 5 | def __init__(self): 6 | super().__init__() 7 | self.MODEL_NAME = 'SwinB_AOTL' 8 | 9 | self.MODEL_ENCODER = 'swin_base' 10 | self.MODEL_ENCODER_PRETRAIN = './pretrain_models/swin_base_patch4_window7_224_22k.pth' # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth 11 | self.MODEL_ALIGN_CORNERS = False 12 | self.MODEL_ENCODER_DIM = [128, 256, 512, 512] # 4x, 8x, 16x, 16x 13 | self.MODEL_LSTT_NUM = 3 14 | 15 | self.TRAIN_LONG_TERM_MEM_GAP = 2 16 | 17 | self.TEST_LONG_TERM_MEM_GAP = 5 -------------------------------------------------------------------------------- /aot_plus/configs/pre.py: -------------------------------------------------------------------------------- 1 | from .default import DefaultEngineConfig 2 | 3 | 4 | class EngineConfig(DefaultEngineConfig): 5 | def __init__(self, exp_name='default', model='AOTT'): 6 | super().__init__(exp_name, model) 7 | self.STAGE_NAME = 'pre' 8 | 9 | self.init_dir() 10 | 11 | self.DATASETS = ['static'] 12 | 13 | self.DATA_DYNAMIC_MERGE_PROB = 1.0 14 | 15 | self.TRAIN_LR = 4e-4 16 | self.TRAIN_LR_MIN = 2e-5 17 | self.TRAIN_WEIGHT_DECAY = 0.03 18 | self.TRAIN_SEQ_TRAINING_START_RATIO = 1.0 19 | self.TRAIN_AUX_LOSS_RATIO = 0.1 20 | 21 | self.MODEL_SIMPLIFIED_STM = True 22 | self.MODEL_LINEAR_Q = True 23 | self.MODEL_RECURRENT_STM = True 24 | self.MODEL_JOINT_LONGATT = True 25 | -------------------------------------------------------------------------------- /aot_plus/configs/pre_dav.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'PRE_DAV' 9 | 10 | self.init_dir() 11 | 12 | self.DATASETS = ['davis2017'] 13 | 14 | self.TRAIN_TOTAL_STEPS = 50000 15 | 16 | pretrain_stage = 'PRE' 17 | pretrain_ckpt = 'save_step_100000.pth' 18 | self.PRETRAIN_FULL = True # if False, load encoder only 19 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', 20 | self.EXP_NAME, pretrain_stage, 21 | 'ema_ckpt', pretrain_ckpt) 22 | -------------------------------------------------------------------------------- /aot_plus/configs/pre_vost.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'pre_vost' 9 | 10 | self.init_dir() 11 | 12 | self.DATASETS = ['vost'] 13 | self.TRAIN_TOTAL_STEPS = 20000 14 | self.DATA_SEQ_LEN = 15 15 | self.TRAIN_LONG_TERM_MEM_GAP = 4 16 | self.MODEL_LINEAR_Q = False 17 | self.MODEL_JOINT_LONGATT = False 18 | self.MODEL_SIMPLIFIED_STM = True 19 | self.MODEL_RECURRENT_STM = True 20 | self.MODEL_IGNORE_TOKEN = True 21 | 22 | self.PRETRAIN_FULL = True # if False, load encoder only 23 | self.PRETRAIN_MODEL = os.path.join('pretrain_models', 'pre_ytb.pth') -------------------------------------------------------------------------------- /aot_plus/configs/pre_ytb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 
3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'pre_ytb' 9 | 10 | self.init_dir() 11 | 12 | pretrain_stage = 'PRE' 13 | pretrain_ckpt = 'save_step_100000.pth' 14 | self.DATA_SEQ_LEN = 10 15 | self.TRAIN_LONG_TERM_MEM_GAP = 4 16 | self.MODEL_JOINT_LONGATT = False 17 | self.MODEL_STM_STOPGRAD = False 18 | self.TRAIN_TOTAL_STEPS = 80000 19 | self.MODEL_LINEAR_Q = True 20 | self.PRETRAIN_FULL = True # if False, load encoder only 21 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', 22 | self.EXP_NAME, pretrain_stage, 23 | 'ema_ckpt', pretrain_ckpt) 24 | -------------------------------------------------------------------------------- /aot_plus/configs/pre_ytb_dav.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'PRE_YTB_DAV' 9 | 10 | self.init_dir() 11 | 12 | self.DATASETS = ['youtubevos', 'davis2017'] 13 | 14 | pretrain_stage = 'PRE' 15 | pretrain_ckpt = 'save_step_100000.pth' 16 | self.PRETRAIN_FULL = True # if False, load encoder only 17 | self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result', 18 | self.EXP_NAME, pretrain_stage, 19 | 'ema_ckpt', pretrain_ckpt) 20 | -------------------------------------------------------------------------------- /aot_plus/configs/ytb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .default import DefaultEngineConfig 3 | 4 | 5 | class EngineConfig(DefaultEngineConfig): 6 | def __init__(self, exp_name='default', model='AOTT'): 7 | super().__init__(exp_name, model) 8 | self.STAGE_NAME = 'YTB' 9 | 10 | self.init_dir() 11 | -------------------------------------------------------------------------------- /aot_plus/dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/dataloaders/__init__.py -------------------------------------------------------------------------------- /aot_plus/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG PYTORCH="1.9.0" 2 | ARG CUDA="11.1" 3 | ARG CUDNN="8" 4 | 5 | FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel 6 | RUN rm -rf /etc/apt/sources.list.d/cuda.list 7 | RUN rm -rf /etc/apt/sources.list.d/nvidia-ml.list 8 | RUN apt-key del 7fa2af80 9 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub 10 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub 11 | 12 | ENV PROJECT=aot 13 | ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX" 14 | ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all" 15 | ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" 16 | ARG python=3.7 17 | ENV PYTHON_VERSION=${python} 18 | 19 | # To fix GPG key error when running apt-get update 20 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub 21 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub 22 | 23 | # Core tools 24 | 
RUN apt-get update && apt-get install -y \ 25 | cmake \ 26 | curl \ 27 | docker.io \ 28 | ffmpeg \ 29 | git \ 30 | htop \ 31 | libsm6 \ 32 | libxext6 \ 33 | libglib2.0-0 \ 34 | libsm6 \ 35 | libxrender-dev \ 36 | libxext6 \ 37 | ninja-build \ 38 | unzip \ 39 | vim \ 40 | wget \ 41 | sudo \ 42 | && apt-get clean \ 43 | && rm -rf /var/lib/apt/lists/* 44 | 45 | # Install OpenSSH for MPI to communicate between containers 46 | RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ 47 | echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ 48 | mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config 49 | 50 | RUN ln -sf /usr/bin/python${PYTHON_VERSION} /usr/bin/python 51 | 52 | RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ 53 | python get-pip.py && \ 54 | rm get-pip.py 55 | 56 | # Install Pydata and other deps 57 | RUN pip install easydict scipy numpy pyquaternion matplotlib jupyter h5py \ 58 | awscli tqdm progress path.py pyyaml opencv-python pandas \ 59 | numba cython scikit-learn moviepy imageio yacs Pillow scikit-image 60 | 61 | # Settings for S3 62 | RUN aws configure set default.s3.max_concurrent_requests 100 && \ 63 | aws configure set default.s3.max_queue_size 10000 64 | 65 | #pip install spatial-correlation-sampler 66 | 67 | # Expose Port for jupyter (8888) 68 | EXPOSE 8888 69 | 70 | # create project workspace dir 71 | WORKDIR /workspace/${PROJECT} 72 | 73 | ENV PYTHONPATH="/workspace/${PROJECT}:$PYTHONPATH" 74 | RUN git config --global --add safe.directory /workspace/${PROJECT} 75 | -------------------------------------------------------------------------------- /aot_plus/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/networks/__init__.py -------------------------------------------------------------------------------- /aot_plus/networks/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/networks/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /aot_plus/networks/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/networks/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /aot_plus/networks/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.decoders.fpn import FPNSegmentationHead 2 | 3 | 4 | def build_decoder(name, **kwargs): 5 | 6 | if name == 'fpn': 7 | return FPNSegmentationHead(**kwargs) 8 | else: 9 | raise NotImplementedError 10 | -------------------------------------------------------------------------------- /aot_plus/networks/decoders/fpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from networks.layers.basic import ConvGN 5 | 6 | 7 | class FPNSegmentationHead(nn.Module): 8 | def __init__(self, 9 | in_dim, 10 | out_dim, 11 | decode_intermediate_input=True, 12 | hidden_dim=256, 13 | shortcut_dims=[24, 32, 96, 1280], 14 | align_corners=True): 15 
| super().__init__() 16 | self.align_corners = align_corners 17 | 18 | self.decode_intermediate_input = decode_intermediate_input 19 | 20 | self.conv_in = ConvGN(in_dim, hidden_dim, 1) 21 | 22 | self.conv_16x = ConvGN(hidden_dim, hidden_dim, 3) 23 | self.conv_8x = ConvGN(hidden_dim, hidden_dim // 2, 3) 24 | self.conv_4x = ConvGN(hidden_dim // 2, hidden_dim // 2, 3) 25 | 26 | self.adapter_16x = nn.Conv2d(shortcut_dims[-2], hidden_dim, 1) 27 | self.adapter_8x = nn.Conv2d(shortcut_dims[-3], hidden_dim, 1) 28 | self.adapter_4x = nn.Conv2d(shortcut_dims[-4], hidden_dim // 2, 1) 29 | 30 | self.conv_out = nn.Conv2d(hidden_dim // 2, out_dim, 1) 31 | 32 | self._init_weight() 33 | 34 | def forward(self, inputs, shortcuts): 35 | 36 | if self.decode_intermediate_input: 37 | x = torch.cat(inputs, dim=1) 38 | else: 39 | x = inputs[-1] 40 | 41 | x = F.relu_(self.conv_in(x)) 42 | x = F.relu_(self.conv_16x(self.adapter_16x(shortcuts[-2]) + x)) 43 | 44 | x = F.interpolate(x, 45 | size=shortcuts[-3].size()[-2:], 46 | mode="bilinear", 47 | align_corners=self.align_corners) 48 | x = F.relu_(self.conv_8x(self.adapter_8x(shortcuts[-3]) + x)) 49 | 50 | x = F.interpolate(x, 51 | size=shortcuts[-4].size()[-2:], 52 | mode="bilinear", 53 | align_corners=self.align_corners) 54 | x = F.relu_(self.conv_4x(self.adapter_4x(shortcuts[-4]) + x)) 55 | 56 | x = self.conv_out(x) 57 | 58 | return x 59 | 60 | def _init_weight(self): 61 | for p in self.parameters(): 62 | if p.dim() > 1: 63 | nn.init.xavier_uniform_(p) 64 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.encoders.mobilenetv2 import MobileNetV2 2 | from networks.encoders.mobilenetv3 import MobileNetV3Large 3 | from networks.encoders.resnet import ResNet101, ResNet50 4 | from networks.encoders.resnest import resnest 5 | from networks.encoders.swin import build_swin_model 6 | from networks.layers.normalization import FrozenBatchNorm2d 7 | from torch import nn 8 | 9 | 10 | def build_encoder(name, frozen_bn=True, freeze_at=-1): 11 | if frozen_bn: 12 | BatchNorm = FrozenBatchNorm2d 13 | else: 14 | BatchNorm = nn.BatchNorm2d 15 | 16 | if name == 'mobilenetv2': 17 | return MobileNetV2(16, BatchNorm, freeze_at=freeze_at) 18 | elif name == 'mobilenetv3': 19 | return MobileNetV3Large(16, BatchNorm, freeze_at=freeze_at) 20 | elif name == 'resnet50': 21 | return ResNet50(16, BatchNorm, freeze_at=freeze_at) 22 | elif name == 'resnet101': 23 | return ResNet101(16, BatchNorm, freeze_at=freeze_at) 24 | elif name == 'resnest50': 25 | return resnest.resnest50(norm_layer=BatchNorm, 26 | dilation=2, 27 | freeze_at=freeze_at) 28 | elif name == 'resnest101': 29 | return resnest.resnest101(norm_layer=BatchNorm, 30 | dilation=2, 31 | freeze_at=freeze_at) 32 | elif 'swin' in name: 33 | return build_swin_model(name, freeze_at=freeze_at) 34 | else: 35 | raise NotImplementedError 36 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch import Tensor 3 | from typing import Callable, Optional, List 4 | from utils.learning import freeze_params 5 | 6 | __all__ = ['MobileNetV2'] 7 | 8 | 9 | def _make_divisible(v: float, 10 | divisor: int, 11 | min_value: Optional[int] = None) -> int: 12 | """ 13 | This function is taken from the original 
tf repo. 14 | It ensures that all layers have a channel number that is divisible by 8 15 | It can be seen here: 16 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 17 | """ 18 | if min_value is None: 19 | min_value = divisor 20 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 21 | # Make sure that round down does not go down by more than 10%. 22 | if new_v < 0.9 * v: 23 | new_v += divisor 24 | return new_v 25 | 26 | 27 | class ConvBNActivation(nn.Sequential): 28 | def __init__( 29 | self, 30 | in_planes: int, 31 | out_planes: int, 32 | kernel_size: int = 3, 33 | stride: int = 1, 34 | groups: int = 1, 35 | padding: int = -1, 36 | norm_layer: Optional[Callable[..., nn.Module]] = None, 37 | activation_layer: Optional[Callable[..., nn.Module]] = None, 38 | dilation: int = 1, 39 | ) -> None: 40 | if padding == -1: 41 | padding = (kernel_size - 1) // 2 * dilation 42 | if norm_layer is None: 43 | norm_layer = nn.BatchNorm2d 44 | if activation_layer is None: 45 | activation_layer = nn.ReLU6 46 | super().__init__( 47 | nn.Conv2d(in_planes, 48 | out_planes, 49 | kernel_size, 50 | stride, 51 | padding, 52 | dilation=dilation, 53 | groups=groups, 54 | bias=False), norm_layer(out_planes), 55 | activation_layer(inplace=True)) 56 | self.out_channels = out_planes 57 | 58 | 59 | # necessary for backwards compatibility 60 | ConvBNReLU = ConvBNActivation 61 | 62 | 63 | class InvertedResidual(nn.Module): 64 | def __init__( 65 | self, 66 | inp: int, 67 | oup: int, 68 | stride: int, 69 | dilation: int, 70 | expand_ratio: int, 71 | norm_layer: Optional[Callable[..., nn.Module]] = None) -> None: 72 | super(InvertedResidual, self).__init__() 73 | self.stride = stride 74 | assert stride in [1, 2] 75 | 76 | if norm_layer is None: 77 | norm_layer = nn.BatchNorm2d 78 | 79 | self.kernel_size = 3 80 | self.dilation = dilation 81 | 82 | hidden_dim = int(round(inp * expand_ratio)) 83 | self.use_res_connect = self.stride == 1 and inp == oup 84 | 85 | layers: List[nn.Module] = [] 86 | if expand_ratio != 1: 87 | # pw 88 | layers.append( 89 | ConvBNReLU(inp, 90 | hidden_dim, 91 | kernel_size=1, 92 | norm_layer=norm_layer)) 93 | layers.extend([ 94 | # dw 95 | ConvBNReLU(hidden_dim, 96 | hidden_dim, 97 | stride=stride, 98 | dilation=dilation, 99 | groups=hidden_dim, 100 | norm_layer=norm_layer), 101 | # pw-linear 102 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 103 | norm_layer(oup), 104 | ]) 105 | self.conv = nn.Sequential(*layers) 106 | self.out_channels = oup 107 | self._is_cn = stride > 1 108 | 109 | def forward(self, x: Tensor) -> Tensor: 110 | if self.use_res_connect: 111 | return x + self.conv(x) 112 | else: 113 | return self.conv(x) 114 | 115 | 116 | class MobileNetV2(nn.Module): 117 | def __init__(self, 118 | output_stride=8, 119 | norm_layer: Optional[Callable[..., nn.Module]] = None, 120 | width_mult: float = 1.0, 121 | inverted_residual_setting: Optional[List[List[int]]] = None, 122 | round_nearest: int = 8, 123 | block: Optional[Callable[..., nn.Module]] = None, 124 | freeze_at=0) -> None: 125 | """ 126 | MobileNet V2 main class 127 | Args: 128 | num_classes (int): Number of classes 129 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 130 | inverted_residual_setting: Network structure 131 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 132 | Set to 1 to turn off rounding 133 | block: Module specifying inverted residual building block for mobilenet 134 
| norm_layer: Module specifying the normalization layer to use 135 | """ 136 | super(MobileNetV2, self).__init__() 137 | 138 | if block is None: 139 | block = InvertedResidual 140 | 141 | if norm_layer is None: 142 | norm_layer = nn.BatchNorm2d 143 | 144 | last_channel = 1280 145 | input_channel = 32 146 | current_stride = 1 147 | rate = 1 148 | 149 | if inverted_residual_setting is None: 150 | inverted_residual_setting = [ 151 | # t, c, n, s 152 | [1, 16, 1, 1], 153 | [6, 24, 2, 2], 154 | [6, 32, 3, 2], 155 | [6, 64, 4, 2], 156 | [6, 96, 3, 1], 157 | [6, 160, 3, 2], 158 | [6, 320, 1, 1], 159 | ] 160 | 161 | # only check the first element, assuming user knows t,c,n,s are required 162 | if len(inverted_residual_setting) == 0 or len( 163 | inverted_residual_setting[0]) != 4: 164 | raise ValueError("inverted_residual_setting should be non-empty " 165 | "or a 4-element list, got {}".format( 166 | inverted_residual_setting)) 167 | 168 | # building first layer 169 | input_channel = _make_divisible(input_channel * width_mult, 170 | round_nearest) 171 | self.last_channel = _make_divisible( 172 | last_channel * max(1.0, width_mult), round_nearest) 173 | features: List[nn.Module] = [ 174 | ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer) 175 | ] 176 | current_stride *= 2 177 | # building inverted residual blocks 178 | for t, c, n, s in inverted_residual_setting: 179 | if current_stride == output_stride: 180 | stride = 1 181 | dilation = rate 182 | rate *= s 183 | else: 184 | stride = s 185 | dilation = 1 186 | current_stride *= s 187 | output_channel = _make_divisible(c * width_mult, round_nearest) 188 | for i in range(n): 189 | if i == 0: 190 | features.append( 191 | block(input_channel, output_channel, stride, dilation, 192 | t, norm_layer)) 193 | else: 194 | features.append( 195 | block(input_channel, output_channel, 1, rate, t, 196 | norm_layer)) 197 | input_channel = output_channel 198 | 199 | # building last several layers 200 | features.append( 201 | ConvBNReLU(input_channel, 202 | self.last_channel, 203 | kernel_size=1, 204 | norm_layer=norm_layer)) 205 | # make it nn.Sequential 206 | self.features = nn.Sequential(*features) 207 | 208 | self._initialize_weights() 209 | 210 | feature_4x = self.features[0:4] 211 | feautre_8x = self.features[4:7] 212 | feature_16x = self.features[7:14] 213 | feature_32x = self.features[14:] 214 | 215 | self.stages = [feature_4x, feautre_8x, feature_16x, feature_32x] 216 | 217 | self.freeze(freeze_at) 218 | 219 | def forward(self, x): 220 | xs = [] 221 | for stage in self.stages: 222 | x = stage(x) 223 | xs.append(x) 224 | return xs 225 | 226 | def _initialize_weights(self): 227 | # weight initialization 228 | for m in self.modules(): 229 | if isinstance(m, nn.Conv2d): 230 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 231 | if m.bias is not None: 232 | nn.init.zeros_(m.bias) 233 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 234 | nn.init.ones_(m.weight) 235 | nn.init.zeros_(m.bias) 236 | elif isinstance(m, nn.Linear): 237 | nn.init.normal_(m.weight, 0, 0.01) 238 | nn.init.zeros_(m.bias) 239 | 240 | def freeze(self, freeze_at): 241 | if freeze_at >= 1: 242 | for m in self.stages[0][0]: 243 | freeze_params(m) 244 | 245 | for idx, stage in enumerate(self.stages, start=2): 246 | if freeze_at >= idx: 247 | freeze_params(stage) 248 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/mobilenetv3.py: -------------------------------------------------------------------------------- 1 | 
""" 2 | Creates a MobileNetV3 Model as defined in: 3 | Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam. (2019). 4 | Searching for MobileNetV3 5 | arXiv preprint arXiv:1905.02244. 6 | """ 7 | 8 | import torch.nn as nn 9 | import math 10 | from utils.learning import freeze_params 11 | 12 | 13 | def _make_divisible(v, divisor, min_value=None): 14 | """ 15 | This function is taken from the original tf repo. 16 | It ensures that all layers have a channel number that is divisible by 8 17 | It can be seen here: 18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 19 | :param v: 20 | :param divisor: 21 | :param min_value: 22 | :return: 23 | """ 24 | if min_value is None: 25 | min_value = divisor 26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 27 | # Make sure that round down does not go down by more than 10%. 28 | if new_v < 0.9 * v: 29 | new_v += divisor 30 | return new_v 31 | 32 | 33 | class h_sigmoid(nn.Module): 34 | def __init__(self, inplace=True): 35 | super(h_sigmoid, self).__init__() 36 | self.relu = nn.ReLU6(inplace=inplace) 37 | 38 | def forward(self, x): 39 | return self.relu(x + 3) / 6 40 | 41 | 42 | class h_swish(nn.Module): 43 | def __init__(self, inplace=True): 44 | super(h_swish, self).__init__() 45 | self.sigmoid = h_sigmoid(inplace=inplace) 46 | 47 | def forward(self, x): 48 | return x * self.sigmoid(x) 49 | 50 | 51 | class SELayer(nn.Module): 52 | def __init__(self, channel, reduction=4): 53 | super(SELayer, self).__init__() 54 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 55 | self.fc = nn.Sequential( 56 | nn.Linear(channel, _make_divisible(channel // reduction, 8)), 57 | nn.ReLU(inplace=True), 58 | nn.Linear(_make_divisible(channel // reduction, 8), channel), 59 | h_sigmoid()) 60 | 61 | def forward(self, x): 62 | b, c, _, _ = x.size() 63 | y = self.avg_pool(x).view(b, c) 64 | y = self.fc(y).view(b, c, 1, 1) 65 | return x * y 66 | 67 | 68 | def conv_3x3_bn(inp, oup, stride, norm_layer=nn.BatchNorm2d): 69 | return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 70 | norm_layer(oup), h_swish()) 71 | 72 | 73 | def conv_1x1_bn(inp, oup, norm_layer=nn.BatchNorm2d): 74 | return nn.Sequential(nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 75 | norm_layer(oup), h_swish()) 76 | 77 | 78 | class InvertedResidual(nn.Module): 79 | def __init__(self, 80 | inp, 81 | hidden_dim, 82 | oup, 83 | kernel_size, 84 | stride, 85 | use_se, 86 | use_hs, 87 | dilation=1, 88 | norm_layer=nn.BatchNorm2d): 89 | super(InvertedResidual, self).__init__() 90 | assert stride in [1, 2] 91 | 92 | self.identity = stride == 1 and inp == oup 93 | 94 | if inp == hidden_dim: 95 | self.conv = nn.Sequential( 96 | # dw 97 | nn.Conv2d(hidden_dim, 98 | hidden_dim, 99 | kernel_size, 100 | stride, (kernel_size - 1) // 2 * dilation, 101 | dilation=dilation, 102 | groups=hidden_dim, 103 | bias=False), 104 | norm_layer(hidden_dim), 105 | h_swish() if use_hs else nn.ReLU(inplace=True), 106 | # Squeeze-and-Excite 107 | SELayer(hidden_dim) if use_se else nn.Identity(), 108 | # pw-linear 109 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 110 | norm_layer(oup), 111 | ) 112 | else: 113 | self.conv = nn.Sequential( 114 | # pw 115 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 116 | norm_layer(hidden_dim), 117 | h_swish() if use_hs else nn.ReLU(inplace=True), 118 | # dw 119 | nn.Conv2d(hidden_dim, 120 | hidden_dim, 121 | kernel_size, 122 | stride, 
(kernel_size - 1) // 2 * dilation, 123 | dilation=dilation, 124 | groups=hidden_dim, 125 | bias=False), 126 | norm_layer(hidden_dim), 127 | # Squeeze-and-Excite 128 | SELayer(hidden_dim) if use_se else nn.Identity(), 129 | h_swish() if use_hs else nn.ReLU(inplace=True), 130 | # pw-linear 131 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 132 | norm_layer(oup), 133 | ) 134 | 135 | def forward(self, x): 136 | if self.identity: 137 | return x + self.conv(x) 138 | else: 139 | return self.conv(x) 140 | 141 | 142 | class MobileNetV3Large(nn.Module): 143 | def __init__(self, 144 | output_stride=16, 145 | norm_layer=nn.BatchNorm2d, 146 | width_mult=1., 147 | freeze_at=0): 148 | super(MobileNetV3Large, self).__init__() 149 | """ 150 | Constructs a MobileNetV3-Large model 151 | """ 152 | cfgs = [ 153 | # k, t, c, SE, HS, s 154 | [3, 1, 16, 0, 0, 1], 155 | [3, 4, 24, 0, 0, 2], 156 | [3, 3, 24, 0, 0, 1], 157 | [5, 3, 40, 1, 0, 2], 158 | [5, 3, 40, 1, 0, 1], 159 | [5, 3, 40, 1, 0, 1], 160 | [3, 6, 80, 0, 1, 2], 161 | [3, 2.5, 80, 0, 1, 1], 162 | [3, 2.3, 80, 0, 1, 1], 163 | [3, 2.3, 80, 0, 1, 1], 164 | [3, 6, 112, 1, 1, 1], 165 | [3, 6, 112, 1, 1, 1], 166 | [5, 6, 160, 1, 1, 2], 167 | [5, 6, 160, 1, 1, 1], 168 | [5, 6, 160, 1, 1, 1] 169 | ] 170 | self.cfgs = cfgs 171 | 172 | # building first layer 173 | input_channel = _make_divisible(16 * width_mult, 8) 174 | layers = [conv_3x3_bn(3, input_channel, 2, norm_layer)] 175 | # building inverted residual blocks 176 | block = InvertedResidual 177 | now_stride = 2 178 | rate = 1 179 | for k, t, c, use_se, use_hs, s in self.cfgs: 180 | if now_stride == output_stride: 181 | dilation = rate 182 | rate *= s 183 | s = 1 184 | else: 185 | dilation = 1 186 | now_stride *= s 187 | output_channel = _make_divisible(c * width_mult, 8) 188 | exp_size = _make_divisible(input_channel * t, 8) 189 | layers.append( 190 | block(input_channel, exp_size, output_channel, k, s, use_se, 191 | use_hs, dilation, norm_layer)) 192 | input_channel = output_channel 193 | 194 | self.features = nn.Sequential(*layers) 195 | self.conv = conv_1x1_bn(input_channel, exp_size, norm_layer) 196 | # building last several layers 197 | 198 | self._initialize_weights() 199 | 200 | feature_4x = self.features[0:4] 201 | feautre_8x = self.features[4:7] 202 | feature_16x = self.features[7:13] 203 | feature_32x = self.features[13:] 204 | 205 | self.stages = [feature_4x, feautre_8x, feature_16x, feature_32x] 206 | 207 | self.freeze(freeze_at) 208 | 209 | def forward(self, x): 210 | xs = [] 211 | for stage in self.stages: 212 | x = stage(x) 213 | xs.append(x) 214 | xs[-1] = self.conv(xs[-1]) 215 | return xs 216 | 217 | def _initialize_weights(self): 218 | for m in self.modules(): 219 | if isinstance(m, nn.Conv2d): 220 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 221 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 222 | if m.bias is not None: 223 | m.bias.data.zero_() 224 | elif isinstance(m, nn.BatchNorm2d): 225 | m.weight.data.fill_(1) 226 | m.bias.data.zero_() 227 | elif isinstance(m, nn.Linear): 228 | n = m.weight.size(1) 229 | m.weight.data.normal_(0, 0.01) 230 | m.bias.data.zero_() 231 | 232 | def freeze(self, freeze_at): 233 | if freeze_at >= 1: 234 | for m in self.stages[0][0]: 235 | freeze_params(m) 236 | 237 | for idx, stage in enumerate(self.stages, start=2): 238 | if freeze_at >= idx: 239 | freeze_params(stage) 240 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/resnest/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnest import * 2 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/resnest/resnest.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .resnet import ResNet, Bottleneck 3 | 4 | __all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269'] 5 | 6 | _url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth' 7 | 8 | _model_sha256 = { 9 | name: checksum 10 | for checksum, name in [ 11 | ('528c19ca', 'resnest50'), 12 | ('22405ba7', 'resnest101'), 13 | ('75117900', 'resnest200'), 14 | ('0cc87c48', 'resnest269'), 15 | ] 16 | } 17 | 18 | 19 | def short_hash(name): 20 | if name not in _model_sha256: 21 | raise ValueError( 22 | 'Pretrained model for {name} is not available.'.format(name=name)) 23 | return _model_sha256[name][:8] 24 | 25 | 26 | resnest_model_urls = { 27 | name: _url_format.format(name, short_hash(name)) 28 | for name in _model_sha256.keys() 29 | } 30 | 31 | 32 | def resnest50(pretrained=False, root='~/.encoding/models', **kwargs): 33 | model = ResNet(Bottleneck, [3, 4, 6, 3], 34 | radix=2, 35 | groups=1, 36 | bottleneck_width=64, 37 | deep_stem=True, 38 | stem_width=32, 39 | avg_down=True, 40 | avd=True, 41 | avd_first=False, 42 | **kwargs) 43 | if pretrained: 44 | model.load_state_dict( 45 | torch.hub.load_state_dict_from_url(resnest_model_urls['resnest50'], 46 | progress=True, 47 | check_hash=True)) 48 | return model 49 | 50 | 51 | def resnest101(pretrained=False, root='~/.encoding/models', **kwargs): 52 | model = ResNet(Bottleneck, [3, 4, 23, 3], 53 | radix=2, 54 | groups=1, 55 | bottleneck_width=64, 56 | deep_stem=True, 57 | stem_width=64, 58 | avg_down=True, 59 | avd=True, 60 | avd_first=False, 61 | **kwargs) 62 | if pretrained: 63 | model.load_state_dict( 64 | torch.hub.load_state_dict_from_url( 65 | resnest_model_urls['resnest101'], 66 | progress=True, 67 | check_hash=True)) 68 | return model 69 | 70 | 71 | def resnest200(pretrained=False, root='~/.encoding/models', **kwargs): 72 | model = ResNet(Bottleneck, [3, 24, 36, 3], 73 | radix=2, 74 | groups=1, 75 | bottleneck_width=64, 76 | deep_stem=True, 77 | stem_width=64, 78 | avg_down=True, 79 | avd=True, 80 | avd_first=False, 81 | **kwargs) 82 | if pretrained: 83 | model.load_state_dict( 84 | torch.hub.load_state_dict_from_url( 85 | resnest_model_urls['resnest200'], 86 | progress=True, 87 | check_hash=True)) 88 | return model 89 | 90 | 91 | def resnest269(pretrained=False, root='~/.encoding/models', **kwargs): 92 | model = ResNet(Bottleneck, [3, 30, 48, 8], 93 | radix=2, 94 | groups=1, 95 | bottleneck_width=64, 96 | deep_stem=True, 97 | stem_width=64, 98 | avg_down=True, 99 | avd=True, 100 | avd_first=False, 101 | **kwargs) 102 | if pretrained: 103 | 
model.load_state_dict( 104 | torch.hub.load_state_dict_from_url( 105 | resnest_model_urls['resnest269'], 106 | progress=True, 107 | check_hash=True)) 108 | return model 109 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/resnest/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | 4 | from .splat import SplAtConv2d, DropBlock2D 5 | from utils.learning import freeze_params 6 | 7 | __all__ = ['ResNet', 'Bottleneck'] 8 | 9 | _url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth' 10 | 11 | _model_sha256 = {name: checksum for checksum, name in []} 12 | 13 | 14 | def short_hash(name): 15 | if name not in _model_sha256: 16 | raise ValueError( 17 | 'Pretrained model for {name} is not available.'.format(name=name)) 18 | return _model_sha256[name][:8] 19 | 20 | 21 | resnest_model_urls = { 22 | name: _url_format.format(name, short_hash(name)) 23 | for name in _model_sha256.keys() 24 | } 25 | 26 | 27 | class GlobalAvgPool2d(nn.Module): 28 | def __init__(self): 29 | """Global average pooling over the input's spatial dimensions""" 30 | super(GlobalAvgPool2d, self).__init__() 31 | 32 | def forward(self, inputs): 33 | return nn.functional.adaptive_avg_pool2d(inputs, 34 | 1).view(inputs.size(0), -1) 35 | 36 | 37 | class Bottleneck(nn.Module): 38 | """ResNet Bottleneck 39 | """ 40 | # pylint: disable=unused-argument 41 | expansion = 4 42 | 43 | def __init__(self, 44 | inplanes, 45 | planes, 46 | stride=1, 47 | downsample=None, 48 | radix=1, 49 | cardinality=1, 50 | bottleneck_width=64, 51 | avd=False, 52 | avd_first=False, 53 | dilation=1, 54 | is_first=False, 55 | rectified_conv=False, 56 | rectify_avg=False, 57 | norm_layer=None, 58 | dropblock_prob=0.0, 59 | last_gamma=False): 60 | super(Bottleneck, self).__init__() 61 | group_width = int(planes * (bottleneck_width / 64.)) * cardinality 62 | self.conv1 = nn.Conv2d(inplanes, 63 | group_width, 64 | kernel_size=1, 65 | bias=False) 66 | self.bn1 = norm_layer(group_width) 67 | self.dropblock_prob = dropblock_prob 68 | self.radix = radix 69 | self.avd = avd and (stride > 1 or is_first) 70 | self.avd_first = avd_first 71 | 72 | if self.avd: 73 | self.avd_layer = nn.AvgPool2d(3, stride, padding=1) 74 | stride = 1 75 | 76 | if dropblock_prob > 0.0: 77 | self.dropblock1 = DropBlock2D(dropblock_prob, 3) 78 | if radix == 1: 79 | self.dropblock2 = DropBlock2D(dropblock_prob, 3) 80 | self.dropblock3 = DropBlock2D(dropblock_prob, 3) 81 | 82 | if radix >= 1: 83 | self.conv2 = SplAtConv2d(group_width, 84 | group_width, 85 | kernel_size=3, 86 | stride=stride, 87 | padding=dilation, 88 | dilation=dilation, 89 | groups=cardinality, 90 | bias=False, 91 | radix=radix, 92 | rectify=rectified_conv, 93 | rectify_avg=rectify_avg, 94 | norm_layer=norm_layer, 95 | dropblock_prob=dropblock_prob) 96 | elif rectified_conv: 97 | from rfconv import RFConv2d 98 | self.conv2 = RFConv2d(group_width, 99 | group_width, 100 | kernel_size=3, 101 | stride=stride, 102 | padding=dilation, 103 | dilation=dilation, 104 | groups=cardinality, 105 | bias=False, 106 | average_mode=rectify_avg) 107 | self.bn2 = norm_layer(group_width) 108 | else: 109 | self.conv2 = nn.Conv2d(group_width, 110 | group_width, 111 | kernel_size=3, 112 | stride=stride, 113 | padding=dilation, 114 | dilation=dilation, 115 | groups=cardinality, 116 | bias=False) 117 | self.bn2 = norm_layer(group_width) 118 | 119 | self.conv3 = nn.Conv2d(group_width, 120 | planes 
* 4, 121 | kernel_size=1, 122 | bias=False) 123 | self.bn3 = norm_layer(planes * 4) 124 | 125 | if last_gamma: 126 | from torch.nn.init import zeros_ 127 | zeros_(self.bn3.weight) 128 | self.relu = nn.ReLU(inplace=True) 129 | self.downsample = downsample 130 | self.dilation = dilation 131 | self.stride = stride 132 | 133 | def forward(self, x): 134 | residual = x 135 | 136 | out = self.conv1(x) 137 | out = self.bn1(out) 138 | if self.dropblock_prob > 0.0: 139 | out = self.dropblock1(out) 140 | out = self.relu(out) 141 | 142 | if self.avd and self.avd_first: 143 | out = self.avd_layer(out) 144 | 145 | out = self.conv2(out) 146 | if self.radix == 0: 147 | out = self.bn2(out) 148 | if self.dropblock_prob > 0.0: 149 | out = self.dropblock2(out) 150 | out = self.relu(out) 151 | 152 | if self.avd and not self.avd_first: 153 | out = self.avd_layer(out) 154 | 155 | out = self.conv3(out) 156 | out = self.bn3(out) 157 | if self.dropblock_prob > 0.0: 158 | out = self.dropblock3(out) 159 | 160 | if self.downsample is not None: 161 | residual = self.downsample(x) 162 | 163 | out += residual 164 | out = self.relu(out) 165 | 166 | return out 167 | 168 | 169 | class ResNet(nn.Module): 170 | """ResNet Variants 171 | Parameters 172 | ---------- 173 | block : Block 174 | Class for the residual block. Options are BasicBlockV1, BottleneckV1. 175 | layers : list of int 176 | Numbers of layers in each block 177 | classes : int, default 1000 178 | Number of classification classes. 179 | dilated : bool, default False 180 | Applying dilation strategy to pretrained ResNet yielding a stride-8 model, 181 | typically used in Semantic Segmentation. 182 | norm_layer : object 183 | Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`; 184 | for Synchronized Cross-GPU BachNormalization). 185 | Reference: 186 | - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. 187 | - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." 
188 | """ 189 | 190 | # pylint: disable=unused-variable 191 | def __init__(self, 192 | block, 193 | layers, 194 | radix=1, 195 | groups=1, 196 | bottleneck_width=64, 197 | num_classes=1000, 198 | dilated=False, 199 | dilation=1, 200 | deep_stem=False, 201 | stem_width=64, 202 | avg_down=False, 203 | rectified_conv=False, 204 | rectify_avg=False, 205 | avd=False, 206 | avd_first=False, 207 | final_drop=0.0, 208 | dropblock_prob=0, 209 | last_gamma=False, 210 | norm_layer=nn.BatchNorm2d, 211 | freeze_at=0): 212 | self.cardinality = groups 213 | self.bottleneck_width = bottleneck_width 214 | # ResNet-D params 215 | self.inplanes = stem_width * 2 if deep_stem else 64 216 | self.avg_down = avg_down 217 | self.last_gamma = last_gamma 218 | # ResNeSt params 219 | self.radix = radix 220 | self.avd = avd 221 | self.avd_first = avd_first 222 | 223 | super(ResNet, self).__init__() 224 | self.rectified_conv = rectified_conv 225 | self.rectify_avg = rectify_avg 226 | if rectified_conv: 227 | from rfconv import RFConv2d 228 | conv_layer = RFConv2d 229 | else: 230 | conv_layer = nn.Conv2d 231 | conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {} 232 | if deep_stem: 233 | self.conv1 = nn.Sequential( 234 | conv_layer(3, 235 | stem_width, 236 | kernel_size=3, 237 | stride=2, 238 | padding=1, 239 | bias=False, 240 | **conv_kwargs), 241 | norm_layer(stem_width), 242 | nn.ReLU(inplace=True), 243 | conv_layer(stem_width, 244 | stem_width, 245 | kernel_size=3, 246 | stride=1, 247 | padding=1, 248 | bias=False, 249 | **conv_kwargs), 250 | norm_layer(stem_width), 251 | nn.ReLU(inplace=True), 252 | conv_layer(stem_width, 253 | stem_width * 2, 254 | kernel_size=3, 255 | stride=1, 256 | padding=1, 257 | bias=False, 258 | **conv_kwargs), 259 | ) 260 | else: 261 | self.conv1 = conv_layer(3, 262 | 64, 263 | kernel_size=7, 264 | stride=2, 265 | padding=3, 266 | bias=False, 267 | **conv_kwargs) 268 | self.bn1 = norm_layer(self.inplanes) 269 | self.relu = nn.ReLU(inplace=True) 270 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 271 | self.layer1 = self._make_layer(block, 272 | 64, 273 | layers[0], 274 | norm_layer=norm_layer, 275 | is_first=False) 276 | self.layer2 = self._make_layer(block, 277 | 128, 278 | layers[1], 279 | stride=2, 280 | norm_layer=norm_layer) 281 | if dilated or dilation == 4: 282 | self.layer3 = self._make_layer(block, 283 | 256, 284 | layers[2], 285 | stride=1, 286 | dilation=2, 287 | norm_layer=norm_layer, 288 | dropblock_prob=dropblock_prob) 289 | elif dilation == 2: 290 | self.layer3 = self._make_layer(block, 291 | 256, 292 | layers[2], 293 | stride=2, 294 | dilation=1, 295 | norm_layer=norm_layer, 296 | dropblock_prob=dropblock_prob) 297 | else: 298 | self.layer3 = self._make_layer(block, 299 | 256, 300 | layers[2], 301 | stride=2, 302 | norm_layer=norm_layer, 303 | dropblock_prob=dropblock_prob) 304 | 305 | self.stem = [self.conv1, self.bn1] 306 | self.stages = [self.layer1, self.layer2, self.layer3] 307 | 308 | for m in self.modules(): 309 | if isinstance(m, nn.Conv2d): 310 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 311 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 312 | elif isinstance(m, norm_layer): 313 | m.weight.data.fill_(1) 314 | m.bias.data.zero_() 315 | 316 | self.freeze(freeze_at) 317 | 318 | def _make_layer(self, 319 | block, 320 | planes, 321 | blocks, 322 | stride=1, 323 | dilation=1, 324 | norm_layer=None, 325 | dropblock_prob=0.0, 326 | is_first=True): 327 | downsample = None 328 | if stride != 1 or self.inplanes != planes * block.expansion: 329 | down_layers = [] 330 | if self.avg_down: 331 | if dilation == 1: 332 | down_layers.append( 333 | nn.AvgPool2d(kernel_size=stride, 334 | stride=stride, 335 | ceil_mode=True, 336 | count_include_pad=False)) 337 | else: 338 | down_layers.append( 339 | nn.AvgPool2d(kernel_size=1, 340 | stride=1, 341 | ceil_mode=True, 342 | count_include_pad=False)) 343 | down_layers.append( 344 | nn.Conv2d(self.inplanes, 345 | planes * block.expansion, 346 | kernel_size=1, 347 | stride=1, 348 | bias=False)) 349 | else: 350 | down_layers.append( 351 | nn.Conv2d(self.inplanes, 352 | planes * block.expansion, 353 | kernel_size=1, 354 | stride=stride, 355 | bias=False)) 356 | down_layers.append(norm_layer(planes * block.expansion)) 357 | downsample = nn.Sequential(*down_layers) 358 | 359 | layers = [] 360 | if dilation == 1 or dilation == 2: 361 | layers.append( 362 | block(self.inplanes, 363 | planes, 364 | stride, 365 | downsample=downsample, 366 | radix=self.radix, 367 | cardinality=self.cardinality, 368 | bottleneck_width=self.bottleneck_width, 369 | avd=self.avd, 370 | avd_first=self.avd_first, 371 | dilation=1, 372 | is_first=is_first, 373 | rectified_conv=self.rectified_conv, 374 | rectify_avg=self.rectify_avg, 375 | norm_layer=norm_layer, 376 | dropblock_prob=dropblock_prob, 377 | last_gamma=self.last_gamma)) 378 | elif dilation == 4: 379 | layers.append( 380 | block(self.inplanes, 381 | planes, 382 | stride, 383 | downsample=downsample, 384 | radix=self.radix, 385 | cardinality=self.cardinality, 386 | bottleneck_width=self.bottleneck_width, 387 | avd=self.avd, 388 | avd_first=self.avd_first, 389 | dilation=2, 390 | is_first=is_first, 391 | rectified_conv=self.rectified_conv, 392 | rectify_avg=self.rectify_avg, 393 | norm_layer=norm_layer, 394 | dropblock_prob=dropblock_prob, 395 | last_gamma=self.last_gamma)) 396 | else: 397 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 398 | 399 | self.inplanes = planes * block.expansion 400 | for i in range(1, blocks): 401 | layers.append( 402 | block(self.inplanes, 403 | planes, 404 | radix=self.radix, 405 | cardinality=self.cardinality, 406 | bottleneck_width=self.bottleneck_width, 407 | avd=self.avd, 408 | avd_first=self.avd_first, 409 | dilation=dilation, 410 | rectified_conv=self.rectified_conv, 411 | rectify_avg=self.rectify_avg, 412 | norm_layer=norm_layer, 413 | dropblock_prob=dropblock_prob, 414 | last_gamma=self.last_gamma)) 415 | 416 | return nn.Sequential(*layers) 417 | 418 | def forward(self, x): 419 | x = self.conv1(x) 420 | x = self.bn1(x) 421 | x = self.relu(x) 422 | x = self.maxpool(x) 423 | 424 | xs = [] 425 | 426 | x = self.layer1(x) 427 | xs.append(x) # 4X 428 | x = self.layer2(x) 429 | xs.append(x) # 8X 430 | x = self.layer3(x) 431 | xs.append(x) # 16X 432 | # Following STMVOS, we drop stage 5. 
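        # The stride-16 feature is appended a second time in place of the
        # dropped stage-5 output, so the returned list always has four entries
        # (4x, 8x, 16x, 16x) for downstream consumers.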
433 | xs.append(x) # 16X 434 | 435 | return xs 436 | 437 | def freeze(self, freeze_at): 438 | if freeze_at >= 1: 439 | for m in self.stem: 440 | freeze_params(m) 441 | 442 | for idx, stage in enumerate(self.stages, start=2): 443 | if freeze_at >= idx: 444 | freeze_params(stage) 445 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/resnest/splat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn import Conv2d, Module, ReLU 5 | from torch.nn.modules.utils import _pair 6 | 7 | __all__ = ['SplAtConv2d', 'DropBlock2D'] 8 | 9 | 10 | class DropBlock2D(object): 11 | def __init__(self, *args, **kwargs): 12 | raise NotImplementedError 13 | 14 | 15 | class SplAtConv2d(Module): 16 | """Split-Attention Conv2d 17 | """ 18 | def __init__(self, 19 | in_channels, 20 | channels, 21 | kernel_size, 22 | stride=(1, 1), 23 | padding=(0, 0), 24 | dilation=(1, 1), 25 | groups=1, 26 | bias=True, 27 | radix=2, 28 | reduction_factor=4, 29 | rectify=False, 30 | rectify_avg=False, 31 | norm_layer=None, 32 | dropblock_prob=0.0, 33 | **kwargs): 34 | super(SplAtConv2d, self).__init__() 35 | padding = _pair(padding) 36 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0) 37 | self.rectify_avg = rectify_avg 38 | inter_channels = max(in_channels * radix // reduction_factor, 32) 39 | self.radix = radix 40 | self.cardinality = groups 41 | self.channels = channels 42 | self.dropblock_prob = dropblock_prob 43 | if self.rectify: 44 | from rfconv import RFConv2d 45 | self.conv = RFConv2d(in_channels, 46 | channels * radix, 47 | kernel_size, 48 | stride, 49 | padding, 50 | dilation, 51 | groups=groups * radix, 52 | bias=bias, 53 | average_mode=rectify_avg, 54 | **kwargs) 55 | else: 56 | self.conv = Conv2d(in_channels, 57 | channels * radix, 58 | kernel_size, 59 | stride, 60 | padding, 61 | dilation, 62 | groups=groups * radix, 63 | bias=bias, 64 | **kwargs) 65 | self.use_bn = norm_layer is not None 66 | if self.use_bn: 67 | self.bn0 = norm_layer(channels * radix) 68 | self.relu = ReLU(inplace=True) 69 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) 70 | if self.use_bn: 71 | self.bn1 = norm_layer(inter_channels) 72 | self.fc2 = Conv2d(inter_channels, 73 | channels * radix, 74 | 1, 75 | groups=self.cardinality) 76 | if dropblock_prob > 0.0: 77 | self.dropblock = DropBlock2D(dropblock_prob, 3) 78 | self.rsoftmax = rSoftMax(radix, groups) 79 | 80 | def forward(self, x): 81 | x = self.conv(x) 82 | if self.use_bn: 83 | x = self.bn0(x) 84 | if self.dropblock_prob > 0.0: 85 | x = self.dropblock(x) 86 | x = self.relu(x) 87 | 88 | batch, rchannel = x.shape[:2] 89 | if self.radix > 1: 90 | if torch.__version__ < '1.5': 91 | splited = torch.split(x, int(rchannel // self.radix), dim=1) 92 | else: 93 | splited = torch.split(x, rchannel // self.radix, dim=1) 94 | gap = sum(splited) 95 | else: 96 | gap = x 97 | gap = F.adaptive_avg_pool2d(gap, 1) 98 | gap = self.fc1(gap) 99 | 100 | if self.use_bn: 101 | gap = self.bn1(gap) 102 | gap = self.relu(gap) 103 | 104 | atten = self.fc2(gap) 105 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1) 106 | 107 | if self.radix > 1: 108 | if torch.__version__ < '1.5': 109 | attens = torch.split(atten, int(rchannel // self.radix), dim=1) 110 | else: 111 | attens = torch.split(atten, rchannel // self.radix, dim=1) 112 | out = sum([att * split for (att, split) in zip(attens, splited)]) 113 
| else: 114 | out = atten * x 115 | return out.contiguous() 116 | 117 | 118 | class rSoftMax(nn.Module): 119 | def __init__(self, radix, cardinality): 120 | super().__init__() 121 | self.radix = radix 122 | self.cardinality = cardinality 123 | 124 | def forward(self, x): 125 | batch = x.size(0) 126 | if self.radix > 1: 127 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 128 | x = F.softmax(x, dim=1) 129 | x = x.reshape(batch, -1) 130 | else: 131 | x = torch.sigmoid(x) 132 | return x 133 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | from utils.learning import freeze_params 4 | 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | 9 | def __init__(self, 10 | inplanes, 11 | planes, 12 | stride=1, 13 | dilation=1, 14 | downsample=None, 15 | BatchNorm=None): 16 | super(Bottleneck, self).__init__() 17 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 18 | self.bn1 = BatchNorm(planes) 19 | self.conv2 = nn.Conv2d(planes, 20 | planes, 21 | kernel_size=3, 22 | stride=stride, 23 | dilation=dilation, 24 | padding=dilation, 25 | bias=False) 26 | self.bn2 = BatchNorm(planes) 27 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 28 | self.bn3 = BatchNorm(planes * 4) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.downsample = downsample 31 | self.stride = stride 32 | self.dilation = dilation 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv3(out) 46 | out = self.bn3(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class ResNet(nn.Module): 58 | def __init__(self, block, layers, output_stride, BatchNorm, freeze_at=0): 59 | self.inplanes = 64 60 | super(ResNet, self).__init__() 61 | 62 | if output_stride == 16: 63 | strides = [1, 2, 2, 1] 64 | dilations = [1, 1, 1, 2] 65 | elif output_stride == 8: 66 | strides = [1, 2, 1, 1] 67 | dilations = [1, 1, 2, 4] 68 | else: 69 | raise NotImplementedError 70 | 71 | # Modules 72 | self.conv1 = nn.Conv2d(3, 73 | 64, 74 | kernel_size=7, 75 | stride=2, 76 | padding=3, 77 | bias=False) 78 | self.bn1 = BatchNorm(64) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 81 | 82 | self.layer1 = self._make_layer(block, 83 | 64, 84 | layers[0], 85 | stride=strides[0], 86 | dilation=dilations[0], 87 | BatchNorm=BatchNorm) 88 | self.layer2 = self._make_layer(block, 89 | 128, 90 | layers[1], 91 | stride=strides[1], 92 | dilation=dilations[1], 93 | BatchNorm=BatchNorm) 94 | self.layer3 = self._make_layer(block, 95 | 256, 96 | layers[2], 97 | stride=strides[2], 98 | dilation=dilations[2], 99 | BatchNorm=BatchNorm) 100 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 101 | 102 | self.stem = [self.conv1, self.bn1] 103 | self.stages = [self.layer1, self.layer2, self.layer3] 104 | 105 | self._init_weight() 106 | self.freeze(freeze_at) 107 | 108 | def _make_layer(self, 109 | block, 110 | planes, 111 | blocks, 112 | stride=1, 113 | dilation=1, 114 | BatchNorm=None): 115 | downsample = None 
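        # A 1x1 convolution + BatchNorm shortcut is only created when the
        # stride or channel count changes, i.e. when the identity mapping
        # cannot be added to the block output directly.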
116 | if stride != 1 or self.inplanes != planes * block.expansion: 117 | downsample = nn.Sequential( 118 | nn.Conv2d(self.inplanes, 119 | planes * block.expansion, 120 | kernel_size=1, 121 | stride=stride, 122 | bias=False), 123 | BatchNorm(planes * block.expansion), 124 | ) 125 | 126 | layers = [] 127 | layers.append( 128 | block(self.inplanes, planes, stride, max(dilation // 2, 1), 129 | downsample, BatchNorm)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append( 133 | block(self.inplanes, 134 | planes, 135 | dilation=dilation, 136 | BatchNorm=BatchNorm)) 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, input): 141 | x = self.conv1(input) 142 | x = self.bn1(x) 143 | x = self.relu(x) 144 | x = self.maxpool(x) 145 | 146 | xs = [] 147 | 148 | x = self.layer1(x) 149 | xs.append(x) # 4X 150 | x = self.layer2(x) 151 | xs.append(x) # 8X 152 | x = self.layer3(x) 153 | xs.append(x) # 16X 154 | # Following STMVOS, we drop stage 5. 155 | xs.append(x) # 16X 156 | 157 | return xs 158 | 159 | def _init_weight(self): 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 163 | m.weight.data.normal_(0, math.sqrt(2. / n)) 164 | elif isinstance(m, nn.BatchNorm2d): 165 | m.weight.data.fill_(1) 166 | m.bias.data.zero_() 167 | 168 | def freeze(self, freeze_at): 169 | if freeze_at >= 1: 170 | for m in self.stem: 171 | freeze_params(m) 172 | 173 | for idx, stage in enumerate(self.stages, start=2): 174 | if freeze_at >= idx: 175 | freeze_params(stage) 176 | 177 | 178 | def ResNet50(output_stride, BatchNorm, freeze_at=0): 179 | """Constructs a ResNet-50 model. 180 | Args: 181 | pretrained (bool): If True, returns a model pre-trained on ImageNet 182 | """ 183 | model = ResNet(Bottleneck, [3, 4, 6, 3], 184 | output_stride, 185 | BatchNorm, 186 | freeze_at=freeze_at) 187 | return model 188 | 189 | 190 | def ResNet101(output_stride, BatchNorm, freeze_at=0): 191 | """Constructs a ResNet-101 model. 
192 | Args: 193 | pretrained (bool): If True, returns a model pre-trained on ImageNet 194 | """ 195 | model = ResNet(Bottleneck, [3, 4, 23, 3], 196 | output_stride, 197 | BatchNorm, 198 | freeze_at=freeze_at) 199 | return model 200 | 201 | 202 | if __name__ == "__main__": 203 | import torch 204 | model = ResNet101(BatchNorm=nn.BatchNorm2d, output_stride=8) 205 | input = torch.rand(1, 3, 512, 512) 206 | output, low_level_feat = model(input) 207 | print(output.size()) 208 | print(low_level_feat.size()) 209 | -------------------------------------------------------------------------------- /aot_plus/networks/encoders/swin/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_swin_model -------------------------------------------------------------------------------- /aot_plus/networks/encoders/swin/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from .swin_transformer import SwinTransformer 9 | 10 | 11 | def build_swin_model(model_type, freeze_at=0): 12 | if model_type == 'swin_base': 13 | model = SwinTransformer(embed_dim=128, 14 | depths=[2, 2, 18, 2], 15 | num_heads=[4, 8, 16, 32], 16 | window_size=7, 17 | drop_path_rate=0.3, 18 | out_indices=(0, 1, 2), 19 | ape=False, 20 | patch_norm=True, 21 | frozen_stages=freeze_at, 22 | use_checkpoint=False) 23 | 24 | else: 25 | raise NotImplementedError(f"Unkown model: {model_type}") 26 | 27 | return model 28 | -------------------------------------------------------------------------------- /aot_plus/networks/engines/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.engines.aot_engine import AOTEngine, AOTInferEngine 2 | 3 | 4 | def build_engine(name, phase='train', **kwargs): 5 | if name == 'aotengine': 6 | if phase == 'train': 7 | return AOTEngine(**kwargs) 8 | elif phase == 'eval': 9 | return AOTInferEngine(**kwargs) 10 | else: 11 | raise NotImplementedError 12 | else: 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /aot_plus/networks/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/networks/layers/__init__.py -------------------------------------------------------------------------------- /aot_plus/networks/layers/basic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class GroupNorm1D(nn.Module): 7 | def __init__(self, indim, groups=8): 8 | super().__init__() 9 | self.gn = nn.GroupNorm(groups, indim) 10 | 11 | def forward(self, x): 12 | return self.gn(x.permute(1, 2, 0)).permute(2, 0, 1) 13 | 14 | 15 | class GNActDWConv2d(nn.Module): 16 | def __init__(self, indim, gn_groups=32): 17 | super().__init__() 18 | self.gn = nn.GroupNorm(gn_groups, indim) 19 | self.conv = nn.Conv2d(indim, 20 | indim, 21 | 5, 22 | dilation=1, 23 | padding=2, 24 | groups=indim, 25 | bias=False) 26 | 27 | def forward(self, x, size_2d): 28 | h, w = size_2d 29 | _, bs, c = x.size() 30 | x = 
x.view(h, w, bs, c).permute(2, 3, 0, 1) 31 | x = self.gn(x) 32 | x = F.gelu(x) 33 | x = self.conv(x) 34 | x = x.view(bs, c, h * w).permute(2, 0, 1) 35 | return x 36 | 37 | 38 | class ConvGN(nn.Module): 39 | def __init__(self, indim, outdim, kernel_size, gn_groups=8): 40 | super().__init__() 41 | self.conv = nn.Conv2d(indim, 42 | outdim, 43 | kernel_size, 44 | padding=kernel_size // 2) 45 | self.gn = nn.GroupNorm(gn_groups, outdim) 46 | 47 | def forward(self, x): 48 | return self.gn(self.conv(x)) 49 | 50 | 51 | def seq_to_2d(tensor, size_2d): 52 | h, w = size_2d 53 | _, n, c = tensor.size() 54 | tensor = tensor.view(h, w, n, c).permute(2, 3, 0, 1).contiguous() 55 | return tensor 56 | 57 | def twod_to_seq(tensor): 58 | n, c, h, w = tensor.size() 59 | tensor = tensor.view(n, c, h * w).permute(3, 0, 1).contiguous() 60 | return tensor 61 | 62 | 63 | def drop_path(x, drop_prob: float = 0., training: bool = False): 64 | if drop_prob == 0. or not training: 65 | return x 66 | keep_prob = 1 - drop_prob 67 | shape = (x.shape[0], ) + (1, ) * ( 68 | x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 69 | random_tensor = keep_prob + torch.rand( 70 | shape, dtype=x.dtype, device=x.device) 71 | random_tensor.floor_() # binarize 72 | output = x.div(keep_prob) * random_tensor 73 | return output 74 | 75 | 76 | class DropPath(nn.Module): 77 | def __init__(self, drop_prob=None, batch_dim=0): 78 | super(DropPath, self).__init__() 79 | self.drop_prob = drop_prob 80 | self.batch_dim = batch_dim 81 | 82 | def forward(self, x): 83 | return self.drop_path(x, self.drop_prob) 84 | 85 | def drop_path(self, x, drop_prob): 86 | if drop_prob == 0. or not self.training: 87 | return x 88 | keep_prob = 1 - drop_prob 89 | shape = [1 for _ in range(x.ndim)] 90 | shape[self.batch_dim] = x.shape[self.batch_dim] 91 | random_tensor = keep_prob + torch.rand( 92 | shape, dtype=x.dtype, device=x.device) 93 | random_tensor.floor_() # binarize 94 | output = x.div(keep_prob) * random_tensor 95 | return output 96 | 97 | 98 | class DropOutLogit(nn.Module): 99 | def __init__(self, drop_prob=None): 100 | super(DropOutLogit, self).__init__() 101 | self.drop_prob = drop_prob 102 | 103 | def forward(self, x): 104 | return self.drop_logit(x, self.drop_prob) 105 | 106 | def drop_logit(self, x, drop_prob): 107 | if drop_prob == 0. 
or not self.training: 108 | return x 109 | random_tensor = drop_prob + torch.rand( 110 | x.shape, dtype=x.dtype, device=x.device) 111 | random_tensor.floor_() # binarize 112 | mask = random_tensor * 1e+8 if ( 113 | x.dtype == torch.float32) else random_tensor * 1e+4 114 | output = x - mask 115 | return output 116 | -------------------------------------------------------------------------------- /aot_plus/networks/layers/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | try: 6 | from itertools import ifilterfalse 7 | except ImportError: # py3k 8 | from itertools import filterfalse as ifilterfalse 9 | 10 | 11 | def dice_loss(probas, labels, smooth=1): 12 | 13 | C = probas.size(1) 14 | losses = [] 15 | for c in list(range(C)): 16 | fg = (labels == c).float() 17 | if fg.sum() == 0: 18 | continue 19 | class_pred = probas[:, c] 20 | p0 = class_pred 21 | g0 = fg 22 | numerator = 2 * torch.sum(p0 * g0) + smooth 23 | denominator = torch.sum(p0) + torch.sum(g0) + smooth 24 | losses.append(1 - ((numerator) / (denominator))) 25 | return mean(losses) 26 | 27 | 28 | def tversky_loss(probas, labels, alpha=0.5, beta=0.5, epsilon=1e-6): 29 | ''' 30 | Tversky loss function. 31 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 32 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 33 | 34 | Same as soft dice loss when alpha=beta=0.5. 35 | Same as Jaccord loss when alpha=beta=1.0. 36 | See `Tversky loss function for image segmentation using 3D fully convolutional deep networks` 37 | https://arxiv.org/pdf/1706.05721.pdf 38 | ''' 39 | C = probas.size(1) 40 | losses = [] 41 | for c in list(range(C)): 42 | fg = (labels == c).float() 43 | if fg.sum() == 0: 44 | continue 45 | class_pred = probas[:, c] 46 | p0 = class_pred 47 | p1 = 1 - class_pred 48 | g0 = fg 49 | g1 = 1 - fg 50 | numerator = torch.sum(p0 * g0) 51 | denominator = numerator + alpha * \ 52 | torch.sum(p0*g1) + beta*torch.sum(p1*g0) 53 | losses.append(1 - ((numerator) / (denominator + epsilon))) 54 | return mean(losses) 55 | 56 | 57 | def flatten_probas(probas, labels, ignore=255): 58 | """ 59 | Flattens predictions in the batch 60 | """ 61 | B, C, H, W = probas.size() 62 | probas = probas.permute(0, 2, 3, 63 | 1).contiguous().view(-1, C) # B * H * W, C = P, C 64 | labels = labels.view(-1) 65 | if ignore is None: 66 | return probas, labels 67 | valid = (labels != ignore) 68 | vprobas = probas[valid.view(-1, 1).expand(-1, C)].reshape(-1, C) 69 | # vprobas = probas[torch.nonzero(valid).squeeze()] 70 | vlabels = labels[valid] 71 | return vprobas, vlabels 72 | 73 | 74 | def isnan(x): 75 | return x != x 76 | 77 | 78 | def mean(l, ignore_nan=False, empty=0): 79 | """ 80 | nanmean compatible with generators. 
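    Illustrative examples (added for clarity, not in the original source):
        mean([1.0, 2.0, 3.0])  -> 2.0
        mean([], empty=0)      -> 0    (fallback when the iterable is empty)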
81 | """ 82 | l = iter(l) 83 | if ignore_nan: 84 | l = ifilterfalse(isnan, l) 85 | try: 86 | n = 1 87 | acc = next(l) 88 | except StopIteration: 89 | if empty == 'raise': 90 | raise ValueError('Empty mean') 91 | return empty 92 | for n, v in enumerate(l, 2): 93 | acc += v 94 | if n == 1: 95 | return acc 96 | return acc / n 97 | 98 | 99 | class DiceLoss(nn.Module): 100 | def __init__(self, ignore_index=255): 101 | super(DiceLoss, self).__init__() 102 | self.ignore_index = ignore_index 103 | 104 | def forward(self, tmp_dic, label_dic, step=None): 105 | total_loss = [] 106 | for idx in range(len(tmp_dic)): 107 | pred = tmp_dic[idx] 108 | label = label_dic[idx] 109 | pred = F.softmax(pred, dim=1) 110 | label = label.view(1, 1, pred.size()[2], pred.size()[3]) 111 | loss = dice_loss( 112 | *flatten_probas(pred, label, ignore=self.ignore_index)) 113 | total_loss.append(loss.unsqueeze(0)) 114 | total_loss = torch.cat(total_loss, dim=0) 115 | return total_loss 116 | 117 | 118 | class SoftJaccordLoss(nn.Module): 119 | def __init__(self, ignore_index=255): 120 | super(SoftJaccordLoss, self).__init__() 121 | self.ignore_index = ignore_index 122 | 123 | def forward(self, tmp_dic, label_dic, step=None): 124 | total_loss = [] 125 | for idx in range(len(tmp_dic)): 126 | pred = tmp_dic[idx] 127 | label = label_dic[idx] 128 | pred = F.softmax(pred, dim=1) 129 | label = label.view(1, 1, pred.size()[2], pred.size()[3]) 130 | loss = tversky_loss(*flatten_probas(pred, 131 | label, 132 | ignore=self.ignore_index), 133 | alpha=1.0, 134 | beta=1.0) 135 | if loss != 0: 136 | total_loss.append(loss.unsqueeze(0)) 137 | else: 138 | total_loss.append(torch.zeros(1).cuda()) 139 | total_loss = torch.cat(total_loss, dim=0) 140 | return total_loss 141 | 142 | 143 | class CrossEntropyLoss(nn.Module): 144 | def __init__(self, 145 | top_k_percent_pixels=None, 146 | hard_example_mining_step=100000): 147 | super(CrossEntropyLoss, self).__init__() 148 | self.top_k_percent_pixels = top_k_percent_pixels 149 | if top_k_percent_pixels is not None: 150 | assert (top_k_percent_pixels > 0 and top_k_percent_pixels < 1) 151 | self.hard_example_mining_step = hard_example_mining_step + 1e-5 152 | if self.top_k_percent_pixels is None: 153 | self.celoss = nn.CrossEntropyLoss(ignore_index=255, 154 | reduction='mean') 155 | else: 156 | self.celoss = nn.CrossEntropyLoss(ignore_index=255, 157 | reduction='none') 158 | 159 | def forward(self, dic_tmp, y, step): 160 | total_loss = [] 161 | for i in range(len(dic_tmp)): 162 | pred_logits = dic_tmp[i] 163 | gts = y[i] 164 | if self.top_k_percent_pixels is None: 165 | final_loss = self.celoss(pred_logits, gts) 166 | else: 167 | # Only compute the loss for top k percent pixels. 168 | # First, compute the loss for all pixels. Note we do not put the loss 169 | # to loss_collection and set reduction = None to keep the shape. 
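                # `hard_example_mining_step` linearly anneals the kept fraction
                # from 100% of the pixels at step 0 down to
                # `top_k_percent_pixels` once `step` reaches that threshold.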
170 | num_pixels = float(pred_logits.size(2) * pred_logits.size(3)) 171 | pred_logits = pred_logits.view( 172 | -1, pred_logits.size(1), 173 | pred_logits.size(2) * pred_logits.size(3)) 174 | gts = gts.view(-1, gts.size(1) * gts.size(2)) 175 | pixel_losses = self.celoss(pred_logits, gts) 176 | if self.hard_example_mining_step == 0: 177 | top_k_pixels = int(self.top_k_percent_pixels * num_pixels) 178 | else: 179 | ratio = min(1.0, 180 | step / float(self.hard_example_mining_step)) 181 | top_k_pixels = int((ratio * self.top_k_percent_pixels + 182 | (1.0 - ratio)) * num_pixels) 183 | top_k_loss, top_k_indices = torch.topk(pixel_losses, 184 | k=top_k_pixels, 185 | dim=1) 186 | 187 | final_loss = torch.mean(top_k_loss) 188 | if final_loss != 0: 189 | final_loss = final_loss.unsqueeze(0) 190 | else: 191 | final_loss = torch.zeros(1).cuda() 192 | total_loss.append(final_loss) 193 | total_loss = torch.cat(total_loss, dim=0) 194 | return total_loss 195 | -------------------------------------------------------------------------------- /aot_plus/networks/layers/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FrozenBatchNorm2d(nn.Module): 7 | """ 8 | BatchNorm2d where the batch statistics and the affine parameters 9 | are fixed 10 | """ 11 | def __init__(self, n, epsilon=1e-5): 12 | super(FrozenBatchNorm2d, self).__init__() 13 | self.register_buffer("weight", torch.ones(n)) 14 | self.register_buffer("bias", torch.zeros(n)) 15 | self.register_buffer("running_mean", torch.zeros(n)) 16 | self.register_buffer("running_var", torch.ones(n) - epsilon) 17 | self.epsilon = epsilon 18 | 19 | def forward(self, x): 20 | """ 21 | Refer to Detectron2 (https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/batch_norm.py) 22 | """ 23 | if x.requires_grad: 24 | # When gradients are needed, F.batch_norm will use extra memory 25 | # because its backward op computes gradients for weight/bias as well. 26 | scale = self.weight * (self.running_var + self.epsilon).rsqrt() 27 | bias = self.bias - self.running_mean * scale 28 | scale = scale.reshape(1, -1, 1, 1) 29 | bias = bias.reshape(1, -1, 1, 1) 30 | out_dtype = x.dtype # may be half 31 | return x * scale.to(out_dtype) + bias.to(out_dtype) 32 | else: 33 | # When gradients are not needed, F.batch_norm is a single fused op 34 | # and provide more optimization opportunities. 
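            # `training=False` keeps the stored running_mean/running_var fixed,
            # so this reduces to a per-channel affine transform.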
35 | return F.batch_norm( 36 | x, 37 | self.running_mean, 38 | self.running_var, 39 | self.weight, 40 | self.bias, 41 | training=False, 42 | eps=self.epsilon, 43 | ) 44 | -------------------------------------------------------------------------------- /aot_plus/networks/layers/position.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from utils.math import truncated_normal_ 8 | 9 | 10 | class Downsample2D(nn.Module): 11 | def __init__(self, mode='nearest', scale=4): 12 | super().__init__() 13 | self.mode = mode 14 | self.scale = scale 15 | 16 | def forward(self, x): 17 | n, c, h, w = x.size() 18 | x = F.interpolate(x, 19 | size=(h // self.scale + 1, w // self.scale + 1), 20 | mode=self.mode) 21 | return x 22 | 23 | 24 | def generate_coord(x): 25 | _, _, h, w = x.size() 26 | device = x.device 27 | col = torch.arange(0, h, device=device) 28 | row = torch.arange(0, w, device=device) 29 | grid_h, grid_w = torch.meshgrid(col, row) 30 | return grid_h, grid_w 31 | 32 | 33 | class PositionEmbeddingSine(nn.Module): 34 | def __init__(self, 35 | num_pos_feats=64, 36 | temperature=10000, 37 | normalize=False, 38 | scale=None): 39 | super().__init__() 40 | self.num_pos_feats = num_pos_feats 41 | self.temperature = temperature 42 | self.normalize = normalize 43 | if scale is not None and normalize is False: 44 | raise ValueError("normalize should be True if scale is passed") 45 | if scale is None: 46 | scale = 2 * math.pi 47 | self.scale = scale 48 | 49 | def forward(self, x): 50 | grid_y, grid_x = generate_coord(x) 51 | 52 | y_embed = grid_y.unsqueeze(0).float() 53 | x_embed = grid_x.unsqueeze(0).float() 54 | 55 | if self.normalize: 56 | eps = 1e-6 57 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 58 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 59 | 60 | dim_t = torch.arange(self.num_pos_feats, 61 | dtype=torch.float32, 62 | device=x.device) 63 | dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats) 64 | 65 | pos_x = x_embed[:, :, :, None] / dim_t 66 | pos_y = y_embed[:, :, :, None] / dim_t 67 | pos_x = torch.stack( 68 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), 69 | dim=4).flatten(3) 70 | pos_y = torch.stack( 71 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), 72 | dim=4).flatten(3) 73 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 74 | return pos 75 | 76 | 77 | class PositionEmbeddingLearned(nn.Module): 78 | def __init__(self, num_pos_feats=64, H=30, W=30): 79 | super().__init__() 80 | self.H = H 81 | self.W = W 82 | self.pos_emb = nn.Parameter( 83 | truncated_normal_(torch.zeros(1, num_pos_feats, H, W))) 84 | 85 | def forward(self, x): 86 | bs, _, h, w = x.size() 87 | pos_emb = self.pos_emb 88 | if h != self.H or w != self.W: 89 | pos_emb = F.interpolate(pos_emb, size=(h, w), mode="bilinear") 90 | return pos_emb 91 | -------------------------------------------------------------------------------- /aot_plus/networks/models/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.models.aot import AOT 2 | 3 | 4 | def build_vos_model(name, cfg, **kwargs): 5 | 6 | if name == 'aot': 7 | return AOT(cfg, encoder=cfg.MODEL_ENCODER, **kwargs) 8 | else: 9 | raise NotImplementedError 10 | -------------------------------------------------------------------------------- /aot_plus/networks/models/__pycache__/__init__.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/networks/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /aot_plus/networks/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/networks/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /aot_plus/networks/models/__pycache__/aot.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/networks/models/__pycache__/aot.cpython-37.pyc -------------------------------------------------------------------------------- /aot_plus/networks/models/__pycache__/aot.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/networks/models/__pycache__/aot.cpython-38.pyc -------------------------------------------------------------------------------- /aot_plus/networks/models/aot.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from networks.encoders import build_encoder 4 | from networks.layers.transformer import LongShortTermTransformer 5 | from networks.decoders import build_decoder 6 | from networks.layers.position import PositionEmbeddingSine 7 | 8 | 9 | class AOT(nn.Module): 10 | def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'): 11 | super().__init__() 12 | self.cfg = cfg 13 | self.max_obj_num = cfg.MODEL_MAX_OBJ_NUM 14 | self.epsilon = cfg.MODEL_EPSILON 15 | 16 | self.encoder = build_encoder(encoder, 17 | frozen_bn=cfg.MODEL_FREEZE_BN, 18 | freeze_at=cfg.TRAIN_ENCODER_FREEZE_AT) 19 | self.encoder_projector = nn.Conv2d(cfg.MODEL_ENCODER_DIM[-1], 20 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 21 | kernel_size=1) 22 | 23 | self.LSTT = LongShortTermTransformer( 24 | cfg.MODEL_LSTT_NUM, 25 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 26 | cfg.MODEL_SELF_HEADS, 27 | cfg.MODEL_ATT_HEADS, 28 | emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT, 29 | droppath=cfg.TRAIN_LSTT_DROPPATH, 30 | lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT, 31 | st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT, 32 | droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST, 33 | droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING, 34 | intermediate_norm=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, 35 | return_intermediate=True, 36 | simplified=cfg.MODEL_SIMPLIFIED_STM, 37 | stopgrad=cfg.MODEL_STM_STOPGRAD, 38 | joint_longatt=cfg.MODEL_JOINT_LONGATT, 39 | linear_q=cfg.MODEL_LINEAR_Q, 40 | norm_inp=cfg.MODEL_NORM_INP, 41 | recurrent_stm=cfg.MODEL_RECURRENT_STM, 42 | recurrent_ltm=cfg.MODEL_RECURRENT_LTM) 43 | 44 | decoder_indim = cfg.MODEL_ENCODER_EMBEDDING_DIM * \ 45 | (cfg.MODEL_LSTT_NUM + 46 | 1) if cfg.MODEL_DECODER_INTERMEDIATE_LSTT else cfg.MODEL_ENCODER_EMBEDDING_DIM 47 | 48 | self.decoder = build_decoder( 49 | decoder, 50 | in_dim=decoder_indim, 51 | out_dim=cfg.MODEL_MAX_OBJ_NUM + 1, 52 | decode_intermediate_input=cfg.MODEL_DECODER_INTERMEDIATE_LSTT, 53 | hidden_dim=cfg.MODEL_ENCODER_EMBEDDING_DIM, 54 | shortcut_dims=cfg.MODEL_ENCODER_DIM, 55 | 
align_corners=cfg.MODEL_ALIGN_CORNERS) 56 | 57 | id_dim = cfg.MODEL_MAX_OBJ_NUM + 1 58 | if cfg.MODEL_IGNORE_TOKEN: 59 | id_dim = cfg.MODEL_MAX_OBJ_NUM + 2 60 | if cfg.MODEL_ALIGN_CORNERS: 61 | self.patch_wise_id_bank = nn.Conv2d( 62 | id_dim, 63 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 64 | kernel_size=17, 65 | stride=16, 66 | padding=8) 67 | else: 68 | self.patch_wise_id_bank = nn.Conv2d( 69 | id_dim, 70 | cfg.MODEL_ENCODER_EMBEDDING_DIM, 71 | kernel_size=16, 72 | stride=16, 73 | padding=0) 74 | 75 | self.id_dropout = nn.Dropout(cfg.TRAIN_LSTT_ID_DROPOUT, True) 76 | 77 | self.pos_generator = PositionEmbeddingSine( 78 | cfg.MODEL_ENCODER_EMBEDDING_DIM // 2, normalize=True) 79 | 80 | self._init_weight() 81 | 82 | def get_pos_emb(self, x): 83 | pos_emb = self.pos_generator(x) 84 | return pos_emb 85 | 86 | def get_id_emb(self, x): 87 | id_emb = self.patch_wise_id_bank(x) 88 | id_emb = self.id_dropout(id_emb) 89 | return id_emb 90 | 91 | def encode_image(self, img): 92 | xs = self.encoder(img) 93 | xs[-1] = self.encoder_projector(xs[-1]) 94 | return xs 95 | 96 | def decode_id_logits(self, lstt_emb, shortcuts): 97 | n, c, h, w = shortcuts[-1].size() 98 | decoder_inputs = [shortcuts[-1]] 99 | for emb in lstt_emb: 100 | decoder_inputs.append(emb.view(h, w, n, c).permute(2, 3, 0, 1)) 101 | pred_logit = self.decoder(decoder_inputs, shortcuts) 102 | return pred_logit 103 | 104 | def LSTT_forward(self, 105 | curr_embs, 106 | long_term_memories, 107 | short_term_memories, 108 | curr_id_emb=None, 109 | pos_emb=None, 110 | size_2d=(30, 30)): 111 | n, c, h, w = curr_embs[-1].size() 112 | curr_emb = curr_embs[-1].view(n, c, h * w).permute(2, 0, 1) 113 | lstt_embs, lstt_memories = self.LSTT(curr_emb, long_term_memories, 114 | short_term_memories, curr_id_emb, 115 | pos_emb, size_2d) 116 | lstt_curr_memories, lstt_long_memories, lstt_short_memories = zip( 117 | *lstt_memories) 118 | return lstt_embs, lstt_curr_memories, lstt_long_memories, lstt_short_memories 119 | 120 | def _init_weight(self): 121 | nn.init.xavier_uniform_(self.encoder_projector.weight) 122 | nn.init.orthogonal_( 123 | self.patch_wise_id_bank.weight.view( 124 | self.cfg.MODEL_ENCODER_EMBEDDING_DIM, -1).permute(0, 1), 125 | gain=17**-2 if self.cfg.MODEL_ALIGN_CORNERS else 16**-2) 126 | -------------------------------------------------------------------------------- /aot_plus/tools/demo.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | import os 4 | from time import time 5 | 6 | sys.path.append('.') 7 | sys.path.append('..') 8 | 9 | import cv2 10 | from PIL import Image 11 | from skimage.morphology.binary import binary_dilation 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as F 16 | from torch.utils.data import DataLoader 17 | from torchvision import transforms 18 | 19 | from networks.models import build_vos_model 20 | from networks.engines import build_engine 21 | from utils.checkpoint import load_network 22 | 23 | from dataloaders.eval_datasets import VOSTest 24 | import dataloaders.video_transforms as tr 25 | from utils.image import save_mask 26 | 27 | _palette = [ 28 | 255, 0, 0, 0, 0, 139, 255, 255, 84, 0, 255, 0, 139, 0, 139, 0, 128, 128, 29 | 128, 128, 128, 139, 0, 0, 218, 165, 32, 144, 238, 144, 160, 82, 45, 148, 0, 30 | 211, 255, 0, 255, 30, 144, 255, 255, 218, 185, 85, 107, 47, 255, 140, 0, 31 | 50, 205, 50, 123, 104, 238, 240, 230, 140, 72, 61, 139, 128, 128, 0, 0, 0, 32 | 205, 221, 160, 221, 143, 188, 143, 127, 255, 212, 176, 
224, 230, 244, 164, 33 | 96, 250, 128, 114, 70, 130, 180, 0, 128, 0, 173, 255, 47, 255, 105, 180, 34 | 238, 130, 238, 154, 205, 50, 220, 20, 60, 176, 48, 96, 0, 206, 209, 0, 191, 35 | 255, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43, 43, 43, 44, 44, 44, 45, 45, 36 | 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49, 49, 50, 50, 50, 51, 51, 51, 37 | 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55, 56, 56, 56, 57, 57, 57, 58, 38 | 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62, 62, 62, 63, 63, 63, 64, 64, 39 | 64, 65, 65, 65, 66, 66, 66, 67, 67, 67, 68, 68, 68, 69, 69, 69, 70, 70, 70, 40 | 71, 71, 71, 72, 72, 72, 73, 73, 73, 74, 74, 74, 75, 75, 75, 76, 76, 76, 77, 41 | 77, 77, 78, 78, 78, 79, 79, 79, 80, 80, 80, 81, 81, 81, 82, 82, 82, 83, 83, 42 | 83, 84, 84, 84, 85, 85, 85, 86, 86, 86, 87, 87, 87, 88, 88, 88, 89, 89, 89, 43 | 90, 90, 90, 91, 91, 91, 92, 92, 92, 93, 93, 93, 94, 94, 94, 95, 95, 95, 96, 44 | 96, 96, 97, 97, 97, 98, 98, 98, 99, 99, 99, 100, 100, 100, 101, 101, 101, 45 | 102, 102, 102, 103, 103, 103, 104, 104, 104, 105, 105, 105, 106, 106, 106, 46 | 107, 107, 107, 108, 108, 108, 109, 109, 109, 110, 110, 110, 111, 111, 111, 47 | 112, 112, 112, 113, 113, 113, 114, 114, 114, 115, 115, 115, 116, 116, 116, 48 | 117, 117, 117, 118, 118, 118, 119, 119, 119, 120, 120, 120, 121, 121, 121, 49 | 122, 122, 122, 123, 123, 123, 124, 124, 124, 125, 125, 125, 126, 126, 126, 50 | 127, 127, 127, 128, 128, 128, 129, 129, 129, 130, 130, 130, 131, 131, 131, 51 | 132, 132, 132, 133, 133, 133, 134, 134, 134, 135, 135, 135, 136, 136, 136, 52 | 137, 137, 137, 138, 138, 138, 139, 139, 139, 140, 140, 140, 141, 141, 141, 53 | 142, 142, 142, 143, 143, 143, 144, 144, 144, 145, 145, 145, 146, 146, 146, 54 | 147, 147, 147, 148, 148, 148, 149, 149, 149, 150, 150, 150, 151, 151, 151, 55 | 152, 152, 152, 153, 153, 153, 154, 154, 154, 155, 155, 155, 156, 156, 156, 56 | 157, 157, 157, 158, 158, 158, 159, 159, 159, 160, 160, 160, 161, 161, 161, 57 | 162, 162, 162, 163, 163, 163, 164, 164, 164, 165, 165, 165, 166, 166, 166, 58 | 167, 167, 167, 168, 168, 168, 169, 169, 169, 170, 170, 170, 171, 171, 171, 59 | 172, 172, 172, 173, 173, 173, 174, 174, 174, 175, 175, 175, 176, 176, 176, 60 | 177, 177, 177, 178, 178, 178, 179, 179, 179, 180, 180, 180, 181, 181, 181, 61 | 182, 182, 182, 183, 183, 183, 184, 184, 184, 185, 185, 185, 186, 186, 186, 62 | 187, 187, 187, 188, 188, 188, 189, 189, 189, 190, 190, 190, 191, 191, 191, 63 | 192, 192, 192, 193, 193, 193, 194, 194, 194, 195, 195, 195, 196, 196, 196, 64 | 197, 197, 197, 198, 198, 198, 199, 199, 199, 200, 200, 200, 201, 201, 201, 65 | 202, 202, 202, 203, 203, 203, 204, 204, 204, 205, 205, 205, 206, 206, 206, 66 | 207, 207, 207, 208, 208, 208, 209, 209, 209, 210, 210, 210, 211, 211, 211, 67 | 212, 212, 212, 213, 213, 213, 214, 214, 214, 215, 215, 215, 216, 216, 216, 68 | 217, 217, 217, 218, 218, 218, 219, 219, 219, 220, 220, 220, 221, 221, 221, 69 | 222, 222, 222, 223, 223, 223, 224, 224, 224, 225, 225, 225, 226, 226, 226, 70 | 227, 227, 227, 228, 228, 228, 229, 229, 229, 230, 230, 230, 231, 231, 231, 71 | 232, 232, 232, 233, 233, 233, 234, 234, 234, 235, 235, 235, 236, 236, 236, 72 | 237, 237, 237, 238, 238, 238, 239, 239, 239, 240, 240, 240, 241, 241, 241, 73 | 242, 242, 242, 243, 243, 243, 244, 244, 244, 245, 245, 245, 246, 246, 246, 74 | 247, 247, 247, 248, 248, 248, 249, 249, 249, 250, 250, 250, 251, 251, 251, 75 | 252, 252, 252, 253, 253, 253, 254, 254, 254, 255, 255, 255, 0, 0, 0 76 | ] 77 | color_palette = np.array(_palette).reshape(-1, 3) 78 | 79 | 80 | def overlay(image, 
mask, colors=[255, 0, 0], cscale=1, alpha=0.4): 81 | colors = np.atleast_2d(colors) * cscale 82 | 83 | im_overlay = image.copy() 84 | object_ids = np.unique(mask) 85 | 86 | for object_id in object_ids[1:]: 87 | # Overlay color on binary mask 88 | 89 | foreground = image * alpha + np.ones( 90 | image.shape) * (1 - alpha) * np.array(colors[object_id]) 91 | binary_mask = mask == object_id 92 | 93 | # Compose image 94 | im_overlay[binary_mask] = foreground[binary_mask] 95 | 96 | countours = binary_dilation(binary_mask) ^ binary_mask 97 | im_overlay[countours, :] = 0 98 | 99 | return im_overlay.astype(image.dtype) 100 | 101 | 102 | def demo(cfg): 103 | video_fps = 10 104 | gpu_id = cfg.TEST_GPU_ID 105 | 106 | # Load pre-trained model 107 | print('Build AOT model.') 108 | model = build_vos_model(cfg.MODEL_VOS, cfg).cuda(gpu_id) 109 | 110 | print('Load checkpoint from {}'.format(cfg.TEST_CKPT_PATH)) 111 | model, _ = load_network(model, cfg.TEST_CKPT_PATH, gpu_id) 112 | 113 | print('Build AOT engine.') 114 | engine = build_engine(cfg.MODEL_ENGINE, 115 | phase='eval', 116 | aot_model=model, 117 | gpu_id=gpu_id, 118 | long_term_mem_gap=cfg.TEST_LONG_TERM_MEM_GAP) 119 | 120 | # Prepare datasets for each sequence 121 | transform = transforms.Compose([ 122 | tr.MultiRestrictSize(cfg.TEST_MIN_SIZE, cfg.TEST_MAX_SIZE, 123 | cfg.TEST_FLIP, cfg.TEST_MULTISCALE, 124 | cfg.MODEL_ALIGN_CORNERS), 125 | tr.MultiToTensor() 126 | ]) 127 | if cfg.TEST_DATA_PATH is not None and cfg.SEQ_LIST is None: 128 | image_root = os.path.join(cfg.TEST_DATA_PATH, 'images') 129 | label_root = os.path.join(cfg.TEST_DATA_PATH, 'masks') 130 | sequences = os.listdir(image_root) 131 | elif cfg.SEQ_LIST is not None: 132 | sequences = [seq.strip() for seq in cfg.SEQ_LIST] 133 | print(sequences) 134 | image_root = os.path.join(cfg.TEST_DATA_PATH, 'JPEGImages_10fps') 135 | label_root = os.path.join(cfg.TEST_DATA_PATH, 'Annotations') 136 | else: 137 | image_root = cfg.TEST_FRAME_PATH 138 | label_root = cfg.TEST_LABEL_PATH 139 | sequences = [image_root.split('/')[-2]] 140 | 141 | seq_datasets = [] 142 | for seq_name in sequences: 143 | print('Build a dataset for sequence {}.'.format(seq_name)) 144 | if cfg.TEST_DATA_PATH is not None: 145 | seq_images = np.sort(os.listdir(os.path.join(image_root, seq_name))) 146 | else: 147 | seq_images = np.sort(os.listdir(image_root)) 148 | image_root = "/".join(image_root.split('/')[:-2]) 149 | label_root = "/".join(label_root.split('/')[:-2]) 150 | 151 | seq_labels = [seq_images[0].replace('jpg', 'png')] 152 | seq_dataset = VOSTest(image_root, 153 | label_root, 154 | seq_name, 155 | seq_images, 156 | seq_labels, 157 | transform=transform) 158 | seq_datasets.append(seq_dataset) 159 | 160 | # Infer 161 | output_root = cfg.TEST_OUTPUT_PATH 162 | output_mask_root = os.path.join(output_root, 'pred_masks') 163 | if not os.path.exists(output_mask_root): 164 | os.makedirs(output_mask_root) 165 | 166 | for seq_dataset in seq_datasets: 167 | seq_name = seq_dataset.seq_name 168 | image_seq_root = os.path.join(image_root, seq_name) 169 | output_mask_seq_root = os.path.join(output_mask_root, seq_name) 170 | if not os.path.exists(output_mask_seq_root): 171 | os.makedirs(output_mask_seq_root) 172 | print('Build a dataloader for sequence {}.'.format(seq_name)) 173 | seq_dataloader = DataLoader(seq_dataset, 174 | batch_size=1, 175 | shuffle=False, 176 | num_workers=cfg.TEST_WORKERS, 177 | pin_memory=True) 178 | num_frames = len(seq_dataset) 179 | max_gap = int(round(num_frames / 30)) 180 | gap = max(max_gap, 5) 181 | 
print(gap) 182 | engine.long_term_mem_gap = gap 183 | 184 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 185 | output_video_path = os.path.join( 186 | output_root, '{}_{}fps.avi'.format(seq_name, video_fps)) 187 | 188 | print('Start the inference of sequence {}:'.format(seq_name)) 189 | model.eval() 190 | engine.restart_engine() 191 | with torch.no_grad(): 192 | time_start = 0 193 | for frame_idx, samples in enumerate(seq_dataloader): 194 | # if time_start != 0: 195 | # print(time() - time_start) 196 | # time_start = time() 197 | sample = samples[0] 198 | img_name = sample['meta']['current_name'][0] 199 | 200 | obj_nums = sample['meta']['obj_num'] 201 | output_height = sample['meta']['height'] 202 | output_width = sample['meta']['width'] 203 | obj_idx = sample['meta']['obj_idx'] 204 | 205 | obj_nums = [int(obj_num) for obj_num in obj_nums] 206 | obj_idx = [int(_obj_idx) for _obj_idx in obj_idx] 207 | 208 | current_img = sample['current_img'] 209 | current_img = current_img.cuda(gpu_id, non_blocking=True) 210 | 211 | if frame_idx == 0: 212 | videoWriter = cv2.VideoWriter( 213 | output_video_path, fourcc, video_fps, 214 | (int(output_width), int(output_height))) 215 | print( 216 | 'Object number: {}. Inference size: {}x{}. Output size: {}x{}.' 217 | .format(obj_nums[0], 218 | current_img.size()[2], 219 | current_img.size()[3], int(output_height), 220 | int(output_width))) 221 | current_label = sample['current_label'].cuda( 222 | gpu_id, non_blocking=True).float() 223 | current_label = F.interpolate(current_label, 224 | size=current_img.size()[2:], 225 | mode="nearest") 226 | # add reference frame 227 | engine.add_reference_frame(current_img, 228 | current_label, 229 | frame_step=0, 230 | obj_nums=obj_nums) 231 | else: 232 | print('Processing image {}...'.format(img_name)) 233 | # predict segmentation 234 | engine.match_propogate_one_frame(current_img) 235 | pred_logit = engine.decode_current_logits( 236 | (output_height, output_width)) 237 | pred_prob = torch.softmax(pred_logit, dim=1) 238 | pred_label = torch.argmax(pred_prob, dim=1, 239 | keepdim=True).float() 240 | _pred_label = F.interpolate(pred_label, 241 | size=engine.input_size_2d, 242 | mode="nearest") 243 | # update memory 244 | engine.update_memory(_pred_label) 245 | 246 | # save results 247 | input_image_path = os.path.join(image_seq_root, img_name) 248 | output_mask_path = os.path.join( 249 | output_mask_seq_root, 250 | img_name.split('.')[0] + '.png') 251 | 252 | pred_label = Image.fromarray( 253 | pred_label.squeeze(0).squeeze(0).cpu().numpy().astype( 254 | 'uint8')).convert('P') 255 | pred_label.putpalette(_palette) 256 | pred_label.save(output_mask_path) 257 | 258 | input_image = Image.open(input_image_path) 259 | 260 | overlayed_image = overlay( 261 | np.array(input_image, dtype=np.uint8), 262 | np.array(pred_label, dtype=np.uint8), color_palette) 263 | videoWriter.write(overlayed_image[..., [2, 1, 0]]) 264 | 265 | print('Save a visualization video to {}.'.format(output_video_path)) 266 | videoWriter.release() 267 | 268 | 269 | #CUDA_VISIBLE_DEVICES=0 python tools/demo.py --seq_list datasets/VOST/ImageSets/val.txt --ckpt_path pretrain_models/vost_aot_best.pth --output_path demo_output/aot_best_10fps 270 | def main(): 271 | import argparse 272 | parser = argparse.ArgumentParser(description="AOT Demo") 273 | parser.add_argument('--exp_name', type=str, default='default') 274 | 275 | parser.add_argument('--stage', type=str, default='pre_vost') 276 | parser.add_argument('--model', type=str, default='r50_aotl') 277 | 278 | 
parser.add_argument('--gpu_id', type=int, default=0) 279 | 280 | parser.add_argument('--data_path', type=str, default='./datasets/Demo') 281 | parser.add_argument('--seq_name', type=str, default='') 282 | parser.add_argument('--seq_list', type=str, default='') 283 | parser.add_argument('--output_path', type=str, default='./demo_output') 284 | parser.add_argument('--ckpt_path', 285 | type=str, 286 | default='./pretrain_models/R50_AOTL_PRE_YTB_DAV.pth') 287 | 288 | parser.add_argument('--max_resolution', type=float, default=480 * 1.3) 289 | 290 | parser.add_argument('--amp', action='store_true') 291 | parser.set_defaults(amp=False) 292 | 293 | args = parser.parse_args() 294 | 295 | engine_config = importlib.import_module('configs.' + args.stage) 296 | cfg = engine_config.EngineConfig(args.exp_name, args.model) 297 | 298 | cfg.TEST_GPU_ID = args.gpu_id 299 | 300 | cfg.TEST_CKPT_PATH = args.ckpt_path 301 | if len(args.seq_list) != 0: 302 | lst_file = open(args.seq_list, 'r') 303 | lst = lst_file.readlines() 304 | cfg.SEQ_LIST = lst 305 | cfg.TEST_DATA_PATH = '/mnt/fsx/VOST/' 306 | cfg.TEST_FRAME_PATH = None 307 | elif len(args.seq_name) != 0: 308 | cfg.TEST_DATA_PATH = None 309 | cfg.TEST_SEQ_LIST = None 310 | cfg.SEQ_LIST = None 311 | cfg.TEST_FRAME_PATH = 'datasets/VOST/JPEGImages_10fps/%s/' % args.seq_name 312 | cfg.TEST_LABEL_PATH = 'datasets/VOST/Annotations/%s/' % args.seq_name 313 | else: 314 | cfg.TEST_DATA_PATH = args.data_path 315 | cfg.TEST_FRAME_PATH = None 316 | cfg.TEST_SEQ_LIST = None 317 | cfg.TEST_OUTPUT_PATH = args.output_path 318 | 319 | cfg.TEST_MIN_SIZE = None 320 | cfg.TEST_MAX_SIZE = args.max_resolution * 800. / 480. 321 | 322 | if args.amp: 323 | with torch.cuda.amp.autocast(enabled=True): 324 | demo(cfg) 325 | else: 326 | demo(cfg) 327 | 328 | 329 | if __name__ == '__main__': 330 | main() 331 | -------------------------------------------------------------------------------- /aot_plus/tools/eval.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | 4 | sys.path.append('.') 5 | sys.path.append('..') 6 | 7 | import torch 8 | import torch.multiprocessing as mp 9 | 10 | from networks.managers.evaluator import Evaluator 11 | 12 | 13 | def main_worker(gpu, cfg, seq_queue=None, info_queue=None, enable_amp=False): 14 | # Initiate a evaluating manager 15 | evaluator = Evaluator(rank=gpu, 16 | cfg=cfg, 17 | seq_queue=seq_queue, 18 | info_queue=info_queue) 19 | # Start evaluation 20 | if enable_amp: 21 | with torch.cuda.amp.autocast(enabled=True): 22 | evaluator.evaluating() 23 | else: 24 | evaluator.evaluating() 25 | 26 | 27 | #python tools/eval.py --exp_name aotplus --stage pre_vost --model r50_aotl --dataset vost --split val --gpu_num 8 --ckpt_path pretrain_models/aotplus.pth --ms 1.0 1.1 1.2 0.9 0.8 28 | def main(): 29 | import argparse 30 | parser = argparse.ArgumentParser(description="Eval VOS") 31 | parser.add_argument('--exp_name', type=str, default='default') 32 | 33 | parser.add_argument('--stage', type=str, default='pre') 34 | parser.add_argument('--model', type=str, default='aott') 35 | 36 | parser.add_argument('--gpu_id', type=int, default=0) 37 | parser.add_argument('--gpu_num', type=int, default=1) 38 | 39 | parser.add_argument('--ckpt_path', type=str, default='') 40 | parser.add_argument('--ckpt_step', type=int, default=-1) 41 | 42 | parser.add_argument('--dataset', type=str, default='') 43 | parser.add_argument('--split', type=str, default='') 44 | 45 | parser.add_argument('--no_ema', 
action='store_true') 46 | parser.set_defaults(no_ema=False) 47 | 48 | parser.add_argument('--flip', action='store_true') 49 | parser.set_defaults(flip=False) 50 | parser.add_argument('--ms', nargs='+', type=float, default=[1.]) 51 | 52 | parser.add_argument('--max_resolution', type=float, default=480 * 1.3) 53 | 54 | parser.add_argument('--amp', action='store_true') 55 | parser.set_defaults(amp=False) 56 | 57 | args = parser.parse_args() 58 | 59 | engine_config = importlib.import_module('configs.' + args.stage) 60 | cfg = engine_config.EngineConfig(args.exp_name, args.model) 61 | 62 | cfg.TEST_EMA = not args.no_ema 63 | 64 | cfg.TEST_GPU_ID = args.gpu_id 65 | cfg.TEST_GPU_NUM = args.gpu_num 66 | 67 | if args.ckpt_path != '': 68 | cfg.TEST_CKPT_PATH = args.ckpt_path 69 | if args.ckpt_step > 0: 70 | cfg.TEST_CKPT_STEP = args.ckpt_step 71 | 72 | if args.dataset != '': 73 | cfg.TEST_DATASET = args.dataset 74 | 75 | if args.split != '': 76 | cfg.TEST_DATASET_SPLIT = args.split 77 | 78 | cfg.TEST_FLIP = args.flip 79 | cfg.TEST_MULTISCALE = args.ms 80 | 81 | cfg.TEST_MIN_SIZE = None 82 | cfg.TEST_MAX_SIZE = args.max_resolution * 800. / 480. 83 | 84 | if args.gpu_num > 1: 85 | mp.set_start_method('spawn') 86 | seq_queue = mp.Queue() 87 | info_queue = mp.Queue() 88 | mp.spawn(main_worker, 89 | nprocs=cfg.TEST_GPU_NUM, 90 | args=(cfg, seq_queue, info_queue, args.amp)) 91 | else: 92 | main_worker(0, cfg, enable_amp=args.amp) 93 | 94 | 95 | if __name__ == '__main__': 96 | main() 97 | -------------------------------------------------------------------------------- /aot_plus/tools/train.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import random 3 | import sys 4 | 5 | sys.setrecursionlimit(10000) 6 | sys.path.append('.') 7 | sys.path.append('..') 8 | 9 | import torch.multiprocessing as mp 10 | 11 | from networks.managers.trainer import Trainer 12 | 13 | 14 | def main_worker(gpu, cfg, enable_amp=True, exp_name='default'): 15 | # Initiate a training manager 16 | trainer = Trainer(rank=gpu, cfg=cfg, enable_amp=enable_amp) 17 | # Start Training 18 | trainer.sequential_training() 19 | 20 | def main(): 21 | import argparse 22 | parser = argparse.ArgumentParser(description="Train VOS") 23 | parser.add_argument('--exp_name', type=str, default='') 24 | parser.add_argument('--stage', type=str, default='pre') 25 | parser.add_argument('--model', type=str, default='aott') 26 | 27 | parser.add_argument('--start_gpu', type=int, default=0) 28 | parser.add_argument('--gpu_num', type=int, default=-1) 29 | parser.add_argument('--batch_size', type=int, default=-1) 30 | parser.add_argument('--dist_url', type=str, default='') 31 | parser.add_argument('--amp', action='store_true') 32 | parser.set_defaults(amp=False) 33 | 34 | parser.add_argument('--pretrained_path', type=str, default='') 35 | 36 | parser.add_argument('--datasets', nargs='+', type=str, default=[]) 37 | parser.add_argument('--lr', type=float, default=-1.) 38 | parser.add_argument('--total_step', type=int, default=-1.) 39 | parser.add_argument('--start_step', type=int, default=-1.) 40 | 41 | args = parser.parse_args() 42 | 43 | engine_config = importlib.import_module('configs.' 
+ args.stage) 44 | 45 | cfg = engine_config.EngineConfig(args.exp_name, args.model) 46 | 47 | if len(args.datasets) > 0: 48 | cfg.DATASETS = args.datasets 49 | 50 | cfg.DIST_START_GPU = args.start_gpu 51 | if args.gpu_num > 0: 52 | cfg.TRAIN_GPUS = args.gpu_num 53 | if args.batch_size > 0: 54 | cfg.TRAIN_BATCH_SIZE = args.batch_size 55 | 56 | if args.pretrained_path != '': 57 | cfg.PRETRAIN_MODEL = args.pretrained_path 58 | 59 | if args.lr > 0: 60 | cfg.TRAIN_LR = args.lr 61 | 62 | if args.total_step > 0: 63 | cfg.TRAIN_TOTAL_STEPS = args.total_step 64 | 65 | if args.start_step > 0: 66 | cfg.TRAIN_START_STEP = args.start_step 67 | 68 | if args.dist_url == '': 69 | cfg.DIST_URL = 'tcp://127.0.0.1:123' + str(random.randint(0, 9)) + str( 70 | random.randint(0, 9)) 71 | else: 72 | cfg.DIST_URL = args.dist_url 73 | if cfg.TRAIN_GPUS == 1: 74 | main_worker(0, cfg, args.amp, args.exp_name) 75 | else: 76 | # Use torch.multiprocessing.spawn to launch distributed processes 77 | mp.spawn(main_worker, nprocs=cfg.TRAIN_GPUS, args=(cfg, args.amp, args.exp_name)) 78 | 79 | 80 | if __name__ == '__main__': 81 | main() 82 | -------------------------------------------------------------------------------- /aot_plus/train_vost.sh: -------------------------------------------------------------------------------- 1 | exp="aotplus" 2 | # exp="debug" 3 | gpu_num="4" 4 | devices="4,5,6,7" 5 | 6 | # model="aott" 7 | # model="aots" 8 | # model="aotb" 9 | # model="aotl" 10 | model="r50_aotl" 11 | # model="swinb_aotl" 12 | 13 | stage="pre_vost" 14 | CUDA_VISIBLE_DEVICES=${devices} python tools/train.py --amp \ 15 | --exp_name ${exp} \ 16 | --stage ${stage} \ 17 | --model ${model} \ 18 | --gpu_num ${gpu_num} 19 | 20 | dataset="vost" 21 | split="val" 22 | CUDA_VISIBLE_DEVICES=${devices} python tools/eval.py --exp_name ${exp} --stage ${stage} --model ${model} \ 23 | --dataset ${dataset} --split ${split} --gpu_num ${gpu_num} --ms 1.0 1.1 1.2 0.9 0.8 -------------------------------------------------------------------------------- /aot_plus/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TRI-ML/VOST/fe274574cb03c8a3ea83e121dd76e20b703781fd/aot_plus/utils/__init__.py -------------------------------------------------------------------------------- /aot_plus/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | 5 | 6 | def load_network_and_optimizer(net, opt, pretrained_dir, gpu, scaler=None): 7 | pretrained = torch.load(pretrained_dir, 8 | map_location=torch.device("cuda:" + str(gpu))) 9 | pretrained_dict = pretrained['state_dict'] 10 | model_dict = net.state_dict() 11 | pretrained_dict_update = {} 12 | pretrained_dict_remove = [] 13 | for k, v in pretrained_dict.items(): 14 | if k in model_dict: 15 | pretrained_dict_update[k] = v 16 | elif k[:7] == 'module.': 17 | if k[7:] in model_dict: 18 | pretrained_dict_update[k[7:]] = v 19 | else: 20 | pretrained_dict_remove.append(k) 21 | model_dict.update(pretrained_dict_update) 22 | net.load_state_dict(model_dict) 23 | opt.load_state_dict(pretrained['optimizer']) 24 | if scaler is not None and 'scaler' in pretrained.keys(): 25 | scaler.load_state_dict(pretrained['scaler']) 26 | del (pretrained) 27 | return net.cuda(gpu), opt, pretrained_dict_remove 28 | 29 | 30 | def load_network_and_optimizer_v2(net, opt, pretrained_dir, gpu, scaler=None): 31 | pretrained = torch.load(pretrained_dir, 32 | 
map_location=torch.device("cuda:" + str(gpu))) 33 | # load model 34 | pretrained_dict = pretrained['state_dict'] 35 | model_dict = net.state_dict() 36 | pretrained_dict_update = {} 37 | pretrained_dict_remove = [] 38 | for k, v in pretrained_dict.items(): 39 | if k in model_dict: 40 | pretrained_dict_update[k] = v 41 | elif k[:7] == 'module.': 42 | if k[7:] in model_dict: 43 | pretrained_dict_update[k[7:]] = v 44 | else: 45 | pretrained_dict_remove.append(k) 46 | model_dict.update(pretrained_dict_update) 47 | net.load_state_dict(model_dict) 48 | 49 | # load optimizer 50 | opt_dict = opt.state_dict() 51 | all_params = { 52 | param_group['name']: param_group['params'][0] 53 | for param_group in opt_dict['param_groups'] 54 | } 55 | pretrained_opt_dict = {'state': {}, 'param_groups': []} 56 | for idx in range(len(pretrained['optimizer']['param_groups'])): 57 | param_group = pretrained['optimizer']['param_groups'][idx] 58 | if param_group['name'] in all_params.keys(): 59 | pretrained_opt_dict['state'][all_params[ 60 | param_group['name']]] = pretrained['optimizer']['state'][ 61 | param_group['params'][0]] 62 | param_group['params'][0] = all_params[param_group['name']] 63 | pretrained_opt_dict['param_groups'].append(param_group) 64 | 65 | opt_dict.update(pretrained_opt_dict) 66 | opt.load_state_dict(opt_dict) 67 | 68 | # load scaler 69 | if scaler is not None and 'scaler' in pretrained.keys(): 70 | scaler.load_state_dict(pretrained['scaler']) 71 | del (pretrained) 72 | return net.cuda(gpu), opt, pretrained_dict_remove 73 | 74 | 75 | def load_network(net, pretrained_dir, gpu): 76 | pretrained = torch.load(pretrained_dir, 77 | map_location=torch.device("cuda:" + str(gpu))) 78 | if 'state_dict' in pretrained.keys(): 79 | pretrained_dict = pretrained['state_dict'] 80 | elif 'model' in pretrained.keys(): 81 | pretrained_dict = pretrained['model'] 82 | else: 83 | pretrained_dict = pretrained 84 | model_dict = net.state_dict() 85 | pretrained_dict_update = {} 86 | pretrained_dict_remove = [] 87 | for k, v in pretrained_dict.items(): 88 | if k in model_dict and (len(v.shape) > 2 and v.shape[1] != model_dict[k].shape[1]): 89 | model_dict[k][:, :-1, :, :] = v 90 | continue 91 | if k in model_dict: 92 | pretrained_dict_update[k] = v 93 | elif k[:7] == 'module.': 94 | if k[7:] in model_dict: 95 | pretrained_dict_update[k[7:]] = v 96 | else: 97 | pretrained_dict_remove.append(k) 98 | model_dict.update(pretrained_dict_update) 99 | net.load_state_dict(model_dict) 100 | del (pretrained) 101 | return net.cuda(gpu), pretrained_dict_remove 102 | 103 | 104 | def save_network(net, 105 | opt, 106 | step, 107 | save_path, 108 | max_keep=8, 109 | backup_dir='./saved_models', 110 | scaler=None): 111 | ckpt = {'state_dict': net.state_dict(), 'optimizer': opt.state_dict()} 112 | if scaler is not None: 113 | ckpt['scaler'] = scaler.state_dict() 114 | 115 | try: 116 | if not os.path.exists(save_path): 117 | os.makedirs(save_path) 118 | save_file = 'save_step_%s.pth' % (step) 119 | save_dir = os.path.join(save_path, save_file) 120 | torch.save(ckpt, save_dir) 121 | except: 122 | save_path = backup_dir 123 | if not os.path.exists(save_path): 124 | os.makedirs(save_path) 125 | save_file = 'save_step_%s.pth' % (step) 126 | save_dir = os.path.join(save_path, save_file) 127 | torch.save(ckpt, save_dir) 128 | 129 | all_ckpt = os.listdir(save_path) 130 | if len(all_ckpt) > max_keep: 131 | all_step = [] 132 | for ckpt_name in all_ckpt: 133 | step = int(ckpt_name.split('_')[-1].split('.')[0]) 134 | all_step.append(step) 135 | 
all_step = list(np.sort(all_step))[:-max_keep] 136 | for step in all_step: 137 | ckpt_path = os.path.join(save_path, 'save_step_%s.pth' % (step)) 138 | os.system('rm {}'.format(ckpt_path)) 139 | -------------------------------------------------------------------------------- /aot_plus/utils/ema.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import unicode_literals 3 | 4 | import torch 5 | 6 | 7 | def get_param_buffer_for_ema(model, 8 | update_buffer=False, 9 | required_buffers=['running_mean', 'running_var']): 10 | params = model.parameters() 11 | all_param_buffer = [p for p in params if p.requires_grad] 12 | if update_buffer: 13 | named_buffers = model.named_buffers() 14 | for key, value in named_buffers: 15 | for buffer_name in required_buffers: 16 | if buffer_name in key: 17 | all_param_buffer.append(value) 18 | break 19 | return all_param_buffer 20 | 21 | 22 | class ExponentialMovingAverage: 23 | """ 24 | Maintains (exponential) moving average of a set of parameters. 25 | """ 26 | def __init__(self, parameters, decay, use_num_updates=True): 27 | """ 28 | Args: 29 | parameters: Iterable of `torch.nn.Parameter`; usually the result of 30 | `model.parameters()`. 31 | decay: The exponential decay. 32 | use_num_updates: Whether to use number of updates when computing 33 | averages. 34 | """ 35 | if decay < 0.0 or decay > 1.0: 36 | raise ValueError('Decay must be between 0 and 1') 37 | self.decay = decay 38 | self.num_updates = 0 if use_num_updates else None 39 | self.shadow_params = [p.clone().detach() for p in parameters] 40 | self.collected_params = [] 41 | 42 | def update(self, parameters): 43 | """ 44 | Update currently maintained parameters. 45 | Call this every time the parameters are updated, such as the result of 46 | the `optimizer.step()` call. 47 | Args: 48 | parameters: Iterable of `torch.nn.Parameter`; usually the same set of 49 | parameters used to initialize this object. 50 | """ 51 | decay = self.decay 52 | if self.num_updates is not None: 53 | self.num_updates += 1 54 | decay = min(decay, 55 | (1 + self.num_updates) / (10 + self.num_updates)) 56 | one_minus_decay = 1.0 - decay 57 | with torch.no_grad(): 58 | for s_param, param in zip(self.shadow_params, parameters): 59 | s_param.sub_(one_minus_decay * (s_param - param)) 60 | 61 | def copy_to(self, parameters): 62 | """ 63 | Copy current parameters into given collection of parameters. 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | updated with the stored moving averages. 67 | """ 68 | for s_param, param in zip(self.shadow_params, parameters): 69 | param.data.copy_(s_param.data) 70 | 71 | def store(self, parameters): 72 | """ 73 | Save the current parameters for restoring later. 74 | Args: 75 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 76 | temporarily stored. 77 | """ 78 | self.collected_params = [param.clone() for param in parameters] 79 | 80 | def restore(self, parameters): 81 | """ 82 | Restore the parameters stored with the `store` method. 83 | Useful to validate the model with EMA parameters without affecting the 84 | original optimization process. Store the parameters before the 85 | `copy_to` method. After validation (or model saving), use this to 86 | restore the former parameters. 87 | Args: 88 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 89 | updated with the stored parameters. 
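        Example (a sketch; `ema` denotes this object and `params` is assumed to be
        the same iterable of parameters it was constructed with):
            ema.store(params)     # stash the current (raw) weights
            ema.copy_to(params)   # swap in the EMA weights for validation
            ...                   # run validation or save the checkpoint
            ema.restore(params)   # put the raw weights back for further training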
90 | """ 91 | for c_param, param in zip(self.collected_params, parameters): 92 | param.data.copy_(c_param.data) 93 | del (self.collected_params) 94 | -------------------------------------------------------------------------------- /aot_plus/utils/eval.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | import os 3 | 4 | 5 | def zip_folder(source_folder, zip_dir): 6 | f = zipfile.ZipFile(zip_dir, 'w', zipfile.ZIP_DEFLATED) 7 | pre_len = len(os.path.dirname(source_folder)) 8 | for dirpath, dirnames, filenames in os.walk(source_folder): 9 | for filename in filenames: 10 | pathfile = os.path.join(dirpath, filename) 11 | arcname = pathfile[pre_len:].strip(os.path.sep) 12 | f.write(pathfile, arcname) 13 | f.close() -------------------------------------------------------------------------------- /aot_plus/utils/image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | import threading 5 | 6 | _palette = [ 7 | 0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128, 8 | 128, 128, 128, 128, 64, 0, 0, 191, 0, 0, 64, 128, 0, 191, 128, 0, 64, 0, 9 | 128, 191, 0, 128, 64, 128, 128, 191, 128, 128, 0, 64, 0, 128, 64, 0, 0, 10 | 191, 0, 128, 191, 0, 0, 64, 128, 128, 64, 128, 22, 22, 22, 23, 23, 23, 24, 11 | 24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28, 28, 28, 29, 29, 29, 30, 30, 12 | 30, 31, 31, 31, 32, 32, 32, 33, 33, 33, 34, 34, 34, 35, 35, 35, 36, 36, 36, 13 | 37, 37, 37, 38, 38, 38, 39, 39, 39, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43, 14 | 43, 43, 44, 44, 44, 45, 45, 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49, 15 | 49, 50, 50, 50, 51, 51, 51, 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55, 16 | 56, 56, 56, 57, 57, 57, 58, 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62, 17 | 62, 62, 63, 63, 63, 64, 64, 64, 65, 65, 65, 66, 66, 66, 67, 67, 67, 68, 68, 18 | 68, 69, 69, 69, 70, 70, 70, 71, 71, 71, 72, 72, 72, 73, 73, 73, 74, 74, 74, 19 | 75, 75, 75, 76, 76, 76, 77, 77, 77, 78, 78, 78, 79, 79, 79, 80, 80, 80, 81, 20 | 81, 81, 82, 82, 82, 83, 83, 83, 84, 84, 84, 85, 85, 85, 86, 86, 86, 87, 87, 21 | 87, 88, 88, 88, 89, 89, 89, 90, 90, 90, 91, 91, 91, 92, 92, 92, 93, 93, 93, 22 | 94, 94, 94, 95, 95, 95, 96, 96, 96, 97, 97, 97, 98, 98, 98, 99, 99, 99, 23 | 100, 100, 100, 101, 101, 101, 102, 102, 102, 103, 103, 103, 104, 104, 104, 24 | 105, 105, 105, 106, 106, 106, 107, 107, 107, 108, 108, 108, 109, 109, 109, 25 | 110, 110, 110, 111, 111, 111, 112, 112, 112, 113, 113, 113, 114, 114, 114, 26 | 115, 115, 115, 116, 116, 116, 117, 117, 117, 118, 118, 118, 119, 119, 119, 27 | 120, 120, 120, 121, 121, 121, 122, 122, 122, 123, 123, 123, 124, 124, 124, 28 | 125, 125, 125, 126, 126, 126, 127, 127, 127, 128, 128, 128, 129, 129, 129, 29 | 130, 130, 130, 131, 131, 131, 132, 132, 132, 133, 133, 133, 134, 134, 134, 30 | 135, 135, 135, 136, 136, 136, 137, 137, 137, 138, 138, 138, 139, 139, 139, 31 | 140, 140, 140, 141, 141, 141, 142, 142, 142, 143, 143, 143, 144, 144, 144, 32 | 145, 145, 145, 146, 146, 146, 147, 147, 147, 148, 148, 148, 149, 149, 149, 33 | 150, 150, 150, 151, 151, 151, 152, 152, 152, 153, 153, 153, 154, 154, 154, 34 | 155, 155, 155, 156, 156, 156, 157, 157, 157, 158, 158, 158, 159, 159, 159, 35 | 160, 160, 160, 161, 161, 161, 162, 162, 162, 163, 163, 163, 164, 164, 164, 36 | 165, 165, 165, 166, 166, 166, 167, 167, 167, 168, 168, 168, 169, 169, 169, 37 | 170, 170, 170, 171, 171, 171, 172, 172, 172, 173, 173, 173, 174, 174, 174, 38 
| 175, 175, 175, 176, 176, 176, 177, 177, 177, 178, 178, 178, 179, 179, 179, 39 | 180, 180, 180, 181, 181, 181, 182, 182, 182, 183, 183, 183, 184, 184, 184, 40 | 185, 185, 185, 186, 186, 186, 187, 187, 187, 188, 188, 188, 189, 189, 189, 41 | 190, 190, 190, 191, 191, 191, 192, 192, 192, 193, 193, 193, 194, 194, 194, 42 | 195, 195, 195, 196, 196, 196, 197, 197, 197, 198, 198, 198, 199, 199, 199, 43 | 200, 200, 200, 201, 201, 201, 202, 202, 202, 203, 203, 203, 204, 204, 204, 44 | 205, 205, 205, 206, 206, 206, 207, 207, 207, 208, 208, 208, 209, 209, 209, 45 | 210, 210, 210, 211, 211, 211, 212, 212, 212, 213, 213, 213, 214, 214, 214, 46 | 215, 215, 215, 216, 216, 216, 217, 217, 217, 218, 218, 218, 219, 219, 219, 47 | 220, 220, 220, 221, 221, 221, 222, 222, 222, 223, 223, 223, 224, 224, 224, 48 | 225, 225, 225, 226, 226, 226, 227, 227, 227, 228, 228, 228, 229, 229, 229, 49 | 230, 230, 230, 231, 231, 231, 232, 232, 232, 233, 233, 233, 234, 234, 234, 50 | 235, 235, 235, 236, 236, 236, 237, 237, 237, 238, 238, 238, 239, 239, 239, 51 | 240, 240, 240, 241, 241, 241, 242, 242, 242, 243, 243, 243, 244, 244, 244, 52 | 245, 245, 245, 246, 246, 246, 247, 247, 247, 248, 248, 248, 249, 249, 249, 53 | 250, 250, 250, 251, 251, 251, 252, 252, 252, 253, 253, 253, 254, 254, 254, 54 | 255, 255, 255 55 | ] 56 | 57 | 58 | def label2colormap(label): 59 | 60 | m = label.astype(np.uint8) 61 | r, c = m.shape 62 | cmap = np.zeros((r, c, 3), dtype=np.uint8) 63 | cmap[:, :, 0] = (m & 1) << 7 | (m & 8) << 3 | (m & 64) >> 1 64 | cmap[:, :, 1] = (m & 2) << 6 | (m & 16) << 2 | (m & 128) >> 2 65 | cmap[:, :, 2] = (m & 4) << 5 | (m & 32) << 1 66 | return cmap 67 | 68 | 69 | def one_hot_mask(mask, cls_num): 70 | if len(mask.size()) == 3: 71 | mask = mask.unsqueeze(1) 72 | indices = torch.arange(0, cls_num + 1, 73 | device=mask.device).view(1, -1, 1, 1) 74 | return (mask == indices).float(), (mask == 255).float() 75 | 76 | 77 | def masked_image(image, colored_mask, mask, alpha=0.7): 78 | mask = np.expand_dims(mask > 0, axis=0) 79 | mask = np.repeat(mask, 3, axis=0) 80 | show_img = (image * alpha + colored_mask * 81 | (1 - alpha)) * mask + image * (1 - mask) 82 | return show_img 83 | 84 | 85 | def save_image(image, path): 86 | im = Image.fromarray(np.uint8(image * 255.).transpose((1, 2, 0))) 87 | im.save(path) 88 | 89 | 90 | def _save_mask(mask, path, squeeze_idx=None): 91 | if squeeze_idx is not None: 92 | unsqueezed_mask = mask * 0 93 | for idx in range(1, len(squeeze_idx)): 94 | obj_id = squeeze_idx[idx] 95 | mask_i = mask == idx 96 | unsqueezed_mask += (mask_i * obj_id).astype(np.uint8) 97 | mask = unsqueezed_mask 98 | mask = Image.fromarray(mask).convert('P') 99 | mask.putpalette(_palette) 100 | mask.save(path) 101 | 102 | 103 | def save_mask(mask_tensor, path, squeeze_idx=None): 104 | mask = mask_tensor.cpu().numpy().astype('uint8') 105 | threading.Thread(target=_save_mask, args=[mask, path, squeeze_idx]).start() 106 | 107 | 108 | def flip_tensor(tensor, dim=0): 109 | inv_idx = torch.arange(tensor.size(dim) - 1, -1, -1, 110 | device=tensor.device).long() 111 | tensor = tensor.index_select(dim, inv_idx) 112 | return tensor 113 | 114 | 115 | def shuffle_obj_mask(mask): 116 | 117 | bs, obj_num, _, _ = mask.size() 118 | new_masks = [] 119 | for idx in range(bs): 120 | now_mask = mask[idx] 121 | random_matrix = torch.eye(obj_num, device=mask.device) 122 | fg = random_matrix[1:][torch.randperm(obj_num - 1)] 123 | random_matrix = torch.cat([random_matrix[0:1], fg], dim=0) 124 | now_mask = torch.einsum('nm,nhw->mhw', random_matrix, 
now_mask) 125 | new_masks.append(now_mask) 126 | 127 | return torch.stack(new_masks, dim=0) 128 | -------------------------------------------------------------------------------- /aot_plus/utils/learning.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def adjust_learning_rate(optimizer, 5 | base_lr, 6 | p, 7 | itr, 8 | max_itr, 9 | restart=1, 10 | warm_up_steps=1000, 11 | is_cosine_decay=False, 12 | min_lr=1e-5, 13 | encoder_lr_ratio=1.0, 14 | freeze_params=[]): 15 | 16 | if restart > 1: 17 | each_max_itr = int(math.ceil(float(max_itr) / restart)) 18 | itr = itr % each_max_itr 19 | warm_up_steps /= restart 20 | max_itr = each_max_itr 21 | 22 | if itr < warm_up_steps: 23 | now_lr = min_lr + (base_lr - min_lr) * itr / warm_up_steps 24 | else: 25 | itr = itr - warm_up_steps 26 | max_itr = max_itr - warm_up_steps 27 | if is_cosine_decay: 28 | now_lr = min_lr + (base_lr - min_lr) * (math.cos(math.pi * itr / 29 | (max_itr + 1)) + 30 | 1.) * 0.5 31 | else: 32 | now_lr = min_lr + (base_lr - min_lr) * (1 - itr / (max_itr + 1))**p 33 | 34 | for param_group in optimizer.param_groups: 35 | if encoder_lr_ratio != 1.0 and "encoder." in param_group["name"]: 36 | param_group['lr'] = (now_lr - min_lr) * encoder_lr_ratio + min_lr 37 | else: 38 | param_group['lr'] = now_lr 39 | 40 | for freeze_param in freeze_params: 41 | if freeze_param in param_group["name"]: 42 | param_group['lr'] = 0 43 | param_group['weight_decay'] = 0 44 | break 45 | 46 | return now_lr 47 | 48 | 49 | def get_trainable_params(model, 50 | base_lr, 51 | weight_decay, 52 | use_frozen_bn=False, 53 | exclusive_wd_dict={}, 54 | no_wd_keys=[]): 55 | params = [] 56 | memo = set() 57 | total_param = 0 58 | for key, value in model.named_parameters(): 59 | if value in memo: 60 | continue 61 | total_param += value.numel() 62 | if not value.requires_grad: 63 | continue 64 | memo.add(value) 65 | wd = weight_decay 66 | for exclusive_key in exclusive_wd_dict.keys(): 67 | if exclusive_key in key: 68 | wd = exclusive_wd_dict[exclusive_key] 69 | break 70 | if len(value.shape) == 1: # normalization layers 71 | if 'bias' in key: # bias requires no weight decay 72 | wd = 0. 73 | elif not use_frozen_bn: # if not use frozen BN, apply zero weight decay 74 | wd = 0. 75 | elif 'encoder.' not in key: # if use frozen BN, apply weight decay to all frozen BNs in the encoder 76 | wd = 0. 77 | else: 78 | for no_wd_key in no_wd_keys: 79 | if no_wd_key in key: 80 | wd = 0. 
81 | break 82 | params += [{ 83 | "params": [value], 84 | "lr": base_lr, 85 | "weight_decay": wd, 86 | "name": key 87 | }] 88 | 89 | print('Total Param: {:.2f}M'.format(total_param / 1e6)) 90 | return params 91 | 92 | 93 | def freeze_params(module): 94 | for p in module.parameters(): 95 | p.requires_grad = False 96 | -------------------------------------------------------------------------------- /aot_plus/utils/math.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def generate_permute_matrix(dim, num, keep_first=True, gpu_id=0): 5 | all_matrix = [] 6 | for idx in range(num): 7 | random_matrix = torch.eye(dim, device=torch.device('cuda', gpu_id)) 8 | if keep_first: 9 | fg = random_matrix[1:][torch.randperm(dim - 1)] 10 | random_matrix = torch.cat([random_matrix[0:1], fg], dim=0) 11 | else: 12 | random_matrix = random_matrix[torch.randperm(dim)] 13 | all_matrix.append(random_matrix) 14 | return torch.stack(all_matrix, dim=0) 15 | 16 | 17 | def truncated_normal_(tensor, mean=0, std=.02): 18 | size = tensor.shape 19 | tmp = tensor.new_empty(size + (4, )).normal_() 20 | valid = (tmp < 2) & (tmp > -2) 21 | ind = valid.max(-1, keepdim=True)[1] 22 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 23 | tensor.data.mul_(std).add_(mean) 24 | return tensor 25 | -------------------------------------------------------------------------------- /aot_plus/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | def __init__(self, momentum=0.999): 7 | self.val = 0 8 | self.avg = 0 9 | self.sum = 0 10 | self.count = 0 11 | self.long_count = 0 12 | self.momentum = momentum 13 | self.moving_avg = 0 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | if self.long_count == 0: 23 | self.moving_avg = val 24 | else: 25 | momentum = min(self.momentum, 1. - 1. 
/ self.long_count) 26 | self.moving_avg = self.moving_avg * momentum + val * (1 - momentum) 27 | self.val = val 28 | self.sum += val * n 29 | self.count += n 30 | self.long_count += n 31 | self.avg = self.sum / self.count 32 | -------------------------------------------------------------------------------- /aot_plus/utils/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pytorch_iou(pred, target, obj_num, epsilon=1e-6): 5 | ''' 6 | pred: [bs, h, w] 7 | target: [bs, h, w] 8 | obj_num: [bs] 9 | ''' 10 | bs = pred.size(0) 11 | all_iou = [] 12 | for idx in range(bs): 13 | now_pred = pred[idx].unsqueeze(0) 14 | now_target = target[idx].unsqueeze(0) 15 | now_obj_num = obj_num[idx] 16 | 17 | obj_ids = torch.arange(0, now_obj_num + 1, 18 | device=now_pred.device).int().view(-1, 1, 1) 19 | if obj_ids.size(0) == 1: # only contain background 20 | continue 21 | else: 22 | obj_ids = obj_ids[1:] 23 | now_pred = (now_pred == obj_ids).float() 24 | now_target = (now_target == obj_ids).float() 25 | 26 | intersection = (now_pred * now_target).sum((1, 2)) 27 | union = ((now_pred + now_target) > 0).float().sum((1, 2)) 28 | 29 | now_iou = (intersection + epsilon) / (union + epsilon) 30 | 31 | all_iou.append(now_iou.mean()) 32 | if len(all_iou) > 0: 33 | all_iou = torch.stack(all_iou).mean() 34 | else: 35 | all_iou = torch.ones((1), device=pred.device) 36 | return all_iou 37 | -------------------------------------------------------------------------------- /data.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 13 | 14 | 15 | 16 | 17 | 140 | 141 | 142 | 143 |
267 | Pavel Tokmakov, Jie Li, Adrien Gaidon. 268 | Breaking the “Object” in Video Object Segmentation. 269 | CVPR 2023.